diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -218,6 +218,17 @@
 void populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Collect patterns to fold tensor.extract_slice -> vector.transfer_read and
+/// vector.transfer_write -> tensor.insert_slice op chains into vector transfer
+/// read and write ops.
+///
+/// If `controlFn` is not nullptr, the pattern will only apply to ops where
+/// `controlFn` returns true, given the vector transfer read/write op as input.
+void populateVectorTransferTensorSliceTransforms(
+    RewritePatternSet &patterns,
+    std::function<bool(Operation *)> controlFn = nullptr,
+    PatternBenefit benefit = 1);
+
 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
 /// `options` structure controls which operations are unrolled and the target
 /// shape.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
@@ -2866,6 +2867,7 @@
                                                        /*benefit=*/2);
   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+  vector::populateVectorTransferTensorSliceTransforms(patterns);
 
   patterns.add<CopyVectorizationPattern>(ctx);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3687,108 +3687,7 @@
                       SideEffects::DefaultResource::get());
 }
 
-/// Returns true if all rank reduced in the given `extractOp` happen in leading
-/// dimensions earlier than last `trailingRank` dimensions.
-static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
-                                        unsigned trailingRank) {
-  // If no ranks are reduced at all, it's a degenerated case; always true.
-  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
-    return true;
-
-  RankedTensorType inferredType = extractOp.inferResultType(
-      extractOp.getSourceType(), extractOp.getMixedOffsets(),
-      extractOp.getMixedSizes(), extractOp.getMixedStrides());
-  return extractOp.getType().getShape().take_back(trailingRank) ==
-         inferredType.getShape().take_back(trailingRank);
-}
-
 namespace {
-/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
-///
-/// ```
-/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
-///     : tensor<?x?xf32> to tensor<?x?xf32>
-/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %p0 = arith.addi %a, %e : index
-/// %p1 = arith.addi %b, %f : index
-/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldExtractSliceIntoTransferRead
-    : public OpRewritePattern<TransferReadOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransferReadOp xferOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (!xferOp.getPermutationMap().isMinorIdentity())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-    if (!extractOp)
-      return failure();
-    if (!extractOp.hasUnitStride())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x1x4xf32> to tensor<2x4xf32>
-    //    %1 = vector.transfer_read %0[0,0], %cst :
-    //      tensor<2x4xf32>, vector<2x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_read %t[0,0,0], %cst :
-    //      tensor<2x1x4xf32>, vector<2x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the extract_slice
-    // result tensor match the trailing dims of the inferred result tensor.
-    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
-      return failure();
-
-    int64_t rankReduced =
-        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
-
-    SmallVector<Value> newIndices;
-    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
-    // indices first.
-    for (int64_t i = 0; i < rankReduced; ++i) {
-      OpFoldResult offset = extractOp.getMixedOffsets()[i];
-      newIndices.push_back(getValueOrCreateConstantIndexOp(
-          rewriter, extractOp.getLoc(), offset));
-    }
-    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
-      OpFoldResult offset =
-          extractOp.getMixedOffsets()[it.index() + rankReduced];
-      newIndices.push_back(rewriter.create<arith::AddIOp>(
-          xferOp->getLoc(), it.value(),
-          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
-                                          offset)));
-    }
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferReadOp>(
-        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
-        xferOp.getPadding(), ArrayRef<bool>{inBounds});
-
-    return success();
-  }
-};
-
 /// Store to load forwarding for transfer operations with permuation maps.
 /// Even if the permutation maps are different we can still propagate the store
 /// into the load if the size of the dimensions read and written match. Then we
@@ -3870,13 +3769,7 @@
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  // clang-format off
-  results.add <
-      // TODO: this is brittle and should be deprecated in favor of a
-      // more general pattern that applies on-demand.
-      FoldExtractSliceIntoTransferRead,
-      TransferReadAfterWriteToBroadcast>(context);
-  // clang-format on
+  results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4212,93 +4105,6 @@
   }
 };
 
-/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
-/// could directly write to the insert_slice's destination. E.g.:
-///
-/// ```
-/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<4x5xf32>
-/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
-///     : tensor<4x5xf32> into tensor<?x?xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<?x?xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldInsertSliceIntoTransferWrite
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    if (!insertOp.hasUnitStride())
-      return failure();
-
-    auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
-    if (!xferOp)
-      return failure();
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    // Fold only if the TransferWriteOp completely overwrites the `source` with
-    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
-    // content is the data of the vector.
-    if (!llvm::equal(xferOp.getVectorType().getShape(),
-                     xferOp.getShapedType().getShape()))
-      return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x4xf32>
-    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x4xf32> into tensor<2x1x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x1x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the insert_slice result
-    // tensor match the trailing dims of the inferred result tensor.
-    int64_t rankReduced =
-        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
-    int64_t vectorRank = xferOp.getVectorType().getRank();
-    RankedTensorType inferredSourceTensorType =
-        tensor::ExtractSliceOp::inferResultType(
-            insertOp.getType(), insertOp.getMixedOffsets(),
-            insertOp.getMixedSizes(), insertOp.getMixedStrides());
-    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
-    if (rankReduced > 0 &&
-        actualSourceTensorShape.take_back(vectorRank) !=
-            inferredSourceTensorType.getShape().take_back(vectorRank))
-      return failure();
-
-    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
-        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
-                                                 insertOp.getDest(), indices,
-                                                 ArrayRef<bool>{inBounds});
-    return success();
-  }
-};
-
 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
 /// overwritten and inserted into another tensor. After this rewrite, the
@@ -4410,13 +4216,7 @@
 
 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  // clang-format off
-  results.add<FoldWaw,
-              // TODO: this is brittle and should be deprecated in favor of a
-              // more general pattern that applies on-demand.
-              FoldInsertSliceIntoTransferWrite,
-              SwapExtractSliceOfTransferWrite>(context);
-  // clang-format on
+  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@
   VectorDropLeadUnitDim.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
   VectorTransferOpTransforms.cpp
+  VectorTransferTensorSliceTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUnroll.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
new file
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
@@ -0,0 +1,237 @@
+//===- VectorTransferTensorSliceTransforms.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+/// Returns true if all rank reduced in the given `extractOp` happen in leading
+/// dimensions earlier than last `trailingRank` dimensions.
+static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
+                                        unsigned trailingRank) {
+  // If no ranks are reduced at all, it's a degenerated case; always true.
+  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
+    return true;
+
+  RankedTensorType inferredType = extractOp.inferResultType(
+      extractOp.getSourceType(), extractOp.getMixedOffsets(),
+      extractOp.getMixedSizes(), extractOp.getMixedStrides());
+  return extractOp.getType().getShape().take_back(trailingRank) ==
+         inferredType.getShape().take_back(trailingRank);
+}
+
+namespace {
+/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
+///
+/// ```
+/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
+///     : tensor<?x?xf32> to tensor<?x?xf32>
+/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
+///     : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %p0 = arith.addi %a, %e : index
+/// %p1 = arith.addi %b, %f : index
+/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
+///     : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldExtractSliceIntoTransferRead final
+    : public OpRewritePattern<vector::TransferReadOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  FoldExtractSliceIntoTransferRead(MLIRContext *context,
+                                   std::function<bool(Operation *)> controlFn,
+                                   PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (controlFn && !controlFn(xferOp))
+      return failure();
+
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    if (xferOp.getMask())
+      return failure();
+    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      return failure();
+    if (!extractOp.hasUnitStride())
+      return failure();
+
+    // Bail on illegal rank-reduction: we need to check that the rank-reduced
+    // dims are exactly the leading dims. I.e. the following is illegal:
+    // ```
+    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
+    //      tensor<2x1x4xf32> to tensor<2x4xf32>
+    //    %1 = vector.transfer_read %0[0,0], %cst :
+    //      tensor<2x4xf32>, vector<2x4xf32>
+    // ```
+    //
+    // Cannot fold into:
+    // ```
+    //    %0 = vector.transfer_read %t[0,0,0], %cst :
+    //      tensor<2x1x4xf32>, vector<2x4xf32>
+    // ```
+    // For this, check the trailing `vectorRank` dims of the extract_slice
+    // result tensor match the trailing dims of the inferred result tensor.
+    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
+      return failure();
+
+    int64_t rankReduced =
+        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
+
+    SmallVector<Value> newIndices;
+    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
+    // indices first.
+    for (int64_t i = 0; i < rankReduced; ++i) {
+      OpFoldResult offset = extractOp.getMixedOffsets()[i];
+      newIndices.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, extractOp.getLoc(), offset));
+    }
+    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
+      OpFoldResult offset =
+          extractOp.getMixedOffsets()[it.index() + rankReduced];
+      newIndices.push_back(rewriter.create<arith::AddIOp>(
+          xferOp->getLoc(), it.value(),
+          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
+                                          offset)));
+    }
+    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
+        xferOp.getPadding(), ArrayRef<bool>{inBounds});
+
+    return success();
+  }
+
+private:
+  std::function<bool(Operation *)> controlFn;
+};
+
+/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
+/// could directly write to the insert_slice's destination. E.g.:
+///
+/// ```
+/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
+///     : vector<4x5xf32>, tensor<4x5xf32>
+/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
+///     : tensor<4x5xf32> into tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
+///     : vector<4x5xf32>, tensor<?x?xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldInsertSliceIntoTransferWrite final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  FoldInsertSliceIntoTransferWrite(MLIRContext *context,
+                                   std::function<bool(Operation *)> controlFn,
+                                   PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    if (!insertOp.hasUnitStride())
+      return failure();
+
+    auto xferOp = insertOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+    if (!xferOp)
+      return failure();
+    if (controlFn && !controlFn(xferOp))
+      return failure();
+
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
+
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();
+    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
+      return failure();
+    if (xferOp.getMask())
+      return failure();
+    // Fold only if the TransferWriteOp completely overwrites the `source` with
+    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
+    // content is the data of the vector.
+    if (!llvm::equal(xferOp.getVectorType().getShape(),
+                     xferOp.getShapedType().getShape()))
+      return failure();
+    if (!xferOp.getPermutationMap().isIdentity())
+      return failure();
+
+    // Bail on illegal rank-reduction: we need to check that the rank-reduced
+    // dims are exactly the leading dims. I.e. the following is illegal:
+    // ```
+    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
+    //      vector<2x4xf32>, tensor<2x4xf32>
+    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
+    //      tensor<2x4xf32> into tensor<2x1x4xf32>
+    // ```
+    //
+    // Cannot fold into:
+    // ```
+    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
+    //      vector<2x4xf32>, tensor<2x1x4xf32>
+    // ```
+    // For this, check the trailing `vectorRank` dims of the insert_slice result
+    // tensor match the trailing dims of the inferred result tensor.
+    int64_t rankReduced =
+        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
+    int64_t vectorRank = xferOp.getVectorType().getRank();
+    RankedTensorType inferredSourceTensorType =
+        tensor::ExtractSliceOp::inferResultType(
+            insertOp.getType(), insertOp.getMixedOffsets(),
+            insertOp.getMixedSizes(), insertOp.getMixedStrides());
+    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
+    if (rankReduced > 0 &&
+        actualSourceTensorShape.take_back(vectorRank) !=
+            inferredSourceTensorType.getShape().take_back(vectorRank))
+      return failure();
+
+    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
+        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
+    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        insertOp, xferOp.getVector(), insertOp.getDest(), indices,
+        ArrayRef<bool>{inBounds});
+    return success();
+  }
+
+private:
+  std::function<bool(Operation *)> controlFn;
+};
+
+} // namespace
+
+void vector::populateVectorTransferTensorSliceTransforms(
+    RewritePatternSet &patterns,
+    std::function<bool(Operation *)> controlFn,
+    PatternBenefit benefit) {
+  patterns
+      .add<FoldExtractSliceIntoTransferRead, FoldInsertSliceIntoTransferWrite>(
+          patterns.getContext(), controlFn, benefit);
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1201,116 +1201,6 @@
 
 // -----
 
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
-// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
-// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
-  return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
-// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
-// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
-  return %1 : vector<6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index
-// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
-// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
-// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
-  return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
-// CHECK: extract_slice
-// CHECK: vector.transfer_read
-func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
-  return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-// CHECK: %[[c3:.*]] = arith.constant 3 : index
-// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?xf32>
-// CHECK: return %[[r]]
-func.func @insert_slice_of_transfer_write(%t1 : tensor<?x?xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?xf32> {
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
-  %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
-// CHECK: vector.transfer_write
-// CHECK: insert_slice
-func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x?xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x?xf32> {
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
-  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x?xf32>
-  return %1 : tensor<?x?x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x?x?xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x?xf32>
-// CHECK: return %[[r]]
-func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x?xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x?xf32> {
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
-  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x?xf32>
-  return %1 : tensor<?x?x?xf32>
-}
-
-// -----
-
 // CHECK: #[[$MAP:[0-9a-z]+]] = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: func @swap_extract_slice_transfer_write
diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt -split-input-file -test-vector-transfer-tensor-slice-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
+  return %1 : vector<6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index
+// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
+// CHECK: extract_slice
+// CHECK: vector.transfer_read
+func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?xf32>
+// CHECK: return %[[r]]
+func.func @insert_slice_of_transfer_write(%t1 : tensor<?x?xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
+// CHECK: vector.transfer_write
+// CHECK: insert_slice
+func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x?xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?x?xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x?xf32>
+// CHECK: return %[[r]]
+func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x?xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -679,6 +679,26 @@
   }
 };
 
+struct TestVectorTransferTensorSlicePatterns
+    : public PassWrapper<TestVectorTransferTensorSlicePatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorTransferTensorSlicePatterns)
+
+  StringRef getArgument() const final {
+    return "test-vector-transfer-tensor-slice-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns that fold vector transfer and tensor slice ops";
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorTransferTensorSliceTransforms(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -713,6 +733,8 @@
   PassRegistration();
   PassRegistration();
+
+  PassRegistration<TestVectorTransferTensorSlicePatterns>();
 }
 } // namespace test
 } // namespace mlir
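
Editorial note (not part of the patch): the header comment above explains the new `populateVectorTransferTensorSliceTransforms` entry point and its optional `controlFn`. The sketch below shows how a downstream user might wire these patterns into a greedy rewrite, mirroring the test pass added in this patch; the helper name and the "skip_fold" opt-out attribute are hypothetical and used only for illustration.

```cpp
// Illustrative usage sketch, assuming the API added by this patch. Only the
// populate and greedy-rewrite calls come from the patch; the function name
// and the "skip_fold" attribute are made up for the example.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult foldTransferTensorSlices(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Restrict folding to transfer ops that have not opted out via a
  // (hypothetical) "skip_fold" unit attribute.
  auto controlFn = [](mlir::Operation *xferOp) {
    return !xferOp->hasAttr("skip_fold");
  };
  mlir::vector::populateVectorTransferTensorSliceTransforms(
      patterns, controlFn, /*benefit=*/1);
  // Apply the patterns greedily to everything nested under `root`.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```

Compared to keeping these folds as canonicalizations, the control function lets a pipeline decide per-op whether a transfer_read/transfer_write should be folded with the surrounding tensor slice ops, which is the motivation for moving them behind a populate function.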