diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -767,6 +767,11 @@
   if (!reshape)
     return failure();
 
+  // We need to make sure that the index dominates the new load. Since we
+  // can't check that here, restrict the transformation to constant indices.
+  if (!matchPattern(dim.index(), matchConstantIndex()))
+    return failure();
+
   // Place the load directly after the reshape to ensure that the shape memref
   // was not mutated.
   rewriter.setInsertionPointAfter(reshape);
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -142,6 +142,21 @@
 
 // -----
 
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is NOT folded when %idx does not dominate the reshape.
+// CHECK-LABEL: func @dim_of_memref_reshape_not_dominating(
+// CHECK: memref.dim
+func @dim_of_memref_reshape_not_dominating(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: memref<?xindex>)
+    -> index {
+  %c3 = constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %idx = memref.load %arg2[%c3] : memref<?xindex>
+  %1 = memref.dim %0, %idx : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
 // Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
 // CHECK-LABEL: func @fold_dim_of_tensor.cast
 // CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>