From d2f499457c7f6d37f53f2c59e3870af6a71d01e5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 15 Apr 2026 08:06:19 +0000 Subject: [PATCH 1/2] Initial plan From d06793f72005af4a48ea85a97c7ad3b8c841b1fa Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 15 Apr 2026 08:09:23 +0000 Subject: [PATCH 2/2] Extract OmegaConf resolver registration into shared utility Move custom resolver registration (slug, len, randbits) from train.py into src/muonly/utils/resolvers.py and call it from both train.py and predict.py to fix unknown resolver errors when loading saved configs. Agent-Logs-Url: https://github.com/cms-kr/DeepMuonReco/sessions/e352cbc8-c6f4-404a-9a7c-179355774a13 Co-authored-by: slowmoyang <20718100+slowmoyang@users.noreply.github.com> --- predict.py | 3 +++ src/muonly/utils/resolvers.py | 26 ++++++++++++++++++++++++++ train.py | 23 +++-------------------- 3 files changed, 32 insertions(+), 20 deletions(-) create mode 100644 src/muonly/utils/resolvers.py diff --git a/predict.py b/predict.py index 55a1ae3..bca343a 100755 --- a/predict.py +++ b/predict.py @@ -5,6 +5,9 @@ from hydra.utils import instantiate import torch from muonly.callbacks import PredictionWriter +from muonly.utils.resolvers import register_resolvers + +register_resolvers() def run(ckpt_file_path: Path, gpu_id: int): diff --git a/src/muonly/utils/resolvers.py b/src/muonly/utils/resolvers.py new file mode 100644 index 0000000..e443530 --- /dev/null +++ b/src/muonly/utils/resolvers.py @@ -0,0 +1,26 @@ +import secrets + +from coolname.impl import generate_slug +from omegaconf import OmegaConf + + +def register_resolvers() -> None: + """Register custom OmegaConf resolvers used across scripts.""" + OmegaConf.register_new_resolver( + "slug", + lambda pattern=2: generate_slug(pattern=pattern), + use_cache=True, + replace=True, + ) + + OmegaConf.register_new_resolver( + name="len", + resolver=len, + replace=True, + ) + + OmegaConf.register_new_resolver( + name="randbits", + resolver=secrets.randbits, + replace=True, + ) diff --git a/train.py b/train.py index 56f6634..a272c62 100755 --- a/train.py +++ b/train.py @@ -2,7 +2,6 @@ import logging from pathlib import Path import warnings -import secrets import os from aim.pytorch_lightning import AimLogger import hydra @@ -10,13 +9,12 @@ import torch from lightning.pytorch import LightningDataModule, Trainer from lightning import seed_everything -from omegaconf import DictConfig -from omegaconf import OmegaConf -from coolname.impl import generate_slug +from omegaconf import DictConfig, OmegaConf from tqdm import TqdmExperimentalWarning import mplhep as mh from muonly.nn.utils import init_params from muonly.utils.logging import log_everything +from muonly.utils.resolvers import register_resolvers mh.style.use("CMS") @@ -32,22 +30,7 @@ warnings.filterwarnings("ignore", category=TqdmExperimentalWarning) -OmegaConf.register_new_resolver( - "slug", - lambda pattern=2: generate_slug(pattern=pattern), - use_cache=True, - replace=True, -) - -OmegaConf.register_new_resolver( - name="len", - resolver=len, -) - -OmegaConf.register_new_resolver( - name="randbits", - resolver=secrets.randbits, -) +register_resolvers() @hydra.main(