Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions src/tilegym/ops/cutile/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,32 @@ def _softmax_kernel(
ct.scatter(output, (row_idx, offsets), softmax_output, check_bounds=True)


@ct.kernel
def _softmax_kernel_multi_wave_full_row_reg_cached_ldg(
    output,
    input,
    N: ConstInt,
    TILE_SIZE: ConstInt,
):
    """scheduling=multi_wave | coverage=full_row | load=ldg | caching=reg_cached

    Multi-wave softmax: each block handles the single row selected by its
    block index, covering it with one TILE_SIZE-wide tile (no grid-stride
    loop).

    Bounds checking on the gather and scatter is requested only when
    TILE_SIZE differs from N; when the two are equal it is switched off.
    Positions padded on load receive -inf, so they contribute exp(-inf) = 0
    to the row sum.
    """
    # Both operands are compile-time constants, so this flag is fixed per
    # kernel specialization.
    needs_bounds_check = TILE_SIZE != N
    col_offsets = ct.arange(TILE_SIZE, dtype=ct.int32)
    block_row = ct.bid(0)

    # Load the row and promote to float32 before doing any arithmetic.
    values = ct.astype(
        ct.gather(
            input,
            (block_row, col_offsets),
            check_bounds=needs_bounds_check,
            padding_value=-math.inf,
        ),
        ct.float32,
    )

    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = values - ct.max(values, 0, keepdims=True)
    exp_values = ct.exp(shifted)
    normalizer = ct.sum(exp_values, 0, keepdims=True)

    # Cast back to the input dtype and store the normalized row.
    result = ct.astype(exp_values / normalizer, input.dtype)
    ct.scatter(output, (block_row, col_offsets), result, check_bounds=needs_bounds_check)


# TMA version with static persistent scheduling
@ct.kernel(occupancy=2)
def _softmax_kernel_tma(
Expand Down Expand Up @@ -178,6 +204,20 @@ def _launch_softmax_kernel(input, output, TILE_SIZE=1024):
)


def _launch_softmax_kernel_multi_wave_full_row_reg_cached_ldg(input, output, TILE_SIZE):
    """Launch the one-block-per-row softmax kernel on the current CUDA stream.

    Args:
        input: 2-D tensor; its shape is unpacked as (n_rows, n_cols). A
            contiguous copy is used for the launch if it is not already
            contiguous.
        output: Tensor that receives the softmax result. If it is not
            contiguous, the kernel writes into a contiguous temporary and the
            result is copied back into ``output`` afterwards.
        TILE_SIZE: Tile width forwarded to the kernel as a constant.
            NOTE(review): the kernel loads a single tile per row, so callers
            presumably pass a value >= n_cols -- confirm at the call site.

    Returns:
        None. The result is written into ``output`` in place.
    """
    n_rows, n_cols = input.shape
    input = input.contiguous()

    # Tensor.contiguous() returns the tensor itself when it is already
    # contiguous and a fresh copy otherwise. Keep the caller's tensor under its
    # own name so a result written to a temporary copy is not silently lost.
    kernel_output = output.contiguous()

    # One block per row.
    grid = (n_rows, 1, 1)
    ct.launch(
        torch.cuda.current_stream(),
        grid,
        _softmax_kernel_multi_wave_full_row_reg_cached_ldg,
        (kernel_output, input, n_cols, TILE_SIZE),
    )

    if kernel_output is not output:
        # The kernel wrote into a temporary; propagate the result to the
        # tensor the caller actually holds.
        output.copy_(kernel_output)


def _launch_softmax_kernel_tma(
input,
output,
Expand Down Expand Up @@ -276,6 +316,7 @@ def forward(
x,
use_tma=False,
use_chunked=False,
use_multi_wave=False,
):
assert not (use_tma and use_chunked), "Cannot use both TMA and chunked softmax at the same time"
# TMA may be emulated on this arch; redirect to non-TMA path with a warning.
Expand All @@ -296,7 +337,9 @@ def forward(
# Create output tensor
y = torch.empty_like(x)

if use_chunked:
if use_multi_wave:
_launch_softmax_kernel_multi_wave_full_row_reg_cached_ldg(x, y, TILE_SIZE=TILE_SIZE)
elif use_chunked:
# Use chunked kernel (3-pass algorithm for large tensors)
# Cap TILE_SIZE at 8192 to enable chunking for very large n_cols
# For smaller n_cols, use next_power_of_2(n_cols) to match data size
Expand Down Expand Up @@ -324,14 +367,17 @@ def softmax(
use_tma: Whether to use TMA (Tensor Memory Accelerator) implementation.
Requires H100+ GPU (compute capability >= 9.0)
**kwargs: Additional arguments for backend-specific configurations
(e.g., use_chunked: whether to use chunked softmax implementation)
use_chunked: whether to use chunked softmax implementation
use_multi_wave: whether to use multi-wave (one block per row)

Returns:
Softmax output tensor with gradient support
"""
use_chunked = kwargs.get("use_chunked", False)
use_multi_wave = kwargs.get("use_multi_wave", False)
return _Softmax.apply(
x,
use_tma,
use_chunked,
use_multi_wave,
)
28 changes: 18 additions & 10 deletions tests/benchmark/bench_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def reference_softmax(
x: torch.Tensor,
use_tma: bool = False, # Unused - kept for interface compatibility
use_chunked: bool = None, # Unused - kept for interface compatibility
use_multi_wave: bool = False, # Unused - kept for interface compatibility
):
"""Reference implementation of softmax using PyTorch"""
return torch.nn.functional.softmax(x, dim=-1)
Expand All @@ -38,7 +39,7 @@ def get_supported_backends():
return [p for p in ALL_BACKENDS if p is not None]


def create_benchmark_config(M, use_tma=True, use_chunked=False):
def create_benchmark_config(M, use_tma=True, use_chunked=False, use_multi_wave=False):
"""Create a benchmark configuration for given parameters"""
available_backends = get_supported_backends()
if not available_backends:
Expand All @@ -54,27 +55,34 @@ def create_benchmark_config(M, use_tma=True, use_chunked=False):
line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"softmax-performance-tma-{use_tma}-chunked-{use_chunked}-GBps",
args={"M": M, "use_tma": use_tma, "use_chunked": use_chunked},
plot_name=f"softmax-performance-tma-{use_tma}-chunked-{use_chunked}-multi-wave-{use_multi_wave}-GBps",
args={"M": M, "use_tma": use_tma, "use_chunked": use_chunked, "use_multi_wave": use_multi_wave},
)


@triton.testing.perf_report(
[
create_benchmark_config(M, use_tma, use_chunked)
create_benchmark_config(M, use_tma, use_chunked, use_multi_wave)
for M in [4096]
for use_tma, use_chunked in [
(False, False), # baseline
(True, False), # TMA softmax
(False, True), # chunked softmax
for use_tma, use_chunked, use_multi_wave in [
(False, False, False), # baseline
(True, False, False), # TMA softmax
(False, True, False), # chunked softmax
(False, False, True), # multi-wave softmax
]
]
)
def bench_softmax(M, N, backend, use_tma, use_chunked, dtype=torch.float32, device=DEVICE):
def bench_softmax(M, N, backend, use_tma, use_chunked, use_multi_wave, dtype=torch.float32, device=DEVICE):
# Create data
x = torch.randn(M, N, dtype=dtype, device=device)

fn = lambda: tilegym.ops.softmax(x, use_tma=use_tma, use_chunked=use_chunked, backend=backend)
fn = lambda: tilegym.ops.softmax(
x,
use_tma=use_tma,
use_chunked=use_chunked,
use_multi_wave=use_multi_wave,
backend=backend,
)
ref = lambda: reference_softmax(x)
torch.testing.assert_close(fn(), ref(), atol=1e-2, rtol=1e-2)

Expand Down
18 changes: 12 additions & 6 deletions tests/ops/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,17 @@ def reference(x):
],
)
@pytest.mark.parametrize("backend", _backends)
@pytest.mark.parametrize("use_tma", [True, False], ids=["use_tma=True", "use_tma=False"])
@pytest.mark.parametrize("use_chunked", [True, False], ids=["use_chunked=True", "use_chunked=False"])
def test_op(self, m, n, dtype, arch, backend, use_tma, use_chunked):
if use_chunked and use_tma:
pytest.skip("Cannot use both TMA and chunked softmax at the same time")
@pytest.mark.parametrize(
"use_tma,use_chunked,use_multi_wave",
[
(False, False, False),
(True, False, False),
(False, True, False),
(False, False, True),
],
ids=["baseline", "use_tma", "use_chunked", "use_multi_wave"],
)
def test_op(self, m, n, dtype, arch, backend, use_tma, use_chunked, use_multi_wave):
if tilegym.is_backend_available(backend):
tilegym.set_backend(backend)
self.setUp()
Expand All @@ -64,7 +70,7 @@ def test_op(self, m, n, dtype, arch, backend, use_tma, use_chunked):
tilegym.ops.softmax,
self.reference,
{"x": x},
extra_test_kwargs={"use_tma": use_tma, "use_chunked": use_chunked},
extra_test_kwargs={"use_tma": use_tma, "use_chunked": use_chunked, "use_multi_wave": use_multi_wave},
gradient=dout,
rtol=rtol,
atol=atol,
Expand Down
Loading