diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1639,11 +1639,11 @@ CollapseMixedReshapeOps, CollapseShapeOpMemRefCastFolder>(context); } + OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { - if (succeeded(foldMemRefCast(*this))) - return getResult(); return foldReshapeOp(*this, operands); } + OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { return foldReshapeOp(*this, operands); } diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -614,6 +614,8 @@ return %collapsed : memref } +// ----- + // CHECK-LABEL: func @collapse_after_memref_cast( // CHECK-SAME: %[[INPUT:.*]]: memref) -> memref { // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]] @@ -624,3 +626,28 @@ %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref into memref return %collapsed : memref } + +// ----- + +// CHECK-LABEL: func @expand_after_memref_cast_type_change( +// CHECK-SAME: %[[INPUT:.*]]: memref<1x?xf32>) -> memref { +// CHECK: %[[CAST:.*]] = memref.cast %[[INPUT]] : memref<1x?xf32> to memref +// CHECK: %[[EXPANDED:.*]] = memref.expand_shape %[[CAST]] {{\[\[}}0], [1, 2, 3]] : memref into memref +// CHECK: return %[[EXPANDED]] : memref +func @expand_after_memref_cast_type_change(%arg0 : memref<1x?xf32>) -> memref { + %dynamic = memref.cast %arg0: memref<1x?xf32> to memref + %expanded = memref.expand_shape %dynamic [[0], [1, 2, 3]] : memref into memref + return %expanded: memref +} + +// ----- + +// CHECK-LABEL: func @expand_after_memref_cast( +// CHECK-SAME: %[[INPUT:.*]]: memref) -> memref { +// CHECK: %[[EXPANDED:.*]] = memref.expand_shape %[[INPUT]] {{\[\[}}0], [1, 2, 3]] : memref into memref +// CHECK: return %[[EXPANDED]] : memref +func @expand_after_memref_cast(%arg0 : memref) -> memref { + %dynamic = memref.cast %arg0: memref to memref + %expanded = memref.expand_shape %dynamic [[0], [1, 2, 3]] : memref into memref + return %expanded: memref +}