Attention Rollout Computation and Heatmap Visualization #79
Open
philipxunwu wants to merge 20 commits into
…ed global norm and per head norm options, as well as filtering on attention weight threshold
Added a demo link for the Protein Structure Prediction web interface.
Attention Rollout Computation and Basic Visualization
Extends #16 (Attention Heatmap) and #12 (Web App)
Addresses Issue #7: Alternative Visualizations of Attention Heads
Overview
This contribution aims to extend the existing attention heatmap visualization with attention rollout, a method for computing the cumulative, cross-layer attention flow through the Evoformer transformer stack. It is integrated directly into the existing Flask web application, so users can toggle between single-layer head views and aggregated rollout views within the same UI without any additional tooling.
Why Attention Rollout?
Standard per-layer, per-head attention maps show local residue-to-residue interactions at a single depth of the model. Because transformer models pass information through residual connections, a raw attention matrix from layer k only reflects that layer's routing decisions, while ignoring all context accumulated by earlier layers.
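Attention rollout (Abnar & Zuidema, 2020; see also this walkthrough) addresses this by:
- averaging the attention weights over heads within each layer,
- adding the identity matrix to account for the residual connection, then re-normalizing each row,
- multiplying the resulting per-layer matrices across the selected layer range.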
The result is a single matrix that approximates how much information from residue j influenced residue i after passing through the full depth of the network.
In the context of protein folding, this captures the cumulative effect of evolutionary co-variation signals integrated across all Evoformer layers, making it a useful complement to the existing per-layer heatmaps for identifying globally important residue relationships.
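The rollout recipe described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the code from this PR: the function name `attention_rollout` and its arguments are hypothetical, and the input is assumed to be a list of per-layer arrays shaped (heads, seq_len, seq_len), ordered from the first to the last layer.

```python
import numpy as np


def attention_rollout(layers, add_identity=True):
    """Hypothetical sketch of attention rollout (Abnar & Zuidema, 2020).

    layers: list of arrays shaped (heads, seq_len, seq_len), first layer first.
    Returns a (seq_len, seq_len) matrix; entry [i, j] approximates how much
    position j contributes to position i after all given layers.
    """
    if not layers:
        raise ValueError("at least one layer is required")

    rollout = None
    for attn in layers:
        # Step 1: average the attention weights over heads.
        a = np.asarray(attn, dtype=np.float64).mean(axis=0)

        # Step 2: add the identity to model the residual connection.
        if add_identity:
            a = a + np.eye(a.shape[0])

        # Step 3: re-normalize so every row sums to 1 (guarding empty rows).
        row_sums = a.sum(axis=-1, keepdims=True)
        a = a / np.where(row_sums == 0.0, 1.0, row_sums)

        # Step 4: chain the layers; later layers multiply from the left.
        rollout = a if rollout is None else a @ rollout

    return rollout
```

Because each per-layer matrix is row-normalized before multiplication, every row of the result also sums to 1, which keeps rollout values comparable across different layer ranges.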
Changes Made
1. Dense Attention Saving (openfold/model/evoformer.py, run_pretrained_openfold.py)
The original attention saving code only logged top-k attention edges as text. To support rollout across the full attention matrix, two functions were added to evoformer.py:
- save_attention_full(): saves a dense (heads × seq_len × seq_len) attention array as a .npz file at each layer.
- save_all_full_from_recent_attention(): batch-saves the last-computed dense attention for all layers.
In run_pretrained_openfold.py, a new CLI flag --save_full_attn was added. When set, inference populates a dense_attention/ subdirectory alongside the existing text-format attention files.
Edge case: If --save_full_attn is not set, the rollout endpoint returns an error instructing the user to re-run with the flag enabled.
2. Rollout Computation (visualize_attention_heatmap_grid.py)
The following functions were added to the existing visualization library:
- load_dense_attention(path): Loads a .npz attention file; validates array shape.
- load_attention_array(dir, layer, type): Unified loader that tries .npz first, then falls back to parsing the text format.
- discover_attention_layers(dir, type): Scans the attention directory and returns the sorted list of available layer indices.
- compute_attention_rollout(dir, start, end, type, add_identity): Core rollout: loads layers [start, end], averages heads, adds identity, normalizes, and chains multiplications. Returns a (seq_len × seq_len) NumPy array.
- create_rollout_heatmap(matrix, threshold): Generates a static Plotly HTML heatmap from a rollout matrix.
- visualize_attention_rollout(dir, out, start, end, type, threshold): High-level driver: computes rollout and saves the static heatmap.
- create_interactive_rollout_html(dir, type, seq_len): Generates a fully client-side interactive HTML page with layer-range sliders that compute rollout in JavaScript (no server needed).
GPU acceleration: compute_attention_rollout() detects whether PyTorch with CUDA is available. If so, the chain of matrix multiplications runs on the GPU, which can be significantly faster for long sequences; otherwise it falls back to NumPy. This is transparent to the caller.
CLI usage:
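Independently of the command line, the CUDA-or-NumPy fallback described above could be sketched as follows. This is an assumption about the general pattern, not the PR's actual implementation: the helper name `chain_matmul` is hypothetical, and it takes already-prepared (seq_len × seq_len) matrices ordered from the first to the last layer.

```python
import numpy as np


def chain_matmul(matrices):
    """Hypothetical helper: compute M_n @ ... @ M_2 @ M_1.

    Uses PyTorch on a CUDA device when both are available; otherwise
    falls back to plain NumPy. The caller always gets a NumPy array back.
    """
    if not matrices:
        raise ValueError("at least one matrix is required")

    # Detect PyTorch with CUDA without making it a hard dependency.
    try:
        import torch
        use_cuda = torch.cuda.is_available()
    except ImportError:
        torch = None
        use_cuda = False

    if use_cuda:
        # GPU path: move each matrix to the device and multiply there.
        result = torch.from_numpy(np.asarray(matrices[0], dtype=np.float64)).to("cuda")
        for m in matrices[1:]:
            nxt = torch.from_numpy(np.asarray(m, dtype=np.float64)).to("cuda")
            result = nxt @ result
        return result.cpu().numpy()

    # CPU path: same chain of products in NumPy.
    result = np.asarray(matrices[0], dtype=np.float64)
    for m in matrices[1:]:
        result = np.asarray(m, dtype=np.float64) @ result
    return result
```

Keeping the device check inside the helper is what makes the choice invisible to callers: both paths accept and return NumPy arrays.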
3. Flask Backend Endpoint (web_app/app.py)
A new endpoint GET /viz/rollout was added to the existing Flask server.
Query parameters:
- protein_id: protein identifier
- attention_type: "msa_row" or "triangle_start"
- seq_len: sequence length
- layer_start: first layer to include in rollout
- layer_end: last layer to include in rollout
- threshold: values below this are zeroed in the response
Response:
    { "rollout_matrix": [[0.12, 0.03, ...], [0.01, 0.45, ...], ...] }
Returns a 2D array (list of lists) suitable for direct use by the frontend canvas renderer.
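To make the request and response shapes above concrete, here is a small sketch that builds the query string and the JSON payload. The parameter names and the `rollout_matrix` key come from the description above; the helper names `build_rollout_query` and `build_rollout_response` and the example protein ID are hypothetical, and the actual Flask handler in web_app/app.py may be structured differently.

```python
import json
from urllib.parse import urlencode

import numpy as np


def build_rollout_query(protein_id, attention_type, seq_len,
                        layer_start, layer_end, threshold):
    """Hypothetical helper: build the path and query string for GET /viz/rollout."""
    params = {
        "protein_id": protein_id,
        "attention_type": attention_type,  # "msa_row" or "triangle_start"
        "seq_len": seq_len,
        "layer_start": layer_start,
        "layer_end": layer_end,
        "threshold": threshold,
    }
    return "/viz/rollout?" + urlencode(params)


def build_rollout_response(matrix, threshold=0.0):
    """Hypothetical helper: zero values below threshold and wrap as a dict."""
    m = np.asarray(matrix, dtype=float)
    m = np.where(m < threshold, 0.0, m)
    # tolist() yields nested Python lists, which json.dumps can serialize.
    return {"rollout_matrix": m.tolist()}


# Example usage with a made-up protein ID and a tiny 2 x 2 matrix.
url = build_rollout_query("example_protein", "msa_row", 2, 0, 47, 0.05)
payload = json.dumps(build_rollout_response([[0.12, 0.03], [0.01, 0.45]], 0.05))
```

Converting to nested lists before serialization is what lets the frontend consume the matrix directly without any NumPy-specific decoding.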
4. Frontend Integration (web_app/templates/index.html)
The existing attention panel was extended with rollout controls below the per-head selector.
New UI elements:
- Layer range controls (layer_start, layer_end), defaulting to 0 and the maximum discovered layer.
JavaScript function computeRolloutBackend():
- Requests /viz/rollout with the parameters.
- Passes the returned rollout_matrix directly to the existing renderHeatmap() function, which reuses the same canvas element already used for single-layer head views.
Running the Web App with Rollout
To enable rollout, predictions must be run with dense attention saving. The web app passes this flag automatically when the prediction is launched through the UI. For manual inference:
    python run_pretrained_openfold.py \
        --fasta_paths ./my_protein.fasta \
        --template_mmcif_dir ./templates \
        --output_dir ./outputs \
        --model_device cuda:0 \
        --save_full_attn  # <-- required for rollout