Description
Integer-to-integer casts on shape arguments of Alloc (and other shape-consuming ops like Reshape) are not removed during canonicalization, even though these ops accept any integer dtype for shape inputs.
For example, Alloc(val, int32_var.astype("int64")) keeps the Cast{int64} node in the optimized graph, despite Alloc working correctly with int32 shape inputs directly.
Reproducer
import pytensor.tensor as pt
from pytensor import function, dprint
from pytensor.compile.mode import get_default_mode
x = pt.scalar('x')
s = pt.scalar('s', dtype='int32')
g1 = pt.alloc(x, s)
g2 = pt.alloc(x, s.astype('int64'))
mode = get_default_mode().including('canonicalize')
f1 = function([x, s], g1, mode=mode)
f2 = function([x, s], g2, mode=mode)
print('Without cast:')
dprint(f1)
# Alloc [id A]
# |-- x [id B]
# `-- s [id C]
print('With useless int32->int64 cast:')
dprint(f2)
# Alloc [id A]
# |-- x [id B]
# `-- Cast{int64} [id C]
# `-- s [id D]
The same issue applies to all integer-to-integer casts on shape arguments, e.g. from int16, int32, uint8, uint16, or uint32 to int64. These casts are functionally unnecessary, since Alloc accepts any integer dtype for shape inputs.
This also affects Reshape:
# Continues the session above (same imports, `function`, `dprint`, and `mode`)
x = pt.matrix('x')
s = pt.scalar('s', dtype='int32')
g = x.reshape((s.astype('int64'), pt.constant(2)))
f = function([x, s], g, mode=mode)
dprint(f)
# Reshape{2} [id A]
# |-- x [id B]
# `-- MakeVector{dtype='int64'} [id C]
# |-- Cast{int64} [id D]
# | `-- s [id E]
# `-- 2 [id F]
Expected behavior
A canonicalization rewrite should strip value-preserving (widening) integer casts from shape arguments of shape-consuming ops (Alloc, Reshape, etc.), since such a cast has no semantic effect. A narrowing cast such as int64 to int32 can overflow and change the shape value, so it is not safe to remove.
This is a minor inefficiency, but it can clutter graphs and interfere with pattern matching in other rewrites that expect clean shape inputs.
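A rewrite like this has to decide which casts it may drop. The sketch below shows one possible criterion, assuming the rule is "strip a cast only if every value of the source dtype is exactly representable in the target dtype". The helpers `int_range` and `cast_is_value_preserving` are hypothetical and are not part of the PyTensor API; an actual PyTensor rewrite would inspect the Cast op's input and output dtypes and might use different criteria.

```python
import re

# Hypothetical helper: matches NumPy-style integer dtype names such as "int32" or "uint8".
_INT_DTYPE = re.compile(r"(u?)int(8|16|32|64)")


def int_range(dtype: str) -> tuple[int, int]:
    """Return the inclusive (min, max) range representable by an integer dtype name."""
    m = _INT_DTYPE.fullmatch(dtype)
    if m is None:
        raise ValueError(f"not an integer dtype: {dtype!r}")
    unsigned, bits = m.group(1) == "u", int(m.group(2))
    if unsigned:
        return 0, 2**bits - 1
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1


def cast_is_value_preserving(src: str, dst: str) -> bool:
    """True if every value of dtype `src` fits exactly in dtype `dst`.

    Under the assumed rule, only such casts would be safe to strip from a
    shape argument; narrowing casts (e.g. int64 -> int32) may overflow.
    """
    src_lo, src_hi = int_range(src)
    dst_lo, dst_hi = int_range(dst)
    return dst_lo <= src_lo and src_hi <= dst_hi


if __name__ == "__main__":
    # The casts listed in this issue all widen to int64, so each prints True.
    for src in ("int16", "int32", "uint8", "uint16", "uint32"):
        print(src, "-> int64:", cast_is_value_preserving(src, "int64"))
```

Under this rule, uint64 to int64 would also be kept, because values above 2**63 - 1 do not fit in int64.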