From 27517761b5cf4081448318f15ec51344a5b1e462 Mon Sep 17 00:00:00 2001 From: Josu Date: Sat, 9 May 2026 21:46:52 +0000 Subject: [PATCH] Add triangle multiplicative update problem --- .../triangle-multiplicative-update/def.py | 203 ++++++++++++++++++ .../triangle-multiplicative-update/problem.md | 67 ++++++ .../triangle-multiplicative-update/torch.py | 12 ++ 3 files changed, 282 insertions(+) create mode 100644 problems/triangle-multiplicative-update/def.py create mode 100644 problems/triangle-multiplicative-update/problem.md create mode 100644 problems/triangle-multiplicative-update/torch.py diff --git a/problems/triangle-multiplicative-update/def.py b/problems/triangle-multiplicative-update/def.py new file mode 100644 index 0000000..a3fafc2 --- /dev/null +++ b/problems/triangle-multiplicative-update/def.py @@ -0,0 +1,203 @@ +import math + +import torch +from typing import Any, Dict, List, Tuple + +from problem import Problem + + +class triangle_multiplicative_update(Problem): + """Masked outgoing triangle multiplicative update.""" + + is_exact = False + + parameters = [ + {"name": "left", "type": "float", "pointer": True, "const": True}, + {"name": "right", "type": "float", "pointer": True, "const": True}, + {"name": "mask", "type": "float", "pointer": True, "const": True}, + {"name": "output", "type": "float", "pointer": True, "const": False}, + {"name": "B", "type": "size_t", "pointer": False, "const": False}, + {"name": "N", "type": "size_t", "pointer": False, "const": False}, + {"name": "H", "type": "size_t", "pointer": False, "const": False}, + ] + + def __init__(self): + super().__init__(name="triangle-multiplicative-update") + + def reference_solution( + self, + left: torch.Tensor, + right: torch.Tensor, + mask: torch.Tensor, + ) -> torch.Tensor: + """ + Reference masked outgoing triangle multiplicative update. + + Args: + left: Tensor of shape (B, N, N, H), indexed as left[b, i, k, h]. + right: Tensor of shape (B, N, N, H), indexed as right[b, j, k, h]. + mask: Tensor of shape (B, N, N), with 0/1 entries applied to both paths. + + Returns: + Tensor of shape (B, N, N, H), where + output[b, i, j, h] = sum_k left[b, i, k, h] * mask[b, i, k] + * right[b, j, k, h] * mask[b, j, k]. + """ + with torch.no_grad(), torch.autocast("cuda", enabled=False, dtype=left.dtype): + mask_expanded = mask.unsqueeze(-1) + left_masked = left * mask_expanded + right_masked = right * mask_expanded + + b, n, _, h = left.shape + left_bh = left_masked.permute(0, 3, 1, 2).reshape(b * h, n, n) + right_bh_t = right_masked.permute(0, 3, 2, 1).reshape(b * h, n, n) + out_bh = torch.bmm(left_bh, right_bh_t) + return out_bh.reshape(b, h, n, n).permute(0, 2, 3, 1).contiguous() + + def _make_case( + self, + *, + b: int, + n: int, + h: int, + distribution: str, + mask_mode: str, + ) -> Dict[str, Any]: + name = f"B={b}, N={n}, H={h}, {distribution}, {mask_mode}" + seed = Problem.get_seed(f"{self.name}_{name}") + dtype = self.param_dtype(0) + + def create_inputs( + b: int = b, + n: int = n, + h: int = h, + distribution: str = distribution, + mask_mode: str = mask_mode, + seed: int = seed, + dtype: torch.dtype = dtype, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + generator = torch.Generator(device="cuda").manual_seed(seed) + shape = (b, n, n, h) + + if distribution == "cauchy": + # Inverse-CDF sampling keeps the case deterministic under the + # per-test CUDA generator. + left_u = torch.rand(shape, device="cuda", dtype=torch.float32, generator=generator) + right_u = torch.rand(shape, device="cuda", dtype=torch.float32, generator=generator) + # Clamp and scale to keep the reference finite while preserving + # heavy-tailed inputs. + left = (2.0 * torch.tan(math.pi * (left_u - 0.5))).clamp(-16.0, 16.0).to(dtype) * 0.125 + right = (2.0 * torch.tan(math.pi * (right_u - 0.5))).clamp(-16.0, 16.0).to(dtype) * 0.125 + else: + left = torch.randn(shape, device="cuda", dtype=dtype, generator=generator) * 0.5 + right = torch.randn(shape, device="cuda", dtype=dtype, generator=generator) * 0.5 + + if mask_mode == "nomask": + mask = torch.ones((b, n, n), device="cuda", dtype=dtype) + else: + mask = torch.randint( + 0, + 2, + (b, n, n), + device="cuda", + dtype=torch.int32, + generator=generator, + ).to(dtype) + + return left.contiguous(), right.contiguous(), mask.contiguous() + + return { + "name": name, + "B": b, + "N": n, + "H": h, + "distribution": distribution, + "mask_mode": mask_mode, + "create_inputs": create_inputs, + } + + def generate_test_cases(self) -> List[Dict[str, Any]]: + return [ + self._make_case(b=1, n=128, h=64, distribution="normal", mask_mode="nomask"), + self._make_case(b=1, n=256, h=128, distribution="normal", mask_mode="random-mask"), + self._make_case(b=2, n=256, h=128, distribution="cauchy", mask_mode="random-mask"), + self._make_case(b=1, n=512, h=128, distribution="normal", mask_mode="nomask"), + self._make_case(b=1, n=1024, h=128, distribution="normal", mask_mode="random-mask"), + ] + + def generate_sample(self) -> Dict[str, Any]: + dtype = self.param_dtype(0) + b, n, h = 1, 4, 2 + + return { + "name": "Sample B=1, N=4, H=2", + "B": b, + "N": n, + "H": h, + "distribution": "deterministic", + "mask_mode": "random-mask", + "create_inputs": lambda: ( + torch.arange(1, b * n * n * h + 1, device="cuda", dtype=dtype).view(b, n, n, h) / 16.0, + torch.arange(1, b * n * n * h + 1, device="cuda", dtype=dtype).flip(0).view(b, n, n, h) / 32.0, + torch.tensor( + [[[1, 1, 0, 1], [1, 0, 1, 1], [1, 1, 1, 0], [0, 1, 1, 1]]], + device="cuda", + dtype=dtype, + ), + ), + } + + def verify_result( + self, + expected_output: torch.Tensor, + actual_output: torch.Tensor, + ) -> Tuple[bool, Dict[str, Any]]: + is_close = torch.allclose(actual_output, expected_output, rtol=2e-2, atol=2e-2) + + debug_info: Dict[str, Any] = {} + if not is_close: + diff = actual_output.float() - expected_output.float() + abs_diff = torch.abs(diff) + flat_diff = abs_diff.flatten() + _, top_indices = torch.topk(flat_diff, min(5, flat_diff.numel())) + shape = expected_output.shape + + sample_differences = {} + for flat_idx in top_indices: + idx = [] + value = flat_idx.item() + for dim_size in reversed(shape): + idx.insert(0, value % dim_size) + value //= dim_size + idx_tuple = tuple(idx) + sample_differences[str(idx_tuple)] = { + "expected": expected_output[idx_tuple].item(), + "actual": actual_output[idx_tuple].item(), + "diff": diff[idx_tuple].item(), + } + + debug_info = { + "max_difference": flat_diff.max().item(), + "mean_difference": abs_diff.mean().item(), + "sample_differences": sample_differences, + } + + return is_close, debug_info + + def get_flops(self, test_case: Dict[str, Any]) -> int: + """ + FLOPs for the masked outgoing triangle update. + + For each output element output[b, i, j, h], the reduction over k uses + one multiply for left*right and one add into the accumulator. There are + B*N*N*H output elements and N reduction terms, so the contraction costs + approximately 2*B*N^3*H FLOPs. The mask multiplications add + 2*B*N^2*H FLOPs, one for each path element. + """ + b = test_case["B"] + n = test_case["N"] + h = test_case["H"] + return int((2 * b * n * n * n * h) + (2 * b * n * n * h)) + + def get_extra_params(self, test_case: Dict[str, Any]) -> List[Any]: + return [test_case["B"], test_case["N"], test_case["H"]] diff --git a/problems/triangle-multiplicative-update/problem.md b/problems/triangle-multiplicative-update/problem.md new file mode 100644 index 0000000..4dee3a2 --- /dev/null +++ b/problems/triangle-multiplicative-update/problem.md @@ -0,0 +1,67 @@ +--- +slug: "triangle-multiplicative-update" +title: "Triangle Multiplicative Update" +difficulty: "HARD" +author: "josusanmartin" +tags: ["alphafold", "matmul", "tensor-contraction", "bio-ml"] +--- + +Implement the core outgoing Triangle Multiplicative Update contraction used by AlphaFold-style pair representations. + +This problem is based on GPU MODE's [(Mini) Competition #3: AlphaFold's Triangle Multiplicative Update](https://stormy-sailor-96a.notion.site/GPU-MODE-Mini-Competition-3-AlphaFold-s-Triangle-Multiplicative-Update-207221cc2ffa8034b3eddff1d898dc14). The Tensara version isolates the cubic contraction from the full AlphaFold block so CUDA submissions can focus on the memory-layout and tensor-contraction challenge. + +Given two transformed pair tensors: + +$$ +L, R \in \mathbb{R}^{B \times N \times N \times H} +$$ + +and a pair mask: + +$$ +M \in \mathbb{R}^{B \times N \times N}, +$$ + +compute: + +$$ +O[b, i, j, h] = +\sum_{k=0}^{N-1} +\left(L[b, i, k, h] \cdot M[b, i, k]\right) +\left(R[b, j, k, h] \cdot M[b, j, k]\right). +$$ + +This is equivalent to the outgoing TriMul einsum: + +```python +output = einsum("bik h,bjk h->bij h", left * mask[..., None], right * mask[..., None]) +``` + +where spacing is added only for readability. + +## Input + +- Tensor `left` of shape $B \times N \times N \times H$ +- Tensor `right` of shape $B \times N \times N \times H$ +- Tensor `mask` of shape $B \times N \times N$, with entries equal to `0` or `1` + +## Output + +- Tensor `output` of shape $B \times N \times N \times H$ + +## Notes + +- All tensors are stored in row-major order. +- The last dimension `H` is contiguous in memory. +- The contraction reduces over the middle sequence index `k`. +- The mask is always provided. No-mask cases use an all-ones mask. +- Projection, sigmoid gates, and layer normalization from the full GPU MODE problem are intentionally excluded here so the Tensara task stays focused on the hard memory-layout and tensor-contraction part. +- Using Tensor Cores directly is non-trivial because the matrices for each hidden channel are strided by `H` in the input layout. + +## Test Case Sizes + +- B=1, N=128, H=64, normal input, no mask +- B=1, N=256, H=128, normal input, random mask +- B=2, N=256, H=128, clamped Cauchy input, random mask +- B=1, N=512, H=128, normal input, no mask +- B=1, N=1024, H=128, normal input, random mask diff --git a/problems/triangle-multiplicative-update/torch.py b/problems/triangle-multiplicative-update/torch.py new file mode 100644 index 0000000..4ba209a --- /dev/null +++ b/problems/triangle-multiplicative-update/torch.py @@ -0,0 +1,12 @@ +import torch + + +def solution(left, right, mask, output, B, N, H): + mask_expanded = mask.unsqueeze(-1) + left_masked = left * mask_expanded + right_masked = right * mask_expanded + + left_bh = left_masked.permute(0, 3, 1, 2).reshape(B * H, N, N) + right_bh_t = right_masked.permute(0, 3, 2, 1).reshape(B * H, N, N) + out_bh = torch.bmm(left_bh, right_bh_t) + output[:] = out_bh.reshape(B, H, N, N).permute(0, 2, 3, 1)