Latent diffusion model with classifier-free guidance (CFG) for conditional image generation on the Galaxy10 DECaLS dataset — 17 736 RGB galaxy images from the Dark Energy Camera Legacy Survey, annotated with 10 morphological classes.
A generative model that produces realistic galaxy images conditioned on morphology class, built as a two-stage pipeline:
- VAE — compress 256×256 RGB images into a compact 4×32×32 latent representation.
- Latent diffusion — train a denoising U-Net in latent space. At inference, reverse diffusion with CFG produces class-conditioned latents, decoded back to images by the VAE.
This follows the core idea of Latent Diffusion Models (Rombach et al., 2022), scaled down for research use.
Sweeping the CFG guidance scale w and comparing generated latents against real latents exposes the typicality–diversity trade-off directly. Classifier recall peaks at intermediate w, while coverage of the real-data manifold falls monotonically. Increasing w pushes samples past their own class manifold (over-extrapolation) rather than collapsing them onto class prototypes. See the coverage analysis notebook (`notebooks/coverage_latents_xattn_v1.ipynb`; GitHub renders it inline, with figures). For end-to-end class-conditioned sampling, see the generation demo.
Coverage decreases monotonically with guidance, while classifier recall peaks at intermediate w. The best all-round operating point is w ≈ 2.5: it keeps recall high and retains more coverage than stronger guidance.
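The coverage metric itself is not spelled out here. A common definition for a "coverage/density vs guidance scale" analysis is the k-nearest-neighbour density and coverage of Naeem et al. (2020). The sketch below assumes that definition and works on flattened latents. `density_coverage` and `k=5` are illustrative choices, not taken from the notebook.

```python
import numpy as np


def pairwise_dist(a, b):
    # Euclidean distances between rows of a (N, D) and b (M, D) -> (N, M)
    sq = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2 * a @ b.T
    return np.sqrt(np.maximum(sq, 0.0))


def density_coverage(real, fake, k=5):
    """k-NN density/coverage (assumed Naeem et al. 2020 definition)."""
    real = real.reshape(len(real), -1)  # e.g. (N, 4, 32, 32) -> (N, 4096)
    fake = fake.reshape(len(fake), -1)
    d_rr = pairwise_dist(real, real)
    # radius of each real sample's k-NN ball; column 0 after sorting is the sample itself
    radii = np.sort(d_rr, axis=1)[:, k]
    inside = pairwise_dist(real, fake) <= radii[:, None]  # (N_real, N_fake)
    density = inside.sum() / (k * len(fake))   # how crowded fake samples are around real ones
    coverage = inside.any(axis=1).mean()        # fraction of real balls hit by >= 1 fake sample
    return float(density), float(coverage)
```

Coverage is the fraction of real samples whose k-NN ball contains at least one generated sample. It therefore drops when generated samples leave the real manifold. In practice it would be computed per class, giving one value per guidance scale w.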
Galaxy10 DECaLS
- 17 736 galaxy images, 256×256 px, 3-channel RGB
- 10 morphological classes: Disturbed, Merging, Round Smooth, In-between Round Smooth, Cigar Shaped Smooth, Barred Spiral, Unbarred Tight Spiral, Unbarred Loose Spiral, Edge-on without Bulge, Edge-on with Bulge
- Class distribution is reasonably balanced except for Cigar Shaped Smooth (334 samples, 1.9%)
- Download: https://astronn.readthedocs.io/en/latest/galaxy10.html (HDF5, ~2.5 GB)
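The astroNN HDF5 file stores images under the `images` key and class labels under `ans`. Cigar Shaped Smooth is much rarer than the other classes, so splits should be stratified. Below is a minimal NumPy sketch of a per-class split. It is a hypothetical stand-in for `stratified_split` in `datasets.py`, whose actual signature and split fractions may differ.

```python
import numpy as np


def stratified_split(labels, val_frac=0.1, seed=0):
    """Split indices so each class keeps roughly the same train/val ratio."""
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))  # shuffled indices of class c
        n_val = max(1, int(round(val_frac * len(idx))))     # at least one val sample per class
        val_idx.append(idx[:n_val])
        train_idx.append(idx[n_val:])
    return np.sort(np.concatenate(train_idx)), np.sort(np.concatenate(val_idx))
```

For the 334 Cigar Shaped Smooth images, `val_frac=0.1` puts 33 in validation.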
```
galaxy_diffusion/
├── src/
│   └── galaxy_diffusion/
│       ├── models/
│       │   ├── vae.py            # VAE — image ↔ latent space
│       │   ├── vae_supcon.py     # VAESupCon — VAE + SupCon projection head
│       │   ├── unet.py           # LatentUNet (AdaGN) and LatentUNetCA (cross-attn)
│       │   └── classifier.py     # GalaxyCNN — pixel-space classifier for evaluation
│       ├── diffusion/
│       │   └── ddpm.py           # cosine schedule, forward diffusion, CFG sampling
│       ├── data/
│       │   └── datasets.py       # Galaxy10Dataset (h5py), stratified_split, encode_dataset
│       └── evaluation.py         # generate_images, classify_images, detect_solid_color
├── examples/
│   ├── train_galaxy10_adagn.py   # VAE + AdaGN latent diffusion training
│   ├── train_galaxy10_xattn.py   # VAE + cross-attention latent diffusion training
│   ├── train_galaxy10_supcon.py  # SupCon VAE + AdaGN latent diffusion training
│   ├── train_classifier.py       # CNN classifier training
│   ├── eval_diffusion.py         # recall/precision vs guidance scale (sweeps w, φ)
│   └── generate_latents.py       # sample and save latents to .npz (per w, φ)
├── scripts/
│   └── convert_to_safetensors.py # convert .pth checkpoints to .safetensors for release
├── tests/
│   ├── test_vae.py               # VAE shape and loss tests
│   └── test_unet.py              # UNet shape and diffusion schedule tests
├── notebooks/
│   ├── galaxy10_samples.ipynb                 # dataset overview — class distribution, visual samples
│   ├── galaxy10_vae_analysis.ipynb            # baseline VAE — reconstruction, UMAP, Silhouette
│   ├── galaxy10_vae_supcon_analysis.ipynb     # SupCon VAE — reconstruction, UMAP, Silhouette
│   ├── classify_images_adagn_v6.ipynb         # evaluation — AdaGN diffusion v6
│   ├── classify_images_adagn_v7.ipynb         # evaluation — AdaGN diffusion v7
│   ├── classify_images_supcon_v1.ipynb        # evaluation — SupCon diffusion v1
│   ├── classify_images_supcon_v2.ipynb        # evaluation — SupCon diffusion v2
│   ├── classify_images_xattn_v1.ipynb         # evaluation — cross-attention diffusion v1
│   ├── eval_galaxy10_classifier.ipynb         # classifier validation on real images
│   ├── eval_xattn_v1_grid.ipynb               # recall/precision grid over (w, φ)
│   ├── eval_latents_xattn_v1_grid.ipynb       # same grid, recomputed from saved latents
│   ├── decode_latents_xattn_v1.ipynb          # decode saved latents back to images
│   ├── coverage_latents_xattn_v1.ipynb        # latent-space coverage/density vs guidance scale
│   ├── demo_galaxy10_adagn.ipynb              # generation demo — AdaGN pipeline
│   ├── demo_galaxy10_xattn.ipynb              # generation demo — cross-attention pipeline
│   └── demo_galaxy10_supcon.ipynb             # generation demo — SupCon pipeline
├── requirements.txt              # pinned dependencies (reproducible env)
└── pyproject.toml
```
The VAE (`vae.py`) is a convolutional encoder–decoder with the reparameterisation trick. It compresses 3×256×256 images to 4×32×32 latents (8× spatial compression).
- Encoder: initial conv + 3× stride-2 conv (32→64→128 channels), outputs μ and log σ²
- Decoder: mirror of encoder with transposed convolutions, Tanh output
- Loss: MSE reconstruction + KL divergence (`kl_weight=0.001`)
- ~1.09M parameters
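The VAE objective described above can be sketched in NumPy: reparameterised sampling, MSE reconstruction, and a KL term against a standard normal prior scaled by `kl_weight`. The reductions (mean over all elements for both terms) are assumptions. The repo's `vae.py` may sum or average differently, which changes the effective KL weight.

```python
import numpy as np


def reparameterise(mu, logvar, rng):
    # z = mu + sigma * eps with eps ~ N(0, I); in an autograd framework this keeps
    # the sample differentiable with respect to mu and logvar
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps


def vae_loss(x, x_recon, mu, logvar, kl_weight=0.001):
    recon = np.mean((x_recon - x) ** 2)  # MSE reconstruction (assumed mean reduction)
    # closed-form KL(N(mu, sigma^2) || N(0, 1)) per latent element, averaged (assumption)
    kl = np.mean(-0.5 * (1 + logvar - mu ** 2 - np.exp(logvar)))
    return recon + kl_weight * kl, recon, kl
```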
`VAESupCon` (`vae_supcon.py`) extends the VAE with a projection head trained with a Supervised Contrastive loss (Khosla et al., 2020). The goal is to enforce class-separated structure in the latent space, so that the diffusion model can learn class-conditional distributions more reliably.
```
μ: (B, 4, 32, 32)
  → flatten → (B, 4096)
  → Linear(4096, 256) → ReLU
  → Linear(256, 128)
  → L2-normalise → z̃: (B, 128)
```
The projection head is used only during VAE training and discarded at inference. Total parameters: ~2.17M.
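A NumPy sketch of the projection head makes the shapes and parameter count easy to check. Weight initialisation here is arbitrary. Bias terms in both Linear layers are assumed, which is the default for PyTorch's `nn.Linear`. Under that assumption the head has 1,081,728 parameters, consistent with the ~2.17M total (~1.09M VAE + ~1.08M head).

```python
import numpy as np


class ProjectionHead:
    """Sketch: flatten -> Linear(4096, 256) -> ReLU -> Linear(256, 128) -> L2-normalise."""

    def __init__(self, rng, in_dim=4 * 32 * 32, hidden=256, out_dim=128):
        self.w1 = rng.standard_normal((in_dim, hidden)) * np.sqrt(2.0 / in_dim)
        self.b1 = np.zeros(hidden)
        self.w2 = rng.standard_normal((hidden, out_dim)) * np.sqrt(1.0 / hidden)
        self.b2 = np.zeros(out_dim)

    def num_params(self):
        return sum(p.size for p in (self.w1, self.b1, self.w2, self.b2))

    def __call__(self, mu):
        h = mu.reshape(len(mu), -1)                  # (B, 4, 32, 32) -> (B, 4096)
        h = np.maximum(h @ self.w1 + self.b1, 0.0)   # Linear + ReLU
        z = h @ self.w2 + self.b2                    # (B, 128)
        norm = np.linalg.norm(z, axis=1, keepdims=True)
        return z / np.maximum(norm, 1e-12)           # unit-length embeddings z̃
```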
Combined loss:
L = L_recon + β · L_KL + λ · L_SupCon (β = 0.001, λ = 0.01, SupCon temperature τ = 0.1)
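The SupCon term can be sketched as follows. It assumes the "L_out" formulation of Khosla et al.: positives are averaged outside the log, and each anchor is excluded from its own denominator. It is applied to the L2-normalised outputs z̃. Anchors with no positive in the batch are skipped, which is one common convention; the repo's implementation may handle them differently.

```python
import numpy as np


def supcon_loss(z, labels, tau=0.1):
    """Supervised contrastive loss on L2-normalised embeddings z: (B, D)."""
    n = len(z)
    self_mask = np.eye(n, dtype=bool)
    sim = np.where(self_mask, -np.inf, z @ z.T / tau)  # drop self-similarity
    # numerically stable log-softmax over all other samples in the batch
    sim_max = sim.max(axis=1, keepdims=True)
    log_prob = sim - sim_max - np.log(np.exp(sim - sim_max).sum(axis=1, keepdims=True))
    pos = (labels[:, None] == labels[None, :]) & ~self_mask  # same-class pairs
    n_pos = pos.sum(axis=1)
    valid = n_pos > 0  # skip anchors without positives
    mean_log_prob_pos = np.where(pos, log_prob, 0.0).sum(axis=1)[valid] / n_pos[valid]
    return float(-mean_log_prob_pos.mean())
```

The total loss is then `recon + 0.001 * kl + 0.01 * supcon`, per the weights above.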
`LatentUNet` (`unet.py`) is a U-Net denoiser that operates on 4×32×32 latents with classifier-free guidance. Time and class embeddings are summed and injected into each ResBlock via AdaGN scale-shift modulation.
```python
LatentUNet(
    latent_channels=4,
    base_channels=128,
    channel_mult=(1, 2, 4),  # 128ch @ 32×32 / 256ch @ 16×16 / 512ch @ 8×8
    num_res_blocks=2,
    time_emb_dim=256,
    num_classes=10,
    attn_levels=(1,),  # self-attention @ 16×16 + always in bottleneck
)
```

- ~25M parameters
- Conditioning: `h = norm(h) * (1 + scale) + shift`, where `scale, shift = Linear(time_emb + class_emb, 2×C)`
- CFG label dropout: 10% during training
- Noise schedule: cosine (T = 1000)
- Training: AdamW (lr=2e-4, weight_decay=0.01), cosine annealing to 1e-6, gradient clipping (max norm 1.0)
- Loss: Min-SNR weighted MSE (Hang et al., 2023), with `weight = min(SNR_t, 5) / SNR_t`
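A NumPy sketch of the conditioning and loss-weighting pieces listed above. Assumptions: 32 GroupNorm groups, the standard cosine schedule of Nichol & Dhariwal (2021) with offset s = 0.008, and ε-prediction, for which Min-SNR uses min(SNR_t, γ) / SNR_t. None of these helper names come from the repo.

```python
import numpy as np


def group_norm(h, groups=32, eps=1e-5):
    # h: (B, C, H, W); normalise each group of C // groups channels over channels and space
    b, c, hh, ww = h.shape
    g = h.reshape(b, groups, c // groups, hh, ww)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(b, c, hh, ww)


def adagn(h, cond, w, b):
    # cond: (B, E) = time_emb + class_emb; (w, b) is a Linear(E, 2C)
    scale, shift = np.split(cond @ w + b, 2, axis=1)  # each (B, C)
    return group_norm(h) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]


def cosine_alpha_bar(T=1000, s=0.008):
    # cumulative signal fraction alpha_bar_t for each of the T noising steps
    t = np.arange(T + 1) / T
    f = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    betas = np.clip(1 - f[1:] / f[:-1], 0, 0.999)
    return np.cumprod(1 - betas)


def min_snr_weight(alpha_bar, gamma=5.0):
    snr = alpha_bar / (1 - alpha_bar)
    return np.minimum(snr, gamma) / snr  # per-timestep weight on the eps-MSE
```

SNR_t exceeds 5 only at low-noise timesteps, so the weighting down-weights those easy steps and leaves the rest at 1.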
`LatentUNetCA` is a variant of `LatentUNet` with separate conditioning pathways for the time and class signals. The time embedding modulates each ResBlock via AdaGN (as above). The class embedding enters through a dedicated cross-attention block after each encoder/decoder level and in the bottleneck.
```python
LatentUNetCA(
    latent_channels=4,
    base_channels=128,
    channel_mult=(1, 2, 4),
    num_res_blocks=2,
    time_emb_dim=256,
    class_emb_dim=256,
    num_classes=10,
    attn_levels=(1,),
    cross_attn_heads=4,
)
```

- ~27.9M parameters
- Cross-attention: Q = feature map `(B, H×W, C)`, K/V = class embedding `(B, 1, 256)`. This gives the class signal a dedicated pathway, independent of the time embedding. With a single key/value token the attention weights are all 1, so the block adds the same projected class vector at every spatial position.
- CFG rescaling (Lin et al., 2023) at sampling time (`φ=0.7`): rescales the guided prediction to match the per-sample std of the conditional prediction, reducing saturation at high guidance scales
- All other training hyperparameters match `LatentUNet`
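Below is a single-head NumPy sketch of the class cross-attention block. Random projection matrices stand in for learned weights. The repo uses 4 heads and presumably normalisation layers; both are omitted here. The sketch also makes the single-token property concrete: with one key, the softmax weight at every query position is exactly 1.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def class_cross_attention(feat, class_emb, wq, wk, wv, wo):
    # feat: (B, C, H, W) feature map; class_emb: (B, 1, D), one class token per sample
    b, c, h, w = feat.shape
    q = feat.reshape(b, c, h * w).transpose(0, 2, 1) @ wq        # queries (B, HW, d)
    k = class_emb @ wk                                            # key     (B, 1, d)
    v = class_emb @ wv                                            # value   (B, 1, d)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1]))  # (B, HW, 1)
    out = (attn @ v) @ wo                                         # (B, HW, C)
    return feat + out.transpose(0, 2, 1).reshape(b, c, h, w)      # residual add
```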
`GalaxyCNN` (`classifier.py`) is a 5-stage CNN trained from scratch to evaluate generated images. Each stage is a ConvBlock with a residual shortcut (1×1 Conv projection on the skip path) for better gradient flow.
```
ConvBlock(3→32, k=7)     → 128×128  # skip: 1×1 Conv(3→32) + pool
ConvBlock(32→64, k=3)    → 64×64    # skip: 1×1 Conv(32→64) + pool
ConvBlock(64→128, k=3)   → 32×32    # skip: 1×1 Conv(64→128) + pool
ConvBlock(128→256, k=3)  → 16×16    # skip: 1×1 Conv(128→256) + pool
ConvBlock(256→512, k=3)  → 8×8      # skip: 1×1 Conv(256→512) + pool
GAP → Dropout(0.3) → Linear(512, 10)
```
- ~1.75M parameters
- Training data: the full Galaxy10 dataset is first passed through the VAE (encode → decode) before training — the classifier trains on VAE-reconstructed images to match the distribution of diffusion-generated outputs (VAE smoothing, no background noise)
- Training: weighted cross-entropy (inverse class frequency) + `label_smoothing=0.1`, 120 epochs, cosine annealing 1e-3→1e-6
- Augmentation: RandomResizedCrop (scale 0.8–1.0), hflip + vflip (p=0.5), rotation 0–360°, brightness/contrast jitter ±20%, GaussianBlur p=0.3
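The weighted, label-smoothed cross-entropy can be sketched in NumPy. Two details are assumptions: the weight normalisation N / (K · count_c), and weighting each sample by the weight of its true class. If the repo uses PyTorch's `nn.CrossEntropyLoss(weight=..., label_smoothing=0.1)`, that implementation may apply class weights to the smoothing term differently from this sketch.

```python
import numpy as np


def inverse_frequency_weights(labels, num_classes=10):
    # rarer classes (e.g. Cigar Shaped Smooth) get larger weights; a class with zero
    # samples would get an infinite weight, so every class must be present
    counts = np.bincount(labels, minlength=num_classes).astype(float)
    return counts.sum() / (num_classes * counts)


def weighted_ce_label_smoothing(logits, labels, class_w, smoothing=0.1):
    n, k = logits.shape
    z = logits - logits.max(axis=1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))  # log-softmax
    target = np.full((n, k), smoothing / k)                    # smoothed one-hot targets
    target[np.arange(n), labels] += 1 - smoothing
    per_sample = -(target * log_p).sum(axis=1)
    sw = class_w[labels]                                       # weight of each true class
    return float((sw * per_sample).sum() / sw.sum())
```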
During training, class labels are replaced with a null token with probability 0.1. At sampling:
```python
eps = eps_uncond + w * (eps_cond - eps_uncond)
```
Higher guidance scale w produces images more strongly tied to the requested class at the cost of diversity. LatentUNetCA additionally applies CFG rescaling to prevent saturation at high w.
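A NumPy sketch of both halves of CFG: label dropout during training, and guided noise prediction at sampling. It also covers the optional rescaling used by `LatentUNetCA`, where φ blends the rescaled and plain guided predictions, following Lin et al. `NULL_CLASS = 10` is a hypothetical index for the null token; the repo may represent it differently.

```python
import numpy as np

NULL_CLASS = 10  # hypothetical index of the learned "null" class embedding


def drop_labels(labels, rng, p_uncond=0.1):
    # training-time CFG label dropout: replace each label by the null token with prob p_uncond
    drop = rng.random(labels.shape) < p_uncond
    return np.where(drop, NULL_CLASS, labels)


def cfg_eps(eps_cond, eps_uncond, w, phi=0.0):
    eps = eps_uncond + w * (eps_cond - eps_uncond)  # classifier-free guidance
    if phi > 0:
        # CFG rescaling: match the per-sample std of the guided prediction to the conditional one
        axes = tuple(range(1, eps.ndim))
        std_cond = eps_cond.std(axis=axes, keepdims=True)
        std_cfg = eps.std(axis=axes, keepdims=True)
        eps = phi * eps * (std_cond / std_cfg) + (1 - phi) * eps
    return eps
```

With w = 1 this reduces to the conditional prediction. With φ = 1, the guided prediction's per-sample std matches the conditional one exactly.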
Generated images are evaluated using the trained GalaxyCNN classifier. Degenerate images (CFG-saturated or blank outputs) are detected and excluded before classification using a combined criterion:
```python
degenerate = (mean > 0.0) | (std < 0.05)  # images in [-1, 1]; galaxies have dark background
```

Metrics reported per run:
- Top-1 accuracy on generated images (degenerate excluded)
- Degenerate rate — fraction of generated images flagged per class
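A NumPy sketch of the degenerate filter and the two metrics. The per-image mean and std are taken over all channels and pixels, which is an assumption; `detect_solid_color` in `evaluation.py` may compute them differently. `report` is a hypothetical helper.

```python
import numpy as np


def detect_degenerate(images, mean_thresh=0.0, std_thresh=0.05):
    # images: (N, 3, H, W) in [-1, 1]. Real galaxies sit on a dark background, so a
    # positive mean (washed out / saturated) or a near-zero std (blank) flags a failure
    flat = images.reshape(len(images), -1)
    return (flat.mean(axis=1) > mean_thresh) | (flat.std(axis=1) < std_thresh)


def report(images, requested, predicted, num_classes=10):
    degenerate = detect_degenerate(images)
    keep = ~degenerate
    # top-1 accuracy on non-degenerate images only
    top1 = float((predicted[keep] == requested[keep]).mean()) if keep.any() else float("nan")
    # fraction of degenerate images per requested class (nan if a class was not sampled)
    degen_rate = np.array([degenerate[requested == c].mean() if (requested == c).any() else np.nan
                           for c in range(num_classes)])
    return top1, degen_rate
```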
| Model | Checkpoint | Top-1 acc (generated) | Degenerate (200 images) |
|---|---|---|---|
| VAE + AdaGN v6 | `latent_diffusion_galaxy10_adagn_v6.pth` | 93.5% | 0 / 200 (0%) |
| VAE + AdaGN v7 | `latent_diffusion_galaxy10_adagn_v7.pth` | — | — |
| SupCon VAE + AdaGN v1 | `latent_diffusion_galaxy10_supcon_v1.pth` | — | 0 / 200 (0%) |
| SupCon VAE + AdaGN v2 | `latent_diffusion_galaxy10_supcon_v2.pth` | 79.5% | — |
| VAE + XAttn v1 | `latent_diffusion_galaxy10_xattn_v1.pth` | 93.6% | n/a † |
† XAttn v1 is evaluated over a guidance sweep (N = 200 images/class) rather than a single degenerate-filtered run — see notebooks/eval_xattn_v1_grid.ipynb. 93.6% is the peak top-1 accuracy, at w = 3.0 without CFG rescaling (φ = 0.0). At the recommended operating point with rescaling (w ≈ 2.5, φ = 0.7) top-1 is 91.6%, trading ~2 pp accuracy for higher diversity / coverage. CFG rescaling suppresses saturated outputs by construction, so the degenerate filter is not applied (hence n/a).
Developed on Python 3.11, PyTorch 2.9 (CUDA 12.x), NVIDIA RTX 4070. PyTorch is pinned because diffusion sampling is sensitive to the framework version.
```bash
# editable install of the package (loose deps, for development)
pip install -e ".[dev]"
# or a fully pinned, reproducible environment
pip install -r requirements.txt
```

If the default torch wheel does not match your CUDA version, install the correct build from https://pytorch.org first, then the rest of `requirements.txt`.
Trained weights are published on the Hugging Face Hub
(llapsus/galaxy-diffusion), in
.safetensors format. Each checkpoint ships as a .model / .vae weight file plus a
small .config.json holding the constructor arguments and latent normalisation stats.
Published:
- `latent_diffusion_galaxy10_xattn_v1`: VAE + cross-attention latent diffusion (the model used in the latent-coverage analysis)
- `galaxy10_classifier`: GalaxyCNN evaluation classifier
```python
import json

from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from galaxy_diffusion.models.unet import LatentUNetCA
from galaxy_diffusion.models.vae import VAE

path = snapshot_download("llapsus/galaxy-diffusion")  # downloads all weight files
name = "latent_diffusion_galaxy10_xattn_v1"

with open(f"{path}/{name}.config.json") as f:  # constructor args + latent stats
    cfg = json.load(f)

vae = VAE(**cfg["vae_config"])
vae.load_state_dict(load_file(f"{path}/{name}.vae.safetensors"))
vae.eval()

unet = LatentUNetCA(**cfg["unet_config"])
unet.load_state_dict(load_file(f"{path}/{name}.model.safetensors"))
unet.eval()

# latent normalisation stats for sampling:
latents_mean, latents_std = cfg["latents_mean"], cfg["latents_std"]
```

The published weights are compatible with the code at tag `v1.0`. Convert your own checkpoints with `python scripts/convert_to_safetensors.py <ckpt.pth>`.
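The loaded models still need a sampling loop. The README does not state which sampler `ddpm.py` implements, so the NumPy sketch below uses deterministic DDIM (η = 0) with CFG as an illustration. `eps_model` is a hypothetical callable standing in for the U-Net, and `null_class=10` is an assumed null-token index. The final step assumes training normalised latents as (z − mean) / std, so that normalisation must be undone before the VAE decoder.

```python
import numpy as np


def cosine_alpha_bar(T=1000, s=0.008):
    # cumulative signal fraction alpha_bar_t for each of the T noising steps (cosine schedule)
    t = np.arange(T + 1) / T
    f = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    betas = np.clip(1 - f[1:] / f[:-1], 0, 0.999)
    return np.cumprod(1 - betas)


def _channel_stat(v):
    # accept a scalar or a per-channel list and make it broadcast over (B, C, H, W)
    v = np.asarray(v, dtype=float)
    return v.reshape(1, -1, 1, 1) if v.ndim == 1 else v


def sample_latents(eps_model, class_ids, latents_mean, latents_std, w=2.5,
                   steps=50, T=1000, shape=(4, 32, 32), null_class=10, seed=0):
    """Deterministic DDIM (eta = 0) with classifier-free guidance."""
    rng = np.random.default_rng(seed)
    alpha_bar = cosine_alpha_bar(T)
    ts = np.linspace(T - 1, 0, steps).round().astype(int)  # descending timesteps
    x = rng.standard_normal((len(class_ids), *shape))       # start from pure noise
    null = np.full_like(class_ids, null_class)
    for i, t in enumerate(ts):
        eps_c = eps_model(x, t, class_ids)
        eps_u = eps_model(x, t, null)
        eps = eps_u + w * (eps_c - eps_u)                    # CFG
        a_t = alpha_bar[t]
        a_prev = alpha_bar[ts[i + 1]] if i + 1 < len(ts) else 1.0
        x0 = (x - np.sqrt(1 - a_t) * eps) / np.sqrt(a_t)     # predicted clean latent
        x = np.sqrt(a_prev) * x0 + np.sqrt(1 - a_prev) * eps
    # undo the (assumed) training-time normalisation before VAE decoding
    return x * _channel_stat(latents_std) + _channel_stat(latents_mean)
```

With the real models you would wrap the U-Net calls in `torch.no_grad()` and pass the de-normalised latents to the VAE decoder.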
Train the AdaGN latent diffusion model:

```bash
python examples/train_galaxy10_adagn.py --data path/to/Galaxy10_DECals.h5
# resume with existing VAE weights:
python examples/train_galaxy10_adagn.py --vae_ckpt models/latent_diffusion_galaxy10_adagn_v6.pth
```

Train the cross-attention latent diffusion model:

```bash
python examples/train_galaxy10_xattn.py \
    --data path/to/Galaxy10_DECals.h5 \
    --vae_ckpt models/latent_diffusion_galaxy10_adagn_v6.pth
```

Train the SupCon VAE + latent diffusion model:

```bash
python examples/train_galaxy10_supcon.py --data path/to/Galaxy10_DECals.h5
```

Train the classifier (on VAE-reconstructed images):

```bash
python examples/train_classifier.py \
    --data path/to/Galaxy10_DECals.h5 \
    --vae_ckpt models/latent_diffusion_galaxy10_adagn_v6.pth
```

Run the tests:

```bash
pytest tests/ -v
```

These are intentionally kept out of the repository:
- Dataset: Galaxy10 DECaLS (~2.5 GB HDF5). Download it from https://astronn.readthedocs.io/en/latest/galaxy10.html and pass its path via `--data`.
- Trained weights: published on the Hugging Face Hub (see Pretrained weights), not committed to git.
- Intermediate checkpoints: only the final `xattn_v1` diffusion model and the evaluation classifier are released; earlier AdaGN / SupCon checkpoints are not.
- Generated latents: the `.npz` files produced by `examples/generate_latents.py` are large and regenerable, so they are not tracked. Regenerate them locally as needed.
Released under CC0 1.0 Universal (public domain dedication) — see LICENSE.
- Rombach et al., High-Resolution Image Synthesis with Latent Diffusion Models, CVPR 2022
- Ho & Salimans, Classifier-Free Diffusion Guidance, NeurIPS 2021 Workshop
- Hang et al., Efficient Diffusion Training via Min-SNR Weighting Strategy, ICCV 2023
- Khosla et al., Supervised Contrastive Learning, NeurIPS 2020
- Lin et al., Common Diffusion Noise Schedules and Sample Steps are Flawed, WACV 2024
- Leung & Bovy, Galaxy10 DECaLS, https://astronn.readthedocs.io/en/latest/galaxy10.html

