diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -49,9 +49,41 @@
     return success();
   }
 };
+
+/// Fold insert_slice(collapse_shape) ops that cancel itself out.
+struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // The source of the insert_slice must be produced by a collapse_shape.
+    auto collapseShapeOp =
+        insertSliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+    RankedTensorType srcType = collapseShapeOp.getSrcType();
+
+    // Only cases where the CollapseShapeOp can be folded away entirely are
+    // supported. Moreover, only simple cases where the resulting InsertSliceOp
+    // has no rank-reduction anymore are supported at the moment.
+    RankedTensorType nonReducingInsertType =
+        RankedTensorType::get(insertSliceOp.getStaticSizes(),
+                              insertSliceOp.getType().getElementType());
+    if (nonReducingInsertType != srcType)
+      return failure();
+
+    // Rebuild the insert_slice directly from the collapse_shape's source,
+    // reusing the original offsets/sizes/strides.
+    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
+  patterns.add<FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -17,3 +17,19 @@
       : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
   return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   return %[[insert]]
+func.func @rank_reducing_insert_of_collapse_shape(
+    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
+  -> tensor<?x?x?x?xf32> {
+  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
+      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
+  %1 = tensor.insert_slice %0 into %d[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
+      : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}