-
Notifications
You must be signed in to change notification settings - Fork 1
refactor: enhance configuration and modularity in aimx-hydra-lightnin… #20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| compile: false | ||
| precision: "32-true" | ||
| fp32_matmul_precision: "highest" | ||
| sdpa: ["efficient", "flash", "math"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # @package _global_ | ||
|
|
||
| # Run with: | ||
| # uv run python src/train.py experiment=exp | ||
|
|
||
| defaults: | ||
| - override /datamodule: dummy | ||
| - override /model: mlp | ||
| - override /plmodule: classifier | ||
| - override /callbacks: default | ||
| - override /trainer: default | ||
| - override /opt: default | ||
| - override /accelerate: default | ||
| - override /logger: aim | ||
|
|
||
| task_name: train_exp | ||
| tags: ["exp", "{{ preset }}"] | ||
|
|
||
| seed: 42 | ||
|
|
||
| autoresearch: | ||
| experiment_name: "{{ project_name }}-exp" | ||
|
|
||
| trainer: | ||
| max_epochs: 2 | ||
| gradient_clip_val: 0.5 | ||
|
|
||
| datamodule: | ||
| batch_size: 32 | ||
|
|
||
| model: | ||
| hidden_dim: 32 | ||
|
|
||
| opt: | ||
| optimizer: | ||
| lr: 0.002 | ||
|
|
||
| accelerate: | ||
| compile: false |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,93 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import hydra | ||
| import lightning as L | ||
| import torch | ||
| from torch.nn.attention import SDPBackend, sdpa_kernel | ||
| from omegaconf import DictConfig | ||
|
|
||
|
|
||
| class BaseLitModule(L.LightningModule): | ||
| def __init__(self, cfg: DictConfig) -> None: | ||
| super().__init__() | ||
|
|
||
| self.save_hyperparameters(logger=False) | ||
| self.cfg = cfg | ||
| self.net = hydra.utils.instantiate(cfg.model) | ||
| self._net_compiled = False | ||
|
|
||
| sdpa_map = { | ||
| "cudnn": SDPBackend.CUDNN_ATTENTION, | ||
| "math": SDPBackend.MATH, | ||
| "efficient": SDPBackend.EFFICIENT_ATTENTION, | ||
| "flash": SDPBackend.FLASH_ATTENTION, | ||
| } | ||
|
|
||
| self.sdpa_backends = [sdpa_map[backend] for backend in self.cfg.accelerate.get("sdpa", ["math"])] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The current implementation directly accesses [`self.cfg.accelerate` …]. Suggested change:

```python
accelerate_cfg = self.cfg.get("accelerate")
sdpa_names = accelerate_cfg.get("sdpa", ["math"]) if accelerate_cfg else ["math"]
self.sdpa_backends = [sdpa_map[backend] for backend in sdpa_names if backend in sdpa_map]
```
|
||
|
|
||
| def forward(self, *args, **kwargs): | ||
| return self._model_forward(*args, **kwargs) | ||
|
|
||
| def _model_forward(self, *args, **kwargs): | ||
| with sdpa_kernel(self.sdpa_backends): | ||
| return self.net(*args, **kwargs) | ||
|
|
||
| def setup(self, stage: str) -> None: | ||
| if self.cfg.accelerate.compile and stage == "fit" and hasattr(torch, "compile") and not self._net_compiled: | ||
| self.net = torch.compile(self.net) | ||
| self._net_compiled = True | ||
|
Comment on lines
+35
to
+38
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Directly accessing [`self.cfg.accelerate.compile` …]. Suggested change:

```python
def setup(self, stage: str) -> None:
    accelerate_cfg = self.cfg.get("accelerate")
    compile_enabled = accelerate_cfg.get("compile", False) if accelerate_cfg else False
    if compile_enabled and stage == "fit" and hasattr(torch, "compile") and not self._net_compiled:
        self.net = torch.compile(self.net)
        self._net_compiled = True
```
|
||
|
|
||
| def get_lr_scheduler(self, optimizer): | ||
| scheduler = hydra.utils.instantiate(self.cfg.opt.scheduler)(optimizer=optimizer) | ||
| kwargs = { | ||
| key: value for key, value in self.cfg.opt.items() if key not in ["optimizer", "scheduler"] | ||
| } | ||
| return { | ||
| "scheduler": scheduler, | ||
| **kwargs, | ||
| } | ||
|
Comment on lines
+40
to
+48
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Directly accessing [`self.cfg.opt` …]. Suggested change:

```python
def get_lr_scheduler(self, optimizer):
    opt_cfg = self.cfg.get("opt", {})
    scheduler_cfg = opt_cfg.get("scheduler")
    scheduler = hydra.utils.instantiate(scheduler_cfg)(optimizer=optimizer)
    kwargs = {
        key: value for key, value in opt_cfg.items() if key not in ["optimizer", "scheduler"]
    }
    return {
        "scheduler": scheduler,
        **kwargs,
    }
```
|
||
|
|
||
| def get_optimizer(self): | ||
| if self.cfg.opt.optimizer._target_ == "torch.optim.AdamW": | ||
| optimizer = hydra.utils.instantiate( | ||
| self.cfg.opt.optimizer, | ||
| params=filter(lambda p: p.requires_grad, self.net.parameters()), | ||
| ) | ||
| elif self.cfg.opt.optimizer._target_ == "colossalai.nn.optimizer.HybridAdam": | ||
| optimizer = hydra.utils.instantiate( | ||
| self.cfg.opt.optimizer, | ||
| model_params=filter(lambda p: p.requires_grad, self.net.parameters()), | ||
| ) | ||
| else: | ||
| optimizer = hydra.utils.instantiate( | ||
| self.cfg.opt.optimizer, | ||
| params=filter(lambda p: p.requires_grad, self.net.parameters()), | ||
| ) | ||
| return optimizer | ||
|
Comment on lines
+50
to
+66
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The current implementation of [`get_optimizer` …]. Suggested change:

```python
def get_optimizer(self):
    opt_cfg = self.cfg.get("opt")
    optimizer_cfg = opt_cfg.get("optimizer") if opt_cfg else None
    if not optimizer_cfg:
        raise ValueError("Optimizer configuration is missing in 'cfg.opt.optimizer'")
    target = optimizer_cfg.get("_target_")
    params = list(filter(lambda p: p.requires_grad, self.net.parameters()))
    if target == "colossalai.nn.optimizer.HybridAdam":
        return hydra.utils.instantiate(optimizer_cfg, model_params=params)
    return hydra.utils.instantiate(optimizer_cfg, params=params)
```
|
||
|
|
||
| def configure_optimizers(self): | ||
| optimizer = self.get_optimizer() | ||
| if not self.cfg.opt.get("scheduler"): | ||
| return optimizer | ||
|
|
||
| lr_scheduler = self.get_lr_scheduler(optimizer) | ||
| return { | ||
| "optimizer": optimizer, | ||
| "lr_scheduler": lr_scheduler, | ||
| } | ||
|
Comment on lines
+68
to
+77
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Directly calling [`self.cfg.opt.get("scheduler")` …]. Suggested change:

```python
def configure_optimizers(self):
    optimizer = self.get_optimizer()
    opt_cfg = self.cfg.get("opt")
    if not opt_cfg or not opt_cfg.get("scheduler"):
        return optimizer
    lr_scheduler = self.get_lr_scheduler(optimizer)
    return {
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
```
|
||
|
|
||
| def _aim_experiments(self): | ||
| for logger in self.loggers: | ||
| experiment = getattr(logger, "experiment", None) | ||
| if experiment is not None and hasattr(experiment, "track"): | ||
| yield experiment | ||
|
|
||
| def _instantiate_metric(self, name: str, defaults: dict[str, dict[str, object]]): | ||
| metrics_cfg = self.cfg.get("metrics", {}) | ||
| metric_cfg = metrics_cfg[name] if name in metrics_cfg else defaults[name] | ||
| return hydra.utils.instantiate(metric_cfg) | ||
|
Comment on lines
+85
to
+88
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

If […]. Suggested change:

```python
def _instantiate_metric(self, name: str, defaults: dict[str, dict[str, object]]):
    metrics_cfg = self.cfg.get("metrics") or {}
    metric_cfg = metrics_cfg.get(name) if name in metrics_cfg else defaults.get(name)
    if metric_cfg is None:
        raise ValueError(f"Metric '{name}' is not defined in config or defaults.")
    return hydra.utils.instantiate(metric_cfg)
```
|
||
|
|
||
|
|
||
| from {{ package_name }}.plmodules.classifier import ClassificationModule | ||
|
|
||
| __all__ = ["ClassificationModule"] | ||
| __all__ = ["BaseLitModule", "ClassificationModule"] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,38 +1,39 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import hydra | ||
| import lightning as L | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from omegaconf import DictConfig | ||
|
|
||
| from {{ package_name }}.plmodules import BaseLitModule | ||
|
|
||
| class ClassificationModule(L.LightningModule): | ||
|
|
||
| class ClassificationModule(BaseLitModule): | ||
| def __init__(self, cfg: DictConfig) -> None: | ||
| super().__init__() | ||
| self.save_hyperparameters(logger=False) | ||
| self.cfg = cfg | ||
| self.net = hydra.utils.instantiate(cfg.model) | ||
| super().__init__(cfg) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| return self.net(x) | ||
| def _parse_batch(self, batch: dict[str, dict[str, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]: | ||
| return batch["input"]["x"], batch["target"]["label"] | ||
|
Comment on lines
+14
to
+15
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The current implementation of [`_parse_batch` …]. Suggested change:

```python
def _parse_batch(self, batch: dict[str, dict[str, torch.Tensor]] | tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    if isinstance(batch, (tuple, list)):
        return batch[0], batch[1]
    return batch["input"]["x"], batch["target"]["label"]
```
|
||
|
|
||
| def _shared_step(self, batch, mode: str) -> torch.Tensor: | ||
| x, y = batch | ||
| def _shared_step(self, batch, mode: str) -> dict[str, torch.Tensor]: | ||
| x, y = self._parse_batch(batch) | ||
| logits = self(x) | ||
| loss = F.cross_entropy(logits, y) | ||
| preds = torch.argmax(logits, dim=1) | ||
| acc = (preds == y).float().mean() | ||
| on_step = mode == "train" | ||
| self.log(f"{mode}/loss", loss, on_step=on_step, on_epoch=True, prog_bar=True) | ||
| self.log(f"{mode}/acc", acc, on_step=on_step, on_epoch=True, prog_bar=True) | ||
| return loss | ||
| res = { | ||
| "y_hat": preds, | ||
| "y": y, | ||
| } | ||
| if mode in ["train", "val"]: | ||
| loss = F.cross_entropy(logits, y) | ||
| acc = (preds == y).float().mean() | ||
| on_step = mode == "train" | ||
| self.log(f"{mode}/loss", loss, on_step=on_step, on_epoch=True, prog_bar=True) | ||
| self.log(f"{mode}/acc", acc, on_step=on_step, on_epoch=True, prog_bar=True) | ||
| res["loss"] = loss | ||
| return res | ||
|
|
||
| def training_step(self, batch, batch_idx: int) -> torch.Tensor: | ||
| return self._shared_step(batch, "train") | ||
|
|
||
| def validation_step(self, batch, batch_idx: int) -> None: | ||
| self._shared_step(batch, "val") | ||
| res = self._shared_step(batch, "train") | ||
| return res["loss"] | ||
|
|
||
| def configure_optimizers(self): | ||
| return hydra.utils.instantiate(self.cfg.opt.optimizer, params=self.parameters()) | ||
| def validation_step(self, batch, batch_idx: int) -> dict[str, torch.Tensor]: | ||
| return self._shared_step(batch, "val") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The generated template still declares `torch>=2.1` in `pyproject.toml`, but this new top-level `torch.nn.attention` import is not available in allowed 2.1.x environments. In a repo that resolves or pins to that declared floor, simply importing `<pkg>.plmodules` — including the new export test and any training run — fails before Lightning starts, so either bump the template dependency floor or keep a fallback for older supported torch versions.
Useful? React with 👍 / 👎.