hunter-heidenreich/vae

vae

A clean, fully-configurable PyTorch Variational Autoencoder trained on MNIST: a from-the-ground-up reimplementation of Kingma & Welling's Auto-Encoding Variational Bayes (2013) with a thorough analysis and diagnostics suite.

What it is

A fully-connected (MLP, not convolutional) VAE on MNIST, deliberately simple in architecture but heavily instrumented. The encoder maps a flattened image to the latent parameters mu and sigma; the reparameterization trick samples z = mu + std * eps with eps ~ N(0, I); the decoder reconstructs pixel logits. The training loss is BCE + KL, which is the negative ELBO.
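
The loss described above can be sketched for a single example in plain Python. This is a framework-free illustration of the arithmetic, not the repository's code (which uses PyTorch tensors); the function names `bce_with_logits`, `reparameterize`, `kl_to_standard_normal`, and `negative_elbo` are hypothetical, and the decoder itself is left out, so the pixel logits are passed in directly.

```python
import math
import random


def bce_with_logits(logit, target):
    """Binary cross-entropy for one pixel, computed stably from a raw logit."""
    # Equivalent to -[t*log(sigmoid(l)) + (1-t)*log(1-sigmoid(l))].
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def reparameterize(mu, std, rng):
    """Draw z = mu + std * eps with eps ~ N(0, 1), one draw per latent dimension."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, std)]


def kl_to_standard_normal(mu, std):
    """KL(N(mu, std^2) || N(0, 1)), summed over latent dimensions."""
    return sum(0.5 * (m * m + s * s - 1.0) - math.log(s) for m, s in zip(mu, std))


def negative_elbo(logits, pixels, mu, std):
    """Reconstruction BCE (summed over pixels) plus the analytical KL term."""
    recon = sum(bce_with_logits(l, x) for l, x in zip(logits, pixels))
    return recon + kl_to_standard_normal(mu, std)


# Hypothetical usage: a 2-D latent and a 3-pixel "image".
rng = random.Random(0)
z = reparameterize([0.0, 0.5], [1.0, 0.2], rng)
loss = negative_elbo([0.0, 2.0, -2.0], [1.0, 1.0, 0.0], [0.0, 0.5], [1.0, 0.2])
```

Because the randomness enters only through eps, z stays a differentiable function of mu and std, which is what lets gradients reach the encoder.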

  • model.py: the VAE and its VAEConfig, with three ways to parameterize the latent standard deviation.
  • trainer.py: the training/validation loop with checkpointing, TensorBoard logging, optional per-step gradient diagnostics, and per-epoch parameter-change tracking.
  • plotting/: the analysis suite (training curves, latent space, generations/interpolations, KL diagnostics, gradient diagnostics, parameter diagnostics).
  • main.py: the CLI tying it together.

Standard-deviation parameterizations

The encoder emits a raw, unconstrained sigma for each latent dimension; the model converts it to a positive std in one of three ways:

| Mode | Conversion | Notes |
| --- | --- | --- |
| log-variance (default) | std = exp(0.5 * sigma) | the raw sigma is read as the log-variance, logvar = log(std^2); analytical KL = 0.5 * sum(mu^2 + exp(logvar) - 1 - logvar) |
| softplus (--use-softplus-std) | std = softplus(sigma) + eps | strictly positive; KL uses std directly |
| sigmoid-bounded (--bound-std X) | std = sigmoid(sigma) * X + eps | caps std at X (plus eps); useful against posterior collapse |
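
The three conversions can be compared side by side with a scalar sketch. This is only an illustration of the formulas in the table, not the code in model.py: the helper names (`to_std`, `kl_from_logvar`), the mode strings, and the value of `EPS` are assumptions, since the README does not state them.

```python
import math

EPS = 1e-6  # assumed value; the README only says "eps"


def softplus(x):
    """log(1 + exp(x)), written so that large |x| does not overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    """Logistic function, split by sign to avoid overflow in exp."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def to_std(raw, mode="logvar", bound=None):
    """Map the encoder's raw output to a positive standard deviation."""
    if mode == "logvar":
        return math.exp(0.5 * raw)        # raw is log(std^2)
    if mode == "softplus":
        return softplus(raw) + EPS        # strictly positive, unbounded above
    if mode == "bounded":
        if bound is None:
            raise ValueError("bounded mode needs an upper bound X")
        return sigmoid(raw) * bound + EPS  # never exceeds bound + EPS
    raise ValueError(f"unknown mode: {mode}")


def kl_from_logvar(mu, logvar):
    """Per-dimension KL to N(0, 1) in the log-variance form from the table."""
    return 0.5 * (mu * mu + math.exp(logvar) - 1.0 - logvar)


# The same raw value gives a different std under each mode.
for mode, kwargs in [("logvar", {}), ("softplus", {}), ("bounded", {"bound": 2.0})]:
    print(mode, to_std(0.0, mode, **kwargs))
```

Note the different growth rates: as the raw value grows, the log-variance form grows exponentially, softplus grows roughly linearly, and the bounded form saturates at X.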

Quickstart

uv sync
python main.py \
    --latent-dim 2 --hidden-dim 512 --activation tanh \
    --learning-rate 1e-3 --weight-decay 0.01 --num-epochs 100 \
    --batch-size 100 --warmup-steps 100 \
    --analyze-gradients \
    --run-dir runs/vae_mnist --device auto --seed 42

Useful flags: --use-softplus-std / --bound-std X (std parameterization), --n-latent-samples (multi-sample ELBO), --analyze-gradients (per-step gradient diagnostics; slower), --interp-method slerp|lerp, and --device auto|cpu|cuda|mps (auto-detects CUDA or Apple-Silicon MPS).
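
For readers unfamiliar with the two --interp-method options, the difference can be sketched with the standard formulas for linear and spherical linear interpolation between two latent vectors. The repository's own implementation is not shown in this README, so the functions below (`lerp`, `slerp`) and their fallback behavior are assumptions, written in plain Python rather than PyTorch.

```python
import math


def lerp(a, b, t):
    """Straight-line interpolation between vectors a and b."""
    return [(1.0 - t) * x + t * y for x, y in zip(a, b)]


def slerp(a, b, t):
    """Spherical linear interpolation: move along the arc between a and b."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return lerp(a, b, t)  # angle undefined; fall back to a straight line
    cos_omega = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)
    cos_omega = max(-1.0, min(1.0, cos_omega))  # guard against rounding error
    omega = math.acos(cos_omega)
    sin_omega = math.sin(omega)
    if sin_omega < 1e-8:
        return lerp(a, b, t)  # (anti)parallel vectors; the arc is degenerate
    w_a = math.sin((1.0 - t) * omega) / sin_omega
    w_b = math.sin(t * omega) / sin_omega
    return [w_a * x + w_b * y for x, y in zip(a, b)]


# Midpoint between two orthogonal unit vectors under each method.
mid_lerp = lerp([1.0, 0.0], [0.0, 1.0], 0.5)
mid_slerp = slerp([1.0, 0.0], [0.0, 1.0], 0.5)
```

In this example the lerp midpoint has norm below 1 while the slerp midpoint stays on the unit circle, which is the usual argument for slerp when latents are drawn from a Gaussian prior.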

Figures, checkpoints, and a metrics JSON are written under --run-dir.

License

MIT. See LICENSE.
