diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1603,6 +1603,37 @@ return success(); } }; + +/// Fold a collapse_shape into its insert_slice user if the two cancel out. +struct FoldInsertOfRankReducingInsert : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp, + PatternRewriter &rewriter) const override { + auto collapseShapeOp = + insertSliceOp.getSource().getDefiningOp(); + if (!collapseShapeOp) + return failure(); + RankedTensorType srcType = collapseShapeOp.getSrcType(); + + // Only cases where the CollapseShapeOp can be folded away entirely are + // supported. Moreover, only simple cases where the resulting InsertSliceOp + // has no rank-reduction anymore are supported at the moment. + RankedTensorType nonReducingInsertType = + RankedTensorType::get(insertSliceOp.getStaticSizes(), + insertSliceOp.getType().getElementType()); + if (nonReducingInsertType != srcType) + return failure(); + + SmallVector mixedOffsets = insertSliceOp.getMixedOffsets(); + SmallVector mixedSizes = insertSliceOp.getMixedSizes(); + SmallVector mixedStrides = insertSliceOp.getMixedStrides(); + rewriter.replaceOpWithNewOp( + insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(), + mixedOffsets, mixedSizes, mixedStrides); + return success(); + } +}; } // namespace void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -1616,12 +1647,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results - .add, - ComposeCollapseOfExpandOp, - FoldReshapeWithConstant, - FoldReshapeWithFromElements, FoldCollapseOfCastOp>( - context); + results.add, + ComposeCollapseOfExpandOp, + FoldReshapeWithConstant, + FoldReshapeWithFromElements, + FoldCollapseOfCastOp, FoldInsertOfRankReducingInsert>(context); 
} OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1698,3 +1698,19 @@ : tensor into tensor return %1, %2 : tensor, tensor } + +// ----- + +// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor into tensor +// CHECK: return %[[insert]] +func.func @rank_reducing_insert_of_collapse_shape( + %t: tensor, %d: tensor, %sz: index) + -> tensor { + %0 = tensor.collapse_shape %t [[0, 1], [2], [3]] + : tensor into tensor + %1 = tensor.insert_slice %0 into %d[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1] + : tensor into tensor + return %1 : tensor +}