Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions safe/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,23 @@ def randomize(mol: dm.Mol, rng: Optional[int] = None):
atom_indices = rng.permutation(atom_indices).tolist()
return Chem.RenumberAtoms(mol, atom_indices)

@staticmethod
def _format_ring_closure(num: int) -> str:
"""Format a ring-closure (branch) number into its SMILES token.

Single digits stay bare (e.g. ``5``), two-digit numbers use the ``%NN``
form, and numbers >= 100 use RDKit's extended ``%(nnn)`` ring-closure
notation (see https://www.rdkit.org/docs/RDKit_Book.html#ring-closures).

Args:
num: ring-closure number to format
"""
if num < 10:
return str(num)
if num < 100:
return f"%{num}"
return f"%({num})"

@classmethod
def _find_branch_number(cls, inp: str):
"""Find the branch number and ring closure in the SMILES representation using regexp
Expand All @@ -113,16 +130,18 @@ def _find_branch_number(cls, inp: str):
inp: input smiles
"""
inp = re.sub(r"\[.*?\]", "", inp) # noqa
matching_groups = re.findall(r"((?<=%)\d{2})|((?<!%)\d+)(?![^\[]*\])", inp)
# first match is for multiple connection as multiple digits
# second match is for single connections requiring 2 digits
# SMILES does not support triple digits
matching_groups = re.findall(r"(?<=%)\((\d+)\)|(?<=%)(\d{2})|((?<!%)\d+)(?![^\[]*\])", inp)
# first match is the extended '%(nnn)' ring closure (a single label)
# second match is for two-digit '%NN' ring closures
# third match is for plain digits, each one a single-digit ring closure
branch_numbers = []
for m in matching_groups:
if m[0] == "":
branch_numbers.extend(int(mm) for mm in m[1])
elif m[1] == "":
branch_numbers.append(int(m[0].replace("%", "")))
for extended, double, single in matching_groups:
if extended:
branch_numbers.append(int(extended))
elif double:
branch_numbers.append(int(double))
elif single:
branch_numbers.extend(int(d) for d in single)
return branch_numbers

def _ensure_valid(self, inp: str):
Expand All @@ -139,11 +158,10 @@ def _ensure_valid(self, inp: str):
branch_numbers = Counter(branch_numbers)
for i, (bnum, bcount) in enumerate(branch_numbers.items()):
if bcount % 2 != 0:
bnum_str = str(bnum) if bnum < 10 else f"%{bnum}"
bnum_str = self._format_ring_closure(bnum)
_tk = f"[*:{i+1}]{bnum_str}"
if self.use_original_opener_for_attach:
bnum_digit = bnum_str.strip("%") # strip out the % sign
_tk = f"[*:{bnum_digit}]{bnum_str}"
_tk = f"[*:{bnum}]{bnum_str}"
missing_tokens.append(_tk)
return ".".join(missing_tokens)

Expand Down Expand Up @@ -261,8 +279,9 @@ def encoder(
# EN: we first normalize the attachment if the molecule is a query:
# inp = dm.reactions.convert_attach_to_isotope(inp, as_smiles=True)

# TODO(maclandrol): RDKit supports some extended form of ring closure, up to 5 digits
# https://www.rdkit.org/docs/RDKit_Book.html#ring-closures and I should try to include them
# RDKit's extended ring-closure form ('%(nnn)', up to 5 digits) is used for
# labels >= 100; see `_format_ring_closure`.
# https://www.rdkit.org/docs/RDKit_Book.html#ring-closures
branch_numbers = self._find_branch_number(inp)

mol = dm.to_mol(inp, remove_hs=False)
Expand Down Expand Up @@ -345,14 +364,14 @@ def encoder(
attach_pos = sorted(attach_pos)
starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1
for attach in attach_pos:
val = str(starting_num) if starting_num < 10 else f"%{starting_num}"
val = self._format_ring_closure(starting_num)
# we cannot have anything of the form "\([@=-#-$/\]*\d+\)"
attach_regexp = re.compile(r"(" + re.escape(attach) + r")")
scaffold_str = attach_regexp.sub(val, scaffold_str)
starting_num += 1

# now we need to remove all the parenthesis around digit only number
wrong_attach = re.compile(r"\(([\%\d]*)\)")
wrong_attach = re.compile(r"(?<!%)\((%\(\d+\)|[\%\d]*)\)")
scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str)
# furthermore, we autoapply rdkit-compatible digit standardization.
if rdkit_safe:
Expand Down
37 changes: 37 additions & 0 deletions tests/test_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,40 @@ def test_stereochemistry_issue():
# check if we ignore the stereo
output = safe.encode(STEREO_MOL_LIST[0], ignore_stereo=True, slicer="brics")
assert dm.same_mol(dm.remove_stereochemistry(dm.to_mol(STEREO_MOL_LIST[0])), output)


def test_large_molecule_ring_closures():
# A long peptide produces > 99 SAFE fragments, whose attachment bonds need
# the %(nnn) extended ring-closure form to be valid SMILES.
from rdkit import Chem

seq = "LVYTDCTESGQNLCLCEGSNVCGQGNKCILGSDGEKNQCVTGEGTPKPQSHNDGDFEEIPEEYLQ"
smiles = Chem.MolToSmiles(Chem.MolFromSequence(seq))
encoded = safe.encode(smiles, canonical=True)
assert "%(" in encoded # uses extended ring closures
assert dm.same_mol(smiles, safe.decode(encoded))


def test_extended_ring_closure_decoding():
# The decoder must understand RDKit's extended '%(nnn)' ring-closure form,
# both when reading branch numbers and when completing unpaired attachment
# points.
from rdkit import Chem

conv = safe.SAFEConverter()

# '%(nnn)' is a single ring-closure label, not three separate digits
assert conv._find_branch_number("C%(100)") == [100]
assert conv._find_branch_number("c1ccccc1%(123)") == [1, 1, 123]
# plain single-digit and two-digit forms keep their existing behaviour
assert conv._find_branch_number("C1CC%23C") == [1, 23]

# an unpaired extended label must be completed into a valid molecule
assert dm.to_mol(conv._ensure_valid("C%(100)")) is not None

# fragment-level decoding of a >99 ring-closure molecule must not silently fail
seq = "LVYTDCTESGQNLCLCEGSNVCGQGNKCILGSDGEKNQCVTGEGTPKPQSHNDGDFEEIPEEYLQ"
encoded = safe.encode(Chem.MolToSmiles(Chem.MolFromSequence(seq)), canonical=True)
decoded_fragments = [safe.decode(fragment, fix=True) for fragment in encoded.split(".")]
assert all(x is not None for x in decoded_fragments)
assert safe.decode(encoded, as_mol=True) is not None