refactor: enhance configuration and modularity in aimx-hydra-lightnin…#20
…g-builder

- Updated SKILL.md to clarify the use of `experiment` in configuration and introduced design principles for module cohesion and configuration management.
- Modified README.md to include instructions for using experiment configurations.
- Enhanced training configuration to support explicit experiment overrides.
- Introduced a new dataset class for better data handling and updated the data module to utilize it.
- Refactored the BaseLitModule to streamline optimizer and scheduler management.
- Improved the ClassificationModule to utilize the new base class and enhance batch parsing.
- Added tests for fast development runs and module exports to ensure functionality.
- Expanded architecture documentation to emphasize configuration best practices and domain adapter patterns.
Code Review
This pull request updates the Aimx AutoResearch template and documentation to introduce a centralized BaseLitModule, separate baseline configurations from experiment overrides via a new opt and experiment config structure, and transition datasets and task modules to use pytree-structured batches. Feedback focuses on improving the robustness of the newly added BaseLitModule and ClassificationModule by using defensive configuration parsing to prevent potential AttributeError or TypeError exceptions when config keys are missing or null, and adding backward compatibility for tuple/list batches in batch parsing.
            "flash": SDPBackend.FLASH_ATTENTION,
        }

        self.sdpa_backends = [sdpa_map[backend] for backend in self.cfg.accelerate.get("sdpa", ["math"])]

The current implementation directly accesses self.cfg.accelerate and assumes that the keys are always present. If accelerate is missing or set to null in the configuration, this will raise an AttributeError. Additionally, if an invalid backend name is provided, it will raise a KeyError. We should make this more robust by using defensive .get() calls and filtering out invalid backends.

        accelerate_cfg = self.cfg.get("accelerate")
        sdpa_names = accelerate_cfg.get("sdpa", ["math"]) if accelerate_cfg else ["math"]
        self.sdpa_backends = [sdpa_map[backend] for backend in sdpa_names if backend in sdpa_map]

    def setup(self, stage: str) -> None:
        if self.cfg.accelerate.compile and stage == "fit" and hasattr(torch, "compile") and not self._net_compiled:
            self.net = torch.compile(self.net)
            self._net_compiled = True

Directly accessing self.cfg.accelerate.compile can raise an AttributeError if accelerate is missing or None in the configuration. It is safer to use defensive .get() calls to retrieve the compile setting.

    def setup(self, stage: str) -> None:
        accelerate_cfg = self.cfg.get("accelerate")
        compile_enabled = accelerate_cfg.get("compile", False) if accelerate_cfg else False
        if compile_enabled and stage == "fit" and hasattr(torch, "compile") and not self._net_compiled:
            self.net = torch.compile(self.net)
            self._net_compiled = True

    def get_lr_scheduler(self, optimizer):
        scheduler = hydra.utils.instantiate(self.cfg.opt.scheduler)(optimizer=optimizer)
        kwargs = {
            key: value for key, value in self.cfg.opt.items() if key not in ["optimizer", "scheduler"]
        }
        return {
            "scheduler": scheduler,
            **kwargs,
        }

Directly accessing self.cfg.opt can raise an AttributeError if opt is missing or None in the configuration. We should use defensive .get() calls to safely retrieve the optimizer and scheduler configurations.

    def get_lr_scheduler(self, optimizer):
        # `or {}` also covers `opt: null`, where a plain default of {} would be bypassed.
        opt_cfg = self.cfg.get("opt") or {}
        scheduler_cfg = opt_cfg.get("scheduler")
        scheduler = hydra.utils.instantiate(scheduler_cfg)(optimizer=optimizer)
        kwargs = {
            key: value for key, value in opt_cfg.items() if key not in ["optimizer", "scheduler"]
        }
        return {
            "scheduler": scheduler,
            **kwargs,
        }

    def get_optimizer(self):
        if self.cfg.opt.optimizer._target_ == "torch.optim.AdamW":
            optimizer = hydra.utils.instantiate(
                self.cfg.opt.optimizer,
                params=filter(lambda p: p.requires_grad, self.net.parameters()),
            )
        elif self.cfg.opt.optimizer._target_ == "colossalai.nn.optimizer.HybridAdam":
            optimizer = hydra.utils.instantiate(
                self.cfg.opt.optimizer,
                model_params=filter(lambda p: p.requires_grad, self.net.parameters()),
            )
        else:
            optimizer = hydra.utils.instantiate(
                self.cfg.opt.optimizer,
                params=filter(lambda p: p.requires_grad, self.net.parameters()),
            )
        return optimizer

The current implementation of get_optimizer has redundant code branches (the AdamW and else branches are identical) and directly accesses nested attributes on self.cfg.opt which can raise an AttributeError if opt or optimizer is missing. Additionally, passing a single-use filter iterator directly to the optimizer can cause issues with some custom optimizers or multiple parameter groups. We should refactor this to be cleaner, safer, and convert the filtered parameters to a list.
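The iterator concern can be reproduced in plain Python without torch. The sketch below is illustrative only: `Param` is a hypothetical stand-in for a model parameter, and the two passes imitate a consumer that walks the parameters more than once, which is an assumption about how some optimizers might behave rather than a claim about any specific library.

```python
class Param:
    """Hypothetical stand-in for a model parameter; only requires_grad matters here."""

    def __init__(self, name: str, requires_grad: bool) -> None:
        self.name = name
        self.requires_grad = requires_grad


params = [Param("w", True), Param("b", True), Param("frozen", False)]

# A filter object is a one-shot iterator: the first full pass consumes it.
trainable_iter = filter(lambda p: p.requires_grad, params)
first_pass = [p.name for p in trainable_iter]
second_pass = [p.name for p in trainable_iter]  # nothing left to yield

# Materializing into a list allows any number of passes.
trainable_list = list(filter(lambda p: p.requires_grad, params))
list_first = [p.name for p in trainable_list]
list_second = [p.name for p in trainable_list]

print(first_pass, second_pass)
print(list_first, list_second)
```

Any code that iterates the parameters a second time sees an empty sequence with the bare iterator, which is why converting to a list before instantiating the optimizer is the safer default.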

    def get_optimizer(self):
        opt_cfg = self.cfg.get("opt")
        optimizer_cfg = opt_cfg.get("optimizer") if opt_cfg else None
        if not optimizer_cfg:
            raise ValueError("Optimizer configuration is missing in 'cfg.opt.optimizer'")
        target = optimizer_cfg.get("_target_")
        params = list(filter(lambda p: p.requires_grad, self.net.parameters()))
        if target == "colossalai.nn.optimizer.HybridAdam":
            return hydra.utils.instantiate(optimizer_cfg, model_params=params)
        return hydra.utils.instantiate(optimizer_cfg, params=params)

    def configure_optimizers(self):
        optimizer = self.get_optimizer()
        if not self.cfg.opt.get("scheduler"):
            return optimizer

        lr_scheduler = self.get_lr_scheduler(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
        }

Directly calling self.cfg.opt.get("scheduler") can raise an AttributeError if opt is missing or None in the configuration. We should use defensive .get() calls to safely check for the scheduler.

    def configure_optimizers(self):
        optimizer = self.get_optimizer()
        opt_cfg = self.cfg.get("opt")
        if not opt_cfg or not opt_cfg.get("scheduler"):
            return optimizer
        lr_scheduler = self.get_lr_scheduler(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
        }

    def _instantiate_metric(self, name: str, defaults: dict[str, dict[str, object]]):
        metrics_cfg = self.cfg.get("metrics", {})
        metric_cfg = metrics_cfg[name] if name in metrics_cfg else defaults[name]
        return hydra.utils.instantiate(metric_cfg)

If metrics is configured as None (e.g., metrics: null), self.cfg.get("metrics", {}) will return None, and the subsequent name in metrics_cfg check will raise a TypeError. We should use defensive programming to handle this case gracefully.
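The difference between a missing key and an explicit null can be seen with an ordinary Python dict. This is a minimal sketch that assumes the config object's `.get()` behaves like `dict.get()` for this case, as the comment describes; `metrics_or_empty` is a hypothetical helper name used only for illustration.

```python
cfg_missing = {}               # no "metrics" key at all
cfg_null = {"metrics": None}   # models `metrics: null` in the YAML

# The default is used only when the key is absent, not when its value is None.
via_default_missing = cfg_missing.get("metrics", {})
via_default_null = cfg_null.get("metrics", {})


def metrics_or_empty(cfg: dict) -> dict:
    """Return the metrics mapping, treating both a missing key and null as empty."""
    return cfg.get("metrics") or {}


safe_missing = metrics_or_empty(cfg_missing)
safe_null = metrics_or_empty(cfg_null)

# With the `or {}` form, a membership test is always safe.
has_acc = "acc" in safe_null

print(via_default_missing, via_default_null, safe_missing, safe_null, has_acc)
```

With `via_default_null` being `None`, an expression such as `"acc" in via_default_null` raises a `TypeError`, which is the failure mode the comment points at.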

    def _instantiate_metric(self, name: str, defaults: dict[str, dict[str, object]]):
        metrics_cfg = self.cfg.get("metrics") or {}
        metric_cfg = metrics_cfg.get(name) if name in metrics_cfg else defaults.get(name)
        if metric_cfg is None:
            raise ValueError(f"Metric '{name}' is not defined in config or defaults.")
        return hydra.utils.instantiate(metric_cfg)

    def _parse_batch(self, batch: dict[str, dict[str, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
        return batch["input"]["x"], batch["target"]["label"]

The current implementation of _parse_batch assumes that the batch is always a pytree dictionary. To make the module more robust and backward-compatible with standard tuple/list batches (e.g., during migration or when using standard PyTorch datasets), we should support a fallback for tuple/list batches.

    def _parse_batch(self, batch: dict[str, dict[str, torch.Tensor]] | tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        if isinstance(batch, (tuple, list)):
            return batch[0], batch[1]
        return batch["input"]["x"], batch["target"]["label"]

💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1aab5cd035

    import hydra
    import lightning as L
    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

Raise the torch floor for the SDPA import
The generated template still declares torch>=2.1 in pyproject.toml, but this new top-level torch.nn.attention import is not available in allowed 2.1.x environments. In a repo that resolves or pins to that declared floor, simply importing <pkg>.plmodules—including the new export test and any training run—fails before Lightning starts, so either bump the template dependency floor or keep a fallback for older supported torch versions.
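If the dependency floor is kept rather than bumped, the fallback option might look like the sketch below. It is an assumption-laden illustration, not the template's code: `HAS_SDPA_KERNEL` and `sdpa_context` are hypothetical names, the `"math"` entry in the map is inferred from the default used elsewhere in this review, and the no-op branch simply skips backend selection when the import is unavailable.

```python
import contextlib

try:
    # Per the review, this module is not importable on the torch 2.1.x line.
    from torch.nn.attention import SDPBackend, sdpa_kernel
    HAS_SDPA_KERNEL = True
except ImportError:
    SDPBackend = None
    sdpa_kernel = None
    HAS_SDPA_KERNEL = False


def sdpa_context(backend_names):
    """Return a context manager selecting SDPA backends, or a no-op fallback."""
    if not HAS_SDPA_KERNEL:
        # Older torch (or no torch at all): leave backend selection untouched.
        return contextlib.nullcontext()
    sdpa_map = {
        "math": SDPBackend.MATH,
        "flash": SDPBackend.FLASH_ATTENTION,
    }
    # Unknown names are dropped, mirroring the filtering suggested earlier.
    backends = [sdpa_map[name] for name in backend_names if name in sdpa_map]
    return sdpa_kernel(backends) if backends else contextlib.nullcontext()


with sdpa_context(["math"]):
    ran_inside_context = True
```

The trade-off is that users on older torch silently lose backend selection, so bumping the declared floor is the simpler choice when no one needs 2.1.x.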