Skip to content

Hello, I would like to ask: is the HMLA implementation genuine? There is no KV cache implementation #5

@aaababaaz

Description

@aaababaaz

helm/helm/modules/hmla.py

Lines 126 to 176 in e8b4821

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], attn_impl='naive'):
    """
    Forward pass for the hyperbolic Multi-Head Latent Attention layer (HMLA).

    NOTE(review): this forward keeps no KV cache. The cache writes
    (``self.k_cache`` / ``self.v_cache``) and the cached-value read are
    commented out, ``end_pos`` is computed but never used, and keys/values
    are built only from the current ``x``. Each call therefore attends only
    among its own ``seqlen`` tokens; a one-token incremental decode step
    would see no earlier context. ``attn_impl`` is likewise unused.

    Args:
        x (torch.Tensor): Input tensor, unpacked as (batch_size, seq_len, dim).
        start_pos (int): Intended start position for caching; currently it
            only feeds the unused ``end_pos``.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for
            rotary embeddings, applied to ``q_pe`` and ``k_pe``.
        mask (Optional[torch.Tensor]): If given, reshaped by
            ``self.shape_mask`` and passed to ``masked_fill``, so positions
            where it is True get a score of -1e18 before the softmax.
        attn_impl: Accepted but unused.

    Returns:
        torch.Tensor: ``self.wo`` applied to the per-head Lorentzian centroids
        flattened over the head axis (presumably the same shape as ``x`` --
        depends on ``self.wo``; confirm).
    """
    bsz, seqlen, embed_dim = x.size()  # embed_dim is not used below
    end_pos = start_pos + seqlen  # NOTE(review): unused; only the commented-out cache code referenced it
    if self.q_lora_rank == 0:
        q = self.wq(x, return_space=True)
    else:
        # q = wq_b(q_norm(wq_a(x))), keeping only the space-like part at each step
        q = self.wq_b(self.q_norm(self.wq_a(x, return_space=True), space_only=True), return_space=True)
    # NOTE(review): the "- 1" in these head sizes presumably accounts for dropping the
    # Lorentz time coordinate when only the space-like part is returned; confirm.
    q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim - 1) #space-like
    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim - 1], dim=-1) #space-like
    q_pe = apply_rotary_emb(q_pe, freqs_cis) #space-like
    # Shared latent plus a single rotary key part (expanded to every head below).
    kv = self.wkv_a(x, return_space=True)
    kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim - 1], dim=-1) #space-like
    k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) #space-like
    q = torch.cat([q_nope, q_pe], dim=-1) #space-like
    # The latent is immediately expanded into full per-head keys/values; nothing
    # latent-sized is retained across calls.
    kv = self.wkv_b(self.kv_norm(kv, space_only=True), return_space=True) #space-like
    kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim - 1)
    k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim - 1], dim=-1)
    k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
    # Cache writes are disabled: k and v come from the current chunk only.
    # self.k_cache[:bsz, start_pos:end_pos] = k
    # self.v_cache[:bsz, start_pos:end_pos] = v
    # MLA based on hyperbolic distance
    qs = self.project(q)
    ks = self.project(k)
    # Presumably the negative squared Lorentzian distance between qs and ks,
    # i.e. -(-2c - 2<q, k>_L) -- depends on manifold.cinner; confirm.
    # NOTE(review): the trailing shape annotation is from the original; with heads
    # moved to dim 1 by transpose(1, 2) the result is presumably (B, H, S, S).
    scores = 2 * self.manifold.c + 2 * self.manifold.cinner(qs.transpose(1, 2), ks.transpose(1, 2)) # [B, S, N, N]
    # Divides by softmax_scale (rather than multiplying) and adds self.bias.
    scores = scores / self.softmax_scale + self.bias
    if mask is not None:
        mask = self.shape_mask(mask, bsz, self.n_local_heads, seqlen)
        scores = scores.masked_fill(mask, -1e18)
    scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
    # With a cache, values would be read from self.v_cache[:bsz, :end_pos];
    # here only the current chunk's v is used.
    # vs = self.project(self.v_cache[:bsz, :end_pos])
    vs = self.project(v)
    # Output: Lorentzian centroid of the projected values, presumably weighted by
    # the softmax scores (semantics of lorentzian_centroid not visible here).
    x = self.manifold.lorentzian_centroid(vs.transpose(1, 2), scores).transpose(1, 2) #[B, S, H, N]
    x = self.wo(x.flatten(2))
    return x

There, I didn't see anything like the code below:

