A PyTorch implementation of a self-pruning Convolutional Neural Network (CNN) that learns to prune its own connections during training using learnable gate mechanisms.
This project implements a CNN whose fully-connected layers learn which of their connections to prune during training. Each weight in these layers is multiplied by a sigmoid gate computed from a learnable gate score, so the network can suppress redundant connections while aiming to preserve accuracy on CIFAR-10 image classification.
- Learnable Gating Mechanism: PrunableLinear layers with sigmoid gates that learn to prune connections
- Sparsity Regularization: L1 penalty on gate values to encourage sparsity
- CIFAR-10 Classification: Trains and evaluates on the CIFAR-10 dataset
- Visualization: Generates histograms showing gate value distributions
- Multiple Lambda Values: Experiments with different sparsity regularization strengths
The model consists of:
- Convolutional Backbone: Two conv layers with BatchNorm, ReLU, and MaxPool
- Prunable FC Layers: Two fully-connected layers with learnable gates
- Dropout: 0.3 dropout rate for regularization
The PrunableLinear layer applies sigmoid gates to weights:
gates = sigmoid(gate_scores)
pruned_weight = weight * gates
output = linear(input, pruned_weight, bias)
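The three lines above can be fleshed out into a small module. This is a minimal sketch, not the code from train.py: the `init_gate_score` parameter, its default of 2.0 (so gates start around 0.88, mostly open), the `gates()` helper, and the Kaiming weight initialization are assumptions made for illustration.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PrunableLinear(nn.Module):
    """Linear layer whose weights are scaled element-wise by sigmoid gates."""

    def __init__(self, in_features: int, out_features: int,
                 init_gate_score: float = 2.0):  # hypothetical default
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        # One learnable score per weight; sigmoid maps it into (0, 1).
        self.gate_scores = nn.Parameter(
            torch.full((out_features, in_features), init_gate_score)
        )
        # Assumed initialization, mirroring the default used by nn.Linear.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def gates(self) -> torch.Tensor:
        return torch.sigmoid(self.gate_scores)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pruned_weight = self.weight * self.gates()
        return F.linear(x, pruned_weight, self.bias)


layer = PrunableLinear(in_features=8, out_features=4)
out = layer(torch.randn(2, 8))  # batch of 2 inputs with 8 features each
```

Because a gate multiplies its weight, a gate driven close to 0 effectively removes that connection from the forward pass, while the weight tensor itself keeps its shape.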
pip install -r requirements.txt
Requirements:
- torch
- torchvision
- matplotlib
- numpy
Run the training script:
python train.py
This will:
- Download the CIFAR-10 dataset to the data/ directory
- Train models with different lambda values (1e-5, 5e-5, 1e-4)
- Evaluate accuracy and sparsity for each model
- Save gate distribution histograms to results/
- Save final results to results/results.txt
The script outputs:
- Test accuracy for each lambda value
- Sparsity percentage (share of gates with value < 0.01)
- Gate distribution histograms
- Summary table in results/results.txt
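The sparsity figure can be computed directly from the gate scores. The sketch below assumes the gates of all prunable layers are pooled before counting; the function name `sparsity_percent` and its arguments are hypothetical, and only the 0.01 threshold comes from the description above.

```python
import torch


def sparsity_percent(gate_scores_list, threshold: float = 0.01) -> float:
    """Percentage of gates whose sigmoid value falls below `threshold`."""
    # Pool the gates of every layer into one flat tensor (assumed behavior).
    gates = torch.cat([torch.sigmoid(s).flatten() for s in gate_scores_list])
    return 100.0 * (gates < threshold).float().mean().item()


# Toy example: two scores far below zero (gates near 0), two gates left open.
scores = [torch.tensor([[-10.0, 10.0], [-10.0, 0.0]])]
pct = sparsity_percent(scores)
```

In this toy example two of the four gates fall below the threshold, so `pct` is 50.0.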
The total loss combines cross-entropy with sparsity regularization:
loss = cross_entropy_loss + λ * sum(gates)
Where λ controls the strength of sparsity regularization.
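A minimal sketch of this objective follows, assuming the gate sum runs over every prunable layer. The helper names `sparsity_penalty` and `total_loss` are hypothetical and not taken from train.py. Since sigmoid gates are always positive, their plain sum equals their L1 norm.

```python
import torch
import torch.nn.functional as F


def sparsity_penalty(gate_scores_list):
    """Sum of sigmoid gate values across all layers (an L1 penalty on gates)."""
    return sum(torch.sigmoid(s).sum() for s in gate_scores_list)


def total_loss(logits, targets, gate_scores_list, lam: float):
    """Cross-entropy plus lam times the summed gate values."""
    ce = F.cross_entropy(logits, targets)
    return ce + lam * sparsity_penalty(gate_scores_list)


# Toy inputs: 4 samples, 10 classes, one layer with 3 x 5 gate scores.
logits = torch.zeros(4, 10)
targets = torch.tensor([0, 1, 2, 3])
gate_scores = [torch.zeros(3, 5, requires_grad=True)]

loss = total_loss(logits, targets, gate_scores, lam=1e-4)
loss.backward()  # gradients flow into the gate scores through the penalty
```

A larger λ pushes more gate scores toward negative values, and therefore more gates toward 0, typically trading some accuracy for higher sparsity.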
.
├── train.py # Main training script
├── requirements.txt # Python dependencies
├── README.md # This file
├── data/ # CIFAR-10 dataset (auto-downloaded)
└── results/ # Output plots and results
MIT License