From e42c705e8610512618a364f33030807b4d3f560c Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 28 May 2026 15:06:09 +0200 Subject: [PATCH 1/3] fix get_data_splits for pubchem batched --- chebai/preprocessing/datasets/pubchem.py | 29 +++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index b485cbbd..a798c8aa 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -4,7 +4,7 @@ import shutil import tempfile from datetime import datetime -from typing import Generator, List, Optional, Tuple, Type, Union +from typing import Dict, Generator, List, Optional, Tuple, Type, Union import pandas as pd import requests @@ -284,10 +284,10 @@ def __init__(self, train_batch_size=1_000_000, *args, **kwargs): self.test_batch_size = 100_000 @property - def processed_file_names_dict(self) -> List[str]: + def processed_file_names_dict(self) -> Dict[str, str]: """ Returns: - List[str]: List of processed data file names. + Dict[str, str]: Dictionary of processed data file names. """ train_samples = ( self._n_samples @@ -404,6 +404,29 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader **kwargs, ) + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + The PubChemBatched dataset comes with pre-split data + """ + + train = self.load_processed_data_from_file( + self.processed_file_names_dict[ + "train" + if "train" in self.processed_file_names_dict + else f"train_{self.curr_epoch}" + ] + ) + train_df = pd.DataFrame(train) + val = self.load_processed_data_from_file( + self.processed_file_names_dict["validation"] + ) + val_df = pd.DataFrame(val) + test = self.load_processed_data_from_file( + self.processed_file_names_dict["test"] + ) + test_df = pd.DataFrame(test) + return train_df, val_df, test_df + class LabeledUnlabeledMixed(XYBaseDataModule): """ From ce5ffde6c4afb22ec73cda4f4a92d929ef57bbd2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 28 May 2026 15:38:28 +0200 Subject: [PATCH 2/3] overwrite load_processed_data to reload batch for every epoch --- chebai/preprocessing/datasets/pubchem.py | 38 +++++++++++------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index a798c8aa..1225ce7f 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -4,7 +4,7 @@ import shutil import tempfile from datetime import datetime -from typing import Dict, Generator, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union import pandas as pd import requests @@ -404,28 +404,26 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader **kwargs, ) - def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + def load_processed_data( + self, kind: Optional[str] = None, filename: Optional[str] = None + ) -> List[Dict[str, Any]]: """ - The PubChemBatched dataset comes with pre-split data + Loads processed data from a specified dataset type or file. Loads data directly from file instead of + using the dynamic_splits_df property. This ensures that a new training batch is loaded for each epoch. """ + if kind is None and filename is None: + raise ValueError( + "Either kind or filename is required to load the correct dataset, both are None" + ) - train = self.load_processed_data_from_file( - self.processed_file_names_dict[ - "train" - if "train" in self.processed_file_names_dict - else f"train_{self.curr_epoch}" - ] - ) - train_df = pd.DataFrame(train) - val = self.load_processed_data_from_file( - self.processed_file_names_dict["validation"] - ) - val_df = pd.DataFrame(val) - test = self.load_processed_data_from_file( - self.processed_file_names_dict["test"] - ) - test_df = pd.DataFrame(test) - return train_df, val_df, test_df + # If both kind and filename are given, use filename + if kind is not None and filename is None: + return self.load_processed_data_from_file( + self.processed_file_names_dict[kind] + ) + + # If filename is provided + return self.load_processed_data_from_file(filename) class LabeledUnlabeledMixed(XYBaseDataModule): From f41319fbc37cc1b57b8582b4cca414c7b140fb38 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 29 May 2026 10:57:54 +0200 Subject: [PATCH 3/3] add hard cutoff for feature length --- chebai/models/electra.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 20678b84..c1e9896e 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -204,8 +204,13 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any * CLS_TOKEN ) model_kwargs["output_attentions"] = True + + x = torch.cat((cls_tokens, batch.x), dim=1) + # cut off to max length of max_position_embeddings + x = x[:, : self.config.max_position_embeddings] + return dict( - features=torch.cat((cls_tokens, batch.x), dim=1), + features=x, labels=batch.y, model_kwargs=model_kwargs, loss_kwargs=loss_kwargs,