fix(arange): preserve static output shape for float-dtype arange with integer bounds#25

Merged
gokulkrishna98 merged 2 commits into
apple:mainfrom
gokulkrishna98:dev/gokul/fix-casting-arange
Jun 29, 2026

@gokulkrishna98

Problem

torch.arange(start, end, step, dtype=torch.float32) was producing a dynamic-shape output (tensor<?xf32>) even when all three bounds were compile-time integer constants.

Root cause: replace_arange_start_step was casting each operand to the FX node's output dtype (e.g. f32) before calling coreai.range_. When range_ receives float operands it cannot statically determine the element count, so it returns tensor<?xf32> regardless of whether the values are constants.

This propagated through composites that use arange to build position indices (e.g. RoPE), corrupting their output type signature from a static to a dynamic dimension and breaking downstream composites that expected a static shape.
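To illustrate why a static length matters for such composites, here is a minimal plain-Python sketch of the usual RoPE inverse-frequency computation. The helper name `rope_inv_freq` and the default `base` of 10000 are assumptions for illustration, not code from this repository.

```python
def rope_inv_freq(head_dim: int, base: float = 10000.0) -> list[float]:
    """Inverse frequencies for a RoPE-style embedding (illustrative only)."""
    # range(0, head_dim, 2) plays the role of arange(0, head_dim, 2):
    # the bounds are integers, so the length is known before any float math.
    return [1.0 / (base ** (i / head_dim)) for i in range(0, head_dim, 2)]


inv_freq = rope_inv_freq(8)
```

For an even `head_dim`, the list always has `head_dim // 2` entries, which is the static dimension downstream composites expect.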

Fix

For integer operands (the common case, e.g. arange(0, head_dim, 2, dtype=float32)): keep all three operands as si32 for the range_ call. The compiler can statically count integer steps, so the result is tensor<Nxsi32> with a known N. Cast the result to the requested output dtype afterward.
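The static count that the integer path relies on can be sketched in plain Python. This only illustrates the arithmetic and is not the compiler's implementation; `static_arange_len` is a hypothetical helper name.

```python
def static_arange_len(start: int, end: int, step: int) -> int:
    """Number of elements in arange(start, end, step) for integer operands."""
    if step == 0:
        raise ValueError("step must be non-zero")
    # Integer ceiling division: ceil((end - start) / step), clamped at 0.
    return max(0, -((start - end) // step))


# Count on integers first, then cast each element to the output dtype.
n = static_arange_len(0, 8, 2)
values = [float(i) for i in range(0, 8, 2)]
```

Because every step is exact integer arithmetic, the count is fixed before any cast to float takes place.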

For float operands (e.g. arange(0.5, 5.0, 0.5)): keep the original behaviour — cast to target_type before range_. A dynamic shape is correct here since floating-point step sizes don't have a statically determinable count, and truncating values like 0.5 → 0 via si32 would produce wrong results.
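A small plain-Python sketch shows why the si32 route is unsafe for float operands. It mimics an integer cast with `int()`, which is an assumption about truncation semantics rather than the actual cast op.

```python
import math

start, end, step = 0.5, 5.0, 0.5

# Truncating to integers (as an si32 cast would) destroys the operands:
# both start and step collapse to 0.
truncated = tuple(int(v) for v in (start, end, step))

# Element count computed on the original float operands.
float_len = math.ceil((end - start) / step)
```

Here `truncated` is `(0, 5, 0)`, so the step becomes zero and the range is meaningless, while the float computation gives the intended count.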

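si32 = IntegerType.get_signed(32)
# Integer operands stay si32 so range_ can count elements statically;
# float operands fall back to the requested output dtype.
range_type = si32 if isinstance(start.type.element_type, IntegerType) else target_type

# start, end and step are cast to range_type before this call (cast omitted in this excerpt).
result = coreai.range_(to_scalar(start), to_scalar(end), to_scalar(step))
# Cast the result only when it does not already have the requested dtype.
if result.type.element_type != target_type:
    result = coreai.cast(result, target_type)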

Tests

  • Updated TestArangeIR::test_dynamic FileCheck patterns to reflect the new IR (si32 operands → range_ → si32 result → cast to f32).
  • Added TestArangeIR::test_static_float_dtype_preserves_shape as a direct regression test: arange(0, 8, 2, dtype=float32) must lower to a static tensor<4xf32>, not tensor<?xf32>.
  • All existing arange numerical tests continue to pass, including float-operand cases (arange(0.0, 5.0, 0.5), arange(0.5, 5.0, 0.5)).

… dtypes

Cast range_ operands to si32 instead of the output dtype, so the compiler
can infer a static element count from constant int bounds. Cast the result
to the requested dtype afterward. Previously casting operands to float
caused range_ to return tensor<?xT> even for compile-time constants,
breaking composite signatures that expected a static dimension.
Float operands (e.g. arange(0.5, 5.0, 0.5)) must not be cast to si32 as
that truncates the values. Only apply the integer-path optimisation when
start has an integer element type; fall back to target_type for float
operands where a dynamic shape is correct anyway.
@gokulkrishna98 gokulkrishna98 merged commit e8dac1d into apple:main Jun 29, 2026
2 checks passed