From be19b319a860cc6dd97c36bbb3ac41b567ad8d40 Mon Sep 17 00:00:00 2001 From: gokulkrishna98 Date: Mon, 29 Jun 2026 12:54:45 -0700 Subject: [PATCH 1/2] _aten_to_core: fix arange lowering to preserve static shape for float dtypes Cast range_ operands to si32 instead of the output dtype, so the compiler can infer a static element count from constant int bounds. Cast the result to the requested dtype afterward. Previously casting operands to float caused range_ to return tensor even for compile-time constants, breaking composite signatures that expected a static dimension. --- coreai_torch/_aten_to_core.py | 16 +++++----- tests/ops/test_ops_ir.py | 57 ++++++++++++++++++++++++++--------- 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/coreai_torch/_aten_to_core.py b/coreai_torch/_aten_to_core.py index 433e27a..418e2af 100644 --- a/coreai_torch/_aten_to_core.py +++ b/coreai_torch/_aten_to_core.py @@ -644,20 +644,22 @@ def replace_arange_start_step( else coreai.constant(1, dtype=start.type.element_type) ) - # coreai.range_ requires scalar (rank-0) operands that share an element - # type. aten.arange promotes mixed-type scalars internally; we replicate - # that here by squeezing each operand to rank-0 and casting to the FX - # node's output dtype before the op. + # Keep operands as integers so coreai.range_ can infer a static shape; + # cast the result to the requested dtype afterward. target_type = get_output_element_type_from_node(node) + si32 = IntegerType.get_signed(32) def to_scalar(v: Value) -> Value: if v.type.rank > 0: v = coreai.shrink_dims(v, list(range(v.type.rank))) - if v.type.element_type != target_type: - v = coreai.cast(v, target_type) + if v.type.element_type != si32: + v = coreai.cast(v, si32) return v - return coreai.range_(to_scalar(start), to_scalar(end), to_scalar(step)) + result = coreai.range_(to_scalar(start), to_scalar(end), to_scalar(step)) + if result.type.element_type != target_type: + result = coreai.cast(result, target_type) + return result def replace_batch_norm( diff --git a/tests/ops/test_ops_ir.py b/tests/ops/test_ops_ir.py index 50c0786..59005b4 100644 --- a/tests/ops/test_ops_ir.py +++ b/tests/ops/test_ops_ir.py @@ -995,11 +995,12 @@ def forward(self, x: Tensor) -> Tensor: dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1)}}, ) # ``end`` is a SymInt (rank-1 si32) sliced from x.shape[0]; - # ``start``/``step`` are scalar (rank-0) si32 constants. The - # lowering casts each operand to the FX node's output dtype (f32) - # before ``coreai.range_`` so the op sees uniform-typed scalars; - # the optimizer then constant-folds the casts on start/step into - # f32 constants directly. + # ``start``/``step`` are scalar (rank-0) si32 constants. + # The lowering keeps all operands as si32 so that coreai.range_ + # can infer a static shape from constant int operands, then casts + # the result to the requested output dtype. (Casting operands to + # float *before* range_ causes it to return tensor even when the + # bounds are compile-time constants.) filecheck_pattern( ir, check_file=""" @@ -1007,24 +1008,50 @@ def forward(self, x: Tensor) -> Tensor: // CHECK-SAME: %[[ARG0:.*]]: tensor // CHECK-SAME: -> (tensor // - // start and step land as f32 constants directly (the si32 - // constant + cast-to-f32 pair gets folded by the optimizer): - // CHECK-DAG: %[[STEP:.+]] = coreai.constant dense<1.000000e+00> : tensor - // CHECK-DAG: %[[START:.+]] = coreai.constant dense<0.000000e+00> : tensor + // start and step remain as si32 constants; no pre-range cast: + // CHECK-DAG: %[[STEP:.+]] = coreai.constant dense<{{.*}}> : tensor + // CHECK-DAG: %[[START:.+]] = coreai.constant dense<{{.*}}> : tensor // - // end: get_shape -> slice -> cast(ui32->si32) -> reshape(rank-1 to rank-0) -> cast(si32->f32): + // end: get_shape -> slice -> cast(ui32->si32) -> reshape(rank-1 to rank-0); + // stays si32, no cast to f32 before range_: // CHECK: %[[END_RANK1:.+]] = coreai.cast {{.*}} : tensor<1xui32> to tensor<1xsi32> // CHECK: %[[END_RANK0:.+]] = coreai.reshape %[[END_RANK1]], {{.*}} : (tensor<1xsi32>, tensor<0xui32>) -> tensor - // CHECK: %[[END_F32:.+]] = coreai.cast %[[END_RANK0]] : tensor to tensor // - // range called with all-f32 scalars; result is f32 directly, - // no post-range cast on the result: - // CHECK: %[[OUT:.+]] = coreai.range %[[START]], %[[END_F32]], %[[STEP]] : (tensor, tensor, tensor) -> tensor - // CHECK-NOT: coreai.cast %[[OUT]] + // range_ called with all-si32 scalars; result is si32 with dynamic shape: + // CHECK: %[[RANGE_OUT:.+]] = coreai.range %[[START]], %[[END_RANK0]], %[[STEP]] : (tensor, tensor, tensor) -> tensor + // + // post-range cast to the requested f32 output dtype: + // CHECK: %[[OUT:.+]] = coreai.cast %[[RANGE_OUT]] : tensor to tensor // CHECK: coreai.output %[[OUT]] : tensor """, ) + def test_static_float_dtype_preserves_shape(self) -> None: + """Regression: arange with a float dtype must keep a static output shape. + + Casting operands to float *before* coreai.range_ causes the op to + return tensor even when all bounds are compile-time constants. + The fix is to run range_ on int operands and cast the result. + """ + + class ArangeFloat(nn.Module): + def forward(self, x: Tensor) -> Tensor: + # Constant int bounds, float output dtype — the problematic case. + return torch.arange(0, 8, 2, dtype=torch.float32, device=x.device) + + ir = get_ir(ArangeFloat().eval(), x=torch.rand(4)) + # The output shape must be static (4 elements: 0,2,4,6). + filecheck_pattern( + ir, + check_file=""" + // CHECK-LABEL: module { + // CHECK-NEXT: coreai.graph @main(%[[ARG0:.*]]: tensor<4xf32> + // CHECK-SAME: -> (tensor<4xf32> + // The optimizer constant-folds the whole thing to a dense literal: + // CHECK: coreai.constant dense<{{.*}}> : tensor<4xf32> + """, + ) + class TestArgmaxIR: def test_static(self) -> None: From 0d1b595f0b1c173109fe2715ba51bb4a371b3099 Mon Sep 17 00:00:00 2001 From: gokulkrishna98 Date: Mon, 29 Jun 2026 13:04:54 -0700 Subject: [PATCH 2/2] =?UTF-8?q?=5Faten=5Fto=5Fcore:=20fix=20arange=20lower?= =?UTF-8?q?ing=20=E2=80=94=20use=20si32=20only=20for=20integer=20operands?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Float operands (e.g. arange(0.5, 5.0, 0.5)) must not be cast to si32 as that truncates the values. Only apply the integer-path optimisation when start has an integer element type; fall back to target_type for float operands where a dynamic shape is correct anyway. --- coreai_torch/_aten_to_core.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/coreai_torch/_aten_to_core.py b/coreai_torch/_aten_to_core.py index 418e2af..20a2f83 100644 --- a/coreai_torch/_aten_to_core.py +++ b/coreai_torch/_aten_to_core.py @@ -644,16 +644,21 @@ def replace_arange_start_step( else coreai.constant(1, dtype=start.type.element_type) ) - # Keep operands as integers so coreai.range_ can infer a static shape; - # cast the result to the requested dtype afterward. + # When operands are integer-typed, keep them as si32 so coreai.range_ can + # infer a static output shape, then cast the result to the requested dtype. + # For float operands the count isn't statically determinable anyway, so + # cast everything to target_type and let range_ return a dynamic shape. target_type = get_output_element_type_from_node(node) si32 = IntegerType.get_signed(32) + range_type = ( + si32 if isinstance(start.type.element_type, IntegerType) else target_type + ) def to_scalar(v: Value) -> Value: if v.type.rank > 0: v = coreai.shrink_dims(v, list(range(v.type.rank))) - if v.type.element_type != si32: - v = coreai.cast(v, si32) + if v.type.element_type != range_type: + v = coreai.cast(v, range_type) return v result = coreai.range_(to_scalar(start), to_scalar(end), to_scalar(step))