From 743e1d3b9d11d8996338366a8374cf9cb172641b Mon Sep 17 00:00:00 2001 From: Timo Imhof Date: Sun, 12 Apr 2026 08:49:19 +0000 Subject: [PATCH] problem statement --- src/fastaf/kernels/fused_lora_linear/bwd_b.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/fastaf/kernels/fused_lora_linear/bwd_b.py b/src/fastaf/kernels/fused_lora_linear/bwd_b.py index dc3b744..350c694 100644 --- a/src/fastaf/kernels/fused_lora_linear/bwd_b.py +++ b/src/fastaf/kernels/fused_lora_linear/bwd_b.py @@ -61,6 +61,8 @@ def _kernel_lora_fused_linear_backward_b( _total_acc = tl.zeros([BLOCK_H, BLOCK_R], dtype=tl.float32) # dL/d(out).T @ inner + # TODO: We have to do the full intermediate computation for each BLOCK_S tile, causing redundant computation AND (much worse) redundant global memory loads of A and X tiles. + # Same fix as in bwd_a: compute it once, then compute partial results for all values and save them via atomic add. for s in range(0, tl.cdiv(dL_dout_S, BLOCK_S)): offs_s = s * BLOCK_S + tl.arange(0, BLOCK_S)