From da24fe0c01640654132a3269d4c82df266766ce7 Mon Sep 17 00:00:00 2001 From: Jack Wu Date: Wed, 10 Dec 2025 19:29:05 -0800 Subject: [PATCH 1/2] Minimal changes to make it run in 2025 on a macbook --- grok/data.py | 3 ++- grok/training.py | 18 ++++++++++++------ setup.py | 2 +- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/grok/data.py b/grok/data.py index e7d3827..9e53eb6 100644 --- a/grok/data.py +++ b/grok/data.py @@ -451,6 +451,7 @@ def __iter__(self): """ :returns: this iterator """ + self.reset_iteration() #jmod return self def __next__(self) -> Dict[str, Tensor]: @@ -463,7 +464,7 @@ def __next__(self) -> Dict[str, Tensor]: batch_begin = self.index * self.batchsize if batch_begin > len(self.dataset) - 1: - self.reset_iteration() + #self.reset_iteration() #jmod raise StopIteration indices = self.permutation[batch_begin : batch_begin + self.batchsize] text = self.dataset.data[indices, :-1] diff --git a/grok/training.py b/grok/training.py index 43b43df..42657bb 100755 --- a/grok/training.py +++ b/grok/training.py @@ -47,7 +47,8 @@ def __init__(self, hparams: Namespace) -> None: self.add_model_specific_args(). """ super().__init__() - self.hparams = hparams # type: ignore + # jmod. self.hparams = hparams # type: ignore + self.save_hyperparameters(hparams) self.prepare_data() self.transformer = Transformer( @@ -450,10 +451,12 @@ def training_step(self, batch, batch_idx): ) self.fwd_time_in_epoch += time.time() - start - schedulers = self.trainer.lr_schedulers[0] + # jmod schedulers = self.trainer.lr_schedulers[0] + schedulers = self.lr_schedulers() if self.current_epoch != self.next_train_epoch_to_log: return {"loss": loss} - lr = schedulers["scheduler"].optimizer.param_groups[0]["lr"] + # jmod lr = schedulers["scheduler"].optimizer.param_groups[0]["lr"] + lr = schedulers.optimizer.param_groups[0]["lr"] output = { "loss": loss, "partial_train_loss": coeff * loss, @@ -585,7 +588,8 @@ def validation_epoch_end(self, outputs): # get the l2 norm of the parameter logs["paramnorm_" + name] = torch.norm( param, 2 - ).detach().cpu().numpy() / np.sqrt(n_params) + #jomod ).detach().cpu().numpy() / np.sqrt(n_params) + ).detach().cpu().numpy().astype(np.float32) / np.sqrt(n_params,dtype=np.float32) # train accuracy device = self.transformer.embedding.weight.device @@ -718,12 +722,14 @@ def train(hparams: Namespace) -> None: "max_steps": hparams.max_steps, "min_steps": hparams.max_steps, "max_epochs": int(1e8), - "val_check_interval": 1, + "val_check_interval": 1.0, #jmod changed from 1 "profiler": False, # "checkpoint_callback": checkpointer, "logger": logger, "log_every_n_steps": 1, - "flush_logs_every_n_steps": 1000, + #jmod. "flush_logs_every_n_steps": 1000, + #jmod + "accelerator": 'mps', "devices": 1 } if torch.cuda.is_available() and hparams.gpu >= 0: trainer_args["gpus"] = [hparams.gpu] diff --git a/setup.py b/setup.py index c31f060..356987e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ packages=find_packages(), version="0.0.1", install_requires=[ - "pytorch_lightning", + "pytorch_lightning<2.0.0", "blobfile", "numpy", "torch", From 8227017a0c800dab3f58c920d707994fdc9928e8 Mon Sep 17 00:00:00 2001 From: Jack Wu Date: Thu, 11 Dec 2025 14:57:00 -0800 Subject: [PATCH 2/2] test if MPS is available automatically. Set gpu to -1 you don't want to use it. --- grok/data.py | 4 ++-- grok/training.py | 16 +++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/grok/data.py b/grok/data.py index 9e53eb6..d75b16b 100644 --- a/grok/data.py +++ b/grok/data.py @@ -451,7 +451,7 @@ def __iter__(self): """ :returns: this iterator """ - self.reset_iteration() #jmod + self.reset_iteration() # return self def __next__(self) -> Dict[str, Tensor]: @@ -464,7 +464,7 @@ def __next__(self) -> Dict[str, Tensor]: batch_begin = self.index * self.batchsize if batch_begin > len(self.dataset) - 1: - #self.reset_iteration() #jmod + #self.reset_iteration() raise StopIteration indices = self.permutation[batch_begin : batch_begin + self.batchsize] text = self.dataset.data[indices, :-1] diff --git a/grok/training.py b/grok/training.py index 42657bb..74f1b53 100755 --- a/grok/training.py +++ b/grok/training.py @@ -47,7 +47,7 @@ def __init__(self, hparams: Namespace) -> None: self.add_model_specific_args(). """ super().__init__() - # jmod. self.hparams = hparams # type: ignore + # self.hparams = hparams # type: ignore self.save_hyperparameters(hparams) self.prepare_data() @@ -451,11 +451,11 @@ def training_step(self, batch, batch_idx): ) self.fwd_time_in_epoch += time.time() - start - # jmod schedulers = self.trainer.lr_schedulers[0] + # schedulers = self.trainer.lr_schedulers[0] schedulers = self.lr_schedulers() if self.current_epoch != self.next_train_epoch_to_log: return {"loss": loss} - # jmod lr = schedulers["scheduler"].optimizer.param_groups[0]["lr"] + # lr = schedulers["scheduler"].optimizer.param_groups[0]["lr"] lr = schedulers.optimizer.param_groups[0]["lr"] output = { "loss": loss, @@ -722,18 +722,20 @@ def train(hparams: Namespace) -> None: "max_steps": hparams.max_steps, "min_steps": hparams.max_steps, "max_epochs": int(1e8), - "val_check_interval": 1.0, #jmod changed from 1 + "val_check_interval": 1.0, # changed from 1 "profiler": False, # "checkpoint_callback": checkpointer, "logger": logger, "log_every_n_steps": 1, - #jmod. "flush_logs_every_n_steps": 1000, - #jmod - "accelerator": 'mps', "devices": 1 + # "flush_logs_every_n_steps": 1000, } if torch.cuda.is_available() and hparams.gpu >= 0: trainer_args["gpus"] = [hparams.gpu] + if torch.backends.mps.is_available() and hparams.gpu >= 0: + trainer_args["accelerator"] = 'mps' + trainer_args["devices"] = 1 + trainer = Trainer(**trainer_args) trainer.fit(model=model) # type: ignore