diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 36e85e6a..5f4962e0 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -539,7 +539,7 @@ class ChEBIFromList(_ChEBIDataExtractor): """ - READER = dr.ChemDataReader + READER = dr.StaticSMILESReader def __init__( self, @@ -585,7 +585,7 @@ class ChEBIOverX(_ChEBIDataExtractor): THRESHOLD (None): The threshold for selecting classes. """ - READER: dr.ChemDataReader = dr.ChemDataReader + READER = dr.StaticSMILESReader @property def _name(self) -> str: @@ -804,11 +804,8 @@ class ChEBIOver100Fingerprints(ChEBIOverXFingerprints, ChEBIOver100): if __name__ == "__main__": - dataset = ChEBIOver50Partial( - chebi_version=247, - subset="3_STAR", - top_class_id="36700", - external_data_ratio=0.5, + dataset = ChEBIOver50( + chebi_version=251, ) dataset.prepare_data() dataset.setup() diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 6df1ee51..b485cbbd 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -33,7 +33,7 @@ class PubChem(_DynamicDataset): LABEL_INDEX = 1 FULL = 0 UNLABELED = True - READER = dr.ChemDataReader + READER = dr.StaticSMILESReader # Column indices in data.pkl _ID_IDX: int = 0 @@ -208,7 +208,9 @@ def _perform_data_preparation(self, *args, **kwargs): """ Checks for raw data, downloads if necessary, then builds data.pkl. """ - print("Check for raw data in", self.raw_dir) + print( + f"Check for raw data ({', '.join(self.raw_file_names)}) in {self.raw_dir}..." + ) if any( not os.path.isfile(os.path.join(self.raw_dir, f)) for f in self.raw_file_names @@ -260,7 +262,7 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: class PubChemBatched(PubChem): """Store train data as batches of 10m, validation and test should each be 100k max""" - READER: Type[dr.ChemDataReader] = dr.ChemDataReader + READER: Type[dr.DataReader] = dr.StaticSMILESReader def __init__(self, train_batch_size=1_000_000, *args, **kwargs): super(PubChemBatched, self).__init__(*args, **kwargs) @@ -566,6 +568,6 @@ class PubChemSELFIES(PubChem): if __name__ == "__main__": - dataset = PubChem(k=10000) + dataset = PubChem(n_samples=10_000) dataset.prepare_data() dataset.setup() diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index cc70be7f..0dae39cd 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -257,6 +257,52 @@ def _back_to_smiles(self, smiles_encoded): return smiles_decoded +class StaticSMILESReader(DataReader): + """ + Data reader for SMILES tokens with a static token set. Atoms are split into 5 components: isotope, element, charge, hydrogens, stereo. + New tokens are not added to the token file, and unknown tokens are mapped to a special index. + """ + + COLLATOR = RaggedCollator + + def __init__(self, *args, **kwargs) -> None: + from chebai.preprocessing.smiles_tokenizer import BasicSmilesTokenizer + + super().__init__(*args, **kwargs) + self.tokenizer = BasicSmilesTokenizer() + + @classmethod + def name(cls) -> str: + """Returns the name of the data reader.""" + return "static_smiles" + + def _read_data(self, raw_data: str | Chem.Mol) -> Optional[List[int]]: + """Tokenize raw SMILES data using BasicSmilesTokenizer with static vocabulary.""" + try: + if isinstance(raw_data, str): + mol = Chem.MolFromSmiles(raw_data.strip()) + else: + mol = raw_data + except ValueError as e: + print(f"could not process {raw_data}") + print(f"\tError: {e}") + return None + + try: + smiles = Chem.MolToSmiles(mol, canonical=True) + except Exception as e: + print(f"RDKit failed to canonicalize the SMILES: {raw_data}") + print(f"\t{e}") + return None + + try: + return self.tokenizer.encode(smiles) + except Exception as e: + print(f"could not tokenize {raw_data}") + print(f"\tError: {e}") + return None + + class DeepChemDataReader(ChemDataReader): """ Data reader for chemical data using DeepSMILES tokens. diff --git a/chebai/preprocessing/smiles_tokenizer.py b/chebai/preprocessing/smiles_tokenizer.py new file mode 100644 index 00000000..c78fefa2 --- /dev/null +++ b/chebai/preprocessing/smiles_tokenizer.py @@ -0,0 +1,317 @@ +from __future__ import annotations +import re +from typing import List + +from rdkit import Chem + + +def _build_bracket_atoms() -> List[str]: + """Enumerate chemically meaningful bracketed atoms.""" + pt = Chem.GetPeriodicTable() + elements = [pt.GetElementSymbol(i) for i in range(1, 119)] + # aromatic forms used in SMILES + elements += ["c", "n", "o", "s", "p", "b", "te", "se", "si"] + charges = [i for i in range(-5, 9) if i != 0] + hydrogens = range(1, 9) + stereo = ["@", "@@"] + isotopes = range(1, 300) # [295Og] is the heaviest isotope in PubChem + + tokens = [] + for el in elements: + tokens.append(f"element_{el}") + for ch in charges: + tokens.append(f"charge_{ch}") + for h in hydrogens: + tokens.append(f"hydrogens_{h}") + for st in stereo: + tokens.append(f"stereo_{st}") + for iso in isotopes: + tokens.append(f"isotope_{iso}") + + return tokens + + +NON_BRACKET_TOKENS = [ + # bonds / structure + "(", + ")", + "=", + "#", + "->", + "<-", + ">>", + "-", + "+", + "/", + "\\", + ":", + ".", + "~", + "*", + "$", + "?", + "@", + "@@", + # ring closures: single-digit + *[str(d) for d in range(10)], + # ring closures: %10..%99 + *[f"%{n:02d}" for n in range(10, 100)], +] + + +def _build_default_vocab() -> List[str]: + """non-bracket symbols + bracketed atoms.""" + + brackets = _build_bracket_atoms() + + # de-duplicate while preserving order (specials first) + seen, vocab = set(), [] + for tok in NON_BRACKET_TOKENS + brackets: + if tok not in seen: + seen.add(tok) + vocab.append(tok) + return vocab + + +# change to original regex: added -> and <- for handling dative bonds (we use RDKit-normalised SMILES, this is not a standard SMILES feature) +SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|->|<-|>>?|-|\+|\\|\/|:|~|@@|@|\?|\*|\$|\%[0-9]{2}|[0-9])""" +EMBEDDING_OFFSET = 10 +UNKNOWN_TOKEN_IDX = 3 +ORGANIC_SUBSET = frozenset( + {"B", "C", "N", "O", "P", "S", "F", "Cl", "Br", "I", "b", "c", "n", "o", "s", "p"} +) + + +class BasicSmilesTokenizer(object): + """ + References + ---------- + .. [1] Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee + ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction + 1572-1583 DOI: 10.1021/acscentsci.9b00576 + """ + + def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN): + """Constructs a BasicSMILESTokenizer. + + Parameters + ---------- + regex: string + SMILES token regex + """ + self.regex_pattern = regex_pattern + self.regex = re.compile(self.regex_pattern) + + self.vocab = _build_default_vocab() + self.vocab_dict = { + tok: idx + EMBEDDING_OFFSET for idx, tok in enumerate(self.vocab) + } + self.idx_to_token = {idx: tok for tok, idx in self.vocab_dict.items()} + + def _parse_bracket_atom(self, bracket_token: str) -> List[str]: + """ + Parse a bracketed atom token into its components. + + When all attributes are at their defaults (charge=0, hydrogens=0, isotope=None, stereo=None), + only the element token is emitted. Otherwise all 5 tokens are emitted. + + E.g. "[N]" -> ["element_N"] + "[85Kr]" -> ["isotope_85", "element_Kr", "charge_0", "hydrogens_0", "stereo_None"] + """ + atom_str = bracket_token[1:-1] # Remove brackets + + # special case: any atom containing a * is treated as a wildcard (e.g. [3*:0],[1*]) + if "*" in atom_str: + return ["*"] + + isotope = None + element = None + charge = 0 + hydrogens = 0 + stereo = None + + pos = 0 + + # Parse isotope (leading digits) + iso_str = "" + while pos < len(atom_str) and atom_str[pos].isdigit(): + iso_str += atom_str[pos] + pos += 1 + if iso_str: + isotope = iso_str + + # Parse element (1-2 letters) + if pos < len(atom_str) and (atom_str[pos].isupper() or atom_str[pos].islower()): + element = atom_str[pos] + pos += 1 + if pos < len(atom_str) and atom_str[pos].islower(): + element += atom_str[pos] + pos += 1 + + # Parse stereo (@ or @@) + if pos < len(atom_str) and atom_str[pos] == "@": + if pos + 1 < len(atom_str) and atom_str[pos + 1] == "@": + stereo = "@@" + pos += 2 + else: + stereo = "@" + pos += 1 + + # Parse hydrogens (H, H2, H3, H4) + if pos < len(atom_str) and atom_str[pos] == "H": + hydrogens = 1 + pos += 1 + if pos < len(atom_str) and atom_str[pos].isdigit(): + hydrogens = int(atom_str[pos]) + pos += 1 + + # Parse charge (+, -, +2, -2, +3, -3) + if pos < len(atom_str): + if atom_str[pos] == "+": + pos += 1 + if pos < len(atom_str) and atom_str[pos].isdigit(): + charge = int(atom_str[pos]) + pos += 1 + else: + charge = 1 + elif atom_str[pos] == "-": + pos += 1 + if pos < len(atom_str) and atom_str[pos].isdigit(): + charge = -int(atom_str[pos]) + pos += 1 + else: + charge = -1 + + # return element token and optionally isotope, charge, hydrogens, stereo tokens + res = [f"element_{element}"] + if isotope is not None: + res.append(f"isotope_{isotope}") + if charge != 0: + res.append(f"charge_{charge}") + if hydrogens != 0: + res.append(f"hydrogens_{hydrogens}") + if stereo is not None: + res.append(f"stereo_{stereo}") + return res + + def tokenize(self, text): + """Tokenize a SMILES string, breaking bracketed atoms into 5 components. + + Non-bracketed tokens are returned as-is. + Bracketed atoms are decomposed into: isotope, element, charge, hydrogens, stereo. + """ + raw_tokens = [token for token in self.regex.findall(text)] + tokens = [] + + for token in raw_tokens: + if token in NON_BRACKET_TOKENS: + tokens.append(token) + continue + if not (token.startswith("[") and token.endswith("]")): + token = ( + f"[{token}]" # Wrap non-bracket tokens in brackets for uniformity + ) + # Parse bracketed atom into 5 components + components = self._parse_bracket_atom(token) + tokens.extend(components) + + return tokens + + def encode(self, text): + tokens = self.tokenize(text) + return [self.vocab_dict.get(token, UNKNOWN_TOKEN_IDX) for token in tokens] + + def _reassemble_bracket_atom( + self, + element: str, + isotope: str = "None", + charge: int = 0, + hydrogens: int = 0, + stereo: str = "None", + ) -> str: + if ( + isotope == "None" + and charge == 0 + and stereo == "None" + and hydrogens == 0 + and element in ORGANIC_SUBSET + ): + return element + inner = "" + if isotope != "None": + inner += isotope + inner += element + if stereo != "None": + inner += stereo + if hydrogens > 0: + inner += "H" + if hydrogens > 1: + inner += str(hydrogens) + if charge > 0: + inner += "+" + if charge > 1: + inner += str(charge) + elif charge < 0: + inner += "-" + if charge < -1: + inner += str(-charge) + return f"[{inner}]" + + def decode(self, token_ids, skip_special_tokens=False): + tokens = [self.idx_to_token.get(idx, "[UNK]") for idx in token_ids] + if skip_special_tokens: + tokens = [tok for tok in tokens if tok not in self.vocab_dict] + + result = [] + i = 0 + while i < len(tokens): + tok = tokens[i] + if tok.startswith("element_"): + i += 1 + add = ["None", 0, 0, "None"] + for idx, additional_token in enumerate( + ["isotope_", "charge_", "hydrogens_", "stereo_"] + ): + if i >= len(tokens): + break + if tokens[i].startswith(additional_token): + add[idx] = tokens[i][len(additional_token) :] + if additional_token in ["charge_", "hydrogens_"]: + add[idx] = int(add[idx]) + i += 1 + + result.append( + self._reassemble_bracket_atom( + tok[len("element_") :], + *add, + ) + ) + else: + result.append(tok) + i += 1 + + return "".join(result) + + +# ---- quick self-test ------------------------------------------------------ +if __name__ == "__main__": + # tok = SmilesTokenizer.build_default() + # print(f"Vocab size: {tok.vocab_size}") + tok = BasicSmilesTokenizer() + print(f"Vocab size: {len(tok.vocab)}") + examples = [ + "CC(=O)Oc1ccccc1C(=O)O", # aspirin + "C[C@H](N)C(=O)O", # L-alanine + "[13CH3]CO", # isotope + "C1CC2(CCCCC2)CC1", # spiro + "c1ccc2c(c1)[nH]cn2", # benzimidazole with [nH] + # "CC(=O)N[C@@H]1[C@H](O[C@H]2[C@H](O)[C@@H](NC(C)=O)[C@H](O)O[C@@H]2CO[C@@H]2O[C@@H](C)[C@@H](O)[C@@H](O)[C@@H]2O)O[C@H](CO)[C@@H](O[C@@H]2O[C@H](CO[C@H]3O[C@H](CO[C@@H]4O[C@H](CO)[C@@H](O[C@@H]5O[C@H](CO)[C@H](O)[C@H](O[C@@H]6O[C@H](CO)[C@@H](O[C@@H]7O[C@H](CO)[C@H](O)[C@H](O[C@]8(C(=O)O)C[C@H](O)[C@@H](NC(C)=O)[C@H]([C@H](O)[C@H](O)CO)O8)[C@H]7O)[C@H](O)[C@H]6NC(C)=O)[C@H]5O)[C@H](O)[C@H]4NC(C)=O)[C@@H](O)[C@H](O)[C@@H]3O[C@@H]3O[C@H](CO)[C@@H](O[C@@H]4O[C@H](CO)[C@H](O)[C@H](O[C@@H]5O[C@H](CO)[C@@H](O[C@@H]6O[C@H](CO)[C@H](O)[C@H](O[C@]7(C(=O)O)C[C@H](O)[C@@H](NC(C)=O)[C@H]([C@H](O)[C@H](O)CO)O7)[C@H]6O)[C@H](O)[C@H]5NC(C)=O)[C@H]4O)[C@H](O)[C@H]3NC(C)=O)[C@@H](O)[C@H](O[C@H]3O[C@H](CO)[C@@H](O[C@@H]4O[C@H](CO)[C@@H](O[C@@H]5O[C@H](CO)[C@H](O)[C@H](O[C@]6(C(=O)O)C[C@H](O)[C@@H](NC(C)=O)[C@H]([C@H](O)[C@H](O)CO)O6)[C@H]5O)[C@H](O)[C@H]4NC(C)=O)[C@H](O)[C@@H]3O[C@@H]3O[C@H](CO)[C@@H](O[C@@H]4O[C@H](CO)[C@H](O)[C@H](O)[C@H]4O)[C@H](O)[C@@H]3NC(C)=O)[C@@H]2O)[C@@H]1O", + ] + + for s in examples: + tokens = tok.tokenize(s) + ids = tok.encode(s) + print(f"\n{s}") + print(f" tokens: {tokens}") + print(f" ids: {ids}") + print(f" decode: {tok.decode(ids)}") diff --git a/configs/model/electra.yml b/configs/model/electra.yml index 4427715f..5240e3fb 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -4,7 +4,7 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - vocab_size: 4400 + vocab_size: 600 max_position_embeddings: 1800 num_attention_heads: 8 num_hidden_layers: 6 diff --git a/tests/unit/readers/testStaticSMILESReader.py b/tests/unit/readers/testStaticSMILESReader.py new file mode 100644 index 00000000..86a2d1e7 --- /dev/null +++ b/tests/unit/readers/testStaticSMILESReader.py @@ -0,0 +1,139 @@ +import unittest + +from rdkit import Chem + +from chebai.preprocessing.reader import StaticSMILESReader +from chebai.preprocessing.smiles_tokenizer import UNKNOWN_TOKEN_IDX + + +class TestStaticSMILESReader(unittest.TestCase): + """ + Unit tests for the StaticSMILESReader class. + + Focuses on two core properties: + - Determinism: the same SMILES always produces the same token sequence. + - Decode correctness: token sequences decode back to the canonical SMILES. + """ + + @classmethod + def setUpClass(cls) -> None: + cls.reader = StaticSMILESReader() + + @staticmethod + def _canonical(smiles: str) -> str: + return Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) + + # --- Determinism tests --- + + def test_same_input_produces_same_tokens(self) -> None: + """Calling _read_data with two instances on the same SMILES gives identical token sequences.""" + smiles = "CC(=O)Oc1ccccc1C(=O)O" # aspirin + result1 = self.reader._read_data(smiles) + new_reader = ( + StaticSMILESReader() + ) # Create a new reader instance to ensure no shared state + result2 = new_reader._read_data(smiles) + self.assertIsNotNone(result1) + self.assertEqual(result1, result2) + + def test_non_canonical_input_matches_canonical(self) -> None: + """Different SMILES representations of the same molecule produce identical tokens.""" + canonical = "CC(=O)Oc1ccccc1C(=O)O" + non_canonical = "OC(=O)c1ccccc1OC(C)=O" + result_canonical = self.reader._read_data(canonical) + result_non_canonical = self.reader._read_data(non_canonical) + self.assertIsNotNone(result_non_canonical) + self.assertEqual( + result_canonical, + result_non_canonical, + "Non-canonical and canonical SMILES of the same molecule produced different token sequences.", + ) + + # --- Static vocabulary tests --- + + def test_vocabulary_does_not_grow(self) -> None: + """Encoding multiple novel SMILES strings never increases the vocabulary size.""" + initial_vocab_size = len(self.reader.tokenizer.vocab) + smiles_list = [ + "CC(=O)Oc1ccccc1C(=O)O", # aspirin + "C[C@H](N)C(=O)O", # L-alanine + "[13CH3]CO", # isotope + "c1ccc2c(c1)[nH]cn2", # benzimidazole + "[NH4+]", # charged atom + ] + for smiles in smiles_list: + self.reader._read_data(smiles) + self.assertEqual( + len(self.reader.tokenizer.vocab), + initial_vocab_size, + "Vocabulary grew after encoding SMILES strings — StaticSMILESReader must not add new tokens.", + ) + + def test_unknown_tokens_use_unknown_idx(self) -> None: + """Tokens outside the static vocabulary are mapped to UNKNOWN_TOKEN_IDX, not added.""" + # [123] has no element symbol, so _parse_bracket_atom produces element_None + # which is not in the vocabulary. + token_ids = self.reader.tokenizer.encode("[123]") + self.assertIn( + UNKNOWN_TOKEN_IDX, + token_ids, + f"Expected UNKNOWN_TOKEN_IDX ({UNKNOWN_TOKEN_IDX}) in encoded output for out-of-vocabulary token.", + ) + + # --- Decode roundtrip tests --- + + def test_decode_roundtrip_simple_organic(self) -> None: + """Encoding then decoding recovers the canonical SMILES for a simple organic molecule.""" + smiles = "CC(=O)Oc1ccccc1C(=O)O" # aspirin + canonical = self._canonical(smiles) + token_ids = self.reader._read_data(smiles) + decoded = self.reader.tokenizer.decode(token_ids) + self.assertEqual(decoded, canonical) + + def test_decode_roundtrip_stereo(self) -> None: + """Encoding then decoding recovers canonical SMILES for a molecule with stereochemistry.""" + smiles = "C[C@H](N)C(=O)O" # L-alanine + canonical = self._canonical(smiles) + token_ids = self.reader._read_data(smiles) + decoded = self.reader.tokenizer.decode(token_ids) + self.assertEqual(decoded, canonical) + + def test_decode_roundtrip_isotope(self) -> None: + """Encoding then decoding recovers canonical SMILES for a molecule with an isotope label.""" + smiles = "[13CH3]CO" + canonical = self._canonical(smiles) + token_ids = self.reader._read_data(smiles) + decoded = self.reader.tokenizer.decode(token_ids) + self.assertEqual(decoded, canonical) + + def test_decode_roundtrip_charged_atom(self) -> None: + """Encoding then decoding recovers canonical SMILES for a molecule with a charged atom.""" + smiles = "[NH4+]" + canonical = self._canonical(smiles) + token_ids = self.reader._read_data(smiles) + decoded = self.reader.tokenizer.decode(token_ids) + self.assertEqual(decoded, canonical) + + def test_decode_roundtrip_bracket_aromatic(self) -> None: + """Encoding then decoding recovers canonical SMILES for an aromatic ring with a bracketed atom.""" + smiles = "c1ccc2c(c1)[nH]cn2" # benzimidazole + canonical = self._canonical(smiles) + token_ids = self.reader._read_data(smiles) + decoded = self.reader.tokenizer.decode(token_ids) + self.assertEqual(decoded, canonical) + + # --- Invalid input tests --- + + def test_invalid_smiles_returns_none(self) -> None: + """Invalid SMILES strings cause _read_data to return None.""" + invalid_smiles = ["%INVALID%", "ADADAD", "ADASDAD"] + for smiles in invalid_smiles: + result = self.reader._read_data(smiles) + self.assertIsNone( + result, + f"Expected None for invalid SMILES '{smiles}', got {result!r}.", + ) + + +if __name__ == "__main__": + unittest.main()