Merged
14 changes: 12 additions & 2 deletions skills/aimx-hydra-lightning-builder/SKILL.md
@@ -54,13 +54,23 @@ Never edit, format, sync dependencies, generate files, or run mutation/codegen c

 Read `references/architecture.md` before scaffold or migration work.

-- `configs/<task>.yaml` composes `data`, `datamodule`, `model`, `plmodule`, `trainer`, `callbacks`, `logger`, `paths`, `accelerate`, and optional `experiment`.
+- `configs/<task>.yaml` composes `datamodule`, `model`, `plmodule`, `trainer`, `callbacks`, `logger`, `paths`, `accelerate`, `opt`, and `experiment`.
+- Keep `configs/<task>.yaml` as the baseline and select experiment deltas with `experiment=<name>`, where `configs/experiment/<name>.yaml` uses Hydra `override` defaults and parameter overrides.
 - `src/train.py` seeds, instantiates configured objects, logs hyperparameters, and calls `trainer.fit/validate/test`.
 - `BaseLitModule` owns `cfg`, `cfg.model` instantiation, optimizer/scheduler, compile/SDPA options, and shared trace helpers.
 - Task modules own batch parsing, loss, metrics, and prediction/evaluation outputs.
-- DataModules own splits, dataloaders, sampler/collate policy, and data preparation boundaries.
+- DataModules own splits, dataloaders, sampler/collate policy, and data preparation boundaries. Prefer dataset samples and batches as pytrees so task modules can evolve without positional tuple churn.
 - Aim trace uses Lightning loggers for scalars and explicit `experiment.track(...)` for images/distributions.

+## Design Principles
+
+- Keep high cohesion inside modules and low coupling across modules.
+- Let config define how an experiment runs; let code define what the domain operation means.
+- Keep inheritance trees shallow and explicit. Prefer composition through Hydra-configured modules when behavior varies.
+- Keep baseline defaults separate from experiment deltas. Experiments should override choices and parameters, not copy whole config trees.
+- Keep optimizer and scheduler policy in `opt`; experiments override `opt` values instead of hiding optimizer settings under `model` or `trainer`.
+- Use domain adapters for domain-specific behavior. Shared bases define contracts and common mechanics; child adapters implement radar, satellite, vision-frame, tabular, sequence, or other domain semantics.
+
 ## References

 - `references/architecture.md`: core relationships and file layout.
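
The "baseline plus experiment delta" principle in this hunk can be illustrated with plain dictionaries. This is only a sketch: it is not Hydra's composition logic, `deep_merge` is a hypothetical helper, and the values are copied from the `opt` defaults and the experiment file shown in this change.

```python
from copy import deepcopy


def deep_merge(base: dict, delta: dict) -> dict:
    """Return a copy of `base` with `delta` applied key by key (hypothetical helper)."""
    merged = deepcopy(base)
    for key, value in delta.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Recurse so sibling keys in the baseline survive the override.
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Baseline values mirror the `opt` defaults shown in this change.
baseline = {
    "opt": {
        "optimizer": {"_target_": "torch.optim.AdamW", "lr": 0.001, "weight_decay": 0.0},
    },
}
# The experiment states only what differs from the baseline.
experiment = {"opt": {"optimizer": {"lr": 0.002}}, "trainer": {"max_epochs": 2}}

config = deep_merge(baseline, experiment)
print(config["opt"]["optimizer"])
```

Only `lr` changes; `_target_` and `weight_decay` are inherited, which is the point of not copying the whole config tree into an experiment.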
@@ -7,13 +7,18 @@ Hydra + Lightning + Aim template for Aimx AutoResearch.
```bash
uv sync
uv run python src/train.py trainer.fast_dev_run=true trainer.logger=false
uv run python src/train.py experiment=exp trainer.fast_dev_run=true trainer.logger=false
uv run pytest
```

Use `experiment=<name>` to apply a file from `configs/experiment/<name>.yaml`.
Experiment yaml files should override config groups and values such as
`model`, `datamodule`, `trainer`, `opt`, `accelerate`, and `logger`.

Enable Aim logging by leaving `trainer.logger=true` and using `logger=aim`.

```bash
uv run python src/train.py
uv run python src/train.py experiment=exp
aimx query params "run.hash != ''" --repo .
aimx query metrics "metric.name != ''" --repo .
aimx query metrics "metric.name == 'acc'" --repo . --json
@@ -1,3 +1,4 @@
compile: false
precision: "32-true"
fp32_matmul_precision: "highest"
sdpa: ["efficient", "flash", "math"]
@@ -0,0 +1,39 @@
+# @package _global_
+
+# Run with:
+# uv run python src/train.py experiment=exp
+
+defaults:
+- override /datamodule: dummy
+- override /model: mlp
+- override /plmodule: classifier
+- override /callbacks: default
+- override /trainer: default
+- override /opt: default
+- override /accelerate: default
+- override /logger: aim
+
+task_name: train_exp
+tags: ["exp", "{{ preset }}"]
+
+seed: 42
+
+autoresearch:
+  experiment_name: "{{ project_name }}-exp"
+
+trainer:
+  max_epochs: 2
+  gradient_clip_val: 0.5
+
+datamodule:
+  batch_size: 32
+
+model:
+  hidden_dim: 32
+
+opt:
+  optimizer:
+    lr: 0.002
+
+accelerate:
+  compile: false
@@ -2,3 +2,9 @@ optimizer:
   _target_: torch.optim.AdamW
   lr: 0.001
   weight_decay: 0.0
+
+scheduler:
+  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+  _partial_: True
+  T_max: 10
+  eta_min: 0
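
With `_partial_: True`, Hydra's `instantiate` returns a `functools.partial` instead of a finished object, which is why `get_lr_scheduler` in this change calls the result with `optimizer=...`. A stdlib-only sketch of that two-step construction follows; `FakeScheduler` is a hypothetical stand-in for `CosineAnnealingLR`, not the real class.

```python
from functools import partial


class FakeScheduler:
    """Hypothetical stand-in for a torch LR scheduler; it only stores its arguments."""

    def __init__(self, optimizer: object, T_max: int, eta_min: float = 0.0) -> None:
        self.optimizer = optimizer
        self.T_max = T_max
        self.eta_min = eta_min


# Step 1: what a `_partial_: True` config conceptually yields: the target with
# its configured keyword arguments bound, but no optimizer yet.
scheduler_factory = partial(FakeScheduler, T_max=10, eta_min=0)

# Step 2: the Lightning module supplies the optimizer once it exists.
optimizer = object()  # placeholder for a real torch optimizer
scheduler = scheduler_factory(optimizer=optimizer)
print(type(scheduler).__name__, scheduler.T_max, scheduler.eta_min)
```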
@@ -10,6 +10,7 @@ defaults:
 - logger: aim
 - opt: default
 - accelerate: default
+- experiment: null

 task_name: train
 tags: ["dev", "{{ preset }}"]
@@ -2,7 +2,22 @@

 import torch
 from lightning import LightningDataModule
-from torch.utils.data import DataLoader, TensorDataset, random_split
+from torch.utils.data import DataLoader, Dataset, random_split


+class PytreeClassificationDataset(Dataset):
+    def __init__(self, x: torch.Tensor, y: torch.Tensor) -> None:
+        self.x = x
+        self.y = y
+
+    def __len__(self) -> int:
+        return int(self.x.shape[0])
+
+    def __getitem__(self, index: int) -> dict[str, dict[str, torch.Tensor]]:
+        return {
+            "input": {"x": self.x[index]},
+            "target": {"label": self.y[index]},
+        }
+
+
 class RandomClassificationDataModule(LightningDataModule):
@@ -25,7 +40,7 @@ def setup(self, stage: str | None = None) -> None:
         x = torch.randn(int(self.hparams.num_samples), int(self.hparams.num_features), generator=generator)
         weights = torch.randn(int(self.hparams.num_features), int(self.hparams.num_classes), generator=generator)
         y = torch.argmax(x @ weights, dim=1)
-        dataset = TensorDataset(x, y)
+        dataset = PytreeClassificationDataset(x, y)
         train_len = max(1, int(0.8 * len(dataset)))
         val_len = len(dataset) - train_len
         self.train_dataset, self.val_dataset = random_split(dataset, [train_len, val_len], generator=generator)
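
Nested-dict samples like the ones `PytreeClassificationDataset` returns collate into a batch with the same nesting, because PyTorch's default collate function recurses into mappings. The stdlib-only sketch below mimics that behaviour with lists in place of tensors; `collate_pytree` is a hypothetical helper, not the `DataLoader` implementation.

```python
from collections.abc import Mapping


def collate_pytree(samples: list) -> object:
    """Group identically structured samples leaf by leaf (hypothetical helper)."""
    first = samples[0]
    if isinstance(first, Mapping):
        # Recurse per key so the batch keeps the sample's nesting.
        return {key: collate_pytree([sample[key] for sample in samples]) for key in first}
    return list(samples)  # a real collate function would stack tensors here


samples = [
    {"input": {"x": [0.1, 0.2]}, "target": {"label": 1}},
    {"input": {"x": [0.3, 0.4]}, "target": {"label": 0}},
]
batch = collate_pytree(samples)
print(batch["input"]["x"])
print(batch["target"]["label"])
```

Because fields are addressed by name, adding a new key such as a mask under `"input"` would not disturb code that reads `batch["target"]["label"]`.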
@@ -1,3 +1,93 @@
+from __future__ import annotations
+
+import hydra
+import lightning as L
+import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel


P2: Raise the torch floor for the SDPA import

The generated template still declares torch>=2.1 in pyproject.toml, but this new top-level torch.nn.attention import is not available in allowed 2.1.x environments. In a repo that resolves or pins to that declared floor, simply importing <pkg>.plmodules—including the new export test and any training run—fails before Lightning starts, so either bump the template dependency floor or keep a fallback for older supported torch versions.
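
One way the "keep a fallback" option could look is a guarded import that degrades to a no-op context manager. This is a sketch under assumptions, not the template's code: `attention_context` is a hypothetical helper, and skipping backend selection on older torch is one possible policy among several.

```python
from contextlib import nullcontext

try:
    # Only newer torch releases ship this module, per the review comment.
    from torch.nn.attention import sdpa_kernel
except ImportError:  # torch is missing, or too old to provide torch.nn.attention
    sdpa_kernel = None


def attention_context(backends: list):
    """Return a context manager that pins SDPA backends when supported (hypothetical helper)."""
    if sdpa_kernel is None or not backends:
        # Nothing to pin: let torch choose its default attention kernel.
        return nullcontext()
    return sdpa_kernel(backends)


with attention_context([]):
    pass  # the forward pass would run here
```

The alternative named in the comment, raising the declared torch floor, avoids the extra branch entirely.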


+from omegaconf import DictConfig
+
+
+class BaseLitModule(L.LightningModule):
+    def __init__(self, cfg: DictConfig) -> None:
+        super().__init__()
+
+        self.save_hyperparameters(logger=False)
+        self.cfg = cfg
+        self.net = hydra.utils.instantiate(cfg.model)
+        self._net_compiled = False
+
+        sdpa_map = {
+            "cudnn": SDPBackend.CUDNN_ATTENTION,
+            "math": SDPBackend.MATH,
+            "efficient": SDPBackend.EFFICIENT_ATTENTION,
+            "flash": SDPBackend.FLASH_ATTENTION,
+        }
+
+        self.sdpa_backends = [sdpa_map[backend] for backend in self.cfg.accelerate.get("sdpa", ["math"])]


medium

The current implementation directly accesses self.cfg.accelerate and assumes that the keys are always present. If accelerate is missing or set to null in the configuration, this will raise an AttributeError. Additionally, if an invalid backend name is provided, it will raise a KeyError. We should make this more robust by using defensive .get() calls and filtering out invalid backends.

        accelerate_cfg = self.cfg.get("accelerate")
        sdpa_names = accelerate_cfg.get("sdpa", ["math"]) if accelerate_cfg else ["math"]
        self.sdpa_backends = [sdpa_map[backend] for backend in sdpa_names if backend in sdpa_map]
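
The behaviour of this suggestion can be checked without torch. In the sketch below a plain `dict` stands in for the Hydra config and `Backend` is a hypothetical stand-in for `SDPBackend`; both are assumptions made for illustration.

```python
from enum import Enum


class Backend(Enum):
    """Hypothetical stand-in for torch's SDPBackend enum."""

    MATH = "math"
    EFFICIENT_ATTENTION = "efficient"
    FLASH_ATTENTION = "flash"


SDPA_MAP = {member.value: member for member in Backend}


def resolve_backends(cfg: dict) -> list:
    """Mirror the suggested lookup: default to math, drop unknown names."""
    accelerate_cfg = cfg.get("accelerate")
    names = accelerate_cfg.get("sdpa", ["math"]) if accelerate_cfg else ["math"]
    return [SDPA_MAP[name] for name in names if name in SDPA_MAP]


print(resolve_backends({}))
print(resolve_backends({"accelerate": None}))
print(resolve_backends({"accelerate": {"sdpa": ["flash", "typo"]}}))
print(resolve_backends({"accelerate": {"sdpa": ["typo"]}}))
```

Unknown names are dropped silently, so a list that contains only misspelled names resolves to an empty list rather than raising.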


+    def forward(self, *args, **kwargs):
+        return self._model_forward(*args, **kwargs)
+
+    def _model_forward(self, *args, **kwargs):
+        with sdpa_kernel(self.sdpa_backends):
+            return self.net(*args, **kwargs)
+
+    def setup(self, stage: str) -> None:
+        if self.cfg.accelerate.compile and stage == "fit" and hasattr(torch, "compile") and not self._net_compiled:
+            self.net = torch.compile(self.net)
+            self._net_compiled = True
Comment on lines +35 to +38


medium

Directly accessing self.cfg.accelerate.compile can raise an AttributeError if accelerate is missing or None in the configuration. It is safer to use defensive .get() calls to retrieve the compile setting.

    def setup(self, stage: str) -> None:
        accelerate_cfg = self.cfg.get("accelerate")
        compile_enabled = accelerate_cfg.get("compile", False) if accelerate_cfg else False
        if compile_enabled and stage == "fit" and hasattr(torch, "compile") and not self._net_compiled:
            self.net = torch.compile(self.net)
            self._net_compiled = True


+    def get_lr_scheduler(self, optimizer):
+        scheduler = hydra.utils.instantiate(self.cfg.opt.scheduler)(optimizer=optimizer)
+        kwargs = {
+            key: value for key, value in self.cfg.opt.items() if key not in ["optimizer", "scheduler"]
+        }
+        return {
+            "scheduler": scheduler,
+            **kwargs,
+        }
Comment on lines +40 to +48


medium

Directly accessing self.cfg.opt can raise an AttributeError if opt is missing or None in the configuration. We should use defensive .get() calls to safely retrieve the optimizer and scheduler configurations.

    def get_lr_scheduler(self, optimizer):
        opt_cfg = self.cfg.get("opt", {})
        scheduler_cfg = opt_cfg.get("scheduler")
        scheduler = hydra.utils.instantiate(scheduler_cfg)(optimizer=optimizer)
        kwargs = {
            key: value for key, value in opt_cfg.items() if key not in ["optimizer", "scheduler"]
        }
        return {
            "scheduler": scheduler,
            **kwargs,
        }


+    def get_optimizer(self):
+        if self.cfg.opt.optimizer._target_ == "torch.optim.AdamW":
+            optimizer = hydra.utils.instantiate(
+                self.cfg.opt.optimizer,
+                params=filter(lambda p: p.requires_grad, self.net.parameters()),
+            )
+        elif self.cfg.opt.optimizer._target_ == "colossalai.nn.optimizer.HybridAdam":
+            optimizer = hydra.utils.instantiate(
+                self.cfg.opt.optimizer,
+                model_params=filter(lambda p: p.requires_grad, self.net.parameters()),
+            )
+        else:
+            optimizer = hydra.utils.instantiate(
+                self.cfg.opt.optimizer,
+                params=filter(lambda p: p.requires_grad, self.net.parameters()),
+            )
+        return optimizer
Comment on lines +50 to +66


medium

The current implementation of get_optimizer has redundant code branches (the AdamW and else branches are identical) and directly accesses nested attributes on self.cfg.opt which can raise an AttributeError if opt or optimizer is missing. Additionally, passing a single-use filter iterator directly to the optimizer can cause issues with some custom optimizers or multiple parameter groups. We should refactor this to be cleaner, safer, and convert the filtered parameters to a list.

    def get_optimizer(self):
        opt_cfg = self.cfg.get("opt")
        optimizer_cfg = opt_cfg.get("optimizer") if opt_cfg else None
        if not optimizer_cfg:
            raise ValueError("Optimizer configuration is missing in 'cfg.opt.optimizer'")

        target = optimizer_cfg.get("_target_")
        params = list(filter(lambda p: p.requires_grad, self.net.parameters()))

        if target == "colossalai.nn.optimizer.HybridAdam":
            return hydra.utils.instantiate(optimizer_cfg, model_params=params)

        return hydra.utils.instantiate(optimizer_cfg, params=params)
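
The single-use iterator concern is plain Python behaviour and can be reproduced without torch. `Param` below is a hypothetical stand-in for a model parameter.

```python
class Param:
    """Hypothetical stand-in for a model parameter."""

    def __init__(self, requires_grad: bool) -> None:
        self.requires_grad = requires_grad


parameters = [Param(True), Param(False), Param(True)]

# `filter` returns a lazy iterator that can be consumed only once.
trainable = filter(lambda p: p.requires_grad, parameters)
first_pass = list(trainable)
second_pass = list(trainable)  # already exhausted, so this is empty
print(len(first_pass), len(second_pass))  # 2 0

# Materializing the result up front makes it safe to iterate repeatedly.
trainable_list = [p for p in parameters if p.requires_grad]
print(len(trainable_list), len(list(trainable_list)))  # 2 2
```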


+    def configure_optimizers(self):
+        optimizer = self.get_optimizer()
+        if not self.cfg.opt.get("scheduler"):
+            return optimizer
+
+        lr_scheduler = self.get_lr_scheduler(optimizer)
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": lr_scheduler,
+        }
Comment on lines +68 to +77


medium

Directly calling self.cfg.opt.get("scheduler") can raise an AttributeError if opt is missing or None in the configuration. We should use defensive .get() calls to safely check for the scheduler.

    def configure_optimizers(self):
        optimizer = self.get_optimizer()
        opt_cfg = self.cfg.get("opt")
        if not opt_cfg or not opt_cfg.get("scheduler"):
            return optimizer

        lr_scheduler = self.get_lr_scheduler(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
        }


+    def _aim_experiments(self):
+        for logger in self.loggers:
+            experiment = getattr(logger, "experiment", None)
+            if experiment is not None and hasattr(experiment, "track"):
+                yield experiment
+
+    def _instantiate_metric(self, name: str, defaults: dict[str, dict[str, object]]):
+        metrics_cfg = self.cfg.get("metrics", {})
+        metric_cfg = metrics_cfg[name] if name in metrics_cfg else defaults[name]
+        return hydra.utils.instantiate(metric_cfg)
Comment on lines +85 to +88


medium

If metrics is configured as None (e.g., metrics: null), self.cfg.get("metrics", {}) will return None, and the subsequent name in metrics_cfg check will raise a TypeError. We should use defensive programming to handle this case gracefully.

    def _instantiate_metric(self, name: str, defaults: dict[str, dict[str, object]]):
        metrics_cfg = self.cfg.get("metrics") or {}
        metric_cfg = metrics_cfg.get(name) if name in metrics_cfg else defaults.get(name)
        if metric_cfg is None:
            raise ValueError(f"Metric '{name}' is not defined in config or defaults.")
        return hydra.utils.instantiate(metric_cfg)
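
The difference between a `.get()` default and an `or {}` fallback shows up with a plain `dict`, used here as an assumed stand-in for the config object described in the comment.

```python
cfg = {"metrics": None}  # plain-dict stand-in for `metrics: null`

with_default = cfg.get("metrics", {})  # key exists, so the default is not used
with_or = cfg.get("metrics") or {}     # a falsy value falls back to {}

try:
    "acc" in with_default  # membership test against None
    membership_ok = True
except TypeError:
    membership_ok = False

print(with_default, with_or, membership_ok)  # None {} False
```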



 from {{ package_name }}.plmodules.classifier import ClassificationModule

-__all__ = ["ClassificationModule"]
+__all__ = ["BaseLitModule", "ClassificationModule"]
@@ -1,38 +1,39 @@
 from __future__ import annotations

-import hydra
-import lightning as L
 import torch
 import torch.nn.functional as F
 from omegaconf import DictConfig

+from {{ package_name }}.plmodules import BaseLitModule

-class ClassificationModule(L.LightningModule):
+
+class ClassificationModule(BaseLitModule):
     def __init__(self, cfg: DictConfig) -> None:
-        super().__init__()
-        self.save_hyperparameters(logger=False)
-        self.cfg = cfg
-        self.net = hydra.utils.instantiate(cfg.model)
+        super().__init__(cfg)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.net(x)
+    def _parse_batch(self, batch: dict[str, dict[str, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
+        return batch["input"]["x"], batch["target"]["label"]
Comment on lines +14 to +15


medium

The current implementation of _parse_batch assumes that the batch is always a pytree dictionary. To make the module more robust and backward-compatible with standard tuple/list batches (e.g., during migration or when using standard PyTorch datasets), we should support a fallback for tuple/list batches.

    def _parse_batch(self, batch: dict[str, dict[str, torch.Tensor]] | tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        if isinstance(batch, (tuple, list)):
            return batch[0], batch[1]
        return batch["input"]["x"], batch["target"]["label"]
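
A torch-free check of the suggested fallback, using lists in place of tensors; the free function `parse_batch` is a hypothetical copy of the method body for illustration only.

```python
def parse_batch(batch):
    """Accept either a positional (x, y) batch or the pytree layout."""
    if isinstance(batch, (tuple, list)):
        return batch[0], batch[1]
    return batch["input"]["x"], batch["target"]["label"]


x = [[0.1, 0.2], [0.3, 0.4]]
y = [1, 0]

from_tuple = parse_batch((x, y))
from_pytree = parse_batch({"input": {"x": x}, "target": {"label": y}})
print(from_tuple == from_pytree)
```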


-    def _shared_step(self, batch, mode: str) -> torch.Tensor:
-        x, y = batch
+    def _shared_step(self, batch, mode: str) -> dict[str, torch.Tensor]:
+        x, y = self._parse_batch(batch)
         logits = self(x)
-        loss = F.cross_entropy(logits, y)
         preds = torch.argmax(logits, dim=1)
-        acc = (preds == y).float().mean()
-        on_step = mode == "train"
-        self.log(f"{mode}/loss", loss, on_step=on_step, on_epoch=True, prog_bar=True)
-        self.log(f"{mode}/acc", acc, on_step=on_step, on_epoch=True, prog_bar=True)
-        return loss
+        res = {
+            "y_hat": preds,
+            "y": y,
+        }
+        if mode in ["train", "val"]:
+            loss = F.cross_entropy(logits, y)
+            acc = (preds == y).float().mean()
+            on_step = mode == "train"
+            self.log(f"{mode}/loss", loss, on_step=on_step, on_epoch=True, prog_bar=True)
+            self.log(f"{mode}/acc", acc, on_step=on_step, on_epoch=True, prog_bar=True)
+            res["loss"] = loss
+        return res

     def training_step(self, batch, batch_idx: int) -> torch.Tensor:
-        return self._shared_step(batch, "train")
-
-    def validation_step(self, batch, batch_idx: int) -> None:
-        self._shared_step(batch, "val")
+        res = self._shared_step(batch, "train")
+        return res["loss"]

-    def configure_optimizers(self):
-        return hydra.utils.instantiate(self.cfg.opt.optimizer, params=self.parameters())
+    def validation_step(self, batch, batch_idx: int) -> dict[str, torch.Tensor]:
+        return self._shared_step(batch, "val")
@@ -3,12 +3,16 @@
 import subprocess
 import sys

+import pytest

-def test_fast_dev_run() -> None:
+
+@pytest.mark.parametrize("overrides", [(), ("experiment=exp",)])
+def test_fast_dev_run(overrides: tuple[str, ...]) -> None:
     result = subprocess.run(
         [
             sys.executable,
             "src/train.py",
+            *overrides,
             "trainer.fast_dev_run=true",
             "trainer.logger=false",
             "trainer.enable_progress_bar=false",
@@ -18,3 +22,10 @@ def test_fast_dev_run() -> None:
         text=True,
     )
     assert result.returncode == 0, result.stderr
+
+
+def test_plmodule_exports() -> None:
+    from {{ package_name }}.plmodules import BaseLitModule, ClassificationModule
+
+    assert BaseLitModule.__name__ == "BaseLitModule"
+    assert ClassificationModule.__name__ == "ClassificationModule"