diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 20678b84..c1e9896e 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -204,8 +204,13 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any * CLS_TOKEN ) model_kwargs["output_attentions"] = True + + x = torch.cat((cls_tokens, batch.x), dim=1) + # truncate dim 1 (CLS included) to at most max_position_embeddings entries + x = x[:, : self.config.max_position_embeddings] + return dict( - features=torch.cat((cls_tokens, batch.x), dim=1), + features=x, labels=batch.y, model_kwargs=model_kwargs, loss_kwargs=loss_kwargs, diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index b485cbbd..1225ce7f 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -4,7 +4,7 @@ import shutil import tempfile from datetime import datetime -from typing import Generator, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union import pandas as pd import requests @@ -284,10 +284,10 @@ def __init__(self, train_batch_size=1_000_000, *args, **kwargs): self.test_batch_size = 100_000 @property - def processed_file_names_dict(self) -> List[str]: + def processed_file_names_dict(self) -> Dict[str, str]: """ Returns: - List[str]: List of processed data file names. + Dict[str, str]: Dictionary of processed data file names. """ train_samples = ( self._n_samples @@ -404,6 +404,27 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader **kwargs, ) + def load_processed_data( + self, kind: Optional[str] = None, filename: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Loads processed data from a specified dataset type or file. Loads data directly from file instead of + using the dynamic_splits_df property. This ensures that a new training batch is loaded for each epoch.
+ """ + if kind is None and filename is None: + raise ValueError( + "Either kind or filename is required to load the correct dataset, both are None" + ) + + # Use kind only if no filename is given; filename takes precedence + if kind is not None and filename is None: + return self.load_processed_data_from_file( + self.processed_file_names_dict[kind] + ) + + # filename is given (with or without kind): load it directly + return self.load_processed_data_from_file(filename) + class LabeledUnlabeledMixed(XYBaseDataModule): """