Add ESMFold backend smoke test and reproducibility documentation #58
Mose-Kim02 wants to merge 19 commits into
…d build dependency
…tates; use hooks for traces
…pipeline
- Rewrite hooks.py: target HF encoder.layer[i].attention.self directly instead of broad name-matching; monkey-patch forward to force output_attentions=True so real attention weights [B,H,N,N] are captured (not hidden states); separate attention/activation hooks with correct layer indices; slice out <cls>/<eos> tokens from attention maps
- Fix _coords_to_minimal_pdb: use 3-letter residue codes (valid PDB)
- Remove dead code: try_use_outputs() path, shared mutable counter
- Extract structure logic into _extract_structure() method
- Unify FASTA reading into single read_fasta() returning (seq, id, hash)
- Wire --dtype through CLI (float32/float16 model loading)
- Log runner.run() result (attention/activation layer counts)
- Fix trace_adapter: correct head-slicing axis for 3D vs 4D tensors; fix entropy calculation (per-row, not per-matrix)
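To make the "per-row, not per-matrix" entropy fix concrete, here is a small pure-Python illustration. It is not the code in trace_adapter.py; the function name `row_entropies` is hypothetical, and renormalizing each row is an assumption (rows may no longer sum to 1 once the `<cls>`/`<eos>` columns are sliced out).

```python
import math


def row_entropies(attention):
    """Shannon entropy (in nats) of each row of an N x N attention matrix.

    Each row is treated as its own distribution over keys, which is the
    per-row reading; a per-matrix version would instead treat all N*N
    entries as one distribution.
    """
    entropies = []
    for row in attention:
        total = sum(row)
        if total <= 0:
            # Degenerate row with no attention mass: report zero entropy.
            entropies.append(0.0)
            continue
        # Assumption: renormalize so the row sums to 1 before taking logs.
        probs = [p / total for p in row if p > 0]
        entropies.append(-sum(p * math.log(p) for p in probs))
    return entropies


# One uniform row (maximum entropy for 2 keys) and one fully peaked row.
example = [[0.5, 0.5], [1.0, 0.0]]
per_row = row_entropies(example)
```

The uniform row comes out at log(2) and the peaked row at zero, so the two queries are distinguishable, which a single per-matrix number would hide.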
def test_esmfold_backend_smoke():
    output_dir = "outputs/test_trace_ci"
Looks like output_dir is hardcoded, so if you run this twice the test can pass on stale data left over from the first run. I'd suggest using tmp_path (the pytest fixture) or tempfile.mkdtemp() so each run starts clean.
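A minimal sketch of the tmp_path suggestion. `run_backend` is a hypothetical stand-in for the real subprocess call (it only writes a placeholder file), and `collect_trace_files` is likewise an illustrative helper, not something in this PR.

```python
from pathlib import Path


def run_backend(output_dir: Path) -> None:
    """Hypothetical stand-in for the real inference command.

    It writes a single placeholder trace file so the sketch is runnable.
    """
    trace_dir = output_dir / "trace"
    trace_dir.mkdir(parents=True, exist_ok=True)
    (trace_dir / "layer_00.pt").write_bytes(b"placeholder")


def collect_trace_files(output_dir: Path) -> list[Path]:
    """Return every .pt file under <output_dir>/trace, sorted for stable output."""
    return sorted((output_dir / "trace").rglob("*.pt"))


def test_esmfold_backend_smoke(tmp_path):
    # pytest hands each test invocation a fresh, empty tmp_path, so files
    # left behind by an earlier run cannot satisfy the assertions below.
    output_dir = tmp_path / "test_trace_ci"
    assert not output_dir.exists()
    run_backend(output_dir)
    assert len(collect_trace_files(output_dir)) == 1
```

With this shape the test never touches `outputs/`, so it also stops leaving artifacts in the working tree.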
    output_dir = "outputs/test_trace_ci"

    cmd = [
        "python",
Use sys.executable instead of "python" on this line. Otherwise the subprocess might pick up a different Python than the one running the tests on some machines.
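A small sketch of the sys.executable pattern. The `-c` one-liner stands in for the real module and flags, which are not shown in this hunk, and `build_command` is a hypothetical helper.

```python
import subprocess
import sys


def build_command(script_args):
    """Prefix the arguments with the interpreter that is running this process."""
    return [sys.executable, *script_args]


# The "-c" snippet is a placeholder for the real CLI invocation.
cmd = build_command(["-c", "import sys; print(sys.executable)"])

# check=True raises CalledProcessError if the child exits non-zero.
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
child_interpreter = result.stdout.strip()
```

Because the child is launched through the same interpreter as the test process, it sees the same virtual environment and installed packages.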
    for root, _, files in os.walk(f"{output_dir}/trace"):
        trace_files += [f for f in files if f.endswith(".pt")]

    assert len(trace_files) >= 36
Checking the file count is good, but it would also be nice to load one of the .pt files and assert that the attention tensor has shape [1, H, N, N]. That way we know the hook captured real attention weights with the right dimensions, and not just empty files.
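One way the extra shape check could look. This is a sketch under two assumptions: each .pt file stores a bare tensor (not a dict), and the batch size is 1. Both helper names are hypothetical.

```python
def check_attention_shape(shape, batch=1):
    """Validate that `shape` is [batch, H, N, N] and return (H, N)."""
    shape = tuple(shape)
    assert len(shape) == 4, f"expected 4 dims [B, H, N, N], got {shape}"
    b, heads, rows, cols = shape
    assert b == batch, f"expected batch size {batch}, got {b}"
    assert heads > 0 and rows > 0, f"empty attention tensor: {shape}"
    assert rows == cols, f"attention map is not square: {shape}"
    return heads, rows


def check_attention_file(path):
    """Load one trace file and validate its shape.

    Assumption: the file holds a single tensor saved with torch.save.
    """
    import torch  # local import: only needed when a real .pt file is checked

    tensor = torch.load(path, map_location="cpu")
    assert torch.is_tensor(tensor), f"{path} does not contain a tensor"
    return check_attention_shape(tensor.shape)
```

In the smoke test, calling `check_attention_file` on one of the collected attention files would complement the existing count assertion.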
@@ -0,0 +1,35 @@
import os
General comment for this file: there's already a test_esmf_smoke.py with a few test cases. It might make sense to add these tests to that file instead of a separate one, so we don't end up with two test files for the same thing.
| ```bash | ||
| python3 -m venv .venv | ||
| source .venv/bin/activate | ||
The markdown is broken here because the code fence never gets closed, so everything after it renders as one big code block. Each code section just needs closing triple backticks.
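A quick way to catch this kind of problem before review is to scan the file for an unpaired fence. The helper below is a simplified sketch: it only looks at triple-backtick fences and toggles on each one, so it ignores tilde fences and longer nested fences. The name `unclosed_fence_line` is hypothetical.

```python
def unclosed_fence_line(markdown_text):
    """Return the 1-based line number of a fence that is never closed, or None."""
    open_line = None
    for number, line in enumerate(markdown_text.splitlines(), start=1):
        if line.lstrip().startswith("```"):
            # Toggle: the first fence opens a block, the next one closes it.
            open_line = number if open_line is None else None
    return open_line


broken = "```bash\npython3 -m venv .venv\nsource .venv/bin/activate\n\nNext section\n"
fixed = "```bash\npython3 -m venv .venv\nsource .venv/bin/activate\n```\n\nNext section\n"

broken_at = unclosed_fence_line(broken)
fixed_at = unclosed_fence_line(fixed)
```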
…re, trace relpaths, summary logging, layer_count
…ence#2) Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
…I2Science#1)
* Add VizFold text-file attention export compatible with existing visualization tools
* Bug fix: override the positional arg in-place instead of adding to kwargs
* Fix: trace_formats missing from meta.json
* Robust attention saving & forward signature handling
hooks.py:
- make the EsmSelfAttention forward patch resilient to signature changes by finding the position of output_attentions by name instead of assuming a fixed positional index
trace_adapter.py:
- reuse OpenFold's save_attention_topk if available, and falls back to a self-contained NumPy implementation (no OpenFold dependency) that writes msa_row_attn text files
- layer-index extraction via regex
- compute produced trace_formats dynamically in build_and_write_meta instead of hardcoding ["pt","txt"]
Co-authored-by: Mose Kim <kimmose2002@gmail.com>
…cience#3)
* Extract s_s folding trunk activations and enforce safetensors
* Update backend pipeline
* Capture s_s and s_z at every recycling iteration via trunk hook
* Remove test output artifacts
---------
Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
* Add VizFold text-file attention export compatible with existing visualization tools
* Bug fix: override the positional arg in-place instead of adding to kwargs
* Fix: trace_formats missing from meta.json
* Robust attention saving & forward signature handling
hooks.py:
- make the EsmSelfAttention forward patch resilient to signature changes by finding the position of output_attentions by name instead of assuming a fixed positional index
trace_adapter.py:
- reuse OpenFold's save_attention_topk if available, and falls back to a self-contained NumPy implementation (no OpenFold dependency) that writes msa_row_attn text files
- layer-index extraction via regex
- compute produced trace_formats dynamically in build_and_write_meta instead of hardcoding ["pt","txt"]
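The "find output_attentions by name" idea from the hooks.py notes above can be sketched with `inspect.signature`. This is an illustration, not the code in hooks.py: `force_output_attentions` is a hypothetical helper, and `fake_forward` is a stand-in whose parameter list is assumed, not copied from the real EsmSelfAttention.forward.

```python
import inspect


def force_output_attentions(forward, args, kwargs):
    """Return (args, kwargs) with output_attentions forced to True.

    The parameter is located by name in the signature, so the override
    keeps working if parameters are added or reordered upstream.
    """
    params = list(inspect.signature(forward).parameters)
    if "output_attentions" not in params:
        # Nothing to override; leave the call untouched.
        return tuple(args), dict(kwargs)
    index = params.index("output_attentions")
    args = list(args)
    kwargs = dict(kwargs)
    if index < len(args):
        # Passed positionally: override in place instead of adding a kwarg,
        # which would raise "got multiple values for argument".
        args[index] = True
    else:
        kwargs["output_attentions"] = True
    return tuple(args), kwargs


def fake_forward(hidden_states, attention_mask=None, head_mask=None,
                 output_attentions=False):
    """Stand-in forward; the real signature may differ."""
    return output_attentions


pos_args, pos_kwargs = force_output_attentions(fake_forward, ("h", None, None, False), {})
kw_args, kw_kwargs = force_output_attentions(fake_forward, ("h",), {})
```

Both call styles end up with `output_attentions=True`, whether the caller passed the flag positionally or not at all.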
* Capture and save evoformer trunk intermediates
Add per-block evoformer tracing and output saving for ESMFold.
- hooks.py: introduce register_trunk_hooks and _make_trunk_block_hook to register forward hooks on model.trunk.blocks (EsmFoldTriangularSelfAttentionBlock). Captured per-block sequence_state and pairwise_state are stored in collector.trunk_blocks; clear() updated and warnings added when trunk/blocks are missing.
- inference.py: register the new trunk hooks in ESMFoldRunner, extract and save final folding trunk pair representations (out.s_z), and write per-block evoformer intermediates to trace/trunk/*.pt while recording shapes. Logging messages adjusted.
- trace_adapter.py: update trace layout to include trunk/ files (block_{idx}_seq/pair, s_s, s_z).
* ESMFold: save trunk tensors, CPU attention
Ensure attention tensors are moved to CPU in hooks (detach().cpu()) to avoid GPU tensor serialization. Stop extracting final trunk outputs from model.out and instead collect final s_s/s_z from collector.recycled_s_s/recycled_s_z (avoids redundant copies) and save per-block trunk tensors plus final s_s/s_z into trace/trunk/.
* Squeeze batch dim in hooks; drop recycling archive
Fix tensor shape handling in ESMFoldTraceCollector hooks by squeezing the leading batch dimension before detaching and moving seq and pair states to CPU, preventing stored activations from containing an extra batch axis. Also remove the prior archival of recycled s_s/s_z tensors in the ESMFoldRunner inference flow to avoid redundant/memory-heavy activation copies and logging related to those recycled tensors.
44ab130 to 26ae087
…recycle output paths
Summary
Validated locally
Purpose
Supports integration testing and reproducibility for the Issue #43 shared backend branch.