diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1640,8 +1640,6 @@ CollapseShapeOpMemRefCastFolder>(context); } OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { - if (succeeded(foldMemRefCast(*this))) - return getResult(); return foldReshapeOp(*this, operands); } OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -600,6 +600,18 @@ // ----- +func @fold_memref_expand_cast(%arg0 : memref) -> memref<2x4x4xf32> { + %0 = memref.cast %arg0 : memref to memref<8x4xf32> + %1 = memref.expand_shape %0 [[0, 1], [2]] + : memref<8x4xf32> into memref<2x4x4xf32> + return %1 : memref<2x4x4xf32> +} + +// CHECK-LABEL: @fold_memref_expand_cast +// CHECK: memref.expand_shape + +// ----- + // CHECK-LABEL: func @collapse_after_memref_cast_type_change( // CHECK-SAME: %[[INPUT:.*]]: memref) -> memref { // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]