Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion chebai/models/electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,13 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any
* CLS_TOKEN
)
model_kwargs["output_attentions"] = True

x = torch.cat((cls_tokens, batch.x), dim=1)
# cut off to max length of max_position_embeddings
x = x[:, : self.config.max_position_embeddings]

return dict(
features=torch.cat((cls_tokens, batch.x), dim=1),
features=x,
labels=batch.y,
model_kwargs=model_kwargs,
loss_kwargs=loss_kwargs,
Expand Down
27 changes: 24 additions & 3 deletions chebai/preprocessing/datasets/pubchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import shutil
import tempfile
from datetime import datetime
from typing import Generator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union

import pandas as pd
import requests
Expand Down Expand Up @@ -284,10 +284,10 @@ def __init__(self, train_batch_size=1_000_000, *args, **kwargs):
self.test_batch_size = 100_000

@property
def processed_file_names_dict(self) -> List[str]:
def processed_file_names_dict(self) -> Dict[str, str]:
"""
Returns:
List[str]: List of processed data file names.
Dict[str, str]: Dictionary of processed data file names.
"""
train_samples = (
self._n_samples
Expand Down Expand Up @@ -404,6 +404,27 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader
**kwargs,
)

def load_processed_data(
    self, kind: Optional[str] = None, filename: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Load processed data, selected either by dataset kind or by file name.

    The data is read straight from file via ``load_processed_data_from_file``
    rather than through the ``dynamic_splits_df`` property, so that a new
    training batch is loaded for each epoch.

    Args:
        kind: Key into ``processed_file_names_dict`` naming the dataset to
            load. Only consulted when ``filename`` is None.
        filename: File to load. Takes precedence over ``kind`` when both
            are given.

    Returns:
        Whatever ``load_processed_data_from_file`` returns for the resolved
        file name.

    Raises:
        ValueError: If both ``kind`` and ``filename`` are None.
    """
    # An explicit filename always wins, whether or not kind was also passed.
    if filename is not None:
        return self.load_processed_data_from_file(filename)

    # No filename: a kind is now mandatory to know which file to read.
    if kind is None:
        raise ValueError(
            "Either kind or filename is required to load the correct dataset, both are None"
        )

    # Resolve the kind to its processed file name, then load that file.
    resolved_name = self.processed_file_names_dict[kind]
    return self.load_processed_data_from_file(resolved_name)


class LabeledUnlabeledMixed(XYBaseDataModule):
"""
Expand Down
Loading