diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3718,93 +3718,7 @@
 }
 
 namespace {
-/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
-///
-/// ```
-/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
-///     : tensor<?x?xf32> to tensor<?x?xf32>
-/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %p0 = arith.addi %a, %e : index
-/// %p1 = arith.addi %b, %f : index
-/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldExtractSliceIntoTransferRead
-    : public OpRewritePattern<TransferReadOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransferReadOp xferOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (!xferOp.getPermutationMap().isMinorIdentity())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-    if (!extractOp)
-      return failure();
-    if (!extractOp.hasUnitStride())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x1x4xf32> to tensor<2x4xf32>
-    //    %1 = vector.transfer_read %0[0,0], %cst :
-    //      tensor<2x4xf32>, vector<2x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_read %t[0,0,0], %cst :
-    //      tensor<2x1x4xf32>, vector<2x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the extract_slice
-    // result tensor match the trailing dims of the inferred result tensor.
-    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
-      return failure();
-
-    int64_t rankReduced =
-        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
-
-    SmallVector<Value> newIndices;
-    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
-    // indices first.
-    for (int64_t i = 0; i < rankReduced; ++i) {
-      OpFoldResult offset = extractOp.getMixedOffsets()[i];
-      newIndices.push_back(getValueOrCreateConstantIndexOp(
-          rewriter, extractOp.getLoc(), offset));
-    }
-    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
-      OpFoldResult offset =
-          extractOp.getMixedOffsets()[it.index() + rankReduced];
-      newIndices.push_back(rewriter.create<arith::AddIOp>(
-          xferOp->getLoc(), it.value(),
-          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
-                                          offset)));
-    }
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferReadOp>(
-        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
-        xferOp.getPadding(), ArrayRef<bool>{inBounds});
-
-    return success();
-  }
-};
-
-/// Store to load forwarding for transfer operations with permuation maps.
+/// Store to load forwarding for transfer operations with permutation maps.
 /// Even if the permutation maps are different we can still propagate the store
 /// into the load if the size of the dimensions read and written match. Then we
 /// can replace the transfer_read + transfer_write by vector.broadcast and
@@ -3885,13 +3799,7 @@
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  // clang-format off
-  results.add<
-      // TODO: this is brittle and should be deprecated in favor of a
-      // more general pattern that applies on-demand.
-      FoldExtractSliceIntoTransferRead,
-      TransferReadAfterWriteToBroadcast>(context);
-  // clang-format on
+  results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4227,93 +4135,6 @@
   }
 };
 
-/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
-/// could directly write to the insert_slice's destination. E.g.:
-///
-/// ```
-/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<4x5xf32>
-/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
-///     : tensor<4x5xf32> into tensor<?x?xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<?x?xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldInsertSliceIntoTransferWrite
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    if (!insertOp.hasUnitStride())
-      return failure();
-
-    auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
-    if (!xferOp)
-      return failure();
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    // Fold only if the TransferWriteOp completely overwrites the `source` with
-    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
-    // content is the data of the vector.
-    if (!llvm::equal(xferOp.getVectorType().getShape(),
-                     xferOp.getShapedType().getShape()))
-      return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x4xf32>
-    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x4xf32> into tensor<2x1x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x1x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the insert_slice result
-    // tensor match the trailing dims of the inferred result tensor.
-    int64_t rankReduced =
-        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
-    int64_t vectorRank = xferOp.getVectorType().getRank();
-    RankedTensorType inferredSourceTensorType =
-        tensor::ExtractSliceOp::inferResultType(
-            insertOp.getType(), insertOp.getMixedOffsets(),
-            insertOp.getMixedSizes(), insertOp.getMixedStrides());
-    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
-    if (rankReduced > 0 &&
-        actualSourceTensorShape.take_back(vectorRank) !=
-            inferredSourceTensorType.getShape().take_back(vectorRank))
-      return failure();
-
-    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
-        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
-                                                 insertOp.getDest(), indices,
-                                                 ArrayRef<bool>{inBounds});
-    return success();
-  }
-};
-
 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
 /// overwritten and inserted into another tensor. After this rewrite, the
@@ -4425,13 +4246,7 @@
 
 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  // clang-format off
-  results.add<FoldWaw,
-              // TODO: this is brittle and should be deprecated in favor of a
-              // more general pattern that applies on-demand.
-              FoldInsertSliceIntoTransferWrite,
-              SwapExtractSliceOfTransferWrite>(context);
-  // clang-format on
+  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
 //===----------------------------------------------------------------------===//
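
For readers skimming the removed code: the doc comment on `FoldExtractSliceIntoTransferRead` only illustrates the non-rank-reducing fold. Here is a sketch of the rank-reducing case that passes the leading-dims check; the shapes and SSA names are made up for illustration and are not taken from the patch:

```mlir
// The leading unit dim of the slice is rank-reduced away, which is legal:
// its offset %a is carried over verbatim as a new leading index.
%0 = tensor.extract_slice %t[%a, %b, %c] [1, 4, 5] [1, 1, 1]
    : tensor<?x?x?xf32> to tensor<4x5xf32>
%1 = vector.transfer_read %0[%d, %e], %cst {in_bounds = [true, true]}
    : tensor<4x5xf32>, vector<4x5xf32>
```

would fold to:

```mlir
// Rank-reduced offset first, then one arith.addi per remaining dim.
%p0 = arith.addi %b, %d : index
%p1 = arith.addi %c, %e : index
%1 = vector.transfer_read %t[%a, %p0, %p1], %cst {in_bounds = [true, true]}
    : tensor<?x?x?xf32>, vector<4x5xf32>
```

This mirrors the two loops in the removed `matchAndRewrite`: the dropped dim's offset is copied as-is, and the remaining offsets are added to the read indices.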
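The surviving `TransferReadAfterWriteToBroadcast` pattern is described only in prose. Below is a self-contained sketch of the IR shape it targets, modeled loosely on the vector canonicalization tests; the function and value names are illustrative assumptions, not part of the patch:

```mlir
// Hypothetical example: write a 4x2 block, then read the same block back
// while broadcasting a trailing dimension of 6.
func.func @store_to_load_forward(%t: tensor<4x4xf32>,
                                 %v: vector<4x2xf32>) -> vector<4x2x6xf32> {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32
  %w = vector.transfer_write %v, %t[%c0, %c0] {in_bounds = [true, true]}
      : vector<4x2xf32>, tensor<4x4xf32>
  %r = vector.transfer_read %w[%c0, %c0], %f0
      {in_bounds = [true, true, true],
       permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>}
      : tensor<4x4xf32>, vector<4x2x6xf32>
  return %r : vector<4x2x6xf32>
}
```

Because the read covers exactly the positions just written and the read/written sizes match, canonicalization can forward %v directly, e.g. as a vector.broadcast to vector<6x4x2xf32> followed by a vector.transpose with permutation [1, 2, 0], eliminating the round trip through the tensor.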
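Similarly for the removed `FoldInsertSliceIntoTransferWrite`: it only fires when the vector fully overwrites the transfer_write's destination, so the insert_slice contributes no data of its own. A made-up rank-reducing sketch that satisfies the trailing-dims check (again illustrative, not from the patch):

```mlir
// %0 fully overwrites %t1, and the slice drops only a leading unit dim.
%0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
    : vector<4x5xf32>, tensor<4x5xf32>
%1 = tensor.insert_slice %0 into %t2[%a, %b, %c] [1, 4, 5] [1, 1, 1]
    : tensor<4x5xf32> into tensor<?x?x?xf32>
```

folds to a single write at the slice's offsets:

```mlir
%1 = vector.transfer_write %v, %t2[%a, %b, %c] {in_bounds = [true, true]}
    : vector<4x5xf32>, tensor<?x?x?xf32>
```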