Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions coreai_torch/_aten_to_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2649,6 +2649,96 @@ def replace_prod_dim_int(
return result if keepdim else coreai.shrink_dims(result, [axis])


def _stable_softplus(x: Value) -> Value:
    """Build an overflow-free softplus of ``x`` (implicit beta of 1).

    Emits ``max(x, 0) + log(1 + exp(-|x|))``. The exponent ``-|x|`` is never
    positive, so the exponential is at most 1 and cannot overflow, unlike the
    textbook ``log(1 + exp(x))``. The motivation is an fp16 discontinuity
    near x ≈ 10.4 reported on Apple Neural Engine; see
    apple/coremltools#2687 and apple/coreai-torch#21.

    Args:
        x: Input value; the constants 1 and 0 are created with its element
            type.

    Returns:
        The value holding ``softplus(x)``.
    """
    elem_type = x.type.element_type

    # exp(-|x|): the only exponential in the expression, bounded above by 1.
    decay = coreai.exp(coreai.neg(coreai.abs_(x)))
    log_term = coreai.log(
        coreai.broadcasting_add(coreai.constant(1.0, dtype=elem_type), decay)
    )

    # max(x, 0) carries the linear part for positive inputs.
    relu_term = coreai.maximum(x, coreai.constant(0.0, dtype=elem_type))
    return coreai.broadcasting_add(relu_term, log_term)


def replace_softplus(
    values_map: dict[str, Value], node: fx.Node, loc: Location
) -> Value:
    """Lower a softplus node to the stable form, honouring beta and threshold.

    The smooth branch is ``(1 / beta) * stable_softplus(beta * x)`` where
    ``stable_softplus(z) = max(z, 0) + log(1 + exp(-|z|))``; the scaling is
    skipped entirely when ``beta == 1``. Wherever ``beta * x > threshold``
    the input ``x`` is passed through unchanged (matching PyTorch semantics).

    Args:
        values_map: Mapping used by ``_get_operand`` to resolve the input.
        node: The FX node; ``args[1]`` is beta (default 1) and ``args[2]`` is
            threshold (default 20.0) when present.
        loc: Source location; not used by this lowering.

    Returns:
        The value selecting between ``x`` and the smooth softplus branch.
    """
    x = _get_operand(values_map, node, 0)
    optional_args = node.args[1:]
    beta = optional_args[0] if len(optional_args) > 0 else 1
    threshold = optional_args[1] if len(optional_args) > 1 else 20.0
    elem_type = x.type.element_type

    if beta != 1:
        # General case: evaluate softplus on beta * x, then rescale by 1/beta.
        scaled = coreai.broadcasting_mul(
            coreai.constant(float(beta), dtype=elem_type), x
        )
        reciprocal = coreai.constant(1.0 / float(beta), dtype=elem_type)
        smooth = coreai.broadcasting_mul(reciprocal, _stable_softplus(scaled))
    else:
        # beta == 1: no scaling ops are emitted.
        scaled = x
        smooth = _stable_softplus(x)

    # The threshold is compared against beta * x, not against x itself.
    limit = coreai.constant(float(threshold), dtype=elem_type)
    past_threshold = coreai.greater(scaled, limit)
    return coreai.select(past_threshold, x, smooth)


def replace_mish(
    values_map: dict[str, Value], node: fx.Node, loc: Location
) -> Value:
    """Lower a mish node as ``x * tanh(softplus(x))``.

    The softplus factor comes from ``_stable_softplus`` so that this
    component does not overflow in fp16.

    Args:
        values_map: Mapping used by ``_get_operand`` to resolve the input.
        node: The FX node; only operand 0 is read.
        loc: Source location; not used by this lowering.

    Returns:
        The value holding ``mish(x)``.
    """
    operand = _get_operand(values_map, node, 0)
    gate = coreai.tanh(_stable_softplus(operand))
    return coreai.broadcasting_mul(operand, gate)


def replace_logsumexp(
    values_map: dict[str, Value], node: fx.Node, loc: Location
) -> Value:
    """Lower a logsumexp node using the max-shift identity.

    ``logsumexp(x) = max(x) + log(sum(exp(x - max(x))))``. Every shifted
    element is <= 0, so each exponential is at most 1 and cannot overflow.

    Args:
        values_map: Mapping used by ``_get_operand`` to resolve the input.
        node: The FX node; ``args[1]`` is the reduction dim (an int or a
            sequence of ints) and ``args[2]`` is keepdim (default False).
        loc: Source location; not used by this lowering.

    Returns:
        The reduced value, with the reduced axes removed unless keepdim is
        set.
    """
    x = _get_operand(values_map, node, 0)
    requested = node.args[1]
    keepdim = node.args[2] if len(node.args) > 2 else False

    if isinstance(requested, int):
        requested = [requested]

    # Normalise negative axes to their non-negative equivalents.
    axes = []
    for axis in requested:
        if axis < 0:
            axis += x.type.rank
        axes.append(axis)

    # The reduced results are combined with ``x`` by broadcasting and only
    # squeezed at the end — presumably reduce_max/reduce_sum keep the reduced
    # axes as size 1; TODO confirm.
    # NOTE(review): if the max along an axis is +/-inf, ``x - max`` is nan and
    # propagates to the result; PyTorch appears to substitute 0 for infinite
    # maxima — confirm whether that case must be handled here.
    peak = coreai.reduce_max(x, axes)
    shifted_exp = coreai.exp(coreai.broadcasting_sub(x, peak))
    log_total = coreai.log(coreai.reduce_sum(shifted_exp, axes))
    result = coreai.broadcasting_add(peak, log_total)

    if keepdim:
        return result
    return coreai.shrink_dims(result, axes)


def replace_log_softmax(
values_map: dict[str, Value], node: fx.Node, loc: Location
) -> Value:
Expand Down Expand Up @@ -3419,6 +3509,7 @@ def sdpa_maskless(q: Value, k: Value, v: Value) -> Value:
_aten_to_core_resolver: dict[str, Callable[..., Any]] = {
"_local_scalar_dense.default": replace_local_scalar_dense,
"_log_softmax.default": replace_log_softmax,
"mish.default": replace_mish,
"_native_batch_norm_legit_no_training.default": replace_batch_norm,
"_softmax.default": replace_softmax,
"_to_copy.default": replace_to_copy,
Expand Down Expand Up @@ -3506,6 +3597,7 @@ def sdpa_maskless(q: Value, k: Value, v: Value) -> Value:
"leaky_relu.default": replace_leaky_relu,
"lift_fresh_copy.default": replace_lift_fresh_copy,
"linalg_vector_norm.default": replace_linalg_vector_norm,
"logsumexp.default": replace_logsumexp,
"log.default": replace_unary_ops,
"log10.default": replace_log10,
"log1p.default": replace_log1p,
Expand Down Expand Up @@ -3567,6 +3659,7 @@ def sdpa_maskless(q: Value, k: Value, v: Value) -> Value:
"select.int": replace_select_int,
"sigmoid.default": replace_unary_ops,
"silu.default": replace_unary_ops,
"softplus.default": replace_softplus,
"sign.default": replace_sign,
"sin.default": replace_unary_ops,
"sinh.default": replace_unary_ops,
Expand Down
9 changes: 9 additions & 0 deletions coreai_torch/_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
torch.ops.aten.instance_norm.default,
torch.ops.aten.logsumexp.default,
torch.ops.aten.mish.default,
torch.ops.aten.pixel_shuffle.default,
torch.ops.aten.scaled_dot_product_attention.default,
torch.ops.aten.silu.default,
torch.ops.aten.softplus.default,
]


Expand All @@ -42,6 +45,12 @@ def get_decomp_table() -> dict:
* ``torch.ops.aten.hardswish.default``
* ``torch.ops.aten.silu.default``

*Numerically stable lowerings (fp16 safety):*

* ``torch.ops.aten.logsumexp.default``
* ``torch.ops.aten.mish.default``
* ``torch.ops.aten.softplus.default``

**Usage with** ``add_exported_program`` (caller handles decomposition)::

import torch
Expand Down
142 changes: 142 additions & 0 deletions tests/ops/test_fp16_stable_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause

"""Tests for numerically stable softplus, mish, and logsumexp converters (#21).

These tests verify that the stable decompositions produce correct results
for inputs in the fp16 overflow range (x > ~10.4), where the naive forms
(log(1+exp(x)), etc.) would overflow to inf.
"""

import numpy as np
import pytest
import torch
import torch.nn as nn
from torch import Tensor

from ..utils import (
_all_dims_dynamic,
validate_numerical_output,
)


# --- Softplus ---


class SoftplusModel(nn.Module):
    """Module that applies softplus with its default arguments."""

    def forward(self, x: Tensor) -> Tensor:
        """Return ``softplus(x)`` element-wise."""
        return nn.functional.softplus(x)


class SoftplusBetaModel(nn.Module):
    """Module that applies softplus with ``beta=2.0``."""

    def forward(self, x: Tensor) -> Tensor:
        """Return ``softplus(x, beta=2.0)`` element-wise."""
        return nn.functional.softplus(x, beta=2.0)


@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.parametrize(
    "x",
    [
        torch.tensor([0.0, 1.0, -1.0, 5.0, -5.0]),
        # Larger inputs, where the naive log(1 + exp(x)) is at risk in fp16
        torch.tensor([10.0, 11.0, 15.0, 20.0, 50.0]),
        torch.randn(3, 4),
    ],
)
async def test_softplus(x: Tensor, dynamic: bool) -> None:
    """Check default softplus on small, large, and random inputs.

    Runs once with static shapes and once with every dimension of ``x``
    marked dynamic.
    """
    shapes = None
    if dynamic:
        shapes = {"x": _all_dims_dynamic(x)}
    await validate_numerical_output(
        model=SoftplusModel().eval(), x=x, dynamic_shapes=shapes
    )


@pytest.mark.parametrize("dynamic", [False])
@pytest.mark.parametrize(
    "x",
    [
        torch.tensor([0.0, 5.0, 10.0, 15.0]),
    ],
)
async def test_softplus_beta(x: Tensor, dynamic: bool) -> None:
    """Check softplus with a non-default beta (static shapes only)."""
    shapes = None
    if dynamic:
        shapes = {"x": _all_dims_dynamic(x)}
    await validate_numerical_output(
        model=SoftplusBetaModel().eval(), x=x, dynamic_shapes=shapes
    )


# --- Mish ---


class MishModel(nn.Module):
    """Module that applies the mish activation."""

    def forward(self, x: Tensor) -> Tensor:
        """Return ``mish(x)`` element-wise."""
        return nn.functional.mish(x)


@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.parametrize(
    "x",
    [
        torch.tensor([0.0, 1.0, -1.0, 5.0, -5.0]),
        # Larger inputs, where a naive softplus inside mish is at risk in fp16
        torch.tensor([10.0, 11.0, 15.0, 20.0, 50.0]),
        torch.randn(3, 4),
    ],
)
async def test_mish(x: Tensor, dynamic: bool) -> None:
    """Check mish on small, large, and random inputs.

    Runs once with static shapes and once with every dimension of ``x``
    marked dynamic.
    """
    shapes = None
    if dynamic:
        shapes = {"x": _all_dims_dynamic(x)}
    await validate_numerical_output(
        model=MishModel().eval(), x=x, dynamic_shapes=shapes
    )


# --- Logsumexp ---


class LogsumexpModel(nn.Module):
    """Module that reduces its input with ``torch.logsumexp``."""

    def __init__(self, dim: int, keepdim: bool = False) -> None:
        """Store the reduction axis and whether to keep it in the output."""
        super().__init__()
        # Plain Python attributes, read back in forward().
        self.dim, self.keepdim = dim, keepdim

    def forward(self, x: Tensor) -> Tensor:
        """Return logsumexp of ``x`` over the configured axis."""
        return torch.logsumexp(x, self.dim, self.keepdim)


@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.parametrize("keepdim", [False, True])
@pytest.mark.parametrize(
    "x",
    [
        torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        # Larger inputs, where an unshifted sum of exp(x) is at risk in fp16
        torch.tensor([[8.0, 9.0, 10.0], [11.0, 12.0, 15.0]]),
        torch.randn(3, 4),
    ],
)
async def test_logsumexp(x: Tensor, keepdim: bool, dynamic: bool) -> None:
    """Check logsumexp over axis 1, with and without keepdim.

    Runs once with static shapes and once with every dimension of ``x``
    marked dynamic.
    """
    shapes = None
    if dynamic:
        shapes = {"x": _all_dims_dynamic(x)}
    await validate_numerical_output(
        model=LogsumexpModel(dim=1, keepdim=keepdim).eval(),
        x=x,
        dynamic_shapes=shapes,
    )


# --- Numeric fp16 overflow proof ---


def test_fp16_overflow_proof() -> None:
    """Verify naive fp16 softplus overflows while the stable form does not.

    Pure NumPy check at x = 15 in float16: the textbook
    ``log(1 + exp(x))`` must produce a non-finite value, while
    ``max(x, 0) + log(1 + exp(-|x|))`` must stay finite and close to 15.
    """
    x_val = np.float16(15.0)
    one = np.float16(1.0)

    # Naive: log(1 + exp(fp16(15))) overflows. The overflow is the point of
    # this check, so NumPy's overflow RuntimeWarning is silenced for just this
    # computation; otherwise a warnings-as-errors run would fail here for the
    # wrong reason.
    with np.errstate(over="ignore"):
        naive = np.float16(np.log(one + np.exp(x_val)))
    assert not np.isfinite(naive), f"Naive fp16 softplus should overflow, got {naive}"

    # Stable: max(15,0) + log(1 + exp(-15)) ≈ 15.0
    stable = np.float16(
        np.maximum(x_val, np.float16(0)) + np.log(one + np.exp(-np.abs(x_val)))
    )
    assert np.isfinite(stable), f"Stable fp16 softplus should not overflow, got {stable}"
    assert abs(float(stable) - 15.0) < 0.5, f"Expected ~15.0, got {stable}"