diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -225,7 +225,7 @@
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -250,8 +250,7 @@
 
     SmallVector<ReassociationIndices> higherRankReassociation,
         lowerRankReassociation;
-    bool isResultCollapsed = srcRank > resultRank;
-    if (isResultCollapsed) {
+    if (srcRank > resultRank) {
       higherRankReassociation = expandOp.getReassociationIndices();
       lowerRankReassociation = collapseOp.getReassociationIndices();
     } else {
@@ -274,12 +273,20 @@
       }
       composedReassociation.push_back(composedIndices);
     }
-    if (isResultCollapsed)
+    if (srcRank > resultRank) {
       rewriter.replaceOpWithNewOp<CollapseOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
-    else
+    } else if (srcRank < resultRank) {
       rewriter.replaceOpWithNewOp<ExpandOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+    } else {
+      // Collapses/expansions that do not change the rank are not allowed. Use
+      // a cast instead.
+      assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
+             "expected same shape");
+      rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
+                                            expandOp.getSrc());
+    }
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2447,7 +2447,7 @@
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
               CollapseShapeOpMemRefCastFolder>(context);
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1586,7 +1586,7 @@
                                                  MLIRContext *context) {
   results
       .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
            FoldReshapeWithConstant, FoldReshapeWithFromElements,
            FoldCollapseOfCastOp>(
           context);
 }
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -859,3 +859,19 @@
   memref.store %v, %0[%i2] : memref<4xf32>
   return %src : memref<2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>, 3>
+//       CHECK:   %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[?], offset: ?>, 3> to memref<?xf32, 3>
+//       CHECK:   return %[[casted]]
+func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[?], offset: ?>, 3>)
+    -> (memref<?xf32, 3>)
+{
+  %0 = memref.expand_shape %m [[0, 1]]
+      : memref<?xf32, strided<[?], offset: ?>, 3> into memref<1x?xf32, 3>
+  %1 = memref.collapse_shape %0 [[0, 1]]
+      : memref<1x?xf32, 3> into memref<?xf32, 3>
+  return %1 : memref<?xf32, 3>
+}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1666,3 +1666,15 @@
   %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+//       CHECK:   return %[[t]]
+func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
+{
+  %0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}