[mlir] Fold memref_cast of memref_cast

MemRefCastOp::fold first tries the generic cast fold (impl::foldCastOp).
If that does not apply, it now calls foldMemRefCast on the op. On success
it returns the op's own result, which is MLIR's convention for an in-place
fold.

The new canonicalize.mlir tests check two things:
- A memref_cast whose operand is another memref_cast is rewritten to cast
  the original value directly.
- A cast chain that returns to the source type folds away completely.

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2386,7 +2386,9 @@
 }
 
 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
-  return impl::foldCastOp(*this);
+  if (Value folded = impl::foldCastOp(*this))
+    return folded;
+  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -334,6 +334,29 @@
   return %1, %2 : f32, f32
 }
 
+// CHECK-LABEL: @fold_memref_cast_in_memref_cast
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_in_memref_cast(%0: memref<42x42xf64>) {
+  // CHECK: %[[folded:.*]] = memref_cast %[[ARG0]] : memref<42x42xf64> to memref<?x?xf64>
+  %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+  // CHECK-NOT: memref_cast
+  %5 = memref_cast %4 : memref<?x42xf64> to memref<?x?xf64>
+  // CHECK: "test.user"(%[[folded]])
+  "test.user"(%5) : (memref<?x?xf64>) -> ()
+  return
+}
+
+// CHECK-LABEL: @fold_memref_cast_chain
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_chain(%0: memref<42x42xf64>) {
+  // CHECK-NOT: memref_cast
+  %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+  %5 = memref_cast %4 : memref<?x42xf64> to memref<42x42xf64>
+  // CHECK: "test.user"(%[[ARG0]])
+  "test.user"(%5) : (memref<42x42xf64>) -> ()
+  return
+}
+
 // CHECK-LABEL: func @alloc_const_fold
 func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: %0 = alloc() : memref<4xf32>