
[ops] feat: add triton invariant attention backend #35

Merged
Luosuu merged 28 commits into verl-project:main from Luosuu:feat/triton-invariant-attn on Jun 11, 2026

@Luosuu Luosuu commented Jun 6, 2026

Collaborator

Summary

  • Add a triton-invariant attention backend based on packed Triton varlen attention kernels.
  • Expose a FlashAttention-compatible flash_attn_varlen_func wrapper with packed non-paged autograd support and paged forward support.
  • Support GQA, packed-only non-paged model inputs, and register the backend for inference plus VeOmni/VeRL actor-side use.
  • Extend packed logits/logprob verification with an optional VeOmni model backend for fused LCE validation.
  • Add CUDA tests for non-paged backward, arbitrary head dims, q_len > kv_len, paged forward aliases, GQA, and prefill/decode bitwise invariance.
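For readers unfamiliar with the packed (varlen) layout, here is a minimal pure-Python sketch of the bookkeeping it relies on. It is not code from this PR: `build_cu_seqlens` and `kv_head_for_q_head` are hypothetical helper names, and the GQA mapping assumes the common convention that consecutive query heads share one KV head.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Cumulative lengths: sequence i occupies packed rows cu[i]:cu[i + 1]."""
    return [0, *accumulate(seq_lens)]


def kv_head_for_q_head(q_head, num_q_heads, num_kv_heads):
    """GQA mapping (assumed convention): consecutive query heads share a KV head."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# Three sequences packed back to back along the token dimension, no padding.
seq_lens = [3, 5, 2]
cu_seqlens = build_cu_seqlens(seq_lens)          # [0, 3, 8, 10]
max_seqlen = max(seq_lens)                        # 5

# Packed rows that belong to sequence 1.
rows_of_seq1 = list(range(cu_seqlens[1], cu_seqlens[2]))

# With 8 query heads and 2 KV heads, query heads 0-3 read KV head 0, 4-7 read KV head 1.
kv_heads = [kv_head_for_q_head(h, 8, 2) for h in range(8)]
```

In the FlashAttention varlen interface these values are passed as `cu_seqlens_q` / `cu_seqlens_k` and `max_seqlen_q` / `max_seqlen_k`. Whether the wrapper in this PR accepts exactly the same argument set is not shown in the description.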

Tests

  • python -m py_compile vexact/batch_invariant_ops/triton_invariant_attention.py tests/batch_invariant_ops/test_triton_attention_varlen.py tests/scripts/verify_logits_vs_native_hf.py
  • ruff check vexact/batch_invariant_ops/triton_invariant_attention.py tests/batch_invariant_ops/test_triton_attention_varlen.py tests/scripts/verify_logits_vs_native_hf.py
  • ruff format --check vexact/batch_invariant_ops/triton_invariant_attention.py tests/batch_invariant_ops/test_triton_attention_varlen.py tests/scripts/verify_logits_vs_native_hf.py
  • git diff --check
  • mlx worker 950994: pytest -q submodules/open-vexact/tests/batch_invariant_ops/test_triton_attention_varlen.py -> 18 passed
  • mlx worker 950994 Qwen3-1.7B verifier, --model_backend hf --attn_impl triton-invariant --use_remove_padding -> logits matched 64/64 tokens, max abs diff 0
  • mlx worker 950994 Qwen3-1.7B verifier, --model_backend veomni --attn_impl triton-invariant --use_remove_padding --use_fused_lce -> logprobs matched 64/64 tokens, max abs diff 0, backward produced nonzero grad norm
  • mlx worker 950994 Qwen3-1.7B vexact rollout smoke using examples/getting_started/run_qwen3_1b7.sh with INFER_FA_IMPL=triton-invariant and VEOMNI_ATTN_IMPLEMENTATION=triton-invariant -> training/rollout_probs_diff_max:0.0, training/rollout_probs_diff_mean:0.0
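As background on why "max abs diff 0" is a meaningful target, here is a small pure-Python illustration. It is not taken from the PR. Floating-point addition is not associative, so a reduction whose grouping changes (for example, if a split point depended on batch shape) can change the low-order bits of the result. The helper names `sum_with_split` and `bitwise_equal` are hypothetical, and the split point stands in, loosely, for a kernel's tiling choice.

```python
import struct


def sum_left_to_right(values):
    """Accumulate strictly left to right, starting from 0.0."""
    total = 0.0
    for v in values:
        total += v
    return total


def sum_with_split(values, split):
    """Reduce two chunks separately, then combine the two partial sums."""
    return sum_left_to_right(values[:split]) + sum_left_to_right(values[split:])


def bitwise_equal(a, b):
    """Compare the raw IEEE-754 double encodings rather than using a tolerance."""
    return struct.pack("<d", a) == struct.pack("<d", b)


values = [0.1, 0.2, 0.3]
grouped_late = sum_with_split(values, 1)   # 0.1 + (0.2 + 0.3)
grouped_early = sum_with_split(values, 2)  # (0.1 + 0.2) + 0.3

# Different grouping: results differ in the last bit.
differs = not bitwise_equal(grouped_late, grouped_early)

# Same grouping every time: results are bit-identical.
stable = bitwise_equal(sum_with_split(values, 2), sum_with_split(values, 2))
```

A batch-invariant kernel aims for the second situation: the reduction order for a given token does not depend on what else is in the batch, so prefill and decode (or rollout and training) can agree bit for bit.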

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a Triton-based batch-invariant attention implementation (triton-invariant) to support A100 (SM80) GPUs. It integrates the backend with the configuration, inferencer, and verification scripts, and adds comprehensive unit tests. The review feedback identifies three improvements in vexact/batch_invariant_ops/oai_fused_attn.py: adding an early-exit condition in _paged_attn_fwd_block_kernel to optimize decode attention; computing max_seqlen_k and max_seqlen_q dynamically when they are not provided, to avoid substantial redundant work; and clamping q_local, logical_page, and page_offset to prevent negative or out-of-bounds pointer arithmetic that could lead to undefined behavior or hardware exceptions.
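Two of the suggestions above can be sketched host-side in plain Python. This is an illustration only, not the code in oai_fused_attn.py: `resolve_max_seqlen` and `locate_kv_slot` are hypothetical names, and the page layout (a fixed `page_size` with `num_logical_pages` per sequence) is an assumption.

```python
def resolve_max_seqlen(cu_seqlens, max_seqlen=None):
    """Use the caller's value if given; otherwise derive it from cu_seqlens."""
    if max_seqlen is not None:
        return max_seqlen
    lengths = (end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))
    return max(lengths, default=0)


def clamp(value, low, high):
    """Restrict value to the closed interval [low, high]."""
    return max(low, min(value, high))


def locate_kv_slot(token_pos, page_size, num_logical_pages):
    """Map a token position to (logical_page, page_offset), clamped in range.

    Out-of-range positions still yield a valid address; a separate mask is
    expected to discard whatever is read from it.
    """
    safe_pos = max(token_pos, 0)
    logical_page = clamp(safe_pos // page_size, 0, num_logical_pages - 1)
    page_offset = clamp(safe_pos % page_size, 0, page_size - 1)
    return logical_page, page_offset


derived = resolve_max_seqlen([0, 3, 8, 10])        # longest sequence has 5 tokens
explicit = resolve_max_seqlen([0, 3, 8, 10], 16)   # caller-provided value wins

in_range = locate_kv_slot(17, page_size=16, num_logical_pages=4)
negative = locate_kv_slot(-4, page_size=16, num_logical_pages=4)
too_far = locate_kv_slot(1000, page_size=16, num_logical_pages=4)
```

The likely motivation for clamping inside a kernel is that masked-out lanes still compute addresses, so keeping every index in range avoids forming invalid pointers even when the loaded value is never used.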


Comment thread vexact/batch_invariant_ops/triton_invariant_attention.py
Comment thread vexact/batch_invariant_ops/oai_fused_attn.py Outdated
Comment thread vexact/batch_invariant_ops/oai_fused_attn.py Outdated
Comment thread vexact/batch_invariant_ops/oai_fused_attn.py Outdated
@Luosuu Luosuu changed the title feat: add triton invariant attention backend [ops] feat: add triton invariant attention backend Jun 6, 2026
@Luosuu Luosuu force-pushed the feat/triton-invariant-attn branch 2 times, most recently from d9acb1c to 46dd656 Compare June 6, 2026 18:48
@Luosuu Luosuu force-pushed the feat/triton-invariant-attn branch from 6abc4ac to f074376 Compare June 10, 2026 17:05
@Luosuu Luosuu force-pushed the feat/triton-invariant-attn branch 3 times, most recently from 7248405 to f46c8b4 Compare June 10, 2026 21:53
@Luosuu Luosuu force-pushed the feat/triton-invariant-attn branch from f46c8b4 to 65c6b61 Compare June 10, 2026 22:04
@Luosuu Luosuu force-pushed the feat/triton-invariant-attn branch from 65c6b61 to 5354b4d Compare June 10, 2026 22:28
@Luosuu Luosuu force-pushed the feat/triton-invariant-attn branch from 0295e54 to 45823f4 Compare June 10, 2026 22:58
@Luosuu Luosuu merged commit 88ad2a9 into verl-project:main Jun 11, 2026
5 checks passed