KNN edge construction fix + align residue indexing with ESM sanitization#82
KNN edge construction fix + align residue indexing with ESM sanitization#82vratins wants to merge 7 commits into
Conversation
📝 WalkthroughWalkthroughThe PR adds shared ESM residue-name sanitization, uses it in embedding generation and dataset residue indexing, and updates KNN edge construction to emit source-to-destination directed edges. Related tests now assert the updated residue and edge semantics. ChangesESM Residue Sanitization Alignment
KNN Edge Direction Correction
Estimated code review effort🎯 4 (Complex) | ⏱️ ~40 minutes Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Pull request overview
This PR fixes KNN-based edge construction directionality in the flow model, hardens dataset preprocessing for edge cases, and aligns protein residue indexing with the same residue-name sanitization used when generating cached ESM embeddings.
Changes:
- Fix
build_knn_edgesto query per-destination point (and swap returned index rows) so edges align with intended src→dst semantics. - Update dataset preprocessing to handle empty
atomsinputs and to compute residue indices using ESM-style residue-name canonicalization before residue-boundary detection. - Adjust tests by adding an
xfailmarker for a batched edge-connectivity test (though the WW coverage assertions likely need updating instead).
Reviewed changes
Copilot reviewed 3 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
uv.lock |
Updates locked dependencies (adds jaxtyping, removes mypy, adjusts some wheels/metadata). |
tests/test_flow.py |
Marks the batched water-edge connectivity test as xfail(strict=True) and updates rationale text. |
src/flow.py |
Fixes KNN query argument order and explicitly swaps index rows to preserve src→dst edge_index layout. |
src/dataset.py |
Handles empty atoms in coordinate matching; updates residue indexing to mirror ESM sanitization behavior. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/test_flow.py (1)
582-617:⚠️ Potential issue | 🟠 Major | ⚡ Quick winNarrow the
xfailscope so it doesn’t hidepwregressions.
xfailis currently applied to the whole test, so failures in the protein-water assertions are also treated as expected. That drops useful coverage beyond the knownwwissue.Suggested split to preserve `pw` coverage
- `@pytest.mark.xfail`( + def test_batched_waters_have_protein_edges(self, batched_hetero_data): + """Ensure all waters in a batched graph have protein-water edges.""" + updater = ProteinWaterUpdate(hidden_dims=(128, 16), layers=1) + edge_dict = updater.build_edges(batched_hetero_data, k_pw=4, k_ww=3) + pw_edges = edge_dict[("protein", "pw", "water")] + n_water = batched_hetero_data["water"].num_nodes + water_nodes_with_pw_edges = torch.unique(pw_edges[1]) + assert len(water_nodes_with_pw_edges) == n_water, ( + f"Only {len(water_nodes_with_pw_edges)}/{n_water} waters have protein edges in batched data" + ) + + `@pytest.mark.xfail`( reason=( "build_knn_edges' src/dst argument-order fix changes self-graph (ww) " "edge direction: row 0 now holds discovered neighbors rather than query " "points, so a point that is nobody's k-nearest neighbor can be dropped " "from coverage. The fixed-degree k_pw/k_ww KNN approach is replaced by " "radius-based edges + KNN-fallback-for-isolated-nodes in a future PR " "(edge type flags & dynamic edge construction), which removes the " "k_pw/k_ww params and fixes this guarantee structurally. will remove this " "marker when that PR is created." ), strict=True, ) - def test_batched_waters_have_edges(self, batched_hetero_data): - """Ensure all waters in a batched graph have edges.""" + def test_batched_waters_have_water_edges(self, batched_hetero_data): + """Ensure all waters in a batched graph have water-water edges.""" updater = ProteinWaterUpdate(hidden_dims=(128, 16), layers=1) - edge_dict = updater.build_edges(batched_hetero_data, k_pw=4, k_ww=3) - pw_edges = edge_dict[("protein", "pw", "water")] ww_edges = edge_dict[("water", "ww", "water")] - n_water = batched_hetero_data["water"].num_nodes - - # Check protein-water edges - water_nodes_with_pw_edges = torch.unique(pw_edges[1]) - assert len(water_nodes_with_pw_edges) == n_water, ( - f"Only {len(water_nodes_with_pw_edges)}/{n_water} waters have protein edges in batched data" - ) - - # Check water-water edges if n_water > 1: water_nodes_with_ww_edges = torch.unique(ww_edges[0]) assert len(water_nodes_with_ww_edges) == n_water, ( f"Only {len(water_nodes_with_ww_edges)}/{n_water} waters have water-water edges in batched data" )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_flow.py` around lines 582 - 617, The xfail marker is currently applied to the entire test_batched_waters_have_edges function, which hides failures in the protein-water edge assertions that should not be expected to fail. Remove the xfail decorator from the function and instead apply it only to the water-water edge checking section (the assertions checking water_nodes_with_ww_edges). This can be done by either splitting the test into two separate test functions with xfail only on the water-water test, or by wrapping just the water-water edge assertion block with pytest.xfail() to preserve protein-water edge coverage while still allowing the known water-water edge failure.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/dataset.py`:
- Around line 993-1000: The insertion code normalization is missing before
calculating residue starts, which causes misalignment with the cached ESM
embeddings. After the loop that sanitizes res_name for the sanitized_for_idx
object (which converts three-letter codes to one-letter and back), add code to
normalize the ins_code field by setting blank or non-standard insertion codes to
a consistent placeholder value (similar to how "X" is used for unknown
residues). This normalization must occur before calling
bts.get_residue_starts(sanitized_for_idx) to ensure the residue count and
protein_res_idx indices match what was computed in generate_esm_embeddings.py.
---
Outside diff comments:
In `@tests/test_flow.py`:
- Around line 582-617: The xfail marker is currently applied to the entire
test_batched_waters_have_edges function, which hides failures in the
protein-water edge assertions that should not be expected to fail. Remove the
xfail decorator from the function and instead apply it only to the water-water
edge checking section (the assertions checking water_nodes_with_ww_edges). This
can be done by either splitting the test into two separate test functions with
xfail only on the water-water test, or by wrapping just the water-water edge
assertion block with pytest.xfail() to preserve protein-water edge coverage
while still allowing the known water-water edge failure.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 65109962-a7f2-481b-b0ef-737262e6f23a
⛔ Files ignored due to path filters (1)
uv.lockis excluded by!**/*.lock
📒 Files selected for processing (3)
src/dataset.pysrc/flow.pytests/test_flow.py
DorisMai
left a comment
There was a problem hiding this comment.
The only blocking request I have is add a test for the knn fix to loudly fail the mix up (previous bug) and guards any additional mix up the ordering. All other comments are nitpick.
| } | ||
|
|
||
|
|
||
| def match_atoms_to_coords( |
There was a problem hiding this comment.
This is a comment outside the diff as I look at this function. Should probably add a check (and log warning if failed) that the length of atoms and target_coords should roughly match, and majority atoms matched.
| return ins | ||
|
|
||
|
|
||
| def sanitize_res_names_for_esm(atoms): |
There was a problem hiding this comment.
Add type hint on input and output
| res_starts = bts.get_residue_starts(sanitized_for_idx) | ||
| num_residues = len(res_starts) | ||
| atom_res_idx = ( | ||
| np.searchsorted(res_starts, np.arange(len(protein_atoms)), side="right") - 1 |
There was a problem hiding this comment.
maybe name the variables clearer or add a comment explaining why you are not using atomarray.res_id. The same sanitization can probably be achieved by bts.spread_residue_wise(sanitized_for_idx, np.arange(num_res)) but as long as it's correct I don't particularly care how it is implemented. With that, probably worth adding a test checking you are assigning the res_id correctly.
| assert torch.allclose(coords[0], expected_coord) | ||
| @staticmethod | ||
| def _esm_key_count(atoms): | ||
| """Replicate the ESM script's residue-key counting.""" |
There was a problem hiding this comment.
if truly a replicate then why not extract this part of ESM script into a helper function to import here?
| idx = knn(x=src_pos, y=dst_pos, k=k, batch_x=batch_src, batch_y=batch_dst) | ||
| idx = torch.stack((idx[1], idx[0]), dim=0) |
There was a problem hiding this comment.
Please add an exact set asymmetric test with toy cases where you know true neighbors to test on the directionalities, both within the knn function (to explicitly fail the x vs y mix up) and for interpreting the output tensor. specifically for the output, I couldn't find anywhere officially documented that the knn output tensor row 0 is destination indices and row 1 is source indices. moreover, it looks like the latests knn function in torch_geometry is different from the 2.7.0 version this repo currently uses, so it might worth making a note here in addition to explicitly test the behavior stays the same.
| updater = ProteinWaterUpdate(hidden_dims=(128, 16), layers=1) | ||
|
|
||
| edge_dict = updater.build_edges(simple_hetero_data, k_pw=4, k_ww=3) | ||
| pw_edges = edge_dict[("protein", "pw", "water")] |
There was a problem hiding this comment.
this is just a curiosity question outside the diff. It looks like PW is the only type with symmetrized knn edges built. Is there a particular motivation for this?
build_knn_edgeswas callingknn(x=dst_pos, y=src_pos), which queries eachsrcpoint's nearestdstpoints instead of eachdstpoint's nearestsrcpoints; swapped the call and the resulting index rows to fix this. This was previously masked by taking the union of edges on both sides. A future PR will use KNN as a fallback to radius graphs. Expanded the docstring to document that the query is per-destination, so every destination is guaranteed incoming edges (row 1) while a source that is nobody's nearest neighbor may be absent fromrow 0— and updated the water-water coverage tests to assert on the destination row (row 1) accordingly, removing the now-unnecessaryxfailmarkers.match_atoms_to_coordsnow also handles an emptyatomsarray instead of only an emptytarget_coordsarray.THREE_TO_ONE->ONE_TO_THREE, unknowns ->UNK) before counting residue boundaries with biotite'sget_residue_starts. Without this, two residues that share (chain,resid,ins_code) but had different originalres_namescould get merged into one under ESM's sanitization but stay separate here, desyncing residue counts/indices from the stored ESM embeddings. Insertion codes are now also normalized (normalize_ins_code) before counting, sinceget_residue_startssplits onins_codetoo — a blank vs. placeholder code (''/'.'/'?') would otherwise split or merge residues differently from the ESM script's residue keys. Both the canonicalization (sanitize_res_names_for_esm) and the insertion-code normalization are now shared betweensrc/dataset.pyandscripts/generate_esm_embeddings.pyto prevent the two paths from drifting apart.build_knn_edgesinsrc/utils.py(the canonical one lives insrc/flow.py) that still carried the old per-source semantics, and removed the orphanedatom37_to_atoms/ATOM37_FILLhelper (unused outside its own tests; the SLAE pipeline uses its own implementation). Added regression tests covering residue-count alignment between the dataset and the ESM residue keys.Summary by CodeRabbit
Bug Fixes
Tests
Documentation