diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3534,11 +3534,114 @@
     return success();
   }
 };
+
+/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
+/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
+/// overwritten and inserted into another tensor. After this rewrite, the
+/// operations bufferize in-place since all of them work on the same slice.
+///
+/// For example:
+/// ```mlir
+///   %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
+///        : vector<8x16xf32>, tensor<8x16xf32>
+///   %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
+///        : tensor<8x16xf32> to tensor<?x?xf32>
+///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
+///        : tensor<?x?xf32> into tensor<27x37xf32>
+/// ```
+/// folds to
+/// ```mlir
+///   %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
+///        : tensor<27x37xf32> to tensor<?x?xf32>
+///   %1 = vector.transfer_write %vec, %0[%c0, %c0]
+///        : vector<8x16xf32>, tensor<?x?xf32>
+///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
+///        : tensor<?x?xf32> into tensor<27x37xf32>
+/// ```
+struct SwapExtractSliceOfTransferWrite
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    if (!insertOp.hasUnitStride())
+      return failure();
+    auto extractOp = insertOp.source().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
+      return failure();
+    auto transferOp = extractOp.source().getDefiningOp<TransferWriteOp>();
+    if (!transferOp || !transferOp->hasOneUse())
+      return failure();
+
+    // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
+    // rank-reducing.
+    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "use-def chain is rank-reducing");
+    }
+
+    // Fail if tensor::ExtractSliceOp has non-zero offset.
+    if (!extractOp.hasZeroOffset()) {
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "ExtractSliceOp has non-zero offset");
+    }
+
+    // Fail if vector::TransferWriteOp has non-zero offset.
+    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
+          return getConstantIntValue(value) == static_cast<int64_t>(0);
+        })) {
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "TransferWriteOp has non-zero offset");
+    }
+
+    // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
+    for (const auto &it :
+         llvm::zip(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
+      if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) {
+        return rewriter.notifyMatchFailure(
+            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
+      }
+    }
+
+    // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
+    assert(transferOp.getVectorType().hasStaticShape() &&
+           "expected vector to have a static shape");
+    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
+    SmallVector<int64_t> resultShape = applyPermutationMap(
+        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
+    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
+      return rewriter.notifyMatchFailure(
+          insertOp, "TransferWriteOp may not write the full tensor");
+    }
+
+    // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
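+    // Note: the transfer's destination changes from the original statically
+    // shaped tensor to the extracted slice, whose sizes may be dynamic, so the
+    // in_bounds attribute must be recomputed. A dimension stays in-bounds only
+    // if the (static) slice size still equals the vector size along that
+    // dimension; dynamic slice sizes are conservatively out-of-bounds.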
+    SmallVector<int64_t> newResultShape = applyPermutationMap(
+        transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
+    SmallVector<bool> newInBounds;
+    for (const auto &en : enumerate(newResultShape))
+      newInBounds.push_back(en.value() == vectorShape[en.index()]);
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        extractOp.getLoc(), insertOp.getSourceType(), insertOp.dest(),
+        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
+        insertOp.getMixedStrides());
+    auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
+        transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
+        transferOp.getIndices(), transferOp.getPermutationMapAttr(),
+        rewriter.getBoolArrayAttr(newInBounds));
+    rewriter.updateRootInPlace(insertOp, [&]() {
+      insertOp.sourceMutable().assign(newTransferWriteOp.getResult());
+    });
+    return success();
+  }
+};
+
 } // namespace
 
 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<FoldWaw>(context);
+  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1149,6 +1149,82 @@
 
 // -----
 
+// CHECK: #[[$MAP:[0-9a-z]+]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: func @swap_extract_slice_transfer_write
+// CHECK-SAME:    %[[VEC:.*]]: vector<8x4xf32>,
+// CHECK-SAME:    %[[INIT_TENSOR:.*]]: tensor<4x8xf32>,
+// CHECK-SAME:    %[[ITER_ARG:.*]]: tensor<64x64xf32>,
+// CHECK-SAME:    %[[IV:.*]]: index, %[[SZ:.*]]: index)
+func.func @swap_extract_slice_transfer_write(%arg0 : vector<8x4xf32>,
+                                             %arg1 : tensor<4x8xf32>,
+                                             %arg2 : tensor<64x64xf32>,
+                                             %iv : index, %sz : index) -> tensor<64x64xf32> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ITER_ARG]]
+  // CHECK-SAME: [%[[IV]], 16] [%[[SZ]], 8]
+  // CHECK: %[[T1:.*]] = vector.transfer_write %[[VEC]]
+  // CHECK-SAME: %[[T0]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: in_bounds = [true, false]
+  // CHECK-SAME: permutation_map = #[[$MAP]]
+  // CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]] into %[[ITER_ARG]]
+  // CHECK-SAME: [%[[IV]], 16] [%[[SZ]], 8]
+  %0 = vector.transfer_write %arg0, %arg1[%c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x4xf32>, tensor<4x8xf32>
+  %1 = tensor.extract_slice %0[0, 0] [%sz, 8] [1, 1] : tensor<4x8xf32> to tensor<?x8xf32>
+  %2 = tensor.insert_slice %1 into %arg2[%iv, 16] [%sz, 8] [1, 1] : tensor<?x8xf32> into tensor<64x64xf32>
+
+  // CHECK: return %[[T2]]
+  func.return %2 : tensor<64x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_swap_extract_slice_transfer_write
+// CHECK-SAME:    %[[VEC:.*]]: vector<8xf32>,
+// CHECK-SAME:    %[[VEC_SMALL:.*]]: vector<4xf32>,
+// CHECK-SAME:    %[[INIT_TENSOR:.*]]: tensor<8xf32>,
+// CHECK-SAME:    %[[ITER_ARG:.*]]: tensor<64xf32>,
+// CHECK-SAME:    %[[IV:.*]]: index, %[[SZ:.*]]: index)
+func.func @do_not_swap_extract_slice_transfer_write(%arg0 : vector<8xf32>,
+                                                    %arg1 : vector<4xf32>,
+                                                    %arg2 : tensor<8xf32>,
+                                                    %arg3 : tensor<64xf32>,
+                                                    %iv : index, %sz : index) -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+
+  // Don't swap if the extracted and inserted slices do not match.
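+  // (The slice is extracted with size %iv but inserted with size %sz, so the
+  // pattern cannot prove both slices cover the same region.)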
+  // CHECK: %[[T0:.*]] = vector.transfer_write %[[VEC]]
+  // CHECK: %[[T1:.*]] = tensor.extract_slice %[[T0]]
+  // CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]]
+  %0 = vector.transfer_write %arg0, %arg2[%c0] {in_bounds = [true]} : vector<8xf32>, tensor<8xf32>
+  %1 = tensor.extract_slice %0[0] [%iv] [1] : tensor<8xf32> to tensor<?xf32>
+  %2 = tensor.insert_slice %1 into %arg3[%iv] [%sz] [1] : tensor<?xf32> into tensor<64xf32>
+
+  // Don't swap if the TransferWriteOp does not overwrite the full tensor.
+  // CHECK: %[[T3:.*]] = vector.transfer_write %[[VEC_SMALL]]
+  // CHECK: %[[T4:.*]] = tensor.extract_slice %[[T3]]
+  // CHECK: %[[T5:.*]] = tensor.insert_slice %[[T4]]
+  %3 = vector.transfer_write %arg1, %arg2[%c0] {in_bounds = [true]} : vector<4xf32>, tensor<8xf32>
+  %4 = tensor.extract_slice %3[0] [%sz] [1] : tensor<8xf32> to tensor<?xf32>
+  %5 = tensor.insert_slice %4 into %arg3[%iv] [%sz] [1] : tensor<?xf32> into tensor<64xf32>
+
+  // Don't swap if one of the operations is rank-reducing.
+  // CHECK: %[[T6:.*]] = vector.transfer_write %[[VEC]]
+  // CHECK: %[[T7:.*]] = tensor.extract_slice %[[T6]]
+  // CHECK: %[[T8:.*]] = tensor.insert_slice %[[T7]]
+  %6 = vector.transfer_write %arg0, %arg2[%c0] {in_bounds = [true]} : vector<8xf32>, tensor<8xf32>
+  %7 = tensor.extract_slice %6[0] [1] [1] : tensor<8xf32> to tensor<f32>
+  %8 = tensor.insert_slice %7 into %arg3[%iv] [1] [1] : tensor<f32> into tensor<64xf32>
+
+  // CHECK: return %[[T2]], %[[T5]], %[[T8]]
+  func.return %2, %5, %8 : tensor<64xf32>, tensor<64xf32>, tensor<64xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_multi_reduction_single_parallel(
 //  CHECK-SAME: %[[v:.*]]: vector<2xf32>
 func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {