[--flatten-memref] Flattened Indices are incorrect #10326

@Abhilekhgautam

The --flatten-memref pass generates incorrect indices for the flattened memref.

For example (check in Compiler Explorer), consider running the pass on the following function:

func.func @relu4d_0(%arg0: memref<1x3x10x10xf32>) -> memref<1x3x10x10xf32> attributes {itypes = "_", otypes = "_"} {
    %c10 = arith.constant 10 : index
    %c3 = arith.constant 3 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %alloc = memref.alloc() {name = "Z"} : memref<1x3x10x10xf32>
    scf.for %arg1 = %c0 to %c3 step %c1 {
      scf.for %arg2 = %c0 to %c10 step %c1 {
        scf.for %arg3 = %c0 to %c10 step %c1 {
          %0 = memref.load %arg0[%c0, %arg1, %arg2, %arg3] : memref<1x3x10x10xf32>
          %1 = arith.maximumf %0, %cst : f32
          memref.store %1, %alloc[%c0, %arg1, %arg2, %arg3] : memref<1x3x10x10xf32>
        }
      }
    }
    return %alloc : memref<1x3x10x10xf32>
  }

It generates:

func.func @relu4d_0(%arg0: memref<300xf32>) -> memref<300xf32> attributes {itypes = "_", otypes = "_"} {
    %c10 = arith.constant 10 : index
    %c3 = arith.constant 3 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %alloc = memref.alloc() : memref<300xf32>
    scf.for %arg1 = %c0 to %c3 step %c1 {
      scf.for %arg2 = %c0 to %c10 step %c1 {
        scf.for %arg3 = %c0 to %c10 step %c1 {
          %c300 = arith.constant 300 : index
          %0 = arith.muli %c0, %c300 : index
          %1 = arith.addi %0, %arg1 : index
          %c100 = arith.constant 100 : index
          %2 = arith.muli %1, %c100 : index
          %3 = arith.addi %2, %arg2 : index
          %c10_0 = arith.constant 10 : index
          %4 = arith.muli %3, %c10_0 : index
          %5 = arith.addi %4, %arg3 : index
          %6 = memref.load %arg0[%5] : memref<300xf32>
          %7 = arith.maximumf %6, %cst : f32
          %c300_1 = arith.constant 300 : index
          %8 = arith.muli %c0, %c300_1 : index
          %9 = arith.addi %8, %arg1 : index
          %c100_2 = arith.constant 100 : index
          %10 = arith.muli %9, %c100_2 : index
          %11 = arith.addi %10, %arg2 : index
          %c10_3 = arith.constant 10 : index
          %12 = arith.muli %11, %c10_3 : index
          %13 = arith.addi %12, %arg3 : index
          memref.store %7, %alloc[%13] : memref<300xf32>
        }
      }
    }
    return %alloc : memref<300xf32>
  }

If we trace the load index for arg1 = 1, arg2 = 1 and arg3 = 1, we get:

%0 = 0
%1 = 1
%2 = 100
%3 = 101
%4 = 1010
%5 = 1011

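So we end up accessing arg0[1011], which is out of bounds for a memref<300xf32>. The expected row-major index for these values is 0*300 + 1*100 + 1*10 + 1 = 111. The generated code instead multiplies the running sum by each stride in turn (300, then 100, then 10), so the earlier terms get scaled more than once.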
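
To make the mismatch easy to reproduce outside the compiler, here is a small Python sketch of the arithmetic. It assumes the default row-major (identity) layout for memref<1x3x10x10xf32>. The function names are mine and are not taken from the pass; `generated_index` only mirrors the muli/addi chain shown in the output above.

```python
# Illustrative sketch only: models the index arithmetic, not the pass itself.
SHAPE = (1, 3, 10, 10)  # memref<1x3x10x10xf32>, assumed row-major


def row_major_strides(shape):
    """Strides in elements for a row-major layout, e.g. (300, 100, 10, 1)."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def expected_index(indices, shape=SHAPE):
    """Sum of index * stride: the usual row-major linearization."""
    return sum(i * s for i, s in zip(indices, row_major_strides(shape)))


def horner_index(indices, shape=SHAPE):
    """Same result in Horner form: multiply the running sum by the next
    dimension SIZE (3, 10, 10), not by a stride."""
    acc = indices[0]
    for idx, dim in zip(indices[1:], shape[1:]):
        acc = acc * dim + idx
    return acc


def generated_index(indices, shape=SHAPE):
    """Mirrors the emitted ops: the running sum is multiplied by the
    strides 300, 100, 10 before each add."""
    strides = row_major_strides(shape)
    acc = indices[0]
    for idx, stride in zip(indices[1:], strides[:-1]):
        acc = acc * stride + idx
    return acc


if __name__ == "__main__":
    point = (0, 1, 1, 1)
    print("expected :", expected_index(point))
    print("horner   :", horner_index(point))
    print("generated:", generated_index(point))
```

For (0, 1, 1, 1), `generated_index` reproduces the 1011 from the trace, while both `expected_index` and `horner_index` give 111. Either form would stay within the 300-element buffer for every in-range index.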
