
Useless integer casts in Alloc/Reshape shape arguments are not canonicalized away #2135

@ricardoV94

Description

Integer-to-integer casts on shape arguments of Alloc (and other shape-consuming ops like Reshape) are not removed during canonicalization, even though these ops accept any integer dtype for shape inputs.

For example, Alloc(val, int32_var.astype("int64")) keeps the Cast{int64} node in the optimized graph, despite Alloc working correctly with int32 shape inputs directly.

Reproducer

import pytensor.tensor as pt
from pytensor import function, dprint
from pytensor.compile.mode import get_default_mode

x = pt.scalar('x')
s = pt.scalar('s', dtype='int32')

g1 = pt.alloc(x, s)
g2 = pt.alloc(x, s.astype('int64'))

mode = get_default_mode().including('canonicalize')

f1 = function([x, s], g1, mode=mode)
f2 = function([x, s], g2, mode=mode)

print('Without cast:')
dprint(f1)
# Alloc [id A]
#  |-- x [id B]
#  `-- s [id C]

print('With useless int32->int64 cast:')
dprint(f2)
# Alloc [id A]
#  |-- x [id B]
#  `-- Cast{int64} [id C]
#     `-- s [id D]

The same issue applies to other integer-to-integer casts on shape arguments, such as int16, int32, uint8, uint16 and uint32 to int64. These casts are functionally unnecessary: each one preserves every value of its input, and Alloc accepts any integer dtype for its shape inputs.
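Whether a cast can safely be dropped comes down to whether every value of the source dtype is representable in the target dtype. The following is a minimal pure-Python sketch of that check, assuming NumPy-style dtype names. The helpers `int_range` and `is_value_preserving_int_cast` are hypothetical and not part of PyTensor.

```python
def int_range(dtype: str) -> tuple[int, int]:
    """Return the inclusive (min, max) range of a NumPy-style integer dtype name."""
    if dtype.startswith("uint"):
        bits = int(dtype[4:])
        return 0, 2**bits - 1
    if dtype.startswith("int"):
        bits = int(dtype[3:])
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    raise ValueError(f"not an integer dtype: {dtype!r}")


def is_value_preserving_int_cast(src: str, dst: str) -> bool:
    """True if every value of `src` fits in `dst`, so the cast cannot change any value."""
    src_min, src_max = int_range(src)
    dst_min, dst_max = int_range(dst)
    return dst_min <= src_min and src_max <= dst_max


# The widening casts to int64 listed above are all value-preserving...
for src in ["int16", "int32", "uint8", "uint16", "uint32"]:
    print(src, "-> int64:", is_value_preserving_int_cast(src, "int64"))

# ...whereas narrowing casts can overflow, so they are not "useless".
print("uint64 -> int64:", is_value_preserving_int_cast("uint64", "int64"))
print("int64 -> int32:", is_value_preserving_int_cast("int64", "int32"))
```

This is a stricter condition than "the op accepts any integer dtype": a narrowing cast such as int64 to int32 can change the shape value, so a rewrite would presumably need to keep it.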

This also affects Reshape:

x = pt.matrix('x')
s = pt.scalar('s', dtype='int32')

g = x.reshape((s.astype('int64'), pt.constant(2)))
f = function([x, s], g, mode=mode)
dprint(f)
# Reshape{2} [id A]
#  |-- x [id B]
#  `-- MakeVector{dtype='int64'} [id C]
#     |-- Cast{int64} [id D]
#     |  `-- s [id E]
#     `-- 2 [id F]

Expected behavior

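A canonicalization rewrite should strip value-preserving integer casts (such as the widening casts to int64 listed above) from shape arguments of shape-consuming ops (Alloc, Reshape, etc.), since such a cast has no semantic effect. Narrowing casts can change the value and should be kept.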

This is a minor inefficiency but can clutter graphs and interfere with pattern matching in other rewrites that expect clean shape inputs.
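To make the expected behavior concrete, here is a toy sketch of such a rewrite on a minimal, made-up graph representation. The `Var`/`Apply` classes, the op-name strings, the `SHAPE_ARGS_START` table and all function names are hypothetical and are not PyTensor's API. A real rewrite would presumably be written with PyTensor's `node_rewriter` decorator and registered for canonicalization. The sketch strips value-preserving integer casts from Alloc's shape inputs and shows how a naive pattern check ("is the shape argument exactly `s`?") only succeeds once the cast is gone. It does not model the Reshape case, where the shape arrives through a `MakeVector` node with its own output dtype.

```python
from __future__ import annotations

from dataclasses import dataclass


def int_range(dtype: str) -> tuple[int, int]:
    """Inclusive (min, max) range of a NumPy-style integer dtype name."""
    if dtype.startswith("uint"):
        return 0, 2 ** int(dtype[4:]) - 1
    bits = int(dtype[3:])  # assumes an "intN" name
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1


def is_int(dtype: str) -> bool:
    return dtype.startswith(("int", "uint"))


def is_value_preserving_int_cast(src: str, dst: str) -> bool:
    """True if every value of integer dtype `src` is representable in `dst`."""
    (lo_s, hi_s), (lo_d, hi_d) = int_range(src), int_range(dst)
    return lo_d <= lo_s and hi_s <= hi_d


@dataclass(frozen=True)
class Var:
    """Toy graph variable: a name, a dtype and the (optional) Apply node producing it."""
    name: str
    dtype: str
    owner: Apply | None = None


@dataclass(frozen=True)
class Apply:
    """Toy application of an op (identified by name) to input variables."""
    op: str
    inputs: tuple[Var, ...]


# Hypothetical table: index of the first shape input for each modelled shape-consuming op.
SHAPE_ARGS_START = {"Alloc": 1}


def strip_useless_cast(v: Var) -> Var:
    """If `v` is a value-preserving integer Cast, return the cast's input, else `v`."""
    if v.owner is not None and v.owner.op == "Cast":
        (inner,) = v.owner.inputs
        if (
            is_int(inner.dtype)
            and is_int(v.dtype)
            and is_value_preserving_int_cast(inner.dtype, v.dtype)
        ):
            return inner
    return v


def local_useless_shape_cast(node: Apply) -> Apply | None:
    """Return `node` with useless casts removed from its shape inputs, or None if unchanged."""
    start = SHAPE_ARGS_START.get(node.op)
    if start is None:
        return None  # not a (modelled) shape-consuming op
    shape = node.inputs[start:]
    new_shape = tuple(strip_useless_cast(s) for s in shape)
    if new_shape == shape:
        return None  # no cast was stripped
    return Apply(node.op, node.inputs[:start] + new_shape)


def first_shape_arg_is(node: Apply, var: Var) -> bool:
    """Hypothetical naive pattern check another rewrite might perform on an Alloc."""
    return node.inputs[1] == var


x = Var("x", "float64")
s = Var("s", "int32")
alloc = Apply("Alloc", (x, Var("Cast{int64}(s)", "int64", Apply("Cast", (s,)))))

print(first_shape_arg_is(alloc, s))      # False: the cast hides `s`
rewritten = local_useless_shape_cast(alloc)
print(first_shape_arg_is(rewritten, s))  # True: the cast was stripped
```

Keying the check on value preservation rather than on "any integer dtype" is what keeps a narrowing cast such as int64 to int32 in the graph: for that input, `local_useless_shape_cast` returns `None`.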
