diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -63,6 +63,18 @@
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.extract_slice -> vector.transfer_read and
+    vector.transfer_write -> tensor.insert_slice op chains should be folded
+    into vector transfer read and write ops.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyMergeConsecutiveInsertExtractSlicePatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.merge_consecutive_insert_extract_slice",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -34,10 +34,15 @@
 // Populate functions.
 //===----------------------------------------------------------------------===//
 
-/// Appends patterns for folding tensor aliasing ops into consumer load/store
-/// ops into `patterns`.
+/// Appends patterns for folding tensor subset ops into consumer load/store
+/// ops into `patterns`. (This includes patterns for folding tensor subset ops
+/// into vector transfer ops.)
 void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns for folding tensor subset ops into vector transfer ops.
+void populateFoldTensorSubsetIntoVectorTransferPatterns(
+    RewritePatternSet &patterns);
+
 /// Collects patterns to merge consecutive tensor.insert_slice/extract_slice
 /// into one. These patterns are in this separate entry point because the
 /// bufferization is sensitive to IR structure, particularly those
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,16 +306,4 @@
   }];
 }
 
-def ApplyFoldTensorSliceIntoTransferPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.vector.fold_tensor_slice_into_transfer",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-    Indicates that tensor.extract_slice -> vector.transfer_read and
-    vector.transfer_write -> tensor.insert_slice op chains should be folded into
-    vector tranfer read and write ops
-  }];
-
-  let assemblyFormat = "attr-dict";
-}
-
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -214,17 +214,6 @@
 void populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Collect patterns to fold tensor.extract_slice -> vector.transfer_read and
-/// vector.transfer_write -> tensor.insert_slice op chains into vector tranfer
-/// read and write ops.
-///
-/// If `controlFn` is not nullptr, the pattern will only apply to ops where
-/// `controlFn` returns true, given the vector transfer read/write op as input.
-void populateVectorTransferTensorSliceTransforms(
-    RewritePatternSet &patterns,
-    std::function<bool(Operation *)> controlFn = nullptr,
-    PatternBenefit benefit = 1);
-
 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
 /// `options` structure controls which operations are unrolled and the target
 /// shape.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2971,7 +2971,7 @@
       /*benefit=*/2);
   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
-  vector::populateVectorTransferTensorSliceTransforms(patterns);
+  tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
   patterns.add<CopyVectorizationPattern>(ctx);
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -103,6 +103,11 @@
   tensor::populateFoldTensorSubsetOpPatterns(patterns);
 }
 
+void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
+}
+
 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
     populatePatterns(RewritePatternSet &patterns) {
   tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -222,12 +222,18 @@
 };
 
 void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
-  patterns.add<TransferReadOfExtractSliceOpFolder,
-               InsertSliceOfTransferWriteOpFolder,
-               InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
+  populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
+  patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
                InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
       patterns.getContext());
 }
+
+void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadOfExtractSliceOpFolder,
+               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // Pass registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -138,11 +138,6 @@
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
-void transform::ApplyFoldTensorSliceIntoTransferPatternsOp::populatePatterns(
-    RewritePatternSet &patterns) {
-  populateVectorTransferTensorSliceTransforms(patterns);
-}
-
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
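To make the relocation easier to review: the moved entry point performs the same extract_slice/transfer_read fold as before, but the retained tensor-dialect implementation materializes the index arithmetic with affine.apply rather than arith.addi, as the updated CHECK lines further down show. A before/after sketch, reconstructed from the first test case (function and value names are illustrative, not part of the patch):

```mlir
// Before: a transfer_read through an extract_slice.
func.func @example(%t: tensor<?x?xf32>, %s1: index, %s2: index) -> vector<5x6xf32> {
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %cst = arith.constant 0.0 : f32
  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1]
      : tensor<?x?xf32> to tensor<10x?xf32>
  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]}
      : tensor<10x?xf32>, vector<5x6xf32>
  return %1 : vector<5x6xf32>
}

// After applying tensor::populateFoldTensorSubsetIntoVectorTransferPatterns:
// the read addresses %t directly; the constant offsets fold away (5 + 3 = 8)
// and the dynamic offset becomes an affine.apply (s0 + 4).
func.func @example(%t: tensor<?x?xf32>, %s1: index, %s2: index) -> vector<5x6xf32> {
  %cst = arith.constant 0.0 : f32
  %c8 = arith.constant 8 : index
  %add = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%s1]
  %1 = vector.transfer_read %t[%c8, %add], %cst {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<5x6xf32>
  return %1 : vector<5x6xf32>
}
```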
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
deleted file mode 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
+++ /dev/null
@@ -1,237 +0,0 @@
-//===- VectorTransferTensorSliceTransforms.cpp ----------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
-#include "mlir/IR/PatternMatch.h"
-
-using namespace mlir;
-
-/// Returns true if all rank reduced in the given `extractOp` happen in leading
-/// dimensions earlier than last `trailingRank` dimensions.
-static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
-                                        unsigned trailingRank) {
-  // If no ranks are reduced at all, it's a degenerated case; always true.
-  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
-    return true;
-
-  RankedTensorType inferredType = extractOp.inferResultType(
-      extractOp.getSourceType(), extractOp.getMixedOffsets(),
-      extractOp.getMixedSizes(), extractOp.getMixedStrides());
-  return extractOp.getType().getShape().take_back(trailingRank) ==
-         inferredType.getShape().take_back(trailingRank);
-}
-
-namespace {
-/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
-///
-/// ```
-/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
-///     : tensor<?x?xf32> to tensor<?x?xf32>
-/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %p0 = arith.addi %a, %e : index
-/// %p1 = arith.addi %b, %f : index
-/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-class FoldExtractSliceIntoTransferRead final
-    : public OpRewritePattern<vector::TransferReadOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  FoldExtractSliceIntoTransferRead(MLIRContext *context,
-                                   std::function<bool(Operation *)> controlFn,
-                                   PatternBenefit benefit)
-      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
-                                PatternRewriter &rewriter) const override {
-    if (controlFn && !controlFn(xferOp))
-      return failure();
-
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (!xferOp.getPermutationMap().isMinorIdentity())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    auto extractOp =
-        xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-    if (!extractOp)
-      return failure();
-    if (!extractOp.hasUnitStride())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x1x4xf32> to tensor<2x4xf32>
-    //    %1 = vector.transfer_read %0[0,0], %cst :
-    //      tensor<2x4xf32>, vector<2x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_read %t[0,0,0], %cst :
-    //      tensor<2x1x4xf32>, vector<2x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the extract_slice
-    // result tensor match the trailing dims of the inferred result tensor.
-    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
-      return failure();
-
-    int64_t rankReduced =
-        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
-
-    SmallVector<Value> newIndices;
-    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
-    // indices first.
-    for (int64_t i = 0; i < rankReduced; ++i) {
-      OpFoldResult offset = extractOp.getMixedOffsets()[i];
-      newIndices.push_back(getValueOrCreateConstantIndexOp(
-          rewriter, extractOp.getLoc(), offset));
-    }
-    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
-      OpFoldResult offset =
-          extractOp.getMixedOffsets()[it.index() + rankReduced];
-      newIndices.push_back(rewriter.create<arith::AddIOp>(
-          xferOp->getLoc(), it.value(),
-          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
-                                          offset)));
-    }
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
-        xferOp.getPadding(), ArrayRef<bool>{inBounds});
-
-    return success();
-  }
-
-private:
-  std::function<bool(Operation *)> controlFn;
-};
-
-/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
-/// could directly write to the insert_slice's destination. E.g.:
-///
-/// ```
-/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<4x5xf32>
-/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
-///     : tensor<4x5xf32> into tensor<?x?xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<?x?xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-class FoldInsertSliceIntoTransferWrite final
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  FoldInsertSliceIntoTransferWrite(MLIRContext *context,
-                                   std::function<bool(Operation *)> controlFn,
-                                   PatternBenefit benefit)
-      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
-
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    if (!insertOp.hasUnitStride())
-      return failure();
-
-    auto xferOp =
-        insertOp.getSource().getDefiningOp<vector::TransferWriteOp>();
-    if (!xferOp)
-      return failure();
-    if (controlFn && !controlFn(xferOp))
-      return failure();
-
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    // Fold only if the TransferWriteOp completely overwrites the `source` with
-    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
-    // content is the data of the vector.
-    if (!llvm::equal(xferOp.getVectorType().getShape(),
-                     xferOp.getShapedType().getShape()))
-      return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x4xf32>
-    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x4xf32> into tensor<2x1x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x1x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the insert_slice
-    // result tensor match the trailing dims of the inferred result tensor.
-    int64_t rankReduced =
-        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
-    int64_t vectorRank = xferOp.getVectorType().getRank();
-    RankedTensorType inferredSourceTensorType =
-        tensor::ExtractSliceOp::inferResultType(
-            insertOp.getType(), insertOp.getMixedOffsets(),
-            insertOp.getMixedSizes(), insertOp.getMixedStrides());
-    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
-    if (rankReduced > 0 &&
-        actualSourceTensorShape.take_back(vectorRank) !=
-            inferredSourceTensorType.getShape().take_back(vectorRank))
-      return failure();
-
-    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
-        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        insertOp, xferOp.getVector(), insertOp.getDest(), indices,
-        ArrayRef<bool>{inBounds});
-    return success();
-  }
-
-private:
-  std::function<bool(Operation *)> controlFn;
-};
-
-} // namespace
-
-void vector::populateVectorTransferTensorSliceTransforms(
-    RewritePatternSet &patterns,
-    std::function<bool(Operation *)> controlFn,
-    PatternBenefit benefit) {
-  patterns
-      .add<FoldExtractSliceIntoTransferRead, FoldInsertSliceIntoTransferWrite>(
-          patterns.getContext(), controlFn, benefit);
-}
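Behavioral note on the renamed test cases that follow: the deleted vector-dialect patterns bailed out whenever a rank reduction was not confined to the leading dimensions (the areAllRankReducedLeadingDim check above), while the tensor-dialect folders that replace them can fold such cases by encoding the dropped unit dimension in a permutation map. A sketch of the read case, with the input reconstructed from the CHECK expectations (the exact offsets and sizes are assumptions):

```mlir
// A non-leading rank reduction: the middle dimension of the 3-D source is
// dropped by the extract_slice.
%0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1]
    : tensor<?x?x?xf32> to tensor<?x12xf32>
%1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]}
    : tensor<?x12xf32>, vector<5x6xf32>

// Previously left unfolded; now rewritten to read %t directly, with a
// permutation map that skips the dropped dimension d1 (indices: 5 + 3 = 8,
// %s1, and 6 + 4 = 10).
%1 = vector.transfer_read %t[%c8, %s1, %c10], %cst
    {in_bounds = [true, true],
     permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>}
    : tensor<?x?x?xf32>, vector<5x6xf32>
```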
diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
rename from mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
rename to mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
@@ -3,15 +3,18 @@
 transform.sequence failures(propagate) {
 ^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
-    transform.apply_patterns.vector.fold_tensor_slice_into_transfer
+    transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers
   } : !transform.op<"func.func">
 }
 
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// CHECK: #[[$map1:.*]] = affine_map<()[s0] -> (s0 + 3)>
+// CHECK: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
 // CHECK-LABEL: func @transfer_read_of_extract_slice(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
 //   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%[[s1]]]
 //       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
 //       CHECK:   return %[[r]]
 func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
@@ -25,9 +28,8 @@
 
 // CHECK-LABEL: func @transfer_read_of_extract_slice_1d(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
 //   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%[[s1]]]
 //       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
 //       CHECK:   return %[[r]]
 func.func @transfer_read_of_extract_slice_1d(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
@@ -41,10 +43,9 @@
 
 // CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
 //   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
 //   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map1]]()[%[[s1]]]
 //       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
 //       CHECK:   return %[[r]]
 func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
@@ -56,10 +57,13 @@
   return %1 : vector<5x6xf32>
 }
 
-// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
-//       CHECK:   extract_slice
-//       CHECK:   vector.transfer_read
-func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+// CHECK-LABEL: func @transfer_read_of_extract_slice_non_leading_rank_reduction(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[s1]], %[[c10]]], %{{.*}} {in_bounds = [true, true], permutation_map = #[[$map2]]} : tensor<?x?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice_non_leading_rank_reduction(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
   %c3 = arith.constant 3 : index
   %c4 = arith.constant 4 : index
   %cst = arith.constant 0.0 : f32
@@ -80,10 +84,12 @@
   return %1 : tensor<?x12xf32>
 }
 
-// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
-//       CHECK:   vector.transfer_write
-//       CHECK:   insert_slice
-func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+// CHECK-LABEL: func @insert_slice_of_transfer_write_non_leading_rank_reduction(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true], permutation_map = #[[$map2]]} : vector<5x6xf32>, tensor<?x?x12xf32>
+func.func @insert_slice_of_transfer_write_non_leading_rank_reduction(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
   %c0 = arith.constant 0 : index
   %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
   %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
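One caveat worth keeping in mind: the deleted patterns bail out on masked transfers and on 0-d transfers (see the getMask() and getTransferRank() == 0 checks above), and as far as I can tell the retained tensor-dialect folders behave the same way. So a chain like the following sketch stays untouched (names and types are illustrative only):

```mlir
// The mask operand blocks the fold: this transfer_read keeps reading from
// the extract_slice result instead of being rewritten against %t.
%0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
    : tensor<?x?xf32> to tensor<?x?xf32>
%1 = vector.transfer_read %0[%e, %f], %cst, %mask {in_bounds = [true, true]}
    : tensor<?x?xf32>, vector<4x5xf32>
```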