Example NoPE MLA Code
class MLALayerOptimized(nn.Module):
    """
    一个纯粹的、无位置编码 (NoPE) 且完全向量化的
    Multi-head Latent Attention (MLA) 的优化实现。
    - 训练/Prefill模式: 使用 F.scaled_dot_product_attention 以获得最佳性能 (支持 Flash Attention)。
    - 推理模式: 实现论文中描述的、通过恒等变换达成的 MQA 式计算优化。
    - 支持 Prefill 和单步解码。
    - 解决了 c_norm 在不同路径下的逻辑一致性问题。
    """
    def __init__(self, d_model: int, num_heads: int, d_latent: int, d_head: int = None, output_dim: int = None, **kwargs):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_latent = d_latent
        self.d_head = d_head if d_head is not None else d_model // num_heads

        # 确保 d_head * num_heads 不会出错
        self.inner_dim = self.num_heads * self.d_head
        # 投影矩阵
        self.W_q = nn.Linear(d_model, self.inner_dim, bias=False)
        self.W_c = nn.Linear(d_model, d_latent, bias=False)
        self.W_k = nn.Linear(d_latent, self.inner_dim, bias=False)
        self.W_v = nn.Linear(d_latent, self.inner_dim, bias=False)
        self.W_o = nn.Linear(self.inner_dim, d_model if not output_dim else output_dim, bias=False)
        self.c_norm = nn.RMSNorm(d_latent)
        self.q_norm = nn.RMSNorm(self.inner_dim)

    def forward(self, x: torch.Tensor, use_cache: bool = False, cache: torch.Tensor = None, attn_mask=None, **kwargs):
        batch_size, seq_len, _ = x.shape

        
        # ------------------------------------------------------------------
        # 路径 1: 单步解码 (Decoding) - 当且仅当 use_cache=True 且 cache 已存在
        # ------------------------------------------------------------------
 
        if use_cache and cache is not None:
            if seq_len != 1:
                raise ValueError(f"Decoding with cache requires seq_len=1. cache: {cache}")
            

            # 1. 计算当前 token 的 c 并更新 cache
            c = self.W_c(x) # x shape: (B, 1, d_model) -> c shape: (B, 1, d_latent)
            # if hasattr(self, 'c_norm'):
            c = self.c_norm(c)

            c_full = torch.cat([cache, c], dim=1) if cache is not None else c

            # 2. 计算当前 token 的 Q
            q = self.W_q(x) # (B, 1, inner_dim)
            q = self.q_norm(q) # Apply q_norm: 这一步是安全的
            q_current = q.view(batch_size, 1, self.num_heads, self.d_head)

            # 3. 核心优化:实现 q' = q @ Wk.T
            # q_current: (B, 1, H, D_h)
            # W_k.weight: (H*D_h, D_l) -> (H, D_h, D_l)
            # q_prime: (B, 1, H, D_l)
            W_k_reshaped = self.W_k.weight.view(self.num_heads, self.d_head, self.d_latent)
            q_prime = torch.einsum('bqhd,hdl->bqhl', q_current, W_k_reshaped)

            # 4. 计算注意力分数 q' @ c.T
            # q_prime: (B, 1, H, D_l)
            # c_full: (B, L, D_l)
            # attn_scores: (B, 1, H, L)
            attn_scores = torch.einsum('bqhl,bkl->bqhk', q_prime, c_full) / math.sqrt(self.d_head)

            # 5. 计算权重并对 c 进行加权求和 ("先求和")
            # attn_weights: (B, 1, H, L)
            # intermediate: (B, 1, H, D_l)
            attn_weights = F.softmax(attn_scores - attn_scores.max(dim=-1, keepdim=True)[0] , dim=-1)
            intermediate = torch.einsum('bqhk,bkl->bqhl', attn_weights, c_full)

            # 6. 用 Wv 对中间结果进行变换 ("后变换")
            # W_v.weight: (H*D_h, D_l) -> (H, D_h, D_l)
            # head_output: (B, 1, H, D_h)
            W_v_reshaped = self.W_v.weight.view(self.num_heads, self.d_head, self.d_latent)
            head_output = torch.einsum('bqhl,hdl->bqhd', intermediate, W_v_reshaped)
            # 7. 合并 head 并输出
            combined_heads = head_output.contiguous().view(batch_size, 1, -1)
            output = self.W_o(combined_heads)
            return output, c_full
        # ------------------------------------------------------------------
        # 路径 2: 并行处理 (训练 或 推理的 Prefill 阶段)
        # ------------------------------------------------------------------
        c = self.W_c(x) # (B, L, d_latent)

        q = self.W_q(x)
        q = q.view(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(c).view(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(c).view(batch_size, seq_len, self.num_heads, self.d_head)

        # 使用 PyTorch 2.0+ 的高效实现,is_causal=True 会自动应用因果掩码
        # (B, L, H, D) -> (B, H, L, D) for SDPA
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # F.scaled_dot_product_attention 内部处理 softmax 和缩放
        head_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=True)
        combined_heads = head_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.W_o(combined_heads)
        cache_to_return = c if use_cache else None
        return output, cache_to_return

I mean, HMLA doesn't look like it has anything to do with a KV cache; it looks more like plain MHA with no KV-cache inference path.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions