
botmahn/mmfa


Multimodal Flash-Attention (MMFA)

CS2.502 – Hardware for AI | IIIT-Hyderabad
Naman Mishra · Abnp Sirisha · Dinesh Nagalla


Overview

MMFA is a memory-efficient attention algorithm that accelerates inference in Vision-Language Models (VLMs). It addresses two dominant bottlenecks in VLM inference:

  1. Quadratic attention cost — self-attention over long multimodal sequences causes repeated high-bandwidth memory (HBM) transfers when materializing full N×N attention matrices.
  2. Visual token redundancy — visual encoders produce many tokens, but only a fraction contribute meaningfully to generation.
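To put the first bottleneck in numbers, the sketch below computes how much memory a naively materialized N×N score matrix needs for one layer. The head count (12) and fp16 scores are illustrative assumptions, not figures measured in this repository.

```python
def score_matrix_bytes(seq_len: int, num_heads: int, bytes_per_elem: int = 2) -> int:
    """Bytes for the full (seq_len x seq_len) score matrices of one layer, batch size 1."""
    return seq_len * seq_len * num_heads * bytes_per_elem


if __name__ == "__main__":
    # Assumed shapes: 12 heads, fp16 (2-byte) scores.
    for n in (1024, 2048, 4096):
        mib = score_matrix_bytes(n, num_heads=12) / 2**20
        print(f"N={n}: {mib:.1f} MiB per layer")
```

Doubling N quadruples the footprint (24, 96 and 384 MiB for the three lengths above), and in naive attention this matrix is written to and read back from HBM at every layer.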

MMFA extends FlashAttention to Turing-generation GPUs (which lack official FlashAttention support) via a from-scratch Triton kernel, and combines it with SparseVLM-style importance-based visual token pruning in a fused operation. Together these eliminate full-matrix materialization, minimize HBM traffic, and reduce the effective sequence length fed to attention.

Empirical results show speedups of up to 7× over native PyTorch attention.

MMFA Flow


Repository Structure

mmfa/
├── profiling_scripts/          # Baseline latency profiling for SOTA VLMs
│   ├── gemma3_benchmark.py
│   ├── llava_benchmark.py
│   ├── llama3.2_benchmark.py
│   ├── smolVLM_benchmark.py
│   └── qwen2.5vl_benchmark.py
│
├── mmfa/
│   ├── turing_fa/              # Triton FlashAttention kernel for Turing GPUs
│   │   └── profile_turing_fa.py
│   ├── turing_mmfa/            # Full MMFA: pruning + attention benchmark
│   │   ├── run_mmfa_benchmark.py
│   │   ├── run_mmfa_benchmark_videos.py
│   │   └── mmfa_educational.py
│   ├── monkey_patch/           # Retro-fit MMFA into real SOTA VLMs
│   │   ├── mmfa_cls.py         # TritonAttentionRunner class
│   │   ├── run_gemma3_benchmark.py
│   │   ├── run_llava_benchmark.py
│   │   └── run_qwen3vl_benchmark.py
│   └── vqa/                    # VQA evaluation on pruned vs original images
│       ├── generate_pruned_images/
│       ├── run_vqa/
│       └── save_pruned_tensors/
│
├── results/
│   ├── model_benchmarks/       # Per-model latency CSVs (TTFT, E2E, ITL, etc.)
│   └── qualitative_analysis.png
│
├── assets/
│   └── mmfa_flow.png
└── environment.yml

Installation

git clone https://github.com/botmahn/mmfa.git
cd mmfa/
# Requires conda
conda env create -f environment.yml
conda activate mmfa

Key dependencies: Python 3.10, PyTorch 2.6 (CUDA 12.4), Triton 3.2, transformers, accelerate, xformers.

Hardware: Turing GPU recommended (RTX 2080 Ti tested). The Triton kernels are designed specifically to work around Turing's lack of official FlashAttention support.
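A quick way to confirm that the GPU is a Turing part is to check its CUDA compute capability, which is 7.5 for Turing cards such as the RTX 2080 Ti. The helper below is a hypothetical convenience script, not part of the repository; it only uses standard PyTorch CUDA queries and falls back gracefully when PyTorch or a GPU is unavailable.

```python
def is_turing(major: int, minor: int) -> bool:
    """Turing GPUs (e.g. RTX 2080 Ti, T4) report CUDA compute capability 7.5."""
    return (major, minor) == (7, 5)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(0)
        name = torch.cuda.get_device_name(0)
        print(f"{name}: sm_{major}{minor}, Turing={is_turing(major, minor)}")
    else:
        print("No CUDA device visible to PyTorch")
```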


Experiments

1. Baseline VLM Profiling

Profile TTFT, end-to-end latency, inter-token latency, decode throughput, and FLOP/byte counts for SOTA VLMs across a grid of image sizes and sequence lengths. Results are logged to CSV.

python profiling_scripts/gemma3_benchmark.py \
    --images_dir /path/to/images \
    --prompt "Describe this image in detail." \
    --image_sizes "448,560,672" \
    --max_seq_lens "1024,2048" \
    --max_new_tokens 128 \
    --dtype float16 \
    --repeats 3 \
    --csv gemma3_vl_grid.csv

Supported models: Gemma-3, LLaVA, LLaMA-3.2, SmolVLM, Qwen2.5-VL. Swap the script name for the target model.
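The latency metrics listed above are usually derived from per-token timestamps. The sketch below shows one common way to compute them; the function names and the timestamp-list input are assumptions, and the profiling scripts may measure differently (for example with CUDA events, or with warm-up runs excluded).

```python
import math
import statistics


def percentile(values: list[float], pct: float) -> float:
    """Nearest-rank percentile (pct in [0, 100])."""
    ordered = sorted(values)
    rank = max(1, math.ceil(pct / 100 * len(ordered)))
    return ordered[rank - 1]


def latency_metrics(start: float, token_times: list[float]) -> dict[str, float]:
    """Derive generation-latency metrics from wall-clock timestamps in seconds.

    start: when generation was requested; token_times: when each new token was produced.
    """
    ttft = token_times[0] - start                        # time to first token (covers prefill)
    e2e = token_times[-1] - start                        # end-to-end latency
    itl = [b - a for a, b in zip(token_times, token_times[1:])]  # inter-token latencies
    metrics = {"ttft_s": ttft, "e2e_s": e2e}
    if itl:
        metrics.update(
            itl_mean_s=statistics.mean(itl),
            itl_p50_s=statistics.median(itl),
            itl_p90_s=percentile(itl, 90),
            decode_tok_per_s=len(itl) / (e2e - ttft),    # tokens after the first / decode time
        )
    return metrics
```

For example, `latency_metrics(0.0, [0.5, 0.75, 1.0, 1.25])` gives a TTFT of 0.5 s, a mean ITL of 0.25 s, and a decode throughput of 4 tokens/s.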


2. Triton FlashAttention on Turing GPUs

Benchmark the from-scratch Triton kernel against PyTorch manual attention and the official flash-attn library. Supports MHA, GQA, MQA, causal masking, and sliding-window attention.

# --b       batch size
# --s       sequence length
# --n_q     query heads
# --n_kv    key/value heads (set < n_q for GQA, set to 1 for MQA)
# --d       per-head dimension
# --w       window size (-1 = full attention)
# --causal  enable causal masking
# --layers  number of layers to simulate
python mmfa/turing_fa/profile_turing_fa.py \
    --b 1 \
    --s 2048 \
    --n_q 12 \
    --n_kv 12 \
    --d 64 \
    --w -1 \
    --causal \
    --layers 32
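When validating a custom kernel, a naive reference that materializes the full score matrix is a convenient correctness oracle. The NumPy sketch below covers the same MHA/GQA/MQA, causal, and sliding-window cases; the function name and the symmetric window rule (|i - j| < w) are assumptions and may differ from how profile_turing_fa.py interprets --w.

```python
import numpy as np


def reference_attention(q, k, v, causal=False, window=-1):
    """Naive attention that materializes the full score matrix.

    q: (B, H_q, S, D); k, v: (B, H_kv, S, D) with H_q % H_kv == 0.
    H_kv == H_q -> MHA, 1 < H_kv < H_q -> GQA, H_kv == 1 -> MQA.
    window > 0 limits each query to keys with |i - j| < window (assumed semantics).
    """
    _, h_q, s, d = q.shape
    h_kv = k.shape[1]
    assert h_q % h_kv == 0, "query heads must be a multiple of KV heads"
    group = h_q // h_kv
    # Query head h reads KV head h // group.
    k = np.repeat(k, group, axis=1)
    v = np.repeat(v, group, axis=1)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)    # (B, H_q, S, S)
    i = np.arange(s)[:, None]                            # query positions
    j = np.arange(s)[None, :]                            # key positions
    mask = np.ones((s, s), dtype=bool)
    if causal:
        mask &= j <= i
    if window > 0:
        mask &= np.abs(i - j) < window
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)          # the diagonal is never masked
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v
```

Comparing a kernel's output against this function with `np.allclose` (after casting to float32 or float64) is a simple way to catch masking or head-mapping mistakes before timing anything.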

3. MMFA: Pruning + Attention Benchmark

Builds a toy multimodal transformer, extracts Q/K/V for interleaved text and image tokens, scores visual token importance (Triton kernel), prunes low-importance patches, and benchmarks attention before and after pruning for both PyTorch and Triton backends. Saves a visualization of kept patches.

Moderate benchmark:

python mmfa/turing_mmfa/run_mmfa_benchmark.py \
    --image sample.jpg \
    --keep 0.4 \
    --num_heads 12 \
    --head_dim 64 \
    --img_size 224 \
    --batch_size 1

Heavy benchmark:

python mmfa/turing_mmfa/run_mmfa_benchmark.py \
    --image sample.jpg \
    --keep 0.5 \
    --num_heads 16 \
    --head_dim 128 \
    --img_size 336 \
    --batch_size 2

Pass --ignore_overhead to measure theoretical max speedup (excluding the importance scoring step).
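Because only visual tokens are pruned, the effect of --keep on the quadratic part of attention can be estimated with simple arithmetic. The sketch assumes 16×16 patches (196 visual tokens at --img_size 224) and a 64-token text prompt purely for illustration; the toy model's actual patch size and prompt length may differ, and the function name is hypothetical.

```python
def pruned_lengths(num_text: int, num_visual: int, keep: float) -> tuple[int, int, float]:
    """Return (tokens before, tokens after, ratio of the quadratic QK^T / PV cost).

    Only visual tokens are pruned; round(keep * num_visual) of them survive.
    """
    kept_visual = max(1, round(keep * num_visual))
    n_before = num_text + num_visual
    n_after = num_text + kept_visual
    return n_before, n_after, (n_after / n_before) ** 2


if __name__ == "__main__":
    num_visual = (224 // 16) ** 2          # assumed 16x16 patches -> 196 tokens
    before, after, ratio = pruned_lengths(num_text=64, num_visual=num_visual, keep=0.4)
    print(f"{before} -> {after} tokens, quadratic attention cost x{ratio:.2f}")
```

Under these assumptions, 260 tokens shrink to 142 and the quadratic term falls to roughly 30% of its original size. Measured speedups can be lower once the importance-scoring step is included, which is the overhead that --ignore_overhead excludes.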


4. Monkey-Patching SOTA VLMs

These scripts load real model weights (Gemma-3, LLaVA, Qwen3-VL), generate dummy multimodal hidden states matching the model's actual tensor shapes, and benchmark the model's default attention against the Triton kernel as a drop-in replacement.

# Gemma-3
python mmfa/monkey_patch/run_gemma3_benchmark.py

# LLaVA
python mmfa/monkey_patch/run_llava_benchmark.py

# Qwen3-VL
python mmfa/monkey_patch/run_qwen3vl_benchmark.py

Models are loaded from Hugging Face. Set HF_TOKEN or HUGGINGFACE_TOKEN in your environment for gated models.
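The monkey-patching idea itself is plain Python: swap a method on the attention class for one backed by the new kernel, check that the outputs still match, and restore the original afterwards. The sketch below uses a hypothetical ToyAttention class and a NumPy stand-in for the Triton kernel; it does not reproduce how mmfa_cls.py or transformers select their attention implementation.

```python
import numpy as np


class ToyAttention:
    """Hypothetical stand-in for a model's attention module."""

    def attend(self, q, k, v):
        scores = q @ k.T / np.sqrt(q.shape[-1])
        p = np.exp(scores - scores.max(axis=-1, keepdims=True))
        return (p / p.sum(axis=-1, keepdims=True)) @ v


def replacement_attend(self, q, k, v):
    """Stand-in for a Triton-backed kernel: same math, computed via log-sum-exp."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    m = scores.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(scores - m).sum(axis=-1, keepdims=True))
    return np.exp(scores - lse) @ v


def run_patched(module_cls, new_fn, *args):
    """Temporarily replace module_cls.attend, run it, and always restore the original."""
    original = module_cls.attend
    module_cls.attend = new_fn
    try:
        return module_cls().attend(*args)
    finally:
        module_cls.attend = original


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal((8, 16)) for _ in range(3))
    baseline = ToyAttention().attend(q, k, v)
    patched = run_patched(ToyAttention, replacement_attend, q, k, v)
    print("drop-in match:", np.allclose(baseline, patched))
```

Patching the class rather than a single instance affects every layer that uses it at once, which is what a whole-model drop-in benchmark needs; the `finally` clause keeps the baseline intact even if the replacement raises.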


5. VQA Evaluation

Measures whether visual token pruning degrades task accuracy on VQAv2. Download the dataset from visualqa.org/download.html first.

Step 1 – Generate pruned images:

python mmfa/vqa/generate_pruned_images/gemma3_generate_pruned_images.py \
    --image_dir /datasets/COCO/val2014 \
    --questions_json v2_OpenEnded_mscoco_val2014_questions.json \
    --out_image_dir results/viz \
    --out_json_dir results/json \
    --keep 0.5 \
    --max_new_tokens 128 \
    --max_samples 200 \
    --ignore_overhead

Step 2 – Run VQA evaluation:

python mmfa/vqa/run_vqa/run_gemma3_vqa.py \
    --pruned_dir results/pruned \
    --normal_dir /datasets/COCO/val2014 \
    --questions_json data/v2_OpenEnded_mscoco_val2014_questions.json \
    --annotations_json data/v2_mscoco_val2014_annotations.json \
    --model_id google/gemma-3-12b-it \
    --split val2014 \
    --max_samples 300
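For reference, the standard VQAv2 accuracy gives full credit when at least 3 of the 10 annotators gave the predicted answer, averaged over every leave-one-annotator-out subset. The sketch below implements that rule with only minimal answer normalization (lowercasing and whitespace stripping); the official evaluator also normalizes punctuation, articles, and number words, and run_gemma3_vqa.py may score answers differently.

```python
from itertools import combinations


def vqa_accuracy(prediction: str, human_answers: list[str]) -> float:
    """VQAv2-style accuracy: min(#matching annotators / 3, 1), averaged over
    every subset that leaves out one annotator. Normalization is simplified."""
    pred = prediction.strip().lower()
    answers = [a.strip().lower() for a in human_answers]
    scores = []
    for subset in combinations(answers, len(answers) - 1):
        matches = sum(a == pred for a in subset)
        scores.append(min(1.0, matches / 3))
    return sum(scores) / len(scores)
```

Comparing the mean of this score over the same questions for original and pruned images is one direct way to read off whether pruning costs accuracy.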

Results

Benchmark CSVs for all evaluated models are in results/model_benchmarks/. Each CSV records per-run metrics including TTFT, E2E latency, ITL (mean/P50/P90), decode throughput, FLOP counts, and GPU memory usage across a grid of image sizes and context lengths.

Model              Sizes
Gemma-3            4B, 12B
LLaVA-OneVision    4B, 8B
LLaMA-3.2          11B
SmolVLM            2B
Qwen2.5-VL         3B, 7B

How It Works

Triton FlashAttention Kernel

The kernel implements the online softmax + tiled attention computation from the FlashAttention paper. Key design choices for Turing compatibility:

  • Tiled computation: Q/K/V are loaded in SRAM-sized blocks (BLOCK_M=64, BLOCK_N=64) to avoid materializing the full N×N matrix in HBM.
  • Online softmax: Running max (m_i) and normalizer (l_i) are updated incrementally, enabling numerically stable attention without a second pass.
  • GQA/MQA support: Query head count (H_q) is decoupled from KV head count (H_kv); grouped-query and multi-query attention are handled natively.
  • Causal and windowed masking: Both masks are applied inside the kernel loop. The causal and window settings are compile-time constants (tl.constexpr), so each configuration is specialized at compile time instead of branching at runtime.
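The tiling and online-softmax bookkeeping described above can be sketched on the CPU in NumPy. This is a single-head illustration, not a port of the repository's Triton kernel; the outer loop over query blocks plays the role of the per-block Triton programs, and block_m/block_n mirror BLOCK_M/BLOCK_N.

```python
import numpy as np


def tiled_attention(q, k, v, block_m=64, block_n=64, causal=False):
    """Single-head FlashAttention-style forward pass (q, k, v: (S, D)).

    Never forms the full (S, S) score matrix: each query block keeps a running
    row max m_i, running normalizer l_i, and an unnormalized accumulator acc.
    """
    s, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty_like(q, dtype=np.float64)
    for start_m in range(0, s, block_m):
        qi = q[start_m:start_m + block_m] * scale
        rows = qi.shape[0]
        m_i = np.full(rows, -np.inf)
        l_i = np.zeros(rows)
        acc = np.zeros((rows, d))
        q_pos = np.arange(start_m, start_m + rows)[:, None]
        # With causal masking, key tiles after this block's last query are skipped.
        end_n = min(s, start_m + rows) if causal else s
        for start_n in range(0, end_n, block_n):
            kj = k[start_n:start_n + block_n]
            vj = v[start_n:start_n + block_n]
            scores = qi @ kj.T                          # one (rows, block_n) tile only
            if causal:
                k_pos = np.arange(start_n, start_n + kj.shape[0])[None, :]
                scores = np.where(k_pos <= q_pos, scores, -np.inf)
            m_new = np.maximum(m_i, scores.max(axis=1))
            alpha = np.exp(m_i - m_new)                 # rescales earlier partial sums
            p = np.exp(scores - m_new[:, None])
            l_i = l_i * alpha + p.sum(axis=1)
            acc = acc * alpha[:, None] + p @ vj
            m_i = m_new
        out[start_m:start_m + rows] = acc / l_i[:, None]
    return out
```

Each key tile rescales the running sums by exp(m_old - m_new) before adding its own contribution, so the final `acc / l_i` equals an ordinary softmax over all keys even though only one tile of scores exists at a time.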

Visual Token Pruning

Importance scores are computed via a fused Triton kernel (_importance_kernel) that measures how strongly the text query vectors attend to each visual token's key vector, a proxy for that visual patch's relevance to the query. The top-k patches are retained, where k is the --keep fraction of the visual tokens; the rest are discarded before the attention call. This directly reduces the sequence length and the attention FLOP count.
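A NumPy sketch of the scoring-and-pruning step follows. It assumes one plausible formulation (a softmax of text-query/visual-key scores over the visual tokens, averaged across text queries) and keeps the selected tokens in their original order; the actual _importance_kernel may normalize or aggregate differently, and both function names are hypothetical.

```python
import numpy as np


def visual_token_importance(q_text, k_vis):
    """Score visual tokens by how much attention text queries place on them.

    q_text: (T, D) text query vectors; k_vis: (V, D) visual key vectors.
    Returns a (V,) array of scores that sums to 1.
    """
    scores = q_text @ k_vis.T / np.sqrt(q_text.shape[-1])   # (T, V)
    scores -= scores.max(axis=1, keepdims=True)              # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=1, keepdims=True)                # softmax over visual tokens
    return probs.mean(axis=0)                                # average over text queries


def prune_visual_tokens(k_vis, v_vis, importance, keep):
    """Keep the top round(keep * V) visual tokens, preserving their original order."""
    n_keep = max(1, round(keep * len(importance)))
    top = np.argsort(-importance, kind="stable")[:n_keep]
    idx = np.sort(top)                                       # restore positional order
    return k_vis[idx], v_vis[idx], idx
```

Sorting the surviving indices keeps relative positions intact, which matters if positional information is attached to the visual tokens.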
