diff --git a/coreai_torch/_aten_to_core.py b/coreai_torch/_aten_to_core.py index 433e27a..621646f 100644 --- a/coreai_torch/_aten_to_core.py +++ b/coreai_torch/_aten_to_core.py @@ -2649,6 +2649,96 @@ def replace_prod_dim_int( return result if keepdim else coreai.shrink_dims(result, [axis]) +def _stable_softplus(x: Value) -> Value: + """Numerically stable softplus: max(x, 0) + log(1 + exp(-|x|)). + + Since -|x| <= 0, exp(-|x|) is always in (0, 1], so no overflow occurs + in any precision. This prevents the fp16 discontinuity at x ≈ 10.4 + on Apple Neural Engine. + + See apple/coremltools#2687 and apple/coreai-torch#21. + """ + abs_x = coreai.abs_(x) + neg_abs_x = coreai.neg(abs_x) + exp_val = coreai.exp(neg_abs_x) + one = coreai.constant(1.0, dtype=x.type.element_type) + log_val = coreai.log(coreai.broadcasting_add(one, exp_val)) + zero = coreai.constant(0.0, dtype=x.type.element_type) + max_val = coreai.maximum(x, zero) + return coreai.broadcasting_add(max_val, log_val) + + +def replace_softplus( + values_map: dict[str, Value], node: fx.Node, loc: Location +) -> Value: + """Numerically stable softplus with threshold support. + + softplus(x) = max(x, 0) + log(1 + exp(-|x|)) + + For beta != 1: softplus(x) = (1/beta) * softplus(beta * x) + For beta * x > threshold: return x directly (matching PyTorch semantics). + """ + x = _get_operand(values_map, node, 0) + beta = node.args[1] if len(node.args) > 1 else 1 + threshold = node.args[2] if len(node.args) > 2 else 20.0 + + if beta == 1: + sp = _stable_softplus(x) + threshold_val = coreai.constant(float(threshold), dtype=x.type.element_type) + cond = coreai.greater(x, threshold_val) + else: + beta_val = coreai.constant(float(beta), dtype=x.type.element_type) + beta_x = coreai.broadcasting_mul(beta_val, x) + inv_beta = coreai.constant(1.0 / float(beta), dtype=x.type.element_type) + sp = coreai.broadcasting_mul(inv_beta, _stable_softplus(beta_x)) + threshold_val = coreai.constant(float(threshold), dtype=x.type.element_type) + cond = coreai.greater(beta_x, threshold_val) + + return coreai.select(cond, x, sp) + + +def replace_mish( + values_map: dict[str, Value], node: fx.Node, loc: Location +) -> Value: + """Numerically stable mish: x * tanh(softplus_stable(x)). + + Uses _stable_softplus to avoid fp16 overflow in the softplus component. + """ + x = _get_operand(values_map, node, 0) + sp = _stable_softplus(x) + return coreai.broadcasting_mul(x, coreai.tanh(sp)) + + +def replace_logsumexp( + values_map: dict[str, Value], node: fx.Node, loc: Location +) -> Value: + """Numerically stable logsumexp via max-shift. + + logsumexp(x) = max(x) + log(sum(exp(x - max(x)))) + + Since x_i - max(x) <= 0, exp(x_i - max(x)) is in (0, 1], preventing + overflow. + """ + x = _get_operand(values_map, node, 0) + dim = node.args[1] + keepdim = node.args[2] if len(node.args) > 2 else False + + if isinstance(dim, int): + dim = [dim] + dim = [d + x.type.rank if d < 0 else d for d in dim] + + max_x = coreai.reduce_max(x, dim) + x_shifted = coreai.broadcasting_sub(x, max_x) + exp_shifted = coreai.exp(x_shifted) + sum_exp = coreai.reduce_sum(exp_shifted, dim) + log_sum_exp = coreai.log(sum_exp) + result = coreai.broadcasting_add(max_x, log_sum_exp) + + if not keepdim: + result = coreai.shrink_dims(result, dim) + return result + + def replace_log_softmax( values_map: dict[str, Value], node: fx.Node, loc: Location ) -> Value: @@ -3419,6 +3509,7 @@ def sdpa_maskless(q: Value, k: Value, v: Value) -> Value: _aten_to_core_resolver: dict[str, Callable[..., Any]] = { "_local_scalar_dense.default": replace_local_scalar_dense, "_log_softmax.default": replace_log_softmax, + "mish.default": replace_mish, "_native_batch_norm_legit_no_training.default": replace_batch_norm, "_softmax.default": replace_softmax, "_to_copy.default": replace_to_copy, @@ -3506,6 +3597,7 @@ def sdpa_maskless(q: Value, k: Value, v: Value) -> Value: "leaky_relu.default": replace_leaky_relu, "lift_fresh_copy.default": replace_lift_fresh_copy, "linalg_vector_norm.default": replace_linalg_vector_norm, + "logsumexp.default": replace_logsumexp, "log.default": replace_unary_ops, "log10.default": replace_log10, "log1p.default": replace_log1p, @@ -3567,6 +3659,7 @@ def sdpa_maskless(q: Value, k: Value, v: Value) -> Value: "select.int": replace_select_int, "sigmoid.default": replace_unary_ops, "silu.default": replace_unary_ops, + "softplus.default": replace_softplus, "sign.default": replace_sign, "sin.default": replace_unary_ops, "sinh.default": replace_unary_ops, diff --git a/coreai_torch/_decomp.py b/coreai_torch/_decomp.py index dd8c599..831f363 100644 --- a/coreai_torch/_decomp.py +++ b/coreai_torch/_decomp.py @@ -17,9 +17,12 @@ torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten.instance_norm.default, + torch.ops.aten.logsumexp.default, + torch.ops.aten.mish.default, torch.ops.aten.pixel_shuffle.default, torch.ops.aten.scaled_dot_product_attention.default, torch.ops.aten.silu.default, + torch.ops.aten.softplus.default, ] @@ -42,6 +45,12 @@ def get_decomp_table() -> dict: * ``torch.ops.aten.hardswish.default`` * ``torch.ops.aten.silu.default`` + *Numerically stable lowerings (fp16 safety):* + + * ``torch.ops.aten.logsumexp.default`` + * ``torch.ops.aten.mish.default`` + * ``torch.ops.aten.softplus.default`` + **Usage with** ``add_exported_program`` (caller handles decomposition):: import torch diff --git a/tests/ops/test_fp16_stable_ops.py b/tests/ops/test_fp16_stable_ops.py new file mode 100644 index 0000000..2208a13 --- /dev/null +++ b/tests/ops/test_fp16_stable_ops.py @@ -0,0 +1,142 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Tests for numerically stable softplus, mish, and logsumexp converters (#21). + +These tests verify that the stable decompositions produce correct results +for inputs in the fp16 overflow range (x > ~10.4), where the naive forms +(log(1+exp(x)), etc.) would overflow to inf. +""" + +import numpy as np +import pytest +import torch +import torch.nn as nn +from torch import Tensor + +from ..utils import ( + _all_dims_dynamic, + validate_numerical_output, +) + + +# --- Softplus --- + + +class SoftplusModel(nn.Module): + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.softplus(x) + + +class SoftplusBetaModel(nn.Module): + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.softplus(x, beta=2.0) + + +@pytest.mark.parametrize("dynamic", [False, True]) +@pytest.mark.parametrize( + "x", + [ + torch.tensor([0.0, 1.0, -1.0, 5.0, -5.0]), + # Inputs in the fp16 overflow range — naive softplus fails here + torch.tensor([10.0, 11.0, 15.0, 20.0, 50.0]), + torch.randn(3, 4), + ], +) +async def test_softplus(x: Tensor, dynamic: bool) -> None: + """Test softplus with inputs spanning the fp16 overflow threshold.""" + model = SoftplusModel().eval() + dynamic_shapes = {"x": _all_dims_dynamic(x)} if dynamic else None + await validate_numerical_output(model=model, x=x, dynamic_shapes=dynamic_shapes) + + +@pytest.mark.parametrize("dynamic", [False]) +@pytest.mark.parametrize( + "x", + [ + torch.tensor([0.0, 5.0, 10.0, 15.0]), + ], +) +async def test_softplus_beta(x: Tensor, dynamic: bool) -> None: + """Test softplus with beta != 1.""" + model = SoftplusBetaModel().eval() + dynamic_shapes = {"x": _all_dims_dynamic(x)} if dynamic else None + await validate_numerical_output(model=model, x=x, dynamic_shapes=dynamic_shapes) + + +# --- Mish --- + + +class MishModel(nn.Module): + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.mish(x) + + +@pytest.mark.parametrize("dynamic", [False, True]) +@pytest.mark.parametrize( + "x", + [ + torch.tensor([0.0, 1.0, -1.0, 5.0, -5.0]), + # Inputs in the fp16 overflow range — naive mish fails here + torch.tensor([10.0, 11.0, 15.0, 20.0, 50.0]), + torch.randn(3, 4), + ], +) +async def test_mish(x: Tensor, dynamic: bool) -> None: + """Test mish with inputs spanning the fp16 overflow threshold.""" + model = MishModel().eval() + dynamic_shapes = {"x": _all_dims_dynamic(x)} if dynamic else None + await validate_numerical_output(model=model, x=x, dynamic_shapes=dynamic_shapes) + + +# --- Logsumexp --- + + +class LogsumexpModel(nn.Module): + def __init__(self, dim: int, keepdim: bool = False) -> None: + super().__init__() + self.dim = dim + self.keepdim = keepdim + + def forward(self, x: Tensor) -> Tensor: + return torch.logsumexp(x, dim=self.dim, keepdim=self.keepdim) + + +@pytest.mark.parametrize("dynamic", [False, True]) +@pytest.mark.parametrize("keepdim", [False, True]) +@pytest.mark.parametrize( + "x", + [ + torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + # Inputs in the fp16 overflow range — naive logsumexp fails here + torch.tensor([[8.0, 9.0, 10.0], [11.0, 12.0, 15.0]]), + torch.randn(3, 4), + ], +) +async def test_logsumexp(x: Tensor, keepdim: bool, dynamic: bool) -> None: + """Test logsumexp with inputs spanning the fp16 overflow threshold.""" + model = LogsumexpModel(dim=1, keepdim=keepdim).eval() + dynamic_shapes = {"x": _all_dims_dynamic(x)} if dynamic else None + await validate_numerical_output(model=model, x=x, dynamic_shapes=dynamic_shapes) + + +# --- Numeric fp16 overflow proof --- + + +def test_fp16_overflow_proof() -> None: + """Verify naive fp16 softplus overflows while stable form does not.""" + x_val = np.float16(15.0) + + # Naive: log(1 + exp(fp16(15))) overflows + naive = np.float16(np.log(np.float16(1.0) + np.exp(x_val))) + assert not np.isfinite(naive), f"Naive fp16 softplus should overflow, got {naive}" + + # Stable: max(15,0) + log(1 + exp(-15)) ≈ 15.0 + stable = np.float16( + np.maximum(x_val, np.float16(0)) + + np.log(np.float16(1.0) + np.exp(-np.abs(x_val))) + ) + assert np.isfinite(stable), f"Stable fp16 softplus should not overflow, got {stable}" + assert abs(float(stable) - 15.0) < 0.5, f"Expected ~15.0, got {stable}"