diff --git a/src/tilegym/ops/cutile/softmax.py b/src/tilegym/ops/cutile/softmax.py
index 4593bbbd..1b04ce5b 100644
--- a/src/tilegym/ops/cutile/softmax.py
+++ b/src/tilegym/ops/cutile/softmax.py
@@ -50,6 +50,44 @@ def _softmax_kernel(
         ct.scatter(output, (row_idx, offsets), softmax_output, check_bounds=True)
 
 
+@ct.kernel
+def _softmax_kernel_multi_wave_full_row_reg_cached_ldg(
+    output,
+    input,
+    N: ConstInt,
+    TILE_SIZE: ConstInt,
+):
+    """scheduling=multi_wave | coverage=full_row | load=ldg | caching=reg_cached
+
+    Multi-wave softmax: one block per row, no grid-stride loop.
+    Uses conditional bounds check — skipped when TILE_SIZE == N (power-of-2 N).
+
+    Args:
+        output: 2D array written at (row, column); receives the softmax result.
+        input: 2D array read at (row, column); also fixes the result dtype.
+        N: Number of columns in a row.
+        TILE_SIZE: Width of the single tile covering a row. The launcher must
+            guarantee TILE_SIZE >= N; columns past TILE_SIZE are never visited.
+    """
+    row_idx = ct.bid(0)
+    offsets = ct.arange(TILE_SIZE, dtype=ct.int32)
+    # Bounds checks are only needed when the tile is not exactly one row wide.
+    check_bound = TILE_SIZE != N
+
+    # Out-of-range lanes load -inf so they contribute exp(-inf) == 0 to the sum.
+    row = ct.gather(input, (row_idx, offsets), check_bounds=check_bound, padding_value=-math.inf)
+    # Reduce in float32 regardless of the storage dtype.
+    row = ct.astype(row, ct.float32)
+
+    # Subtract the row maximum before exponentiating for numerical stability.
+    row_max = ct.max(row, 0, keepdims=True)
+    numerator = ct.exp(row - row_max)
+    denominator = ct.sum(numerator, 0, keepdims=True)
+
+    softmax_output = ct.astype(numerator / denominator, input.dtype)
+    ct.scatter(output, (row_idx, offsets), softmax_output, check_bounds=check_bound)
+
+
 # TMA version with static persistent scheduling
 @ct.kernel(occupancy=2)
 def _softmax_kernel_tma(
@@ -178,6 +216,42 @@ def _launch_softmax_kernel(input, output, TILE_SIZE=1024):
     )
 
 
+def _launch_softmax_kernel_multi_wave_full_row_reg_cached_ldg(input, output, TILE_SIZE):
+    """Launch the multi-wave full-row softmax kernel, one block per row.
+
+    Args:
+        input: 2D tensor; softmax is taken over each row.
+        output: Tensor that receives the result (assumed to match ``input``
+            in shape and dtype).
+        TILE_SIZE: Tile width handed to the kernel; must be >= the row length
+            because the kernel reads each row as one tile.
+
+    Raises:
+        ValueError: If ``TILE_SIZE`` is smaller than the number of columns.
+    """
+    n_rows, n_cols = input.shape
+    if TILE_SIZE < n_cols:
+        # The kernel has no column loop: a short tile would silently normalize
+        # only the first TILE_SIZE columns and leave the rest unwritten.
+        raise ValueError(f"TILE_SIZE ({TILE_SIZE}) must be >= n_cols ({n_cols})")
+
+    input = input.contiguous()
+    # .contiguous() copies a strided tensor, so writing into that copy alone
+    # would drop the result; remember to copy it back into the caller's buffer.
+    needs_copy_back = not output.is_contiguous()
+    out_buf = output.contiguous() if needs_copy_back else output
+
+    grid = (n_rows, 1, 1)
+    ct.launch(
+        torch.cuda.current_stream(),
+        grid,
+        _softmax_kernel_multi_wave_full_row_reg_cached_ldg,
+        (out_buf, input, n_cols, TILE_SIZE),
+    )
+    if needs_copy_back:
+        output.copy_(out_buf)
+
+
 def _launch_softmax_kernel_tma(
     input,
     output,
@@ -276,6 +350,7 @@ def forward(
         x,
         use_tma=False,
         use_chunked=False,
+        use_multi_wave=False,
     ):
         assert not (use_tma and use_chunked), "Cannot use both TMA and chunked softmax at the same time"
         # TMA may be emulated on this arch; redirect to non-TMA path with a warning.
@@ -296,7 +371,9 @@ def forward(
         # Create output tensor
         y = torch.empty_like(x)
 
-        if use_chunked:
+        if use_multi_wave:
+            _launch_softmax_kernel_multi_wave_full_row_reg_cached_ldg(x, y, TILE_SIZE=TILE_SIZE)
+        elif use_chunked:
             # Use chunked kernel (3-pass algorithm for large tensors)
             # Cap TILE_SIZE at 8192 to enable chunking for very large n_cols
             # For smaller n_cols, use next_power_of_2(n_cols) to match data size
@@ -324,14 +401,17 @@ def softmax(
         use_tma: Whether to use TMA (Tensor Memory Accelerator) implementation.
                  Requires H100+ GPU (compute capability >= 9.0)
         **kwargs: Additional arguments for backend-specific configurations
-                 (e.g., use_chunked: whether to use chunked softmax implementation)
+            use_chunked: whether to use chunked softmax implementation
+            use_multi_wave: whether to use multi-wave (one block per row)
 
     Returns:
         Softmax output tensor with gradient support
     """
     use_chunked = kwargs.get("use_chunked", False)
+    use_multi_wave = kwargs.get("use_multi_wave", False)
     return _Softmax.apply(
         x,
         use_tma,
         use_chunked,
+        use_multi_wave,
     )
diff --git a/tests/benchmark/bench_softmax.py b/tests/benchmark/bench_softmax.py
index 25b0b6f8..50815bab 100644
--- a/tests/benchmark/bench_softmax.py
+++ b/tests/benchmark/bench_softmax.py
@@ -17,6 +17,7 @@ def reference_softmax(
     x: torch.Tensor,
     use_tma: bool = False,  # Unused - kept for interface compatibility
     use_chunked: bool = None,  # Unused - kept for interface compatibility
+    use_multi_wave: bool = False,  # Unused - kept for interface compatibility
 ):
     """Reference implementation of softmax using PyTorch"""
     return torch.nn.functional.softmax(x, dim=-1)
@@ -38,7 +39,7 @@ def get_supported_backends():
     return [p for p in ALL_BACKENDS if p is not None]
 
 
-def create_benchmark_config(M, use_tma=True, use_chunked=False):
+def create_benchmark_config(M, use_tma=True, use_chunked=False, use_multi_wave=False):
     """Create a benchmark configuration for given parameters"""
     available_backends = get_supported_backends()
     if not available_backends:
@@ -54,27 +55,34 @@ def create_benchmark_config(M, use_tma=True, use_chunked=False):
         line_names=list(names),
         styles=list(styles),
         ylabel="GB/s",
-        plot_name=f"softmax-performance-tma-{use_tma}-chunked-{use_chunked}-GBps",
-        args={"M": M, "use_tma": use_tma, "use_chunked": use_chunked},
+        plot_name=f"softmax-performance-tma-{use_tma}-chunked-{use_chunked}-multi-wave-{use_multi_wave}-GBps",
+        args={"M": M, "use_tma": use_tma, "use_chunked": use_chunked, "use_multi_wave": use_multi_wave},
     )
 
 
 @triton.testing.perf_report(
     [
-        create_benchmark_config(M, use_tma, use_chunked)
+        create_benchmark_config(M, use_tma, use_chunked, use_multi_wave)
         for M in [4096]
-        for use_tma, use_chunked in [
-            (False, False),  # baseline
-            (True, False),  # TMA softmax
-            (False, True),  # chunked softmax
+        for use_tma, use_chunked, use_multi_wave in [
+            (False, False, False),  # baseline
+            (True, False, False),  # TMA softmax
+            (False, True, False),  # chunked softmax
+            (False, False, True),  # multi-wave softmax
         ]
     ]
 )
-def bench_softmax(M, N, backend, use_tma, use_chunked, dtype=torch.float32, device=DEVICE):
+def bench_softmax(M, N, backend, use_tma, use_chunked, use_multi_wave, dtype=torch.float32, device=DEVICE):
     # Create data
     x = torch.randn(M, N, dtype=dtype, device=device)
 
-    fn = lambda: tilegym.ops.softmax(x, use_tma=use_tma, use_chunked=use_chunked, backend=backend)
+    fn = lambda: tilegym.ops.softmax(
+        x,
+        use_tma=use_tma,
+        use_chunked=use_chunked,
+        use_multi_wave=use_multi_wave,
+        backend=backend,
+    )
     ref = lambda: reference_softmax(x)
     torch.testing.assert_close(fn(), ref(), atol=1e-2, rtol=1e-2)
 
diff --git a/tests/ops/test_softmax.py b/tests/ops/test_softmax.py
index 89ff1644..c0cd74fa 100644
--- a/tests/ops/test_softmax.py
+++ b/tests/ops/test_softmax.py
@@ -34,11 +34,17 @@ def reference(x):
         ],
     )
     @pytest.mark.parametrize("backend", _backends)
-    @pytest.mark.parametrize("use_tma", [True, False], ids=["use_tma=True", "use_tma=False"])
-    @pytest.mark.parametrize("use_chunked", [True, False], ids=["use_chunked=True", "use_chunked=False"])
-    def test_op(self, m, n, dtype, arch, backend, use_tma, use_chunked):
-        if use_chunked and use_tma:
-            pytest.skip("Cannot use both TMA and chunked softmax at the same time")
+    @pytest.mark.parametrize(
+        "use_tma,use_chunked,use_multi_wave",
+        [
+            (False, False, False),
+            (True, False, False),
+            (False, True, False),
+            (False, False, True),
+        ],
+        ids=["baseline", "use_tma", "use_chunked", "use_multi_wave"],
+    )
+    def test_op(self, m, n, dtype, arch, backend, use_tma, use_chunked, use_multi_wave):
         if tilegym.is_backend_available(backend):
             tilegym.set_backend(backend)
         self.setUp()
@@ -64,7 +70,7 @@ def test_op(self, m, n, dtype, arch, backend, use_tma, use_chunked):
             tilegym.ops.softmax,
             self.reference,
             {"x": x},
-            extra_test_kwargs={"use_tma": use_tma, "use_chunked": use_chunked},
+            extra_test_kwargs={"use_tma": use_tma, "use_chunked": use_chunked, "use_multi_wave": use_multi_wave},
             gradient=dout,
             rtol=rtol,
             atol=atol,