Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from hydra.utils import instantiate
import torch
from muonly.callbacks import PredictionWriter
from muonly.utils.resolvers import register_resolvers

register_resolvers()


def run(ckpt_file_path: Path, gpu_id: int):
Expand Down
26 changes: 26 additions & 0 deletions src/muonly/utils/resolvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import secrets

from coolname.impl import generate_slug
from omegaconf import OmegaConf


def register_resolvers() -> None:
"""Register custom OmegaConf resolvers used across scripts."""
OmegaConf.register_new_resolver(
"slug",
lambda pattern=2: generate_slug(pattern=pattern),
use_cache=True,
replace=True,
)

OmegaConf.register_new_resolver(
name="len",
resolver=len,
replace=True,
)

OmegaConf.register_new_resolver(
name="randbits",
resolver=secrets.randbits,
replace=True,
)
23 changes: 3 additions & 20 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@
import logging
from pathlib import Path
import warnings
import secrets
import os
from aim.pytorch_lightning import AimLogger
import hydra
from hydra.utils import instantiate
import torch
from lightning.pytorch import LightningDataModule, Trainer
from lightning import seed_everything
from omegaconf import DictConfig
from omegaconf import OmegaConf
from coolname.impl import generate_slug
from omegaconf import DictConfig, OmegaConf
from tqdm import TqdmExperimentalWarning
import mplhep as mh
from muonly.nn.utils import init_params
from muonly.utils.logging import log_everything
from muonly.utils.resolvers import register_resolvers

mh.style.use("CMS")

Expand All @@ -32,22 +30,7 @@

warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)

OmegaConf.register_new_resolver(
"slug",
lambda pattern=2: generate_slug(pattern=pattern),
use_cache=True,
replace=True,
)

OmegaConf.register_new_resolver(
name="len",
resolver=len,
)

OmegaConf.register_new_resolver(
name="randbits",
resolver=secrets.randbits,
)
register_resolvers()


@hydra.main(
Expand Down