Skip to content
11 changes: 4 additions & 7 deletions chebai/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ class ChEBIFromList(_ChEBIDataExtractor):

"""

READER = dr.ChemDataReader
READER = dr.StaticSMILESReader

def __init__(
self,
Expand Down Expand Up @@ -585,7 +585,7 @@ class ChEBIOverX(_ChEBIDataExtractor):
THRESHOLD (None): The threshold for selecting classes.
"""

READER: dr.ChemDataReader = dr.ChemDataReader
READER = dr.StaticSMILESReader

@property
def _name(self) -> str:
Expand Down Expand Up @@ -804,11 +804,8 @@ class ChEBIOver100Fingerprints(ChEBIOverXFingerprints, ChEBIOver100):


if __name__ == "__main__":
dataset = ChEBIOver50Partial(
chebi_version=247,
subset="3_STAR",
top_class_id="36700",
external_data_ratio=0.5,
dataset = ChEBIOver50(
chebi_version=251,
)
dataset.prepare_data()
dataset.setup()
10 changes: 6 additions & 4 deletions chebai/preprocessing/datasets/pubchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class PubChem(_DynamicDataset):
LABEL_INDEX = 1
FULL = 0
UNLABELED = True
READER = dr.ChemDataReader
READER = dr.StaticSMILESReader

# Column indices in data.pkl
_ID_IDX: int = 0
Expand Down Expand Up @@ -208,7 +208,9 @@ def _perform_data_preparation(self, *args, **kwargs):
"""
Checks for raw data, downloads if necessary, then builds data.pkl.
"""
print("Check for raw data in", self.raw_dir)
print(
f"Check for raw data ({', '.join(self.raw_file_names)}) in {self.raw_dir}..."
)
if any(
not os.path.isfile(os.path.join(self.raw_dir, f))
for f in self.raw_file_names
Expand Down Expand Up @@ -260,7 +262,7 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
class PubChemBatched(PubChem):
"""Store train data as batches of 10m, validation and test should each be 100k max"""

READER: Type[dr.ChemDataReader] = dr.ChemDataReader
READER: Type[dr.DataReader] = dr.StaticSMILESReader

def __init__(self, train_batch_size=1_000_000, *args, **kwargs):
super(PubChemBatched, self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -566,6 +568,6 @@ class PubChemSELFIES(PubChem):


if __name__ == "__main__":
dataset = PubChem(k=10000)
dataset = PubChem(n_samples=10_000)
dataset.prepare_data()
dataset.setup()
46 changes: 46 additions & 0 deletions chebai/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,52 @@ def _back_to_smiles(self, smiles_encoded):
return smiles_decoded


class StaticSMILESReader(DataReader):
    """
    Data reader for SMILES tokens with a static token set. Atoms are split into 5 components: isotope, element, charge, hydrogens, stereo.
    New tokens are not added to the token file, and unknown tokens are mapped to a special index.

    Each input is converted to an RDKit molecule (if it is not one already),
    written back as canonical SMILES and then encoded with
    ``BasicSmilesTokenizer``. Any failure along the way is reported on stdout
    and results in ``None`` instead of an exception.
    """

    COLLATOR = RaggedCollator

    def __init__(self, *args, **kwargs) -> None:
        """Initialise the base reader and create the tokenizer instance.

        All positional and keyword arguments are forwarded unchanged to the
        parent class constructor.
        """
        # Imported inside the constructor, so the tokenizer module is only
        # loaded once a reader is actually instantiated.
        from chebai.preprocessing.smiles_tokenizer import BasicSmilesTokenizer

        super().__init__(*args, **kwargs)
        self.tokenizer = BasicSmilesTokenizer()

    @classmethod
    def name(cls) -> str:
        """Returns the name of the data reader."""
        return "static_smiles"

    def _read_data(self, raw_data: str | Chem.Mol) -> Optional[List[int]]:
        """Tokenize raw SMILES data using BasicSmilesTokenizer with static vocabulary.

        Args:
            raw_data: A SMILES string (surrounding whitespace is stripped) or
                an already constructed RDKit molecule.

        Returns:
            The result of ``self.tokenizer.encode`` for the canonical SMILES,
            or ``None`` if the input could not be parsed, canonicalized or
            tokenized.
        """
        try:
            if isinstance(raw_data, str):
                mol = Chem.MolFromSmiles(raw_data.strip())
            else:
                mol = raw_data
        except ValueError as e:
            print(f"could not process {raw_data}")
            print(f"\tError: {e}")
            return None

        # RDKit reports an unparsable SMILES by returning None rather than by
        # raising, and callers may also hand in None directly. Stop here with
        # an accurate message instead of letting MolToSmiles fail on None.
        if mol is None:
            print(f"could not process {raw_data}")
            print("\tError: input could not be converted to an RDKit molecule")
            return None

        try:
            smiles = Chem.MolToSmiles(mol, canonical=True)
        except Exception as e:
            # Deliberately broad: a single bad molecule must not abort the
            # processing of the remaining data.
            print(f"RDKit failed to canonicalize the SMILES: {raw_data}")
            print(f"\t{e}")
            return None

        try:
            return self.tokenizer.encode(smiles)
        except Exception as e:
            # Same best-effort policy as above: report and skip this entry.
            print(f"could not tokenize {raw_data}")
            print(f"\tError: {e}")
            return None


class DeepChemDataReader(ChemDataReader):
"""
Data reader for chemical data using DeepSMILES tokens.
Expand Down
Loading
Loading