diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -770,8 +770,16 @@
     // Place the load directly after the reshape to ensure that the shape memref
     // was not mutated.
     rewriter.setInsertionPointAfter(reshape);
-    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
-                                        llvm::makeArrayRef({dim.index()}));
+    // NOTE(review): the load is placed right after the reshape, which assumes
+    // dim.index() is already defined there. Confirm this for non-constant
+    // indices (e.g. an induction variable of a loop nested below the reshape).
+    Location loc = dim.getLoc();
+    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
+    // The shape memref may hold integers (e.g. i32) rather than `index`; cast
+    // the loaded extent so the replacement keeps the result type of the dim.
+    if (load.getType() != dim.getType())
+      load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
+    rewriter.replaceOp(dim, load);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -122,6 +122,26 @@
 
 // -----
 
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) with an i32 shape -> index_cast(memref.load %shp[%idx])
+// CHECK-LABEL: func @dim_of_memref_reshape_i32(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
+// CHECK-NEXT: %[[IDX:.*]] = constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[CAST]] : index
+func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
+  -> index {
+  %c3 = constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
 // Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
 // CHECK-LABEL: func @fold_dim_of_tensor.cast
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>