-
Notifications
You must be signed in to change notification settings - Fork 1
KNN edge construction fix + align residue indexing with ESM sanitization #82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
24e220d
25a745e
a001f50
d6e0578
6344ad4
545cd93
0cfbfb1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,12 +34,37 @@ def build_knn_edges( | |
| batch_dst: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| KNN edges from src -> dst (source indices in row 0, dest in row 1). | ||
| Build KNN edges from src -> dst (source indices in row 0, dest in row 1). | ||
|
|
||
| The KNN query is performed *per destination*: for each point in ``dst_pos`` | ||
| we look up its ``k`` nearest neighbors in ``src_pos`` (``knn(x=src_pos, | ||
| y=dst_pos, ...)``) and emit them as incoming edges. As a consequence every | ||
| destination node is guaranteed to have up to ``k`` incoming edges (and so | ||
| appears in row 1), whereas a source node that is no destination's nearest | ||
| neighbor may not appear in row 0 at all. Coverage checks ("every node has an | ||
| edge") must therefore be made against the destination row (row 1). | ||
|
|
||
| For a homogeneous graph (``src_pos is dst_pos``) self-edges are dropped. | ||
|
|
||
| Args: | ||
| src_pos: (N_src, 3) source node positions. | ||
| dst_pos: (N_dst, 3) destination node positions. | ||
| k: Number of nearest source neighbors to find per destination node. | ||
| batch_src: (N_src,) batch assignment for source nodes, or None. | ||
| batch_dst: (N_dst,) batch assignment for destination nodes, or None. | ||
|
|
||
| Returns: | ||
| (2, E) edge index tensor with source indices in row 0, destination in | ||
| row 1. | ||
| """ | ||
| if src_pos.numel() == 0 or dst_pos.numel() == 0: | ||
| return torch.empty(2, 0, dtype=torch.long, device=src_pos.device) | ||
|
|
||
| idx = knn(x=dst_pos, y=src_pos, k=k, batch_x=batch_dst, batch_y=batch_src) | ||
| # knn(x=src_pos, y=dst_pos) returns row 0 = dst (query) indices, row 1 = src | ||
| # (neighbor) indices; swap so the result follows the src(row 0)->dst(row 1) | ||
| # edge_index convention. | ||
| idx = knn(x=src_pos, y=dst_pos, k=k, batch_x=batch_src, batch_y=batch_dst) | ||
| idx = torch.stack((idx[1], idx[0]), dim=0) | ||
|
Comment on lines
+66
to
+67
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add an exact set asymmetric test with toy cases where you know true neighbors to test on the directionalities, both within the |
||
|
|
||
| # remove self-edges if homogeneous | ||
| if src_pos.data_ptr() == dst_pos.data_ptr(): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |
|
|
||
| """ | ||
| Utility functions organized by category: | ||
| 1. Feature encoding (rbf, atom37_to_atoms, normalize_ins_code) | ||
| 1. Feature encoding (rbf, normalize_ins_code) | ||
| 2. Optimal transport (ot_coupling) | ||
| 3. Metrics (recall_precision, compute_rmsd, compute_placement_metrics) | ||
| 4. Visualization (plot_3d_frame, create_trajectory_gif, save_protein_plot) | ||
|
|
@@ -24,44 +24,9 @@ | |
| from PIL import Image | ||
| from scipy.optimize import linear_sum_assignment | ||
| from torch import Tensor | ||
| from torch_geometric.nn import knn | ||
| from tqdm import tqdm | ||
|
|
||
| from src.constants import NUM_RBF, RBF_CUTOFF | ||
|
|
||
|
|
||
| def build_knn_edges( | ||
| src_pos: torch.Tensor, | ||
| dst_pos: torch.Tensor, | ||
| k: int, | ||
| batch_src: torch.Tensor | None = None, | ||
| batch_dst: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Build KNN edges from source to destination nodes. | ||
|
|
||
| Args: | ||
| src_pos: (N_src, 3) source node positions | ||
| dst_pos: (N_dst, 3) destination node positions | ||
| k: Number of nearest neighbors per source node | ||
| batch_src: (N_src,) batch indices for source nodes, or None if single graph | ||
| batch_dst: (N_dst,) batch indices for destination nodes, or None if single graph | ||
|
|
||
| Returns: | ||
| (2, E) edge index tensor with source indices in row 0, destination in row 1. | ||
| Self-edges are removed for homogeneous graphs (src_pos is dst_pos). | ||
| """ | ||
| if src_pos.numel() == 0 or dst_pos.numel() == 0: | ||
| return torch.empty(2, 0, dtype=torch.long, device=src_pos.device) | ||
|
|
||
| idx = knn(x=dst_pos, y=src_pos, k=k, batch_x=batch_dst, batch_y=batch_src) | ||
|
|
||
| # remove self-edges if homogeneous | ||
| if src_pos.data_ptr() == dst_pos.data_ptr(): | ||
| mask = idx[0] != idx[1] | ||
| idx = idx[:, mask] | ||
|
|
||
| return idx.unique(dim=1) | ||
| from src.constants import NUM_RBF, ONE_TO_THREE, RBF_CUTOFF, THREE_TO_ONE | ||
|
|
||
|
|
||
| def setup_logging_for_tqdm( | ||
|
|
@@ -117,6 +82,36 @@ def normalize_ins_code(value) -> str: | |
| return ins | ||
|
|
||
|
|
||
| def sanitize_res_names_for_esm(atoms): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add type hint on input and output |
||
| """ | ||
| Return a copy of an AtomArray with residue names canonicalized to match the | ||
| ESM embedding pipeline. | ||
|
|
||
| Each residue name is mapped to its one-letter code and back | ||
| (``THREE_TO_ONE`` -> ``ONE_TO_THREE``), with anything unrecognized collapsed | ||
| to ``"UNK"``. This merges non-canonical names that share a residue position | ||
| (e.g. modified residues -> their canonical parent, unknowns -> ``UNK``) so | ||
| that biotite's ``get_residue_starts`` does not split them apart. | ||
|
|
||
| This is the single source of truth for residue-name sanitization shared by | ||
| ``scripts/generate_esm_embeddings.py`` (which feeds the sanitized structure | ||
| to ESM3) and ``src/dataset.py`` (which derives residue indices that must line | ||
| up with the stored ESM embeddings). Insertion codes are normalized | ||
| separately via :func:`normalize_ins_code`. | ||
|
|
||
| Args: | ||
| atoms: A biotite ``AtomArray`` with a ``res_name`` annotation. | ||
|
|
||
| Returns: | ||
| A copy of ``atoms`` with ``res_name`` canonicalized. | ||
| """ | ||
| sanitized = atoms.copy() | ||
| for i in range(len(sanitized)): | ||
| aa1 = THREE_TO_ONE.get(sanitized.res_name[i], "X") | ||
| sanitized.res_name[i] = ONE_TO_THREE.get(aa1, "UNK") | ||
| return sanitized | ||
|
|
||
|
|
||
| def parse_split_file(split_file: Path, base_pdb_dir: Path) -> list[dict]: | ||
| """ | ||
| Parse split file and construct entries with paths. | ||
|
|
@@ -164,9 +159,6 @@ def parse_split_file(split_file: Path, base_pdb_dir: Path) -> list[dict]: | |
| return entries | ||
|
|
||
|
|
||
| ATOM37_FILL = 1e-5 | ||
|
|
||
|
|
||
| def rbf(r: Tensor, num_gaussians: int = NUM_RBF, cutoff: float = RBF_CUTOFF) -> Tensor: | ||
| """ | ||
| Compute radial basis function encoding of distances. | ||
|
|
@@ -264,32 +256,6 @@ def compute_edge_features( | |
| return unit_vectors, rbf_features | ||
|
|
||
|
|
||
| def atom37_to_atoms( | ||
| atom_tensor: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Convert atom37 representation to flat atom list. | ||
|
|
||
| Args: | ||
| atom_tensor: (N_res, 37, 3) atom37 coordinates | ||
|
|
||
| Returns: | ||
| coords: (N_atoms, 3) coordinates of present atoms | ||
| residue_index: (N_atoms,) which residue each atom belongs to | ||
| atom_type: (N_atoms,) atom type index (0-36) | ||
| """ | ||
| present = (atom_tensor != ATOM37_FILL).any(dim=-1) # (N_res, 37) | ||
| nz = present.nonzero(as_tuple=False) # (N_atoms, 2) | ||
| residue_index = nz[:, 0] | ||
| atom_type = nz[:, 1].long() | ||
|
|
||
| flat = atom_tensor.reshape(-1, 3) | ||
| flat_mask = present.reshape(-1) | ||
| coords = flat[flat_mask] | ||
|
|
||
| return coords, residue_index, atom_type | ||
|
|
||
|
|
||
| @torch.no_grad() | ||
| def ot_coupling( | ||
| x1: torch.Tensor, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe name the variables clearer or add a comment explaining why you are not using atomarray.res_id. The same sanitization can probably be achieved by
bts.spread_residue_wise(sanitized_for_idx, np.arange(num_res))but as long as it's correct I don't particularly care how it is implemented. With that, probably worth adding a test checking you are assigning the res_id correctly.