CS2.502 – Hardware for AI | IIIT-Hyderabad
Naman Mishra · Abnp Sirisha · Dinesh Nagalla
MMFA is a memory-efficient attention algorithm that accelerates inference in Vision-Language Models (VLMs). It addresses two dominant bottlenecks in VLM inference:
- Quadratic attention cost — self-attention over long multimodal sequences causes repeated high-bandwidth memory (HBM) transfers when materializing full N×N attention matrices.
- Visual token redundancy — visual encoders produce many tokens, but only a fraction contribute meaningfully to generation.
MMFA extends FlashAttention to Turing-generation GPUs (which lack official FlashAttention support) via a from-scratch Triton kernel, and combines it with SparseVLM-style importance-based visual token pruning in a fused operation. Together these eliminate full-matrix materialization, minimize HBM traffic, and reduce the effective sequence length fed to attention.
Empirical results show speedups of up to 7× over native PyTorch attention.
```text
mmfa/
├── profiling_scripts/ # Baseline latency profiling for SOTA VLMs
│ ├── gemma3_benchmark.py
│ ├── llava_benchmark.py
│ ├── llama3.2_benchmark.py
│ ├── smolVLM_benchmark.py
│ └── qwen2.5vl_benchmark.py
│
├── mmfa/
│ ├── turing_fa/ # Triton FlashAttention kernel for Turing GPUs
│ │ └── profile_turing_fa.py
│ ├── turing_mmfa/ # Full MMFA: pruning + attention benchmark
│ │ ├── run_mmfa_benchmark.py
│ │ ├── run_mmfa_benchmark_videos.py
│ │ └── mmfa_educational.py
│ ├── monkey_patch/ # Retro-fit MMFA into real SOTA VLMs
│ │ ├── mmfa_cls.py # TritonAttentionRunner class
│ │ ├── run_gemma3_benchmark.py
│ │ ├── run_llava_benchmark.py
│ │ └── run_qwen3vl_benchmark.py
│ └── vqa/ # VQA evaluation on pruned vs original images
│ ├── generate_pruned_images/
│ ├── run_vqa/
│ └── save_pruned_tensors/
│
├── results/
│ ├── model_benchmarks/ # Per-model latency CSVs (TTFT, E2E, ITL, etc.)
│ └── qualitative_analysis.png
│
├── assets/
│ └── mmfa_flow.png
└── environment.yml
```
```bash
git clone https://github.com/botmahn/mmfa.git
cd mmfa/
# Requires conda
conda env create -f environment.yml
conda activate mmfa
```

Key dependencies: Python 3.10, PyTorch 2.6 (CUDA 12.4), Triton 3.2, transformers, accelerate, xformers.
Hardware: a Turing GPU is recommended (tested on an RTX 2080 Ti). The Triton kernels are designed specifically to work around Turing's lack of official FlashAttention support.
Profile TTFT, end-to-end latency, inter-token latency, decode throughput, and FLOP/byte counts for SOTA VLMs across a grid of image sizes and sequence lengths. Results are logged to CSV.
```bash
python profiling_scripts/gemma3_benchmark.py \
--images_dir /path/to/images \
--prompt "Describe this image in detail." \
--image_sizes "448,560,672" \
--max_seq_lens "1024,2048" \
--max_new_tokens 128 \
--dtype float16 \
--repeats 3 \
--csv gemma3_vl_grid.csv
```

Supported models: Gemma-3, LLaVA, LLaMA-3.2, SmolVLM, Qwen2.5-VL. Swap the script name for the target model.
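For reference, the latency metrics logged above can be derived from per-token timestamps roughly as follows. This is a minimal sketch under common conventions, not the profiling scripts' actual code: the helper name `latency_metrics` is hypothetical, and whether the scripts' TTFT also includes tokenization or image preprocessing is an assumption left open.

```python
import statistics


def latency_metrics(start: float, token_times: list[float]) -> dict[str, float]:
    """Derive latency metrics from wall-clock timestamps (seconds).

    start:       when the request was issued
    token_times: when each generated token became available, in order
    """
    if not token_times:
        raise ValueError("need at least one generated token")
    metrics = {
        "ttft": token_times[0] - start,   # time to first token (includes prefill)
        "e2e": token_times[-1] - start,   # end-to-end latency
    }
    # Inter-token latency: gap between consecutive generated tokens (decode phase)
    itl = [b - a for a, b in zip(token_times, token_times[1:])]
    if itl:
        # statistics.quantiles needs >= 2 points on older Pythons; cut point 9 of 10 is P90
        p90 = statistics.quantiles(itl, n=10, method="inclusive")[8] if len(itl) > 1 else itl[0]
        metrics.update(
            itl_mean=statistics.fmean(itl),
            itl_p50=statistics.median(itl),
            itl_p90=p90,
            # decode throughput: tokens after the first, per second of decode time
            decode_tok_per_s=len(itl) / (token_times[-1] - token_times[0]),
        )
    return metrics


print(latency_metrics(0.0, [0.5, 0.75, 1.0, 1.25, 1.5]))
# → {'ttft': 0.5, 'e2e': 1.5, 'itl_mean': 0.25, 'itl_p50': 0.25, 'itl_p90': 0.25, 'decode_tok_per_s': 4.0}
```

Separating TTFT from ITL matters for MMFA because pruning visual tokens mostly shortens the prefill, so its effect should appear first in TTFT.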
Benchmark the from-scratch Triton kernel against PyTorch manual attention and the official flash-attn library. Supports MHA, GQA, MQA, causal masking, and sliding-window attention.
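To make the head-sharing and masking options concrete, here is a small sketch of what `--n_q`/`--n_kv` and `--causal`/`--w` mean per element. It is illustrative only: the helper names are hypothetical. The window is assumed to be inclusive on both sides, analogous to flash-attn's `window_size=(w, w)`, and whether `profile_turing_fa.py` uses exactly this convention is an assumption.

```python
def kv_head_for(q_head: int, n_q: int, n_kv: int) -> int:
    """Which KV head a query head reads under MHA (n_kv == n_q), GQA, or MQA (n_kv == 1)."""
    if n_q % n_kv:
        raise ValueError("n_q must be a multiple of n_kv")
    group = n_q // n_kv  # number of query heads sharing one KV head
    return q_head // group


def allowed(i: int, j: int, causal: bool, window: int) -> bool:
    """Can query position i attend to key position j? window == -1 means no window."""
    if causal and j > i:
        return False  # future keys are masked
    if window >= 0 and abs(i - j) > window:
        return False  # outside the sliding window (distance == window is still visible)
    return True


print([kv_head_for(h, n_q=12, n_kv=4) for h in range(12)])
for i in range(6):  # causal + window of 2
    print("".join("x" if allowed(i, j, causal=True, window=2) else "." for j in range(6)))
```

With `n_q=12, n_kv=4`, heads 0 to 2 share KV head 0, heads 3 to 5 share KV head 1, and so on. The mask printout shows each row seeing itself plus at most two earlier positions (`x.....`, `xx....`, `xxx...`, `.xxx..`, `..xxx.`, `...xxx`).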
```bash
# --b: batch size             --s: sequence length
# --n_q: query heads          --n_kv: key/value heads (set < n_q for GQA, set to 1 for MQA)
# --d: per-head dimension     --w: window size (-1 = full attention)
# --causal: enable causal masking
# --layers: number of layers to simulate
python mmfa/turing_fa/profile_turing_fa.py \
--b 1 \
--s 2048 \
--n_q 12 \
--n_kv 12 \
--d 64 \
--w -1 \
--causal \
--layers 32
```

The full MMFA benchmark (`mmfa/turing_mmfa/run_mmfa_benchmark.py`) builds a toy multimodal transformer, extracts Q/K/V for interleaved text and image tokens, scores visual-token importance with a Triton kernel, prunes low-importance patches, and benchmarks attention before and after pruning for both the PyTorch and Triton backends. It also saves a visualization of the kept patches.
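The kept-patch visualization boils down to mapping visual-token indices back to grid positions in the image. The following sketch assumes square images cut into non-overlapping patches numbered row-major from the top-left corner, which is the usual ViT layout but is not confirmed by the script. The function name `kept_patch_grid` is hypothetical.

```python
def kept_patch_grid(kept: set[int], img_size: int, patch: int) -> list[str]:
    """Render kept patch indices as an ASCII grid ('#' kept, '.' pruned).

    Assumes square images split into non-overlapping patch x patch tiles,
    numbered row-major from the top-left corner.
    """
    side = img_size // patch  # patches per row and per column
    return [
        "".join("#" if r * side + c in kept else "." for c in range(side))
        for r in range(side)
    ]


# Hypothetical example: a 64x64 image with 16x16 patches -> 4x4 grid, 4 patches kept
print("\n".join(kept_patch_grid({0, 5, 10, 15}, img_size=64, patch=16)))
```

This example prints a diagonal (`#...`, `.#..`, `..#.`, `...#`). With `--img_size 224` and a hypothetical 16-pixel patch, the same layout would be a 14×14 grid of 196 patches.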
Moderate benchmark:
```bash
python mmfa/turing_mmfa/run_mmfa_benchmark.py \
--image sample.jpg \
--keep 0.4 \
--num_heads 12 \
--head_dim 64 \
--img_size 224 \
--batch_size 1
```

Heavy benchmark:
```bash
python mmfa/turing_mmfa/run_mmfa_benchmark.py \
--image sample.jpg \
--keep 0.5 \
--num_heads 16 \
--head_dim 128 \
--img_size 336 \
--batch_size 2
```

Pass `--ignore_overhead` to measure the theoretical maximum speedup (excluding the importance-scoring step).
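A rough upper bound on that theoretical speedup follows from the fact that attention matmul FLOPs grow with the square of the sequence length. The sketch below counts only the `QKᵀ` and `PV` matmuls at 2 FLOPs per multiply-add and ignores softmax and the scoring cost. The 14-pixel patch size, the 32-token text prompt, and the rounding rule for kept tokens are all assumptions. Measured wall-clock speedups can differ from this FLOP ratio because of memory traffic and launch overheads.

```python
def attention_flops(batch: int, heads: int, seq: int, head_dim: int) -> int:
    """Matmul FLOPs of one attention call: QK^T and PV each cost seq*seq*head_dim
    multiply-adds per head, counted as 2 FLOPs each (softmax ignored)."""
    return 4 * batch * heads * seq * seq * head_dim


def pruned_speedup(n_text: int, n_visual: int, keep: float, **shape: int) -> float:
    """Ideal attention-FLOP reduction from keeping a `keep` fraction of visual tokens."""
    n_kept = round(keep * n_visual)  # rounding rule is an assumption
    full = attention_flops(seq=n_text + n_visual, **shape)
    pruned = attention_flops(seq=n_text + n_kept, **shape)
    return full / pruned


# Heavy config (--img_size 336, --keep 0.5, 16 heads, head_dim 128, batch 2), assuming
# 14-px patches (576 visual tokens) and a hypothetical 32-token text prompt:
s = pruned_speedup(n_text=32, n_visual=(336 // 14) ** 2, keep=0.5, batch=2, heads=16, head_dim=128)
print(f"{s:.2f}x fewer attention FLOPs")  # → 3.61x fewer attention FLOPs
```

Because batch, heads, and head dimension cancel, the ratio is just `((n_text + n_visual) / (n_text + n_kept))²`. Halving the visual tokens therefore cuts attention FLOPs by more than 2× whenever visual tokens dominate the sequence.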
These scripts load real model weights (Gemma-3, LLaVA, Qwen3-VL), generate dummy multimodal hidden states matching the model's actual tensor shapes, and benchmark the model's default attention against the Triton kernel as a drop-in replacement.
```bash
# Gemma-3
python mmfa/monkey_patch/run_gemma3_benchmark.py
# LLaVA
python mmfa/monkey_patch/run_llava_benchmark.py
# Qwen3-VL
python mmfa/monkey_patch/run_qwen3vl_benchmark.py
```

Models are loaded from Hugging Face. Set `HF_TOKEN` or `HUGGINGFACE_TOKEN` in your environment to access gated models.
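The "drop-in replacement" idea is ordinary Python monkey patching: swap the attribute the model looks up at call time, check that outputs still match, and restore the original afterwards. The following is a generic sketch with hypothetical toy classes. It does not reproduce `TritonAttentionRunner`'s API or how transformers models dispatch attention internally.

```python
from unittest import mock


class ToySelfAttention:
    """Hypothetical stand-in for a VLM attention module (not a real transformers class)."""

    def attention_fn(self, hidden: list[float]) -> list[float]:
        return [2 * h for h in hidden]  # placeholder for the stock attention path

    def forward(self, hidden: list[float]) -> list[float]:
        # forward() looks attention_fn up at call time, which is what makes patching work
        return self.attention_fn(hidden)


calls: list[int] = []


def fast_attention_fn(self, hidden: list[float]) -> list[float]:
    """Hypothetical drop-in replacement: same signature, same outputs, different kernel."""
    calls.append(len(hidden))
    return [2 * h for h in hidden]


layer = ToySelfAttention()
baseline = layer.forward([1.0, 2.0, 3.0])
# Patch the class attribute only inside the block; mock restores the original on exit,
# even if an exception is raised.
with mock.patch.object(ToySelfAttention, "attention_fn", fast_attention_fn):
    patched = layer.forward([1.0, 2.0, 3.0])

assert patched == baseline and calls == [3]  # replacement ran and matched the baseline
```

Scoping the patch with a context manager lets a benchmark time both paths in one process without leaving the model permanently modified.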
Measures whether visual token pruning degrades task accuracy on VQAv2. Download the dataset from visualqa.org/download.html first.
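For context, VQAv2 scores a predicted answer against 10 human answers and gives full credit once at least 3 annotators agree. The following sketch implements that standard leave-one-out form of the metric, with deliberately minimal answer normalization. Whether `run_gemma3_vqa.py` computes accuracy exactly this way is an assumption, and the function names are hypothetical.

```python
from itertools import combinations


def _normalize(answer: str) -> str:
    # The official VQA evaluation also strips punctuation, articles, and maps number words;
    # this sketch only lowercases and trims whitespace.
    return answer.strip().lower()


def vqa_accuracy(pred: str, human_answers: list[str]) -> float:
    """VQAv2-style accuracy: min(#matching annotators / 3, 1), averaged over every
    subset that leaves one of the (usually 10) human answers out."""
    p = _normalize(pred)
    answers = [_normalize(a) for a in human_answers]
    scores = [
        min(sum(a == p for a in subset) / 3, 1.0)
        for subset in combinations(answers, len(answers) - 1)
    ]
    return sum(scores) / len(scores)


humans = ["yes"] * 3 + ["no"] * 7
print(round(vqa_accuracy("Yes", humans), 3))  # → 0.9
```

Comparing this score on pruned and original images isolates the accuracy cost of pruning, since the model and questions are held fixed.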
Step 1 – Generate pruned images:
```bash
python mmfa/vqa/generate_pruned_images/gemma3_generate_pruned_images.py \
--image_dir /datasets/COCO/val2014 \
--questions_json v2_OpenEnded_mscoco_val2014_questions.json \
--out_image_dir results/viz \
--out_json_dir results/json \
--keep 0.5 \
--max_new_tokens 128 \
--max_samples 200 \
--ignore_overhead
```

Step 2 – Run VQA evaluation:
```bash
python mmfa/vqa/run_vqa/run_gemma3_vqa.py \
--pruned_dir results/pruned \
--normal_dir /datasets/COCO/val2014 \
--questions_json data/v2_OpenEnded_mscoco_val2014_questions.json \
--annotations_json data/v2_mscoco_val2014_annotations.json \
--model_id google/gemma-3-12b-it \
--split val2014 \
--max_samples 300
```

Benchmark CSVs for all evaluated models are in `results/model_benchmarks/`. Each CSV records per-run metrics including TTFT, E2E latency, ITL (mean/P50/P90), decode throughput, FLOP counts, and GPU memory usage across a grid of image sizes and context lengths.
| Model | Parameter sizes |
|---|---|
| Gemma-3 | 4B, 12B |
| LLaVA-OneVision | 4B, 8B |
| LLaMA-3.2 | 11B |
| SmolVLM | 2B |
| Qwen2.5-VL | 3B, 7B |
The kernel implements the online softmax + tiled attention computation from the FlashAttention paper. Key design choices for Turing compatibility:
- Tiled computation: Q/K/V are loaded in SRAM-sized blocks (`BLOCK_M=64`, `BLOCK_N=64`) to avoid materializing the full N×N matrix in HBM.
- Online softmax: a running max (`m_i`) and normalizer (`l_i`) are updated incrementally, enabling numerically stable attention without a second pass.
- GQA/MQA support: the query head count (`H_q`) is decoupled from the KV head count (`H_kv`), so grouped-query and multi-query attention are handled natively.
- Causal and windowed masking: both are applied inside the kernel loop, with the masking flags passed as compile-time constants (`tl.constexpr`) to avoid runtime branching.
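The online-softmax recurrence can be checked in isolation. The following pure-Python sketch processes one query row and streams K/V in blocks of `block_n`. The Triton kernel instead processes `BLOCK_M` query rows at once with tiled matmuls in on-chip memory, so the function names and the single-row simplification are illustrative assumptions. When masking is involved, a fully masked block would make the running max `-inf`, and a real kernel needs a guard for that case; this unmasked sketch does not.

```python
import math


def _dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q, K, V):
    """Reference for one query row: materialize every score, softmax, weight V."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * _dot(q, k) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    denom = sum(p)
    return [sum(pj * v[c] for pj, v in zip(p, V)) / denom for c in range(len(V[0]))]


def online_attention(q, K, V, block_n: int = 2):
    """Same result, streaming K/V in blocks of block_n with a running max m_i,
    running normalizer l_i, and an unnormalized output accumulator."""
    scale = 1.0 / math.sqrt(len(q))
    m_i, l_i = -math.inf, 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block_n):
        Kb, Vb = K[start:start + block_n], V[start:start + block_n]
        s = [scale * _dot(q, k) for k in Kb]
        m_new = max(m_i, max(s))
        alpha = math.exp(m_i - m_new)  # rescales earlier blocks to the new max (0.0 on block 1)
        p = [math.exp(x - m_new) for x in s]
        l_i = alpha * l_i + sum(p)
        acc = [alpha * a + sum(pj * v[c] for pj, v in zip(p, Vb)) for c, a in enumerate(acc)]
        m_i = m_new
    return [a / l_i for a in acc]  # one normalization at the end, no second pass over K/V


q = [0.5, -1.0, 2.0]
K = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1], [-1, 2, 0.5]]
V = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
print(naive_attention(q, K, V))
print(online_attention(q, K, V, block_n=2))  # matches up to floating-point rounding
```

The key step is `alpha`: whenever a new block raises the running max, everything accumulated so far is rescaled by `exp(m_old - m_new)`. The result therefore matches a single global softmax without ever storing the full score row.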
Importance scores are computed by a fused Triton kernel (`_importance_kernel`) that measures how strongly the text query vectors attend to each visual token's key vector, a proxy for that visual patch's relevance to the text query. The highest-scoring patches are retained (`--keep` sets the fraction kept, e.g. `0.4`); the rest are discarded before the attention call. This directly reduces the sequence length and the attention FLOP count.
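A plain-Python sketch of this score, select, and prune flow follows. The exact formula inside `_importance_kernel` is not documented here. This version assumes a SparseVLM-style score: a scaled-dot-product softmax over the visual keys for each text query, averaged over text queries. It also assumes a round-to-nearest rule for the kept count. All function names are hypothetical.

```python
import math


def importance_scores(text_q: list[list[float]], vis_k: list[list[float]]) -> list[float]:
    """Score each visual token by the average attention it receives from text queries.

    Assumption: per text query, a scaled-dot-product softmax over the visual keys only;
    scores are then averaged over text queries (so they sum to 1).
    """
    scale = 1.0 / math.sqrt(len(vis_k[0]))
    totals = [0.0] * len(vis_k)
    for q in text_q:
        logits = [scale * sum(a * b for a, b in zip(q, k)) for k in vis_k]
        m = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(x - m) for x in logits]
        z = sum(weights)
        for j, w in enumerate(weights):
            totals[j] += w / z
    return [t / len(text_q) for t in totals]


def keep_top(scores: list[float], keep: float) -> list[int]:
    """Indices of the highest-scoring `keep` fraction of visual tokens, in original order."""
    k = max(1, round(keep * len(scores)))  # rounding rule is an assumption
    top = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]
    return sorted(top)  # keep positional order so the sequence layout is preserved


def prune_sequence(tokens: list[str], is_visual: list[bool], kept: list[int]) -> list[str]:
    """Drop unkept visual tokens from an interleaved text/image sequence; text always stays."""
    kept_set, out, v = set(kept), [], 0
    for tok, visual in zip(tokens, is_visual):
        if not visual:
            out.append(tok)
            continue
        if v in kept_set:
            out.append(tok)
        v += 1
    return out


scores = importance_scores(
    text_q=[[1.0, 0.0]], vis_k=[[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [-1.0, 0.0]]
)
kept = keep_top(scores, keep=0.5)  # → [0, 2]
tokens = ["t0", "v0", "v1", "v2", "v3", "t1"]
print(prune_sequence(tokens, [False, True, True, True, True, False], kept))
# → ['t0', 'v0', 'v2', 't1']
```

Returning the kept indices in their original order keeps the surviving visual tokens in the same relative positions, so the pruned sequence can be fed to the unchanged attention kernel.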
