diff --git a/src/fastaf/kernels/fused_lora_linear/bwd_b.py b/src/fastaf/kernels/fused_lora_linear/bwd_b.py index dc3b744..350c694 100644 --- a/src/fastaf/kernels/fused_lora_linear/bwd_b.py +++ b/src/fastaf/kernels/fused_lora_linear/bwd_b.py @@ -61,6 +61,8 @@ def _kernel_lora_fused_linear_backward_b( _total_acc = tl.zeros([BLOCK_H, BLOCK_R], dtype=tl.float32) # dL/d(out).T @ inner + # TODO: We have to do the full interm. computation for each BLOCK_S tile, causing redundant computation AND (much worse) redundant global memory loads of A and X tiles. + # Same fix as in bwd_a: compute once and then compute partial results for all values and save via atomic add. for s in range(0, tl.cdiv(dL_dout_S, BLOCK_S)): offs_s = s * BLOCK_S + tl.arange(0, BLOCK_S)