
snn-normalization

A homeostatic framework for batch-free normalization in artificial and spiking neural networks.

License: Apache 2.0 | Python 3.9+


What this is

This repository provides the two normalization layers and the reference architectures described in the accompanying research paper "Why does normalization matter more for spiking neural networks than for conventional ones?". It is intended as a concept-demonstration companion to the paper: readers and reviewers can import a single module, drop it into a PyTorch model, and reproduce the headline finding in a handful of lines.

Two layers are exposed:

Layer                     | Domain                          | Drop-in replacement for
HomeostaticNorm           | Artificial (static) networks    | nn.BatchNorm2d, nn.BatchNorm1d
TemporalHomeostaticNorm2d | Spiking networks, per timestep  | nn.BatchNorm2d replicated over time

Both layers are derived from the same quadratic-programming formulation of homeostatic plasticity: the gain and bias play the roles of synaptic scaling and intrinsic excitability, and their exponential-moving-average (EMA) updates drive the layer's outputs toward zero mean and unit variance without consuming mini-batch statistics at inference time.
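As a rough illustration of the EMA idea only, the sketch below pulls a single channel's gain g and bias b toward the values that would make y = g * x + b zero-mean and unit-variance. The function homeostatic_step, the rate alpha, and the use of per-batch targets during training are assumptions made for illustration; this is not the paper's quadratic-programming derivation and not the code in src/snn_norm/layers.py.

```python
import math
import random
import statistics


# Hypothetical one-channel sketch, NOT the snn_norm implementation:
# g and b move toward the values that would standardize the current batch,
# using an exponential moving average with rate alpha.
def homeostatic_step(batch, g, b, alpha=0.1, eps=1e-5):
    mu = statistics.fmean(batch)
    var = statistics.pvariance(batch, mu)
    g_target = 1.0 / math.sqrt(var + eps)  # gain target: fixes the spread
    b_target = -mu * g_target              # bias target: fixes the offset
    g = (1 - alpha) * g + alpha * g_target
    b = (1 - alpha) * b + alpha * b_target
    return g, b


random.seed(0)
g, b = 1.0, 0.0
for _ in range(200):
    # Shifted and scaled synthetic activations: mean 3, standard deviation 2.
    batch = [random.gauss(3.0, 2.0) for _ in range(256)]
    g, b = homeostatic_step(batch, g, b)

# Inference: only the learned g and b are applied, no batch statistics.
fresh = [random.gauss(3.0, 2.0) for _ in range(4096)]
out = [g * x + b for x in fresh]
print(f"gain={g:.3f} bias={b:.3f} "
      f"mean={statistics.fmean(out):.3f} var={statistics.pvariance(out):.3f}")
```

For this input distribution the gain settles near 1/2 and the bias near -3/2, so the inference-time outputs come out close to zero mean and unit variance.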

The headline finding

Removing normalization costs about 3.5 percentage points of accuracy in a conventional CNN. Removing it from the same architecture trained with LIF neurons causes a 76-point collapse to chance level (10 % on the ten CIFAR-10 classes). The asymmetry is roughly 22x; the homeostatic layers restore trainability in the spiking regime (78.69 %, versus 86.20 % with BNTT).

Architecture             | No norm | Standard norm        | Homeostatic
Artificial CNN, CIFAR-10 | 88.78 % | 92.27 % (BatchNorm)  | 92.51 %
Spiking CNN, CIFAR-10    | 10.00 % | 86.20 % (BNTT)       | 78.69 %

Dependency gap (standard norm minus no norm): 3.49 pp for the artificial CNN, 76.20 pp for the spiking CNN.

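Numbers are the best test accuracy over three seeds, with 30 epochs and default hyperparameters. The one exception in the reproduction commands below is the no-norm SNN, which is run for 5 epochs because it collapses to 10 % by epoch 1. The training scripts used here are in examples/.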
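The dependency gaps and the roughly 22x ratio follow directly from the table. This short check recomputes them from the listed accuracies, taking the gap as standard-norm accuracy minus no-norm accuracy; the dictionary layout is just a convenient container, not part of the package.

```python
# Accuracies (%) copied from the results table.
results = {
    "ANN": {"none": 88.78, "standard": 92.27, "homeostatic": 92.51},
    "SNN": {"none": 10.00, "standard": 86.20, "homeostatic": 78.69},
}

# Dependency gap: accuracy lost when normalization is removed.
gap = {arch: round(r["standard"] - r["none"], 2) for arch, r in results.items()}
ratio = gap["SNN"] / gap["ANN"]

print(gap)              # {'ANN': 3.49, 'SNN': 76.2}
print(f"{ratio:.1f}x")  # 21.8x
```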

Install

git clone https://github.com/ahkhan03/snn-normalization
cd snn-normalization
pip install -e .                 # core layers only
pip install -e '.[examples]'     # adds torchvision + snnTorch + tqdm for the example scripts

Python >= 3.9, PyTorch >= 2.0. GPU recommended for the training examples; equivalence_demo.py runs on CPU in seconds.

Quickstart

Equivalence demo (Proposition 1)

python examples/equivalence_demo.py

Runs the forward-output comparison between nn.BatchNorm2d and HomeostaticNorm on a shared synthetic feature-map distribution and prints the maximum elementwise discrepancy along with the output first- and second-order moments. Both layers converge to indistinguishable outputs after a short warm-up.
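The demo itself compares nn.BatchNorm2d with HomeostaticNorm. To show the kind of measurement it reports (maximum elementwise discrepancy plus output mean and variance) without depending on PyTorch, the hypothetical sketch below compares two pure-Python stand-ins on one synthetic channel: normalization with the batch's own statistics (what BatchNorm does in training mode) and normalization with EMA running statistics after a warm-up. Neither stand-in is HomeostaticNorm, and the distribution, momentum, and batch size are arbitrary choices.

```python
import random
import statistics


def normalize(xs, mean, var, eps=1e-5):
    """Standardize xs with the given mean and variance."""
    return [(x - mean) / (var + eps) ** 0.5 for x in xs]


random.seed(1)
momentum = 0.1
run_mean, run_var = 0.0, 1.0  # BatchNorm's initial running statistics

# Warm-up: update the running statistics on a stationary synthetic distribution.
for _ in range(300):
    batch = [random.gauss(-2.0, 0.5) for _ in range(1024)]
    run_mean = (1 - momentum) * run_mean + momentum * statistics.fmean(batch)
    # PyTorch stores the unbiased batch variance in its running_var buffer.
    run_var = (1 - momentum) * run_var + momentum * statistics.variance(batch)

# Normalize one fresh batch both ways and compare.
batch = [random.gauss(-2.0, 0.5) for _ in range(1024)]
y_batch = normalize(batch, statistics.fmean(batch), statistics.pvariance(batch))
y_run = normalize(batch, run_mean, run_var)

max_gap = max(abs(a - b) for a, b in zip(y_batch, y_run))
print(f"max elementwise discrepancy: {max_gap:.4f}")
for name, ys in (("batch statistics", y_batch), ("running statistics", y_run)):
    print(f"{name}: mean={statistics.fmean(ys):+.4f}, var={statistics.pvariance(ys):.4f}")
```

After warm-up both outputs have mean near 0 and variance near 1, and the remaining discrepancy comes from the sampling noise of the single fresh batch.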

Drop HomeostaticNorm into your own model

import torch.nn as nn
from snn_norm import HomeostaticNorm

model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    HomeostaticNorm(64),              # <- drop-in for BatchNorm2d(64)
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, 3, padding=1),
    HomeostaticNorm(64),
    nn.ReLU(inplace=True),
    # ...
)

The layer tracks running_mean and running_var exactly like BatchNorm2d does, and honours model.train() / model.eval() in the usual way.
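The train/eval contract described here can be sketched with a one-channel, pure-Python stand-in. The class RunningNorm1Ch is hypothetical and is not snn_norm.HomeostaticNorm; it only mirrors the documented nn.BatchNorm2d behaviour: in training mode it normalizes with batch statistics and updates running_mean and running_var with momentum 0.1 (storing the unbiased variance, as PyTorch does), while in eval mode it reads those buffers and leaves them unchanged.

```python
import statistics


class RunningNorm1Ch:
    """Hypothetical one-channel stand-in for the BatchNorm-style contract."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum, self.eps = momentum, eps
        self.running_mean, self.running_var = 0.0, 1.0  # BatchNorm's initial values
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, xs):
        if self.training:
            # Normalize with the batch's own (biased) statistics ...
            mean, var = statistics.fmean(xs), statistics.pvariance(xs)
            m = self.momentum
            # ... and fold them into the running buffers (unbiased variance).
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * statistics.variance(xs)
        else:
            # Eval mode: read the buffers, never update them.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in xs]


norm = RunningNorm1Ch()
norm([1.0, 2.0, 3.0, 4.0])                 # training step: buffers change
snapshot = (norm.running_mean, norm.running_var)
out = norm.eval()([1.0, 2.0, 3.0, 4.0])    # eval step: buffers are only read
```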

Drop TemporalHomeostaticNorm2d into a spiking stack

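import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snn_norm import TemporalHomeostaticNorm2d

num_steps = 25
conv = nn.Conv2d(3, 64, 3, padding=1)
norm = TemporalHomeostaticNorm2d(64, num_steps)  # per-timestep normalization
lif = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid(slope=25), init_hidden=True)

x = torch.randn(8, 3, 32, 32)  # example static input, presented at every timestep
# Before each new batch, reset the LIF hidden state (e.g. snntorch.utils.reset on the enclosing model).

spikes = []
for t in range(num_steps):
    h = norm(conv(x), t)       # normalize with the statistics for timestep t
    spikes.append(lif(h))      # with init_hidden=True, Leaky returns only the spikes
spk_rec = torch.stack(spikes)  # shape: (num_steps, batch, 64, 32, 32)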
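Why the scale of h matters to lif can be illustrated with a simplified, hypothetical single-neuron model that uses the same beta=0.9, snnTorch's default threshold of 1.0, and reset by subtraction (snnTorch's default reset mechanism); the update order is simplified relative to snnTorch's Leaky. With a constant input the membrane settles at current / (1 - beta), so an input of 0.05 plateaus at 0.5 and never fires, while a large input fires at every step. This is one common intuition for the dependence on normalization, offered as an illustration rather than as the paper's argument.

```python
# Hypothetical simplified LIF neuron: U[t+1] = beta * U[t] + current;
# spike when U exceeds the threshold, then subtract the threshold.
def firing_rate(current, beta=0.9, threshold=1.0, num_steps=25):
    u, spikes = 0.0, 0
    for _ in range(num_steps):
        u = beta * u + current
        if u > threshold:
            spikes += 1
            u -= threshold
    return spikes / num_steps


for current in (0.05, 0.5, 5.0):
    print(f"input {current:>4}: firing rate {firing_rate(current):.2f}")
```

Too small an input leaves the neuron silent, too large an input saturates it, and only an intermediate scale gives a graded firing rate.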

Reproduce the headline ANN / SNN gap

# ANN: each run ~15-20 min on a single mid-range GPU
python examples/train_ann_cifar10.py --norm none         --epochs 30
python examples/train_ann_cifar10.py --norm batchnorm    --epochs 30
python examples/train_ann_cifar10.py --norm homeostatic  --epochs 30

# SNN: each run ~1-2 h on a single mid-range GPU (T=25)
python examples/train_snn_cifar10.py --norm none         --epochs 5   # collapses to 10% by epoch 1
python examples/train_snn_cifar10.py --norm bntt         --epochs 30
python examples/train_snn_cifar10.py --norm homeostatic  --epochs 30

Run each command with --seed 42, --seed 123, and --seed 456, and take the best test accuracy of each run to reproduce the numbers in the table above.
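Collecting the per-seed results can be sketched as follows. The helper summarize and its input format (a mapping from seed to the list of per-epoch test accuracies, however you extract them from the training output) are hypothetical; taking the maximum across seeds is our reading of "best test accuracy over three seeds", and the mean is returned only as an extra sanity check.

```python
import statistics


def summarize(curves):
    """Summarize one configuration across seeds.

    curves maps seed -> list of per-epoch test accuracies (%).
    Returns (best accuracy per run, max across seeds, mean across seeds).
    """
    best_per_run = {seed: max(accs) for seed, accs in curves.items()}
    values = list(best_per_run.values())
    return best_per_run, max(values), statistics.fmean(values)
```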

Layout

snn-normalization/
├── src/snn_norm/
│   ├── layers.py      # HomeostaticNorm, TemporalHomeostaticNorm2d
│   └── models.py      # ANNCNN, SNNCNN reference architectures
├── examples/
│   ├── equivalence_demo.py
│   ├── train_ann_cifar10.py
│   └── train_snn_cifar10.py
├── pyproject.toml
├── requirements.txt
├── CITATION.cff
├── LICENSE            # Apache 2.0
└── README.md

Figure-generation, LaTeX-table generation, and the full multi-seed / multi-run analysis pipeline used to produce the paper's figures and tables are kept in a separate internal repository and are not included here. This repository is deliberately scoped to the concept and the minimum code needed to reproduce the headline result.

Citation

If you use the code or the ideas from this repository, please cite the accompanying paper. The full reference will be added once the paper is assigned a DOI; until then, please cite the preprint. The CITATION.cff file in this repository is read by GitHub and by reference managers that support the Citation File Format.

License

Apache License 2.0. See LICENSE.

Authors

  • Ameer H. Khan — School of Artificial Intelligence, Taizhou University, China. ORCID: 0000-0002-5367-5277
  • Ahsan Khan — Academy of Wellness and Human Development, Hong Kong Baptist University, Hong Kong. ORCID: 0000-0001-8133-9839 (corresponding author)

About

Homeostatic normalization layers for artificial and spiking neural networks — companion code for the SNN-Norm paper.
