diff --git a/predict.py b/predict.py index 55a1ae3..bca343a 100755 --- a/predict.py +++ b/predict.py @@ -5,6 +5,9 @@ from hydra.utils import instantiate import torch from muonly.callbacks import PredictionWriter +from muonly.utils.resolvers import register_resolvers + +register_resolvers() def run(ckpt_file_path: Path, gpu_id: int): diff --git a/src/muonly/utils/resolvers.py b/src/muonly/utils/resolvers.py new file mode 100644 index 0000000..e443530 --- /dev/null +++ b/src/muonly/utils/resolvers.py @@ -0,0 +1,26 @@ +import secrets + +from coolname.impl import generate_slug +from omegaconf import OmegaConf + + +def register_resolvers() -> None: + """Register custom OmegaConf resolvers used across scripts.""" + OmegaConf.register_new_resolver( + "slug", + lambda pattern=2: generate_slug(pattern=pattern), + use_cache=True, + replace=True, + ) + + OmegaConf.register_new_resolver( + name="len", + resolver=len, + replace=True, + ) + + OmegaConf.register_new_resolver( + name="randbits", + resolver=secrets.randbits, + replace=True, + ) diff --git a/train.py b/train.py index 56f6634..a272c62 100755 --- a/train.py +++ b/train.py @@ -2,7 +2,6 @@ import logging from pathlib import Path import warnings -import secrets import os from aim.pytorch_lightning import AimLogger import hydra @@ -10,13 +9,12 @@ import torch from lightning.pytorch import LightningDataModule, Trainer from lightning import seed_everything -from omegaconf import DictConfig -from omegaconf import OmegaConf -from coolname.impl import generate_slug +from omegaconf import DictConfig, OmegaConf from tqdm import TqdmExperimentalWarning import mplhep as mh from muonly.nn.utils import init_params from muonly.utils.logging import log_everything +from muonly.utils.resolvers import register_resolvers mh.style.use("CMS") @@ -32,22 +30,7 @@ warnings.filterwarnings("ignore", category=TqdmExperimentalWarning) -OmegaConf.register_new_resolver( - "slug", - lambda pattern=2: generate_slug(pattern=pattern), - use_cache=True, - replace=True, -) - -OmegaConf.register_new_resolver( - name="len", - resolver=len, -) - -OmegaConf.register_new_resolver( - name="randbits", - resolver=secrets.randbits, -) +register_resolvers() @hydra.main(