This repository provides a full training and inference pipeline for an image-classification model built in PyTorch.
It includes:
- A custom dataset loader with image caching
- A modular CNN architecture with BatchNorm + ReLU blocks
- Training loop with augmentation, class balancing, LR scheduling
- Automatic model visualization and learning-curve plotting
- Inference script for batch prediction on PNG images
.
├── dataset.py # Custom dataset classes and device-aware dataloader
├── network.py # StandardCNN architecture
├── training.py # Training pipeline with evaluation and plotting
├── inference.py # Inference on arbitrary PNG folders
├── model.pt # Example trained model
├── model_architecture.png # Auto-generated architecture visualization
├── learning_curves.png # Auto-generated learning-curve plots
├── final_files/ # Final selected outputs
├── output_predictions/ # (generated) inference results
└── pyproject.toml
All dataset logic is defined in dataset.py.
Cached image dataset:
- Wraps torchvision.datasets.ImageFolder.
- Loads all images into memory once using a multithreaded ThreadPoolExecutor.
- Eliminates repeated disk I/O during training → improves throughput.
- Returns (PIL.Image, label).
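The caching idea above can be sketched roughly as follows. This is not the code from dataset.py: the class name CachedImageDataset, the load_rgb helper, and the max_workers default are all hypothetical. It assumes the samples come from ImageFolder(root).samples, which is a list of (path, class_index) pairs.

```python
from concurrent.futures import ThreadPoolExecutor


def load_rgb(path):
    """Open an image file and return it as an RGB PIL image."""
    from PIL import Image  # imported lazily so other loaders can be swapped in

    with Image.open(path) as img:
        return img.convert("RGB")  # convert() returns a fully loaded copy


class CachedImageDataset:
    """Hypothetical in-memory cache over (path, label) samples."""

    def __init__(self, samples, loader=load_rgb, max_workers=8):
        # `samples` is a list of (path, label) pairs, e.g. ImageFolder(root).samples.
        self.labels = [label for _, label in samples]
        paths = [path for path, _ in samples]
        # Executor.map preserves input order, so images stay aligned with labels.
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            self.images = list(pool.map(loader, paths))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Served from memory: no disk access after construction.
        return self.images[idx], self.labels[idx]
```

Threads suit this job because image loading is dominated by file I/O and decoding, both of which release the interpreter lock for much of their runtime. The trade-off is memory: the whole decoded dataset has to fit in RAM.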
Augmentation wrapper dataset:
- Wraps a raw dataset and applies Albumentations transforms.
- Supports synthetic dataset expansion via times=k.
- Converts PIL → NumPy → Augmentation → Tensor.
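A minimal sketch of such a wrapper is shown below. The class name AugmentedDataset is hypothetical, and the way times is implemented here (reporting k times the base length and wrapping the index) is an assumption about how the expansion might work, not a description of dataset.py. The transform is assumed to follow the Albumentations calling convention, transform(image=array)["image"], and to end with ToTensorV2 so that the result is a tensor.

```python
import numpy as np


class AugmentedDataset:
    """Hypothetical wrapper: applies a transform and repeats the base dataset."""

    def __init__(self, base, transform, times=1):
        self.base = base            # yields (PIL.Image, label) pairs
        self.transform = transform  # Albumentations-style callable
        self.times = times

    def __len__(self):
        # Each underlying image is visited `times` times per epoch.
        return len(self.base) * self.times

    def __getitem__(self, idx):
        image, label = self.base[idx % len(self.base)]
        array = np.asarray(image)  # PIL -> NumPy array of shape (H, W, C)
        # The pipeline is assumed to end with ToTensorV2, which returns a tensor.
        return self.transform(image=array)["image"], label
```

Because the augmentations are random, the k visits to the same image generally produce k different training samples, which is what makes the expansion useful.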
Device-aware dataloader:
- Wraps a standard PyTorch dataloader.
- Moves batches to the assigned device (cuda/cpu) inside the iterator.
- Keeps the training loop clean and readable.
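One common way to build a device-aware loader is a thin iterator wrapper, sketched here. DeviceDataLoader is a hypothetical name, and the random TensorDataset only stands in for real image batches so the example runs on its own.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class DeviceDataLoader:
    """Hypothetical wrapper that yields batches already placed on `device`."""

    def __init__(self, loader, device):
        self.loader = loader
        self.device = device

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        for batch in self.loader:
            # Move every tensor in the (inputs, labels) batch to the target device.
            yield tuple(t.to(self.device, non_blocking=True) for t in batch)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Stand-in data: 10 random 64x64 RGB images with labels in [0, 6).
data = TensorDataset(torch.randn(10, 3, 64, 64), torch.randint(0, 6, (10,)))
loader = DeviceDataLoader(DataLoader(data, batch_size=4), device)
```

With this in place the training loop can write for x, y in loader and never call .to(device) itself.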
Defined in network.py.
The model (StandardCNN) follows a classic but robust CNN pattern:
[Conv → BN → ReLU] × 2
MaxPool
Dropout
repeat with increasing channels: 32 → 64 → 128 → 256
→ AdaptiveAvgPool2d(1)
→ Fully connected classifier
Important architectural elements:
- ConvBNReLU blocks: normalize feature maps and accelerate convergence.
- Dropout between blocks: reduces overfitting.
- AdaptiveAvgPool2d(1): averages each of the 256 final feature maps down to a single value, so the classifier always receives 256 features and no large feature map has to be flattened (the pipeline feeds 64×64 inputs).
- Final classifier: Linear(256 → 128 → num_classes), with ReLU + Dropout between the two Linear layers.
The architecture diagram is automatically generated via torchview.
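The pattern described above could look roughly like the sketch below. It is not the code from network.py: the name StandardCNNSketch, the 3×3 kernels with padding 1, the dropout rate, the use of Dropout2d between blocks, and the in_channels parameter are all assumptions. Only the block layout, the channel widths, and the classifier shape come from the description.

```python
import torch
from torch import nn


def conv_bn_relu(in_ch, out_ch):
    """Conv -> BatchNorm -> ReLU; kernel size and padding are assumed."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class StandardCNNSketch(nn.Module):
    def __init__(self, num_classes=6, in_channels=3, dropout=0.25):
        super().__init__()
        stages, prev = [], in_channels
        for width in (32, 64, 128, 256):
            stages += [
                conv_bn_relu(prev, width),   # [Conv -> BN -> ReLU] x 2
                conv_bn_relu(width, width),
                nn.MaxPool2d(2),             # halves height and width
                nn.Dropout2d(dropout),
            ]
            prev = width
        self.features = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool2d(1)  # (N, 256, H, W) -> (N, 256, 1, 1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.pool(self.features(x)))
```

A 64×64 input shrinks to 4×4 after the four pooling stages, and the adaptive pooling then reduces it to 1×1, so the classifier input stays at 256 features even if the input resolution changes.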
Defined in training.py, the process includes:
- Load raw dataset into memory.
- Split into train/validation/test (80/10/10).
- Apply:
  - Heavy augmentations on training set (flips, jittering, affine, noise, resize, normalization)
  - Light transforms on validation/test sets.
- Compute class-balanced weights (manually provided counts).
- Initialize:
  - StandardCNN()
  - Adam optimizer (LR=3e-4, weight decay)
  - ReduceLROnPlateau scheduler (monitors validation loss)
  - Weighted cross-entropy loss
- Training loop:
  - Forward → Loss → Backward → Update
  - Track loss & accuracy per epoch.
- Validation loop each epoch.
- After training:
  - Evaluate on test set.
  - Save:
    - model.pt
    - learning_curves.png
    - model_architecture.png
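The setup steps above can be sketched as follows. Everything concrete here is illustrative: the random stand-in dataset and linear model, the class counts, the inverse-frequency weighting formula, the weight decay value, and the scheduler's factor and patience are assumptions. Only the 80/10/10 split, LR=3e-4, and the choice of Adam, ReduceLROnPlateau, and weighted cross-entropy come from the description.

```python
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, random_split

# Stand-in data and model so the sketch runs on its own.
dataset = TensorDataset(torch.randn(100, 3, 64, 64), torch.randint(0, 6, (100,)))
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 6))

# 80/10/10 split with a fixed seed for reproducibility.
n_train, n_val = len(dataset) * 8 // 10, len(dataset) // 10
n_test = len(dataset) - n_train - n_val
train_set, val_set, test_set = random_split(
    dataset, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(0)
)

# Inverse-frequency class weights from manually provided counts (made-up numbers).
counts = torch.tensor([500.0, 120.0, 300.0, 80.0, 250.0, 150.0])
class_weights = counts.sum() / (len(counts) * counts)

criterion = nn.CrossEntropyLoss(weight=class_weights)
opt = optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)  # decay value assumed
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=3)

# One optimisation step: forward -> loss -> backward -> update.
inputs, targets = dataset[:8]
opt.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
opt.step()

# Once per epoch, the scheduler is fed the validation loss it monitors
# (the training loss is used here only as a placeholder).
scheduler.step(loss.item())
```

With this weighting, rare classes contribute more to the loss per sample, which counteracts the tendency to favour the majority class.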
plot_learning_curves() generates a plot of training & validation losses, allowing you to inspect:
- underfitting/overfitting patterns
- effect of LR scheduling
- convergence stability
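For reference, a loss plot of this kind can be produced with a few lines of matplotlib. The function below is a hypothetical stand-in named plot_curves_sketch, not the repository's plot_learning_curves(); its signature and styling are assumptions.

```python
import matplotlib

matplotlib.use("Agg")  # render to a file without needing a display
import matplotlib.pyplot as plt


def plot_curves_sketch(train_losses, val_losses, path="learning_curves.png"):
    """Plot per-epoch training and validation loss and save the figure."""
    epochs = range(1, len(train_losses) + 1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(epochs, train_losses, label="train loss")
    ax.plot(epochs, val_losses, label="validation loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.legend()
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)  # free the figure so repeated calls do not accumulate memory
    return path
```

When reading the result, a validation curve that flattens or rises while the training curve keeps falling points to overfitting, and two curves that both stay high point to underfitting.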
Defined in inference.py.
- Load the trained model from its state_dict.
- Apply the same preprocessing used in validation:
  - Resize to 64×64
  - Normalize (ImageNet means/std)
  - Convert to tensor
- Recursively scan the dataset directory for *.png.
- Run the forward pass (no gradients).
- Select class with highest logit.
- Save all predictions to output_predictions/predictions.csv with the columns:
  filename, class_id
Inference automatically chooses GPU if available.
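The scan, argmax, and CSV steps can be sketched with the standard library alone. In the sketch below, predict_folder and predict_logits are hypothetical names: predict_logits stands in for "preprocess the image and run the model under torch.no_grad()", and returns a plain list of logits. Whether the real predictions.csv has a header row, or stores bare file names rather than relative paths, is an assumption.

```python
import csv
from pathlib import Path


def predict_folder(image_dir, predict_logits, out_csv="output_predictions/predictions.csv"):
    """Write one (filename, class_id) row per PNG found under image_dir."""
    out_path = Path(out_csv)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    rows = []
    for path in sorted(Path(image_dir).rglob("*.png")):  # recursive scan
        logits = predict_logits(path)
        # Index of the largest logit, i.e. the predicted class.
        class_id = max(range(len(logits)), key=logits.__getitem__)
        rows.append((path.name, class_id))
    with out_path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "class_id"])
        writer.writerows(rows)
    return rows
```

In a PyTorch setting the device would typically be picked with torch.device("cuda" if torch.cuda.is_available() else "cpu"), and the model moved to it once before the loop.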
python -m venv .venv
source .venv/bin/activate
pip install -e .
python training.py /path/to/dataset_root
This will create:
- model.pt
- learning_curves.png
- model_architecture.png
python inference.py /path/to/images /path/to/model.pt
Results are saved to:
output_predictions/predictions.csv
Training and inference rely on the ImageFolder directory layout:
dataset_root/
├── class_0/
│ ├── image1.png
│ ├── ...
├── class_1/
│ ├── ...
└── ...
- Images must be PNG for inference (training can use other formats).
- Labels map to directory names.
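To interpret the class_id values, it helps to know how ImageFolder assigns indices: it sorts the class directory names and numbers them from 0. The helper below mirrors that rule with the standard library; class_index_map is a hypothetical name, and on a real dataset the same mapping is available as the class_to_idx attribute of the ImageFolder instance.

```python
from pathlib import Path


def class_index_map(dataset_root):
    """Mirror ImageFolder's rule: sorted subdirectory names -> 0..N-1."""
    classes = sorted(p.name for p in Path(dataset_root).iterdir() if p.is_dir())
    return {name: idx for idx, name in enumerate(classes)}
```

The sort is lexicographic, so a directory named class_10 comes before class_2. Zero-padding the names (class_02, class_10) keeps the numeric order intact.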
To change the number of classes, set num_classes in network.py:
class StandardCNN(nn.Module):
    def __init__(self, num_classes=6):
        ...
To change the augmentations, edit in training.py:
train_aug = A.Compose([...])
val_tf = A.Compose([...])
Change epoch count, LR, weight decay, etc., in:
opt = optim.Adam(...)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(...)
fit(epochs=30, ...)
This project demonstrates a clean, extensible pipeline for image classification using PyTorch. It showcases:
- Efficient data loading with caching
- Clean modular CNN design
- Strong augmentation strategy
- Full training/validation/testing workflow
- Automated model visualization
- Ready-to-use inference pipeline