Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ conda activate go1
> ⚡️ Our environment has been tested with **CUDA 12.4**.
```bash
pip install -e .
```

Flash Attention is loaded through the [`kernels`](https://github.com/huggingface/kernels) library (installed by the command above), which fetches pre-built Flash Attention 2 binaries on first use — no local compilation required. The binaries are bit-exact with the upstream source build.

If your machine is not covered by the pre-built binaries, the code automatically falls back to a source build of `flash-attn`, which you can install with:
```bash
pip install --no-build-isolation flash-attn==2.4.2
```

Expand Down
87 changes: 87 additions & 0 deletions go1/internvl/model/flash_attn_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Loader for Flash Attention symbols used across the GO-1 model code.

Building ``flash-attn`` from source is slow. To avoid that, we first try to load
pre-built Flash Attention 2 binaries served through the Hugging Face ``kernels``
library. When pre-built binaries are not available for the current platform
(unsupported GPU/arch, no network, ...), we transparently fall back to a source
build of the ``flash-attn`` package, so the dependency stays optional.

See https://github.com/OpenDriveLab/AgiBot-World/issues/158 for context.
"""

import logging
from functools import lru_cache

logger = logging.getLogger(__name__)

# Kernel served by the `kernels` library that mirrors the `flash-attn` v2 API.
_FLASH_ATTN_KERNEL = "kernels-community/flash-attn2"
_FLASH_ATTN_KERNEL_VERSION = 1


@lru_cache(maxsize=None)
def load_flash_attn():
"""Return a dict of Flash Attention symbols, or an empty dict if unavailable.

The returned dict exposes the same callables the model code relies on:
``flash_attn_func``, ``flash_attn_varlen_func``,
``flash_attn_varlen_qkvpacked_func``, ``pad_input``, ``unpad_input`` and
``index_first_axis``.

Resolution order:
1. Pre-built FA2 binaries via the ``kernels`` library (no compilation).
2. A source build of the ``flash-attn`` package (the original behaviour).
"""
symbols = _load_from_kernels()
if symbols is not None:
return symbols

symbols = _load_from_source()
if symbols is not None:
return symbols

return {}


def _load_from_kernels():
try:
from kernels import get_kernel

module = get_kernel(_FLASH_ATTN_KERNEL, version=_FLASH_ATTN_KERNEL_VERSION)
interface = module.flash_attention_interface
bert_padding = module.bert_padding
logger.info("Loaded pre-built Flash Attention 2 binaries via `kernels` (%s).", _FLASH_ATTN_KERNEL)
return {
"flash_attn_func": interface.flash_attn_func,
"flash_attn_varlen_func": interface.flash_attn_varlen_func,
"flash_attn_varlen_qkvpacked_func": interface.flash_attn_varlen_qkvpacked_func,
"pad_input": bert_padding.pad_input,
"unpad_input": bert_padding.unpad_input,
"index_first_axis": bert_padding.index_first_axis,
}
except Exception as exc: # noqa: BLE001 - any failure should trigger the source fallback
logger.info(
"Pre-built Flash Attention via `kernels` unavailable (%s); falling back to a source build.", exc
)
return None


def _load_from_source():
try:
from flash_attn import (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_varlen_qkvpacked_func,
)
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

return {
"flash_attn_func": flash_attn_func,
"flash_attn_varlen_func": flash_attn_varlen_func,
"flash_attn_varlen_qkvpacked_func": flash_attn_varlen_qkvpacked_func,
"pad_input": pad_input,
"unpad_input": unpad_input,
"index_first_axis": index_first_axis,
}
except ImportError:
return None
23 changes: 10 additions & 13 deletions go1/internvl/model/go1/modeling_action_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,18 @@

from .configuration_action_expert import ActionExpertConfig

from go1.internvl.model.flash_attn_utils import load_flash_attn

flash_attn_func, flash_attn_varlen_func = None, None
pad_input, index_first_axis, unpad_input = None, None, None
try:
from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
from flash_attn.bert_padding import (
index_first_axis as _index_first_axis,
pad_input as _pad_input,
unpad_input as _unpad_input,
)

flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
has_flash_attn = True
except:
has_flash_attn = False
_flash_attn_symbols = load_flash_attn()
has_flash_attn = bool(_flash_attn_symbols)
if has_flash_attn:
flash_attn_func = _flash_attn_symbols["flash_attn_func"]
flash_attn_varlen_func = _flash_attn_symbols["flash_attn_varlen_func"]
pad_input = _flash_attn_symbols["pad_input"]
index_first_axis = _flash_attn_symbols["index_first_axis"]
unpad_input = _flash_attn_symbols["unpad_input"]
import logging

from go1.internvl.model.internlm2.modeling_internlm2 import (
Expand Down
23 changes: 10 additions & 13 deletions go1/internvl/model/go1/modeling_internlm2_go1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,18 @@

logger = logging.get_logger(__name__)

from go1.internvl.model.flash_attn_utils import load_flash_attn

flash_attn_func, flash_attn_varlen_func = None, None
pad_input, index_first_axis, unpad_input = None, None, None
try:
from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
from flash_attn.bert_padding import (
index_first_axis as _index_first_axis,
pad_input as _pad_input,
unpad_input as _unpad_input,
)

flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
has_flash_attn = True
except:
has_flash_attn = False
_flash_attn_symbols = load_flash_attn()
has_flash_attn = bool(_flash_attn_symbols)
if has_flash_attn:
flash_attn_func = _flash_attn_symbols["flash_attn_func"]
flash_attn_varlen_func = _flash_attn_symbols["flash_attn_varlen_func"]
pad_input = _flash_attn_symbols["pad_input"]
index_first_axis = _flash_attn_symbols["index_first_axis"]
unpad_input = _flash_attn_symbols["unpad_input"]


class InternLM2AttentionGO1(InternLM2Attention):
Expand Down
41 changes: 17 additions & 24 deletions go1/internvl/model/internlm2/modeling_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,38 +51,31 @@

_CONFIG_FOR_DOC = "InternLM2Config"

from go1.internvl.model.flash_attn_utils import load_flash_attn

flash_attn_func, flash_attn_varlen_func = None, None
pad_input, index_first_axis, unpad_input = None, None, None
try:
from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
from flash_attn.bert_padding import (
index_first_axis as _index_first_axis,
pad_input as _pad_input,
unpad_input as _unpad_input,
)

flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
has_flash_attn = True
except:
has_flash_attn = False
_flash_attn_symbols = load_flash_attn()
has_flash_attn = bool(_flash_attn_symbols)
if has_flash_attn:
flash_attn_func = _flash_attn_symbols["flash_attn_func"]
flash_attn_varlen_func = _flash_attn_symbols["flash_attn_varlen_func"]
pad_input = _flash_attn_symbols["pad_input"]
index_first_axis = _flash_attn_symbols["index_first_axis"]
unpad_input = _flash_attn_symbols["unpad_input"]


def _import_flash_attn():
global flash_attn_func, flash_attn_varlen_func
global pad_input, index_first_axis, unpad_input
try:
from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
from flash_attn.bert_padding import (
index_first_axis as _index_first_axis,
pad_input as _pad_input,
unpad_input as _unpad_input,
)

flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
except ImportError:
symbols = load_flash_attn()
if not symbols:
raise ImportError("flash_attn is not installed.")
flash_attn_func = symbols["flash_attn_func"]
flash_attn_varlen_func = symbols["flash_attn_varlen_func"]
pad_input = symbols["pad_input"]
index_first_axis = symbols["index_first_axis"]
unpad_input = symbols["unpad_input"]


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
Expand Down
12 changes: 7 additions & 5 deletions go1/internvl/model/internvl_chat/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import torch.nn as nn
from einops import rearrange

try: # v1
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except: # v2
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from go1.internvl.model.flash_attn_utils import load_flash_attn

from flash_attn.bert_padding import pad_input, unpad_input
_flash_attn_symbols = load_flash_attn()
if not _flash_attn_symbols:
raise ImportError("flash_attn is not installed.")
flash_attn_unpadded_qkvpacked_func = _flash_attn_symbols["flash_attn_varlen_qkvpacked_func"]
pad_input = _flash_attn_symbols["pad_input"]
unpad_input = _flash_attn_symbols["unpad_input"]


class FlashAttention(nn.Module):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"huggingface_hub",
"imageio",
"json_numpy",
"kernels",
"lerobot @ git+https://github.com/huggingface/lerobot.git@2b71789e15c35418b1ccecbceb81f4a598bfd883",
"matplotlib",
"numpy<2",
Expand Down