Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions coreai_torch/_aten_to_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,6 +1522,111 @@ def replace_argmax(values_map: dict[str, Value], node: fx.Node, loc: Location) -
return result if keepdim else coreai.shrink_dims(result, [dim])


def replace_atan2(values_map: dict[str, Value], node: fx.Node, loc: Location) -> Value:
"""Lower atan2(y, x) using atan(y/x) with quadrant correction.

CoreAI has no native atan2, so it is decomposed as:
- x != 0, finite: atan(y/x) adjusted by ±π for the correct quadrant.
- x == +0: ±π/2 for non-zero y, 0 for y = 0.
- x == -0: ±π for all y (including ±0 → ±π per IEEE-754).
- both infinite: ±π/4 or ±3π/4 per IEEE-754.

Signed-zero handling: IEEE-754 treats -0.0 as distinct from +0.0 for atan2
(e.g. atan2(-0, -1) = -π, not +π). The 1/v trick — 1/-0.0 = -inf — is used
to detect the sign bit of zero inputs so that y_neg and x_neg are correct
for -0.0 inputs without misclassifying ±inf (which use the strict > path).

When x=0, x is replaced with 1 before the divide solely to avoid NaN/inf; that
intermediate result is discarded by the final where-select.
atan2(0, 0) = 0 by convention.
"""
y, x = _get_operands(values_map, node, [0, 1])
ele_type = x.type.element_type

zero = coreai.constant(0.0, dtype=ele_type)
one = coreai.constant(1.0, dtype=ele_type)
pi = coreai.constant(np.pi, dtype=ele_type)
neg_pi = coreai.constant(-np.pi, dtype=ele_type)
half_pi = coreai.constant(np.pi / 2.0, dtype=ele_type)
neg_half_pi = coreai.constant(-np.pi / 2.0, dtype=ele_type)
quarter_pi = coreai.constant(np.pi / 4.0, dtype=ele_type)
neg_quarter_pi = coreai.constant(-np.pi / 4.0, dtype=ele_type)
three_quarter_pi = coreai.constant(3.0 * np.pi / 4.0, dtype=ele_type)
neg_three_quarter_pi = coreai.constant(-3.0 * np.pi / 4.0, dtype=ele_type)

# ── signed-zero-aware sign predicates ─────────────────────────────────────
# 1 / -0.0 = -inf (IEEE-754), so (0 > 1/v) is True iff v = -0.0. Combine with
# the strict > predicate (handles ±inf and non-zero finites) via OR.
y_is_zero = coreai.broadcasting_equal(y, zero)
x_is_zero = coreai.broadcasting_equal(x, zero)
y_neg = coreai.broadcasting_or(
coreai.broadcasting_greater(zero, y),
coreai.broadcasting_and(
y_is_zero,
coreai.broadcasting_greater(zero, coreai.broadcasting_divide(one, y)),
),
)
x_neg = coreai.broadcasting_or(
coreai.broadcasting_greater(zero, x),
coreai.broadcasting_and(
x_is_zero,
coreai.broadcasting_greater(zero, coreai.broadcasting_divide(one, x)),
),
)
x_is_neg_zero = coreai.broadcasting_and(
x_is_zero,
coreai.broadcasting_greater(zero, coreai.broadcasting_divide(one, x)),
)

# ── both-infinite branch ──────────────────────────────────────────────────
# atan(inf/inf) = atan(NaN) = NaN; handle before the divide.
pos_inf = coreai.constant(float("inf"), dtype=ele_type)
neg_inf = coreai.constant(float("-inf"), dtype=ele_type)
x_is_inf = coreai.broadcasting_or(
coreai.broadcasting_equal(x, pos_inf), coreai.broadcasting_equal(x, neg_inf)
)
y_is_inf = coreai.broadcasting_or(
coreai.broadcasting_equal(y, pos_inf), coreai.broadcasting_equal(y, neg_inf)
)
both_inf = coreai.broadcasting_and(x_is_inf, y_is_inf)
inf_result = coreai.broadcasting_where(
y_neg,
coreai.broadcasting_where(x_neg, neg_three_quarter_pi, neg_quarter_pi),
coreai.broadcasting_where(x_neg, three_quarter_pi, quarter_pi),
)

# ── x = 0 branch ──────────────────────────────────────────────────────────
# x = +0: ±π/2 for strictly ±y, 0 when y = 0.
# x = -0: ±π for all y (y_neg covers y = -0.0 via the 1/y trick above).
y_pos_strict = coreai.broadcasting_greater(y, zero)
y_neg_strict = coreai.broadcasting_greater(zero, y)
pos_x_zero_result = coreai.broadcasting_where(
y_pos_strict,
half_pi,
coreai.broadcasting_where(y_neg_strict, neg_half_pi, zero),
)
neg_x_zero_result = coreai.broadcasting_where(y_neg, neg_pi, pi)
zero_result = coreai.broadcasting_where(
x_is_neg_zero, neg_x_zero_result, pos_x_zero_result
)

# ── finite nonzero x branch ────────────────────────────────────────────────
# Avoid division by zero: substitute x = 1 when x = 0; result discarded by
# the outer where-select.
x_safe = coreai.broadcasting_where(x_is_zero, one, x)
base = coreai.atan(coreai.broadcasting_divide(y, x_safe))
correction = coreai.broadcasting_where(
y_neg,
coreai.broadcasting_sub(base, pi),
coreai.broadcasting_add(base, pi),
)
nonzero_result = coreai.broadcasting_where(x_neg, correction, base)

# ── combine ────────────────────────────────────────────────────────────────
result = coreai.broadcasting_where(x_is_zero, zero_result, nonzero_result)
return coreai.broadcasting_where(both_inf, inf_result, result)


def replace_gather(values_map: dict[str, Value], node: fx.Node, loc: Location) -> Value:
"""Converts aten.gather to coreai.gather_along_axis."""
x, index = _get_operands(values_map, node, [0, 2])
Expand Down Expand Up @@ -3440,6 +3545,7 @@ def sdpa_maskless(q: Value, k: Value, v: Value) -> Value:
"asin.default": replace_unary_ops,
"asinh.default": replace_unary_ops,
"atan.default": replace_unary_ops,
"atan2.default": replace_atan2,
"atanh.default": replace_unary_ops,
"_adaptive_avg_pool2d.default": replace_adaptive_avg_pool2d,
"_unsafe_view.default": replace_view,
Expand Down
Loading