Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions coreai_torch/_aten_to_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,20 +644,27 @@ def replace_arange_start_step(
else coreai.constant(1, dtype=start.type.element_type)
)

# coreai.range_ requires scalar (rank-0) operands that share an element
# type. aten.arange promotes mixed-type scalars internally; we replicate
# that here by squeezing each operand to rank-0 and casting to the FX
# node's output dtype before the op.
# When operands are integer-typed, keep them as si32 so coreai.range_ can
# infer a static output shape, then cast the result to the requested dtype.
# For float operands the count isn't statically determinable anyway, so
# cast everything to target_type and let range_ return a dynamic shape.
target_type = get_output_element_type_from_node(node)
si32 = IntegerType.get_signed(32)
range_type = (
si32 if isinstance(start.type.element_type, IntegerType) else target_type
)

def to_scalar(v: Value) -> Value:
if v.type.rank > 0:
v = coreai.shrink_dims(v, list(range(v.type.rank)))
if v.type.element_type != target_type:
v = coreai.cast(v, target_type)
if v.type.element_type != range_type:
v = coreai.cast(v, range_type)
return v

return coreai.range_(to_scalar(start), to_scalar(end), to_scalar(step))
result = coreai.range_(to_scalar(start), to_scalar(end), to_scalar(step))
if result.type.element_type != target_type:
result = coreai.cast(result, target_type)
return result


def replace_batch_norm(
Expand Down
57 changes: 42 additions & 15 deletions tests/ops/test_ops_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,36 +995,63 @@ def forward(self, x: Tensor) -> Tensor:
dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1)}},
)
# ``end`` is a SymInt (rank-1 si32) sliced from x.shape[0];
# ``start``/``step`` are scalar (rank-0) si32 constants. The
# lowering casts each operand to the FX node's output dtype (f32)
# before ``coreai.range_`` so the op sees uniform-typed scalars;
# the optimizer then constant-folds the casts on start/step into
# f32 constants directly.
# ``start``/``step`` are scalar (rank-0) si32 constants.
# The lowering keeps all operands as si32 so that coreai.range_
# can infer a static shape from constant int operands, then casts
# the result to the requested output dtype. (Casting operands to
# float *before* range_ causes it to return tensor<?> even when the
# bounds are compile-time constants.)
filecheck_pattern(
ir,
check_file="""
// CHECK-LABEL: coreai.graph @main
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>
// CHECK-SAME: -> (tensor<?xf32>
//
// start and step land as f32 constants directly (the si32
// constant + cast-to-f32 pair gets folded by the optimizer):
// CHECK-DAG: %[[STEP:.+]] = coreai.constant dense<1.000000e+00> : tensor<f32>
// CHECK-DAG: %[[START:.+]] = coreai.constant dense<0.000000e+00> : tensor<f32>
// start and step remain as si32 constants; no pre-range cast:
// CHECK-DAG: %[[STEP:.+]] = coreai.constant dense<{{.*}}> : tensor<si32>
// CHECK-DAG: %[[START:.+]] = coreai.constant dense<{{.*}}> : tensor<si32>
//
// end: get_shape -> slice -> cast(ui32->si32) -> reshape(rank-1 to rank-0) -> cast(si32->f32):
// end: get_shape -> slice -> cast(ui32->si32) -> reshape(rank-1 to rank-0);
// stays si32, no cast to f32 before range_:
// CHECK: %[[END_RANK1:.+]] = coreai.cast {{.*}} : tensor<1xui32> to tensor<1xsi32>
// CHECK: %[[END_RANK0:.+]] = coreai.reshape %[[END_RANK1]], {{.*}} : (tensor<1xsi32>, tensor<0xui32>) -> tensor<si32>
// CHECK: %[[END_F32:.+]] = coreai.cast %[[END_RANK0]] : tensor<si32> to tensor<f32>
//
// range called with all-f32 scalars; result is f32 directly,
// no post-range cast on the result:
// CHECK: %[[OUT:.+]] = coreai.range %[[START]], %[[END_F32]], %[[STEP]] : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
// CHECK-NOT: coreai.cast %[[OUT]]
// range_ called with all-si32 scalars; result is si32 with dynamic shape:
// CHECK: %[[RANGE_OUT:.+]] = coreai.range %[[START]], %[[END_RANK0]], %[[STEP]] : (tensor<si32>, tensor<si32>, tensor<si32>) -> tensor<?xsi32>
//
// post-range cast to the requested f32 output dtype:
// CHECK: %[[OUT:.+]] = coreai.cast %[[RANGE_OUT]] : tensor<?xsi32> to tensor<?xf32>
// CHECK: coreai.output %[[OUT]] : tensor<?xf32>
""",
)

def test_static_float_dtype_preserves_shape(self) -> None:
"""Regression: arange with a float dtype must keep a static output shape.

Casting operands to float *before* coreai.range_ causes the op to
return tensor<?xf32> even when all bounds are compile-time constants.
The fix is to run range_ on int operands and cast the result.
"""

class ArangeFloat(nn.Module):
def forward(self, x: Tensor) -> Tensor:
# Constant int bounds, float output dtype — the problematic case.
return torch.arange(0, 8, 2, dtype=torch.float32, device=x.device)

ir = get_ir(ArangeFloat().eval(), x=torch.rand(4))
# The output shape must be static (4 elements: 0,2,4,6).
filecheck_pattern(
ir,
check_file="""
// CHECK-LABEL: module {
// CHECK-NEXT: coreai.graph @main(%[[ARG0:.*]]: tensor<4xf32>
// CHECK-SAME: -> (tensor<4xf32>
// The optimizer constant-folds the whole thing to a dense literal:
// CHECK: coreai.constant dense<{{.*}}> : tensor<4xf32>
""",
)


class TestArgmaxIR:
def test_static(self) -> None:
Expand Down