Attention Rollout Computation and Heatmap Visualization#79

Open
philipxunwu wants to merge 20 commits into
AI2Science:mainfrom
philipxunwu:attention_rollout

@philipxunwu philipxunwu commented Apr 30, 2026

Attention Rollout Computation and Basic Visualization

Extends #16 (Attention Heatmap) and #12 (Web App)
Addresses Issue #7: Alternative Visualizations of Attention Heads

Overview

This contribution extends the existing attention heatmap visualization with attention rollout, a method for computing the cumulative, cross-layer attention flow through the Evoformer stack. It is integrated directly into the existing Flask web application, so users can toggle between single-layer head views and aggregated rollout views in the same UI without additional tooling.

Why Attention Rollout?

Standard per-layer, per-head attention maps show local residue-to-residue interactions at a single depth of the model. Because transformer models also pass information forward through residual connections, a raw attention matrix from layer k reflects only that layer's routing decisions and ignores the context accumulated by earlier layers.

Attention rollout (Abnar & Zuidema, 2020; see also this walkthrough) addresses this by:

  1. Averaging across attention heads at each layer to get a single per-layer routing matrix.
  2. Adding the identity matrix at each layer (to model the residual/skip connection).
  3. Row-normalizing each matrix so it represents a probability distribution.
  4. Chaining the normalized matrices via sequential matrix multiplication across the selected layer range.

The result is a single matrix that approximates how much information from residue j influenced residue i after passing through the full depth of the network.
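
The four steps above can be sketched in a few lines of NumPy. This is an illustrative sketch, not the code in this PR: the function name `attention_rollout` is hypothetical, the input is assumed to be a list of `(heads, seq_len, seq_len)` arrays ordered from first to last layer, and the multiplication order (each new layer multiplied on the left) follows the convention in Abnar & Zuidema, which the implementation here may or may not share.

```python
import numpy as np


def attention_rollout(layers, add_identity=True):
    """Chain per-layer attention into one (seq_len, seq_len) matrix.

    `layers`: list of arrays shaped (heads, seq_len, seq_len), ordered
    from the first layer to the last. Hypothetical helper for illustration.
    """
    if not layers:
        raise ValueError("need at least one layer")
    rollout = None
    for attn in layers:
        # Step 1: average over heads -> (seq_len, seq_len)
        a = np.asarray(attn, dtype=np.float64).mean(axis=0)
        # Step 2: add the identity to model the residual connection
        if add_identity:
            a = a + np.eye(a.shape[-1])
        # Step 3: row-normalize so every row sums to 1
        # (assumes no all-zero rows, which holds when the identity is added)
        a = a / a.sum(axis=-1, keepdims=True)
        # Step 4: multiply on the left so this layer composes with earlier ones
        rollout = a if rollout is None else a @ rollout
    return rollout


# Toy input: 3 layers, 4 heads, 5 residues, each attention row sums to 1
rng = np.random.default_rng(0)
toy = [rng.dirichlet(np.ones(5), size=(4, 5)) for _ in range(3)]
R = attention_rollout(toy)
print(R.shape)        # (5, 5)
print(R.sum(axis=1))  # each row sums to 1, up to floating-point rounding
```

Because every factor is row-stochastic, the product is too, so row i of the result can be read as a distribution over source residues j.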

In the context of protein folding, this captures the cumulative effect of evolutionary co-variation signals integrated across all Evoformer layers, making it a useful complement to the existing per-layer heatmaps for identifying globally important residue relationships.

Changes Made

1. Dense Attention Saving (openfold/model/evoformer.py, run_pretrained_openfold.py)

The original attention saving code only logged top-k attention edges as text. To support rollout across the full attention matrix, two functions were added to evoformer.py:

  • save_attention_full() — saves a dense (heads × seq_len × seq_len) attention array as a .npz file at each layer.
  • save_all_full_from_recent_attention() — batch-saves the last-computed dense attention for all layers.
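
As a rough illustration of what saving a dense per-layer array involves, the sketch below writes one `(heads, seq_len, seq_len)` array to a compressed `.npz` file and reads it back. The helper name `save_dense_sketch`, the file naming scheme, and the `attention` array key are all assumptions made for this example; the actual names used by `save_attention_full()` are not specified in this description.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_dense_sketch(attn, out_dir, layer_idx, attention_type="msa_row"):
    """Write one layer's (heads, L, L) attention to a compressed .npz file."""
    attn = np.asarray(attn, dtype=np.float32)
    if attn.ndim != 3 or attn.shape[1] != attn.shape[2]:
        raise ValueError(f"expected (heads, L, L), got {attn.shape}")
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Hypothetical naming scheme, chosen only for this example
    path = out_dir / f"{attention_type}_layer{layer_idx}.npz"
    # "attention" is an assumed key name for the stored array
    np.savez_compressed(path, attention=attn)
    return path


# Round trip in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    attn = np.random.default_rng(0).random((8, 16, 16))
    saved = save_dense_sketch(attn, tmp, layer_idx=0)
    with np.load(saved) as data:
        restored = data["attention"]

print(restored.shape, restored.dtype)  # (8, 16, 16) float32
```

Storing float32 rather than float64 halves the file size, which matters because the dense array grows quadratically with sequence length.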

In run_pretrained_openfold.py, a new CLI flag --save_full_attn was added. When set, inference populates a dense_attention/ subdirectory alongside the existing text-format attention files.

Edge case: if inference was run without --save_full_attn, no dense attention files exist, so the rollout endpoint returns an error instructing the user to re-run with the flag enabled.

2. Rollout Computation (visualize_attention_heatmap_grid.py)

The following functions were added to the existing visualization library:

  • load_dense_attention(path): Loads a .npz attention file and validates the array shape.
  • load_attention_array(dir, layer, type): Unified loader that tries .npz first and falls back to parsing the text format.
  • discover_attention_layers(dir, type): Scans the attention directory and returns the sorted list of available layer indices.
  • compute_attention_rollout(dir, start, end, type, add_identity): Core rollout: loads layers start through end (inclusive), averages heads, adds identity, normalizes, and chains multiplications. Returns a (seq_len × seq_len) NumPy array.
  • create_rollout_heatmap(matrix, threshold): Generates a static Plotly HTML heatmap from a rollout matrix.
  • visualize_attention_rollout(dir, out, start, end, type, threshold): High-level driver that computes rollout and saves the static heatmap.
  • create_interactive_rollout_html(dir, type, seq_len): Generates a fully client-side interactive HTML page with layer-range sliders that compute rollout in JavaScript (no server needed).
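
To make the layer-discovery step concrete, here is a small standard-library sketch. It assumes a purely hypothetical file naming scheme of `<attention_type>_layer<N>.npz`; the real on-disk layout produced by this PR is not described here, and `discover_layers_sketch` is not the PR's function.

```python
import re
import tempfile
from pathlib import Path


def discover_layers_sketch(attention_dir, attention_type):
    """Return sorted layer indices for files named <type>_layer<N>.npz.

    The naming pattern is an assumption made for this example.
    """
    pattern = re.compile(rf"^{re.escape(attention_type)}_layer(\d+)\.npz$")
    found = set()
    for entry in Path(attention_dir).iterdir():
        match = pattern.match(entry.name)
        if match:
            found.add(int(match.group(1)))
    # Sort numerically so that layer 10 comes after layer 2
    return sorted(found)


# Demo with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    for name in ("msa_row_layer10.npz", "msa_row_layer2.npz",
                 "triangle_start_layer0.npz", "notes.txt"):
        (Path(tmp) / name).touch()
    layers = discover_layers_sketch(tmp, "msa_row")

print(layers)  # [2, 10]
```

Parsing the index as an integer before sorting avoids the lexical ordering problem where "10" would sort before "2".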

GPU acceleration: compute_attention_rollout() detects whether PyTorch with CUDA is available. If so, the chain of matrix multiplications runs on GPU (significantly faster for long sequences); otherwise it falls back to NumPy. This is transparent to the caller.
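
One common way to get this kind of transparent fallback is an optional import guarded by `try`/`except`. The sketch below shows the pattern under that assumption; `chain_matrices` is a hypothetical helper, not the code in `compute_attention_rollout()`, and it expects already row-normalized `(L, L)` matrices.

```python
import numpy as np

try:
    import torch
    _HAS_CUDA = torch.cuda.is_available()
except ImportError:
    # PyTorch is optional here; without it we always use NumPy
    torch = None
    _HAS_CUDA = False


def chain_matrices(mats):
    """Multiply (L, L) matrices in order, on GPU when CUDA is available."""
    if _HAS_CUDA:
        stack = [torch.as_tensor(m, dtype=torch.float32, device="cuda")
                 for m in mats]
        out = stack[0]
        for m in stack[1:]:
            out = m @ out
        # Always hand a NumPy array back, so callers never see a tensor
        return out.cpu().numpy()
    out = np.asarray(mats[0], dtype=np.float32)
    for m in mats[1:]:
        out = np.asarray(m, dtype=np.float32) @ out
    return out


a = np.array([[0.5, 0.5], [0.0, 1.0]])
b = np.array([[1.0, 0.0], [0.5, 0.5]])
result = chain_matrices([a, b])  # computes b @ a
print(result)
```

Both branches return a float32 NumPy array, which is what keeps the device choice invisible to the caller.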

CLI usage:

# Static rollout heatmap for layers 0-47, MSA row attention
python visualize_attention_heatmap_grid.py \
  --attention_dir ./outputs/attention_files_6KWC_demo_tri_18 \
  --output_dir ./outputs/rollout_out \
  --seq_len 100 \
  --rollout \
  --rollout_layer_start 0 \
  --rollout_layer_end 47 \
  --attention_type msa_row \
  --threshold 0.01

# Interactive HTML (client-side, no server required)
python visualize_attention_heatmap_grid.py \
  --attention_dir ./outputs/attention_files_6KWC_demo_tri_18 \
  --output_dir ./outputs/rollout_interactive \
  --seq_len 100 \
  --interactive_rollout \
  --attention_type msa_row

3. Flask Backend Endpoint (web_app/app.py)

A new endpoint GET /viz/rollout was added to the existing Flask server:

Query parameters:

  • protein_id: protein identifier
  • attention_type: "msa_row" or "triangle_start"
  • seq_len: sequence length
  • layer_start: first layer to include in rollout
  • layer_end: last layer to include in rollout
  • threshold: values below this are set to zero in the response

Response:

{
  "rollout_matrix": [[0.12, 0.03, ...], [0.01, 0.45, ...], ...]
}

Returns a 2D array (list of lists) suitable for direct use by the frontend canvas renderer.
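
The thresholding and serialization step might look like the following. This is a sketch under assumptions: `build_rollout_payload` is a hypothetical helper, and it treats "below" as strictly less than the threshold, which this description does not spell out.

```python
import json

import numpy as np


def build_rollout_payload(matrix, threshold):
    """Zero entries strictly below `threshold` and wrap them for JSON output."""
    m = np.asarray(matrix, dtype=float)
    m = np.where(m < threshold, 0.0, m)
    # tolist() converts to nested Python lists, which json can serialize
    return {"rollout_matrix": m.tolist()}


payload = build_rollout_payload([[0.12, 0.003], [0.01, 0.45]], threshold=0.01)
print(json.dumps(payload))
# {"rollout_matrix": [[0.12, 0.0], [0.01, 0.45]]}
```

Note that the entry equal to the threshold (0.01) is kept under this reading; only 0.003 is zeroed.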

4. Frontend Integration (web_app/templates/index.html)

The existing attention panel was extended with rollout controls below the per-head selector:

New UI elements:

  • Layer range start / end — numeric inputs (layer_start, layer_end) defaulting to 0 and the maximum discovered layer.
  • Threshold — numeric input for zeroing low-weight entries before rendering.
  • Compute Rollout button — triggers the backend call.

JavaScript function computeRolloutBackend():

  1. Reads the layer range, threshold, and currently selected attention type from the UI.
  2. Sends a GET request to /viz/rollout with these values as query parameters.
  3. On success, passes the returned rollout_matrix directly to the existing renderHeatmap() function, which reuses the same canvas element already used for single-layer head views.
  4. Leaves the arc diagram unchanged, because a rollout is a global matrix that does not map naturally onto the existing per-head SVG arc layout, and shows a small "Rollout view" label below the heatmap to indicate the mode.

Running the Web App with Rollout

cd web_app
python app.py
# Server starts on http://localhost:9000

To enable rollout, predictions must be run with dense attention saving. The web app passes this flag automatically when the prediction is launched through the UI. For manual inference:

python run_pretrained_openfold.py \
  --fasta_paths ./my_protein.fasta \
  --template_mmcif_dir ./templates \
  --output_dir ./outputs \
  --model_device cuda:0 \
  --save_full_attn          # <-- required for rollout

@philipxunwu philipxunwu changed the title Attention Rollout Attention Rollout Computation and Heatmap Visualization Apr 30, 2026