The --flatten-memref pass generates incorrect indices for the flattened memref.
For example (this can be checked in Compiler Explorer), consider running the pass on the following function:
func.func @relu4d_0(%arg0: memref<1x3x10x10xf32>) -> memref<1x3x10x10xf32> attributes {itypes = "_", otypes = "_"} {
  %c10 = arith.constant 10 : index
  %c3 = arith.constant 3 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %alloc = memref.alloc() {name = "Z"} : memref<1x3x10x10xf32>
  scf.for %arg1 = %c0 to %c3 step %c1 {
    scf.for %arg2 = %c0 to %c10 step %c1 {
      scf.for %arg3 = %c0 to %c10 step %c1 {
        %0 = memref.load %arg0[%c0, %arg1, %arg2, %arg3] : memref<1x3x10x10xf32>
        %1 = arith.maximumf %0, %cst : f32
        memref.store %1, %alloc[%c0, %arg1, %arg2, %arg3] : memref<1x3x10x10xf32>
      }
    }
  }
  return %alloc : memref<1x3x10x10xf32>
}
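For reference, flattening an identity-layout memref<1x3x10x10xf32> into memref<300xf32> should follow row-major order: the strides are (300, 100, 10, 1), so an access at [i0, i1, i2, i3] maps to i0*300 + i1*100 + i2*10 + i3. The Python sketch below only illustrates that expected mapping; the helper names are hypothetical and not part of any MLIR API.

```python
# Sketch: expected row-major flattening of memref<1x3x10x10xf32> into memref<300xf32>.
# Helper names are illustrative only; they are not part of any MLIR API.
from math import prod

SHAPE = (1, 3, 10, 10)

def row_major_strides(shape):
    """Stride of each dimension = product of all dimension sizes to its right."""
    return tuple(prod(shape[k + 1:]) for k in range(len(shape)))

def linear_index(indices, shape):
    """Flattened offset = sum over dimensions of index * stride."""
    return sum(i * s for i, s in zip(indices, row_major_strides(shape)))

print(row_major_strides(SHAPE))           # (300, 100, 10, 1)
print(linear_index((0, 1, 1, 1), SHAPE))  # 111
print(linear_index((0, 2, 9, 9), SHAPE))  # 299, the last valid element
```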
It generates:
func.func @relu4d_0(%arg0: memref<300xf32>) -> memref<300xf32> attributes {itypes = "_", otypes = "_"} {
  %c10 = arith.constant 10 : index
  %c3 = arith.constant 3 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %alloc = memref.alloc() : memref<300xf32>
  scf.for %arg1 = %c0 to %c3 step %c1 {
    scf.for %arg2 = %c0 to %c10 step %c1 {
      scf.for %arg3 = %c0 to %c10 step %c1 {
        %c300 = arith.constant 300 : index
        %0 = arith.muli %c0, %c300 : index
        %1 = arith.addi %0, %arg1 : index
        %c100 = arith.constant 100 : index
        %2 = arith.muli %1, %c100 : index
        %3 = arith.addi %2, %arg2 : index
        %c10_0 = arith.constant 10 : index
        %4 = arith.muli %3, %c10_0 : index
        %5 = arith.addi %4, %arg3 : index
        %6 = memref.load %arg0[%5] : memref<300xf32>
        %7 = arith.maximumf %6, %cst : f32
        %c300_1 = arith.constant 300 : index
        %8 = arith.muli %c0, %c300_1 : index
        %9 = arith.addi %8, %arg1 : index
        %c100_2 = arith.constant 100 : index
        %10 = arith.muli %9, %c100_2 : index
        %11 = arith.addi %10, %arg2 : index
        %c10_3 = arith.constant 10 : index
        %12 = arith.muli %11, %c10_3 : index
        %13 = arith.addi %12, %arg3 : index
        memref.store %7, %alloc[%13] : memref<300xf32>
      }
    }
  }
  return %alloc : memref<300xf32>
}
Tracing the load index for %arg1 = 1, %arg2 = 1 and %arg3 = 1 gives:
%0 = 0 * 300 = 0
%1 = 0 + 1 = 1
%2 = 1 * 100 = 100
%3 = 100 + 1 = 101
%4 = 101 * 10 = 1010
%5 = 1010 + 1 = 1011
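So the load ultimately accesses %arg0[1011], which is out of bounds for a memref<300xf32> (valid indices are 0 to 299). The correct row-major offset for [0, 1, 1, 1] is 0*300 + 1*100 + 1*10 + 1 = 111.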
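The emitted chain appears to be a Horner-style accumulation in which the multipliers are the strides 300, 100 and 10, where the sizes of the next dimension (3, 10 and 10) would be expected. The Python sketch below, with hypothetical helper names, replays both computations over the whole 1x3x10x10 iteration space to show how far the emitted indices stray.

```python
# Sketch: replay the index arithmetic emitted by --flatten-memref (%0..%5 above)
# and compare it with a correct row-major computation. Helper names are hypothetical.
from itertools import product

SHAPE = (1, 3, 10, 10)
EMITTED_MULTIPLIERS = (300, 100, 10)  # %c300, %c100, %c10_0 in the emitted IR

def emitted_index(idx):
    """Mirror the emitted chain: acc = ((i0 * 300 + i1) * 100 + i2) * 10 + i3."""
    acc = idx[0]
    for mult, i in zip(EMITTED_MULTIPLIERS, idx[1:]):
        acc = acc * mult + i
    return acc

def horner_index(idx, shape=SHAPE):
    """Correct Horner form: multiply by the size of the next dimension, not a stride."""
    acc = idx[0]
    for size, i in zip(shape[1:], idx[1:]):
        acc = acc * size + i
    return acc

print(emitted_index((0, 1, 1, 1)))  # 1011, out of bounds for memref<300xf32>
print(horner_index((0, 1, 1, 1)))   # 111

points = list(product(*(range(n) for n in SHAPE)))
out_of_bounds = [p for p in points if emitted_index(p) >= 300]
print(len(out_of_bounds), "of", len(points))  # 200 of 300
print(max(emitted_index(p) for p in points))  # 2099
print(max(horner_index(p) for p in points))   # 299
```

Under this model the emitted offset reduces to 1000*%arg1 + 10*%arg2 + %arg3, so only the iterations with %arg1 = 0 stay in bounds (and happen to be correct). Either multiplying by the dimension sizes in the Horner chain or summing index * stride directly would give the expected 0 to 299 range.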