diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3718,92 +3718,6 @@ } namespace { -/// Fold transfer_reads of a tensor.extract_slice op. E.g.: -/// -/// ``` -/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1] -/// : tensor to tensor -/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]} -/// : tensor, vector<4x5xf32> -/// ``` -/// is rewritten to: -/// ``` -/// %p0 = arith.addi %a, %e : index -/// %p1 = arith.addi %b, %f : index -/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]} -/// : tensor, vector<4x5xf32> -/// ``` -// TODO: this is brittle and should be deprecated in favor of a more general -// pattern that applies on-demand. -struct FoldExtractSliceIntoTransferRead - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TransferReadOp xferOp, - PatternRewriter &rewriter) const override { - // TODO: support 0-d corner case. - if (xferOp.getTransferRank() == 0) - return failure(); - if (xferOp.hasOutOfBoundsDim()) - return failure(); - if (!xferOp.getPermutationMap().isMinorIdentity()) - return failure(); - if (xferOp.getMask()) - return failure(); - auto extractOp = xferOp.getSource().getDefiningOp(); - if (!extractOp) - return failure(); - if (!extractOp.hasUnitStride()) - return failure(); - - // Bail on illegal rank-reduction: we need to check that the rank-reduced - // dims are exactly the leading dims. I.e. the following is illegal: - // ``` - // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] : - // tensor<2x1x4xf32> to tensor<2x4xf32> - // %1 = vector.transfer_read %0[0,0], %cst : - // tensor<2x4xf32>, vector<2x4xf32> - // ``` - // - // Cannot fold into: - // ``` - // %0 = vector.transfer_read %t[0,0,0], %cst : - // tensor<2x1x4xf32>, vector<2x4xf32> - // ``` - // For this, check the trailing `vectorRank` dims of the extract_slice - // result tensor match the trailing dims of the inferred result tensor. - if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank())) - return failure(); - - int64_t rankReduced = - extractOp.getSourceType().getRank() - extractOp.getType().getRank(); - - SmallVector newIndices; - // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced - // indices first. - for (int64_t i = 0; i < rankReduced; ++i) { - OpFoldResult offset = extractOp.getMixedOffsets()[i]; - newIndices.push_back(getValueOrCreateConstantIndexOp( - rewriter, extractOp.getLoc(), offset)); - } - for (const auto &it : llvm::enumerate(xferOp.getIndices())) { - OpFoldResult offset = - extractOp.getMixedOffsets()[it.index() + rankReduced]; - newIndices.push_back(rewriter.create( - xferOp->getLoc(), it.value(), - getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), - offset))); - } - SmallVector inBounds(xferOp.getTransferRank(), true); - rewriter.replaceOpWithNewOp( - xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices, - xferOp.getPadding(), ArrayRef{inBounds}); - - return success(); - } -}; - /// Store to load forwarding for transfer operations with permuation maps. /// Even if the permutation maps are different we can still propagate the store /// into the load if the size of the dimensions read and written match. Then we @@ -3885,13 +3799,7 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - // clang-format off - results.add < - // TODO: this is brittle and should be deprecated in favor of a - // more general pattern that applies on-demand. - FoldExtractSliceIntoTransferRead, - TransferReadAfterWriteToBroadcast>(context); - // clang-format on + results.add(context); } //===----------------------------------------------------------------------===// @@ -4227,93 +4135,6 @@ } }; -/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write -/// could directly write to the insert_slice's destination. E.g.: -/// -/// ``` -/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]} -/// : vector<4x5xf32>, tensor<4x5xf32> -/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1] -/// : tensor<4x5xf32> into tensor -/// ``` -/// is rewritten to: -/// ``` -/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]} -/// : vector<4x5xf32>, tensor -/// ``` -// TODO: this is brittle and should be deprecated in favor of a more general -// pattern that applies on-demand. -struct FoldInsertSliceIntoTransferWrite - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, - PatternRewriter &rewriter) const override { - if (!insertOp.hasUnitStride()) - return failure(); - - auto xferOp = insertOp.getSource().getDefiningOp(); - if (!xferOp) - return failure(); - // TODO: support 0-d corner case. - if (xferOp.getTransferRank() == 0) - return failure(); - - if (xferOp.hasOutOfBoundsDim()) - return failure(); - if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank()) - return failure(); - if (xferOp.getMask()) - return failure(); - // Fold only if the TransferWriteOp completely overwrites the `source` with - // a vector. I.e., the result of the TransferWriteOp is a new tensor whose - // content is the data of the vector. - if (!llvm::equal(xferOp.getVectorType().getShape(), - xferOp.getShapedType().getShape())) - return failure(); - if (!xferOp.getPermutationMap().isIdentity()) - return failure(); - - // Bail on illegal rank-reduction: we need to check that the rank-reduced - // dims are exactly the leading dims. I.e. the following is illegal: - // ``` - // %0 = vector.transfer_write %v, %t[0,0], %cst : - // vector<2x4xf32>, tensor<2x4xf32> - // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] : - // tensor<2x4xf32> into tensor<2x1x4xf32> - // ``` - // - // Cannot fold into: - // ``` - // %0 = vector.transfer_write %v, %t[0,0,0], %cst : - // vector<2x4xf32>, tensor<2x1x4xf32> - // ``` - // For this, check the trailing `vectorRank` dims of the insert_slice result - // tensor match the trailing dims of the inferred result tensor. - int64_t rankReduced = - insertOp.getType().getRank() - insertOp.getSourceType().getRank(); - int64_t vectorRank = xferOp.getVectorType().getRank(); - RankedTensorType inferredSourceTensorType = - tensor::ExtractSliceOp::inferResultType( - insertOp.getType(), insertOp.getMixedOffsets(), - insertOp.getMixedSizes(), insertOp.getMixedStrides()); - auto actualSourceTensorShape = insertOp.getSourceType().getShape(); - if (rankReduced > 0 && - actualSourceTensorShape.take_back(vectorRank) != - inferredSourceTensorType.getShape().take_back(vectorRank)) - return failure(); - - SmallVector indices = getValueOrCreateConstantIndexOp( - rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); - SmallVector inBounds(xferOp.getTransferRank(), true); - rewriter.replaceOpWithNewOp(insertOp, xferOp.getVector(), - insertOp.getDest(), indices, - ArrayRef{inBounds}); - return success(); - } -}; - /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is /// overwritten and inserted into another tensor. After this rewrite, the @@ -4425,13 +4246,7 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - // clang-format off - results.add(context); - // clang-format on + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir @@ -39,6 +39,7 @@ %0 = affine.min #map0()[%arg5] %1 = tensor.extract_slice %arg0[%arg3, %arg5] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> %2 = tensor.extract_slice %arg1[%arg5, %arg4] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor + // CHECK: %[[sC:.*]] = tensor.extract_slice %[[C]] %3 = tensor.extract_slice %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> %4 = affine.apply #map1()[%0] // CHECK: %[[pA:.*]] = tensor.pad @@ -54,9 +55,9 @@ } : tensor to tensor<7x5xf32> // CHECK: %[[vA:.+]] = vector.transfer_read %[[pA]] // CHECK: %[[vB:.+]] = vector.transfer_read %[[pB]] - // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]] + // CHECK: %[[vC:.+]] = vector.transfer_read %[[sC]] // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]] - // CHECK: vector.transfer_write %[[vR]], %[[C]] + // CHECK: vector.transfer_write %[[vR]], %[[sC]] %8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> %9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> return %9 : tensor<24x25xf32> @@ -87,6 +88,7 @@ %0 = affine.min #map0()[%arg5] // CHECK: %[[sA:.+]] = tensor.extract_slice %[[A]] // CHECK: %[[sB:.+]] = tensor.extract_slice %[[B]] + // CHECK: %[[sC:.+]] = tensor.extract_slice %[[C]] %1 = tensor.extract_slice %arg0[%arg3, %arg5] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> %2 = tensor.extract_slice %arg1[%arg5, %arg4] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor %3 = tensor.extract_slice %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> @@ -102,9 +104,9 @@ ^bb0(%arg6: index, %arg7: index): tensor.yield %cst : f32 } : tensor to tensor<7x5xf32> - // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]] + // CHECK: %[[vC:.+]] = vector.transfer_read %[[sC]] // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]] - // CHECK: vector.transfer_write %[[vR]], %[[C]] + // CHECK: vector.transfer_write %[[vR]], %[[sC]] %8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> %9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> return %9 : tensor<24x25xf32> diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1188,116 +1188,6 @@ // ----- -// CHECK-LABEL: func @transfer_read_of_extract_slice( -// CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index -// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index -// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]] -// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor, vector<5x6xf32> -// CHECK: return %[[r]] -func.func @transfer_read_of_extract_slice(%t : tensor, %s1 : index, %s2 : index) -> vector<5x6xf32> { - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %cst = arith.constant 0.0 : f32 - %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor to tensor<10x?xf32> - %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32> - return %1 : vector<5x6xf32> -} - -// ----- - -// CHECK-LABEL: func @transfer_read_of_extract_slice( -// CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index -// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index -// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]] -// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor, vector<6xf32> -// CHECK: return %[[r]] -func.func @transfer_read_of_extract_slice(%t : tensor, %s1 : index, %s2 : index) -> vector<6xf32> { - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %cst = arith.constant 0.0 : f32 - %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor to tensor<10x?xf32> - %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32> - return %1 : vector<6xf32> -} - -// ----- - -// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing( -// CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index -// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index -// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c3]] -// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor, vector<5x6xf32> -// CHECK: return %[[r]] -func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor, %s1 : index, %s2 : index) -> vector<5x6xf32> { - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %cst = arith.constant 0.0 : f32 - %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor to tensor - %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor, vector<5x6xf32> - return %1 : vector<5x6xf32> -} - -// ----- - -// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing( -// CHECK: extract_slice -// CHECK: vector.transfer_read -func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor, %s1 : index, %s2 : index) -> vector<5x6xf32> { - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %cst = arith.constant 0.0 : f32 - %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor to tensor - %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor, vector<5x6xf32> - return %1 : vector<5x6xf32> -} - -// ----- - -// CHECK-LABEL: func @insert_slice_of_transfer_write( -// CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index -// CHECK: %[[c3:.*]] = arith.constant 3 : index -// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor -// CHECK: return %[[r]] -func.func @insert_slice_of_transfer_write(%t1 : tensor, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor { - %c0 = arith.constant 0 : index - %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32> - %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor - return %1 : tensor -} - -// ----- - -// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending( -// CHECK: vector.transfer_write -// CHECK: insert_slice -func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor { - %c0 = arith.constant 0 : index - %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32> - %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor - return %1 : tensor -} - -// ----- - -// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending( -// CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index -// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor -// CHECK: return %[[r]] -func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor { - %c0 = arith.constant 0 : index - %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32> - %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor - return %1 : tensor -} - -// ----- - // CHECK: #[[$MAP:[0-9a-z]+]] = affine_map<(d0, d1) -> (d1, d0)> // CHECK-LABEL: func @swap_extract_slice_transfer_write diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter -cse -canonicalize \ +// RUN: --split-input-file | FileCheck %s // CHECK-LABEL: func @matmul_tensors func.func @matmul_tensors( @@ -6,7 +7,7 @@ -> tensor<8x32xf32> { // CHECK-NOT: linalg // CHECK: vector.extract {{.*}} : vector<8x4xf32> -// CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32> +// CHECK: vector.store {{.*}} : memref<8x4xf32, strided<[32, 1], offset: ?>>, vector<4xf32> %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>) outs(%arg2: tensor<8x32xf32>) -> tensor<8x32xf32>