Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 30 additions & 40 deletions deepethogram/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.tuner import Tuner

try:
from ray.tune import CLIReporter, get_trial_dir # noqa: F401
Expand Down Expand Up @@ -35,6 +37,12 @@
log = logging.getLogger(__name__)


def trainer_device_kwargs(cfg: DictConfig) -> dict:
if torch.cuda.is_available():
return {"accelerator": "gpu", "devices": [cfg.compute.gpu_id]}
return {"accelerator": "cpu", "devices": 1}


class BaseLightningModule(pl.LightningModule):
"""Base class for all Lightning modules for training"""

Expand Down Expand Up @@ -279,7 +287,7 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s

if cfg.compute.batch_size == "auto" or cfg.train.lr == "auto":
trainer = pl.Trainer(
gpus=[cfg.compute.gpu_id],
**trainer_device_kwargs(cfg),
precision=16 if cfg.compute.fp16 else 32,
limit_train_batches=1.0,
limit_val_batches=1.0,
Expand Down Expand Up @@ -310,7 +318,7 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
lightning_module.gpu_transforms = gpu_transforms
log.debug("new: {}".format(lightning_module.gpu_transforms))

tuner = pl.tuner.tuning.Tuner(trainer)
tuner = Tuner(trainer)
# hack for lightning to find the batch size
cfg.batch_size = 2 # to start

Expand All @@ -326,7 +334,7 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
if cfg.compute.batch_size == "auto":
max_trials = int(math.log2(cfg.compute.max_batch_size)) - int(math.log2(cfg.compute.min_batch_size))
log.info("max trials: {}".format(max_trials))
new_batch_size = trainer.tuner.scale_batch_size(
new_batch_size = tuner.scale_batch_size(
lightning_module,
mode="power",
steps_per_trial=30,
Expand All @@ -336,7 +344,7 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
cfg.compute.batch_size = new_batch_size
log.info("auto-tuned batch size: {}".format(new_batch_size))
if cfg.train.lr == "auto":
lr_finder = trainer.tuner.lr_find(lightning_module, early_stop_threshold=None, min_lr=1e-6, max_lr=10.0)
lr_finder = tuner.lr_find(lightning_module, early_stop_threshold=None, min_lr=1e-6, max_lr=10.0)
# log.info(lr_finder.results)
plt.style.use("seaborn")
fig = lr_finder.plot(suggest=True, show=False)
Expand Down Expand Up @@ -381,41 +389,23 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
tensorboard_logger = pl.loggers.tensorboard.TensorBoardLogger(os.getcwd())
refresh_rate = 1

# tuning messes with the callbacks
try:
# will be deprecated in the future; pytorch lightning updated their kwargs for this function
# don't like how they keep updating the api without proper deprecation warnings, etc.
trainer = pl.Trainer(
gpus=[cfg.compute.gpu_id],
precision=16 if cfg.compute.fp16 else 32,
limit_train_batches=steps_per_epoch["train"],
limit_val_batches=steps_per_epoch["val"],
limit_test_batches=steps_per_epoch["test"],
logger=tensorboard_logger,
max_epochs=cfg.train.num_epochs,
num_sanity_val_steps=0,
callbacks=callback_list,
reload_dataloaders_every_epoch=True,
progress_bar_refresh_rate=refresh_rate,
profiler=profiler,
log_every_n_steps=1,
)

except TypeError:
trainer = pl.Trainer(
gpus=[cfg.compute.gpu_id],
precision=16 if cfg.compute.fp16 else 32,
limit_train_batches=steps_per_epoch["train"],
limit_val_batches=steps_per_epoch["val"],
limit_test_batches=steps_per_epoch["test"],
logger=tensorboard_logger,
max_epochs=cfg.train.num_epochs,
num_sanity_val_steps=0,
callbacks=callback_list,
reload_dataloaders_every_n_epochs=1,
progress_bar_refresh_rate=refresh_rate,
profiler=profiler,
log_every_n_steps=1,
)
if refresh_rate > 0:
callback_list.append(TQDMProgressBar(refresh_rate=refresh_rate))

trainer = pl.Trainer(
**trainer_device_kwargs(cfg),
precision=16 if cfg.compute.fp16 else 32,
limit_train_batches=steps_per_epoch["train"],
limit_val_batches=steps_per_epoch["val"],
limit_test_batches=steps_per_epoch["test"],
logger=tensorboard_logger,
max_epochs=cfg.train.num_epochs,
num_sanity_val_steps=0,
callbacks=callback_list,
reload_dataloaders_every_n_epochs=1,
enable_progress_bar=refresh_rate > 0,
profiler=profiler,
log_every_n_steps=1,
)
torch.cuda.empty_cache()
return trainer
46 changes: 20 additions & 26 deletions deepethogram/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@ def __init__(self):
super().__init__()
log.info("callback initialized")

def on_init_end(self, trainer):
log.info("on init start")

def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
log.debug("on train batch start")

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
log.debug("on train batch end")

def on_train_epoch_start(self, trainer, pl_module):
Expand All @@ -43,12 +40,6 @@ def on_test_epoch_start(self, trainer, pl_module):
def on_test_epoch_end(self, trainer, pl_module):
log.info("on test epoch end")

def on_epoch_start(self, trainer, pl_module):
log.info("on epoch start")

def on_epoch_end(self, trainer, pl_module):
log.info("on epoch end")

def on_train_start(self, trainer, pl_module):
log.info("on train start")

Expand All @@ -61,8 +52,9 @@ def on_validation_start(self, trainer, pl_module):
def on_validation_end(self, trainer, pl_module):
log.info("on validation end")

def on_keyboard_interrupt(self, trainer, pl_module):
log.info("on keyboard interrupt")
def on_exception(self, trainer, pl_module, exception):
if isinstance(exception, KeyboardInterrupt):
log.info("on keyboard interrupt")


class FPSCallback(Callback):
Expand Down Expand Up @@ -92,22 +84,22 @@ def end_batch(self, split, batch, pl_module, eps: float = 1e-7):

pl_module.metrics.buffer.append(split, {"fps": fps})

def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
self.start_timer("train")

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.end_batch("train", batch, pl_module)

def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
self.start_timer("val")

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
self.end_batch("val", batch, pl_module)

def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
self.start_timer("speedtest")

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
self.end_batch("speedtest", batch, pl_module)


Expand Down Expand Up @@ -162,8 +154,9 @@ def on_test_epoch_end(self, trainer, pl_module):
log_metrics(pl_module, "test")
# pl_module.metrics.end_epoch('speedtest')

def on_keyboard_interrupt(self, trainer, pl_module):
pl_module.metrics.buffer.clear()
def on_exception(self, trainer, pl_module, exception):
if isinstance(exception, KeyboardInterrupt):
pl_module.metrics.buffer.clear()


class ExampleImagesCallback(Callback):
Expand All @@ -188,13 +181,13 @@ def on_validation_epoch_end(self, trainer, pl_module):
def on_test_epoch_end(self, trainer, pl_module):
self.reset_cnt(pl_module, "test")

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
pl_module.viz_cnt["train"] += 1

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
pl_module.viz_cnt["val"] += 1

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
pl_module.viz_cnt["test"] += 1


Expand All @@ -208,8 +201,9 @@ def checkpoint(self, pl_module):
def on_train_epoch_end(self, trainer, pl_module):
self.checkpoint(pl_module)

def on_keyboard_interrupt(self, trainer, pl_module):
self.checkpoint(pl_module)
def on_exception(self, trainer, pl_module, exception):
if isinstance(exception, KeyboardInterrupt):
self.checkpoint(pl_module)


class StopperCallback(Callback):
Expand Down
3 changes: 2 additions & 1 deletion deepethogram/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def __init__(self, model: nn.Module, path_to_pretrained_weights, alpha: float, b
self.beta = beta

assert os.path.isfile(path_to_pretrained_weights)
state = torch.load(path_to_pretrained_weights, map_location="cpu")
# DeepEthogram checkpoints include training metadata, so they need the legacy full checkpoint loader.
state = torch.load(path_to_pretrained_weights, map_location="cpu", weights_only=False)

pretrained_state = state["state_dict"]

Expand Down
1 change: 0 additions & 1 deletion deepethogram/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def initialize_scheduler(optimizer, cfg: DictConfig, mode: str = "max", reductio
mode=mode,
factor=reduction_factor,
patience=cfg.train.patience,
verbose=True,
min_lr=cfg.train.min_lr,
)
scheduler.name = "plateau"
Expand Down
3 changes: 2 additions & 1 deletion deepethogram/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ def load_state_from_dict(model, state_dict):


def load_state_dict_from_file(weights_file, distributed: bool = False):
state = torch.load(weights_file, map_location="cpu")
# DeepEthogram checkpoints include training metadata, so they need the legacy full checkpoint loader.
state = torch.load(weights_file, map_location="cpu", weights_only=False)

is_pure_weights = "epoch" not in list(state.keys())
# load params
Expand Down
2 changes: 1 addition & 1 deletion modal_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

APP_NAME = "deepethogram-gpu-tests"
VOLUME_NAME = "deepethogram-test-data"
CUDA_IMAGE = "nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04"
CUDA_IMAGE = "nvidia/cuda:12.6.3-cudnn-runtime-ubuntu22.04"
UV_VERSION = "0.6.14"
DEFAULT_GPU = os.environ.get("DEG_MODAL_GPU", "T4")
REMOTE_WORKDIR = Path("/workspace")
Expand Down
11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ dependencies = [
"PySide6>=6.6.0",
"scikit-learn>=1.2,<1.4",
"scipy>=1.9,<1.11",
"torch==2.4.1",
"torch==2.8.0",
"tensorboard",
"tqdm",
"vidio",
"pytorch_lightning==1.6.5",
"pytorch_lightning==2.6.0",
]

[dependency-groups]
Expand All @@ -52,7 +53,7 @@ deepethogram = "deepethogram.gui.main:entry"
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
{ index = "pytorch-cu124", marker = "sys_platform == 'linux'" },
{ index = "pytorch-cu126", marker = "sys_platform == 'linux'" },
]

[[tool.uv.index]]
Expand All @@ -61,8 +62,8 @@ url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true

[tool.setuptools.packages.find]
Expand Down
60 changes: 60 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
from types import SimpleNamespace

from deepethogram import callbacks
from deepethogram.callbacks import CheckpointCallback, MetricsCallback


class Buffer:
def __init__(self):
self.cleared = False

def clear(self):
self.cleared = True


def test_metrics_callback_clears_buffer_on_keyboard_interrupt():
buffer = Buffer()
pl_module = SimpleNamespace(metrics=SimpleNamespace(buffer=buffer))

MetricsCallback().on_exception(None, pl_module, KeyboardInterrupt())

assert buffer.cleared


def test_metrics_callback_ignores_other_exceptions():
buffer = Buffer()
pl_module = SimpleNamespace(metrics=SimpleNamespace(buffer=buffer))

MetricsCallback().on_exception(None, pl_module, RuntimeError("boom"))

assert not buffer.cleared


def test_checkpoint_callback_saves_on_keyboard_interrupt(monkeypatch):
calls = []
model = object()
pl_module = SimpleNamespace(model=model, current_epoch=7)

def checkpoint(model, directory, epoch):
calls.append((model, directory, epoch))

monkeypatch.setattr(callbacks.utils, "checkpoint", checkpoint)

CheckpointCallback().on_exception(None, pl_module, KeyboardInterrupt())

assert len(calls) == 1
assert calls[0][0] is model
assert calls[0][1] == os.getcwd()
assert calls[0][2] == 7


def test_checkpoint_callback_ignores_other_exceptions(monkeypatch):
calls = []
pl_module = SimpleNamespace(model=object(), current_epoch=7)

monkeypatch.setattr(callbacks.utils, "checkpoint", lambda *args: calls.append(args))

CheckpointCallback().on_exception(None, pl_module, RuntimeError("boom"))

assert calls == []
Loading
Loading