diff --git a/coreai_torch/_aten_to_core.py b/coreai_torch/_aten_to_core.py index 433e27a..20a2f83 100644 --- a/coreai_torch/_aten_to_core.py +++ b/coreai_torch/_aten_to_core.py @@ -644,20 +644,27 @@ def replace_arange_start_step( else coreai.constant(1, dtype=start.type.element_type) ) - # coreai.range_ requires scalar (rank-0) operands that share an element - # type. aten.arange promotes mixed-type scalars internally; we replicate - # that here by squeezing each operand to rank-0 and casting to the FX - # node's output dtype before the op. + # When operands are integer-typed, keep them as si32 so coreai.range_ can + # infer a static output shape, then cast the result to the requested dtype. + # For float operands the count isn't statically determinable anyway, so + # cast everything to target_type and let range_ return a dynamic shape. target_type = get_output_element_type_from_node(node) + si32 = IntegerType.get_signed(32) + range_type = ( + si32 if isinstance(start.type.element_type, IntegerType) else target_type + ) def to_scalar(v: Value) -> Value: if v.type.rank > 0: v = coreai.shrink_dims(v, list(range(v.type.rank))) - if v.type.element_type != target_type: - v = coreai.cast(v, target_type) + if v.type.element_type != range_type: + v = coreai.cast(v, range_type) return v - return coreai.range_(to_scalar(start), to_scalar(end), to_scalar(step)) + result = coreai.range_(to_scalar(start), to_scalar(end), to_scalar(step)) + if result.type.element_type != target_type: + result = coreai.cast(result, target_type) + return result def replace_batch_norm( diff --git a/tests/ops/test_ops_ir.py b/tests/ops/test_ops_ir.py index 50c0786..59005b4 100644 --- a/tests/ops/test_ops_ir.py +++ b/tests/ops/test_ops_ir.py @@ -995,11 +995,12 @@ def forward(self, x: Tensor) -> Tensor: dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1)}}, ) # ``end`` is a SymInt (rank-1 si32) sliced from x.shape[0]; - # ``start``/``step`` are scalar (rank-0) si32 constants. The - # lowering casts each operand to the FX node's output dtype (f32) - # before ``coreai.range_`` so the op sees uniform-typed scalars; - # the optimizer then constant-folds the casts on start/step into - # f32 constants directly. + # ``start``/``step`` are scalar (rank-0) si32 constants. + # The lowering keeps all operands as si32 so that coreai.range_ + # can infer a static shape from constant int operands, then casts + # the result to the requested output dtype. (Casting operands to + # float *before* range_ causes it to return tensor even when the + # bounds are compile-time constants.) filecheck_pattern( ir, check_file=""" @@ -1007,24 +1008,50 @@ def forward(self, x: Tensor) -> Tensor: // CHECK-SAME: %[[ARG0:.*]]: tensor // CHECK-SAME: -> (tensor // - // start and step land as f32 constants directly (the si32 - // constant + cast-to-f32 pair gets folded by the optimizer): - // CHECK-DAG: %[[STEP:.+]] = coreai.constant dense<1.000000e+00> : tensor - // CHECK-DAG: %[[START:.+]] = coreai.constant dense<0.000000e+00> : tensor + // start and step remain as si32 constants; no pre-range cast: + // CHECK-DAG: %[[STEP:.+]] = coreai.constant dense<{{.*}}> : tensor + // CHECK-DAG: %[[START:.+]] = coreai.constant dense<{{.*}}> : tensor // - // end: get_shape -> slice -> cast(ui32->si32) -> reshape(rank-1 to rank-0) -> cast(si32->f32): + // end: get_shape -> slice -> cast(ui32->si32) -> reshape(rank-1 to rank-0); + // stays si32, no cast to f32 before range_: // CHECK: %[[END_RANK1:.+]] = coreai.cast {{.*}} : tensor<1xui32> to tensor<1xsi32> // CHECK: %[[END_RANK0:.+]] = coreai.reshape %[[END_RANK1]], {{.*}} : (tensor<1xsi32>, tensor<0xui32>) -> tensor - // CHECK: %[[END_F32:.+]] = coreai.cast %[[END_RANK0]] : tensor to tensor // - // range called with all-f32 scalars; result is f32 directly, - // no post-range cast on the result: - // CHECK: %[[OUT:.+]] = coreai.range %[[START]], %[[END_F32]], %[[STEP]] : (tensor, tensor, tensor) -> tensor - // CHECK-NOT: coreai.cast %[[OUT]] + // range_ called with all-si32 scalars; result is si32 with dynamic shape: + // CHECK: %[[RANGE_OUT:.+]] = coreai.range %[[START]], %[[END_RANK0]], %[[STEP]] : (tensor, tensor, tensor) -> tensor + // + // post-range cast to the requested f32 output dtype: + // CHECK: %[[OUT:.+]] = coreai.cast %[[RANGE_OUT]] : tensor to tensor // CHECK: coreai.output %[[OUT]] : tensor """, ) + def test_static_float_dtype_preserves_shape(self) -> None: + """Regression: arange with a float dtype must keep a static output shape. + + Casting operands to float *before* coreai.range_ causes the op to + return tensor even when all bounds are compile-time constants. + The fix is to run range_ on int operands and cast the result. + """ + + class ArangeFloat(nn.Module): + def forward(self, x: Tensor) -> Tensor: + # Constant int bounds, float output dtype — the problematic case. + return torch.arange(0, 8, 2, dtype=torch.float32, device=x.device) + + ir = get_ir(ArangeFloat().eval(), x=torch.rand(4)) + # The output shape must be static (4 elements: 0,2,4,6). + filecheck_pattern( + ir, + check_file=""" + // CHECK-LABEL: module { + // CHECK-NEXT: coreai.graph @main(%[[ARG0:.*]]: tensor<4xf32> + // CHECK-SAME: -> (tensor<4xf32> + // The optimizer constant-folds the whole thing to a dense literal: + // CHECK: coreai.constant dense<{{.*}}> : tensor<4xf32> + """, + ) + class TestArgmaxIR: def test_static(self) -> None: