diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -340,6 +340,12 @@
                                   const UnrollVectorOptions &options,
                                   PatternBenefit benefit = 1);
 
+/// Swaps vector.transfer_write ops that are used as a tensor.insert_slice
+/// source or destination tensor with the tensor.insert_slice op. This may
+/// help bring tensor insert/extract slice ops next to each other so that they
+/// can be folded away later.
+void populateVectorReorderTransferExtractInsertSlicePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
 //===----------------------------------------------------------------------===//
 // Finer-grained patterns exposed for more control over individual lowerings.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRVectorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ReorderTransferInsertExtractSlice.cpp
   VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/ReorderTransferInsertExtractSlice.cpp b/mlir/lib/Dialect/Vector/Transforms/ReorderTransferInsertExtractSlice.cpp
new file
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/ReorderTransferInsertExtractSlice.cpp
@@ -0,0 +1,231 @@
+//===- ReorderTransferInsertExtractSlice.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Returns true if all rank reductions in the given `extractOp` happen in
+/// leading dimensions, i.e., before the last `trailingRank` dimensions.
+static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
+                                        unsigned trailingRank) {
+  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
+    return true;
+
+  RankedTensorType inferredType = extractOp.inferResultType(
+      extractOp.getSourceType(), extractOp.getMixedOffsets(),
+      extractOp.getMixedSizes(), extractOp.getMixedStrides());
+  return extractOp.getType().getShape().take_back(trailingRank) ==
+         inferredType.getShape().take_back(trailingRank);
+}
+
+/// Returns true if all rank reductions in the given `insertOp` happen in
+/// leading dimensions, i.e., before the last `trailingRank` dimensions.
+static bool areAllRankReducedLeadingDim(tensor::InsertSliceOp insertOp,
+                                        unsigned trailingRank) {
+  // If no ranks are reduced, simply return true.
+  if (insertOp.getSourceType().getRank() == insertOp.getDestType().getRank())
+    return true;
+
+  // Infer the small type by extracting from the large type.
+  RankedTensorType inferredType = tensor::ExtractSliceOp::inferResultType(
+      insertOp.getDestType(), insertOp.getMixedOffsets(),
+      insertOp.getMixedSizes(), insertOp.getMixedStrides());
+  return insertOp.getSourceType().getShape().take_back(trailingRank) ==
+         inferredType.getShape().take_back(trailingRank);
+}
+
+namespace {
+
+/// Reorders a vector.transfer_write op that is the tensor.insert_slice source
+/// to be after the tensor.insert_slice op.
+///
+/// To make sure the reordering is beneficial, this pattern additionally
+/// requires that the vector.transfer_write writes into a tensor.extract_slice
+/// that extracts from the tensor.insert_slice's destination tensor.
+///
+/// For example, given the following IR:
+/// ```mlir
+/// %extract = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+///     : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+/// %write0 = vector.transfer_write %val0, %extract[%c0, %c0, %c0]
+///     {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+/// %write1 = vector.transfer_write %val1, %write0[%c0, %c1, %c0]
+///     {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+/// %insert = tensor.insert_slice %write1 into %input[0, 0, 0, 0] [1, 1, 2, 4]
+///     [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+/// ```
+/// We can fold it into
+/// ```mlir
+/// %write0 = vector.transfer_write %val0, %input[%c0, %c0, %c0, %c0]
+///     {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+/// %write1 = vector.transfer_write %val1, %write0[%c0, %c0, %c1, %c0]
+///     {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+/// ```
+struct ReorderTransferWriteAsInsertSource final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    auto writeOp =
+        insertOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+    if (!writeOp)
+      return failure();
+
+    Value writeDest = writeOp.getSource();
+    // Allow a chain of vector.transfer_write ops that build upon one another.
+    // Such chains are common after vector unrolling.
+    while (auto prevOp = writeDest.getDefiningOp<vector::TransferWriteOp>())
+      writeDest = prevOp.getSource();
+    auto extractOp = writeDest.getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      return failure();
+
+    // To be beneficial, require that 1) the extract source is the same as the
+    // insert destination, and 2) the extract and insert slice ops have
+    // matching offsets, sizes, and strides. This makes sure they can be
+    // folded away afterwards.
+    if (extractOp.getSource() != insertOp.getDest())
+      return failure();
+    const auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+    if (!extractOp.isSameAs(insertOp, isSame))
+      return rewriter.notifyMatchFailure(insertOp, "mismatched parameters");
+
+    // Make sure the transfer_write op has a minor identity map and all
+    // reduced ranks are in leading dimensions. This avoids complicated
+    // rank-reducing issues when swapping the transfer and slice ops.
+    int64_t largeTensorRank = insertOp.getType().getRank();
+    int64_t smallTensorRank = insertOp.getSourceType().getRank();
+    int64_t vectorRank = writeOp.getVectorType().getRank();
+    if (!writeOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(insertOp, "not minor identity map");
+    if (!areAllRankReducedLeadingDim(extractOp, smallTensorRank))
+      return rewriter.notifyMatchFailure(insertOp, "not leading rank reduced");
+
+    Location loc = insertOp.getLoc();
+    auto newInsertOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, writeOp.getSource(), insertOp.getDest(),
+        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
+        insertOp.getMixedStrides());
+
+    // If the slice ops are rank reducing, prepend the leading insert_slice
+    // offsets to the indices so that they match the large tensor.
+    SmallVector<Value> newIndices;
+    newIndices.reserve(largeTensorRank);
+    int64_t reducedRank = largeTensorRank - smallTensorRank;
+    for (int i = 0; i < reducedRank; ++i) {
+      OpFoldResult offset = insertOp.getMixedOffsets()[i];
+      newIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, offset));
+    }
+    AffineExpr dim0, dim1;
+    bindDims(getContext(), dim0, dim1);
+    for (int i = 0; i < smallTensorRank; ++i) {
+      OpFoldResult offset = insertOp.getMixedOffsets()[i + reducedRank];
+      Value offsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
+      newIndices.push_back(makeComposedAffineApply(
+          rewriter, loc, dim0 + dim1, {writeOp.getIndices()[i], offsetVal}));
+    }
+
+    auto newMap = AffineMap::getMinorIdentityMap(largeTensorRank, vectorRank,
+                                                 writeOp.getContext());
+
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        insertOp, writeOp.getVector(), newInsertOp.getResult(), newIndices,
+        AffineMapAttr::get(newMap), writeOp.getMask(),
+        writeOp.getInBoundsAttr());
+    return success();
+  }
+};
+
+/// Reorders a vector.transfer_write op that is the tensor.insert_slice
+/// destination to be after the tensor.insert_slice op, when the two ops
+/// access disjoint ranges.
+///
+/// E.g., the following IR:
+/// ```mlir
+/// %0 = vector.transfer_write %val, %src[%c0, %c0, %c1, %c0]
+///     {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+/// %1 = tensor.insert_slice %slice into %0[0, 1, 0, 0] [1, 1, 2, 4]
+///     [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+/// ```
+/// can be converted into
+/// ```mlir
+/// %0 = tensor.insert_slice %slice into %src[0, 1, 0, 0] [1, 1, 2, 4]
+///     [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+/// %1 = vector.transfer_write %val, %0[%c0, %c0, %c1, %c0]
+///     {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+/// ```
+struct ReorderTransferWriteAsInsertDest final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    auto writeOp = insertOp.getDest().getDefiningOp<vector::TransferWriteOp>();
+    if (!writeOp)
+      return rewriter.notifyMatchFailure(insertOp, "not inserting into write");
+    if (!writeOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(insertOp, "not minor identity map");
+
+    if (!insertOp.hasUnitStride())
+      return rewriter.notifyMatchFailure(insertOp, "not unit stride");
+    if (!areAllRankReducedLeadingDim(insertOp,
+                                     insertOp.getSourceType().getRank()))
+      return rewriter.notifyMatchFailure(insertOp, "not leading rank reduced");
+
+    unsigned writeTensorRank = writeOp.getShapedType().getRank();
+    unsigned writeReducedRank = writeOp.getLeadingShapedRank();
+
+    SmallVector<OpFoldResult> writeOffsets;
+    writeOffsets.reserve(writeTensorRank);
+    llvm::append_range(writeOffsets, writeOp.getIndices());
+
+    SmallVector<OpFoldResult> writeSizes;
+    writeSizes.reserve(writeTensorRank);
+    for (unsigned i = 0; i < writeReducedRank; ++i)
+      writeSizes.push_back(rewriter.getIndexAttr(1));
+    for (unsigned i = writeReducedRank; i < writeTensorRank; ++i)
+      writeSizes.push_back(rewriter.getIndexAttr(
+          writeOp.getVectorType().getDimSize(i - writeReducedRank)));
+
+    SmallVector<OpFoldResult> insertOffsets = insertOp.getMixedOffsets();
+    SmallVector<OpFoldResult> insertSizes = insertOp.getMixedSizes();
+
+    if (!areDisjointRanges(writeOffsets, writeSizes, insertOffsets,
+                           insertSizes))
+      return rewriter.notifyMatchFailure(insertOp, "not disjoint ranges");
+
+    auto newInsertOp = rewriter.create<tensor::InsertSliceOp>(
+        insertOp.getLoc(), insertOp.getSource(), writeOp.getSource(),
+        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
+        insertOp.getMixedStrides());
+
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        insertOp, writeOp.getVector(), newInsertOp.getResult(),
+        writeOp.getIndices(), writeOp.getPermutationMapAttr(),
+        writeOp.getMask(), writeOp.getInBoundsAttr());
+
+    return success();
+  }
+};
+
+} // namespace
+
+void vector::populateVectorReorderTransferExtractInsertSlicePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ReorderTransferWriteAsInsertSource,
+               ReorderTransferWriteAsInsertDest>(patterns.getContext(),
+                                                 benefit);
+}
diff --git a/mlir/test/Dialect/Vector/reorder-transfer-extract-insert-slice.mlir b/mlir/test/Dialect/Vector/reorder-transfer-extract-insert-slice.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Vector/reorder-transfer-extract-insert-slice.mlir
@@ -0,0 +1,186 @@
+// RUN: mlir-opt %s -split-input-file -test-vector-reorder-transfer | FileCheck %s
+
+func.func @write_as_insert_source(%input: tensor<8x8x8x4xf32>, %val0: vector<4xf32>) -> tensor<8x8x8x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %extract = tensor.extract_slice %input[4, 3, 1, 2] [1, 4, 4, 4] [1, 1, 1, 1] : tensor<8x8x8x4xf32> to tensor<4x4x4xf32>
+  %write0 = vector.transfer_write %val0, %extract[%c1, %c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4x4xf32>
+  %insert = tensor.insert_slice %write0 into %input[4, 3, 1, 2] [1, 4, 4, 4] [1, 1, 1, 1] : tensor<4x4x4xf32> into tensor<8x8x8x4xf32>
+  return %insert : tensor<8x8x8x4xf32>
+}
+
+// CHECK-LABEL: func.func @write_as_insert_source
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x8x8x4xf32>, %[[VAL:.+]]: vector<4xf32>)
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[VAL]], %[[INPUT]][%[[C4]], %[[C4]], %[[C3]], %[[C2]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<4xf32>, tensor<8x8x8x4xf32>
+// CHECK: return %[[WRITE]]

+// -----
+
+// CHECK-LABEL: func.func @write_as_insert_source
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>, %[[VAL0:.+]]: vector<4xf32>, %[[VAL1:.+]]: vector<4xf32>)
+// CHECK: %[[W0:.+]] = vector.transfer_write %[[VAL0]], %[[INPUT]]
+// CHECK: %[[W1:.+]] = vector.transfer_write %[[VAL1]], %[[W0]]
+// CHECK: return %[[W1]]
+func.func @write_as_insert_source(%input: tensor<1x2x2x4xf32>, %val0: vector<4xf32>, %val1: vector<4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %extract = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  %write0 = vector.transfer_write %val0, %extract[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %write1 = vector.transfer_write %val1, %write0[%c0, %c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %insert = tensor.insert_slice %write1 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %insert : tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rank_reduced_in_trailing_dimensions
+// CHECK: tensor.extract_slice
+// CHECK-COUNT-2: vector.transfer_write
+// CHECK: tensor.insert_slice
+func.func @rank_reduced_in_trailing_dimensions(%input: tensor<1x2x2x4xf32>, %val0: vector<4xf32>, %val1: vector<4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %extract = tensor.extract_slice %input[0, 0, 0, 0] [1, 2, 1, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  %write0 = vector.transfer_write %val0, %extract[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %write1 = vector.transfer_write %val1, %write0[%c0, %c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %insert = tensor.insert_slice %write1 into %input[%c0, 0, %c0, 0] [1, 2, 1, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %insert : tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_minor_identity_map
+// CHECK: tensor.extract_slice
+// CHECK-COUNT-2: vector.transfer_write
+// CHECK: tensor.insert_slice
+func.func @not_minor_identity_map(%input: tensor<1x2x2x4xf32>, %val0: vector<2xf32>, %val1: vector<2xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %extract = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  %write0 = vector.transfer_write %val0, %extract[%c0, %c0, %c0] {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2) -> (d1)>} : vector<2xf32>, tensor<1x2x4xf32>
+  %write1 = vector.transfer_write %val1, %write0[%c0, %c1, %c0] {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2) -> (d1)>} : vector<2xf32>, tensor<1x2x4xf32>
+  %insert = tensor.insert_slice %write1 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %insert : tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @mismatched_slice_parameters
+// CHECK: tensor.extract_slice
+// CHECK-COUNT-2: vector.transfer_write
+// CHECK: tensor.insert_slice
+func.func @mismatched_slice_parameters(%input: tensor<1x2x2x4xf32>, %val0: vector<4xf32>, %val1: vector<4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %extract = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  %write0 = vector.transfer_write %val0, %extract[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %write1 = vector.transfer_write %val1, %write0[%c0, %c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %insert = tensor.insert_slice %write1 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %insert : tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_insert_back_to_original_tensor
+// CHECK: tensor.extract_slice
+// CHECK-COUNT-2: vector.transfer_write
+// CHECK: tensor.insert_slice
+func.func @not_insert_back_to_original_tensor(%input0: tensor<1x2x2x4xf32>, %input1: tensor<1x2x2x4xf32>, %val0: vector<4xf32>, %val1: vector<4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %extract = tensor.extract_slice %input0[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  %write0 = vector.transfer_write %val0, %extract[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %write1 = vector.transfer_write %val1, %write0[%c0, %c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %insert = tensor.insert_slice %write1 into %input1[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %insert : tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @write_as_insert_dest
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>, %[[VAL0:.+]]: vector<4xf32>, %[[VAL1:.+]]: vector<4xf32>)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[W0:.+]] = vector.transfer_write %[[VAL0]], %[[INPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+// CHECK: %[[W1:.+]] = vector.transfer_write %[[VAL1]], %[[W0]][%[[C0]], %[[C0]], %[[C1]], %[[C0]]] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+// CHECK: return %[[W1]] : tensor<1x2x2x4xf32>
+func.func @write_as_insert_dest(%input: tensor<1x2x2x4xf32>, %val0: vector<4xf32>, %val1: vector<4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = tensor.extract_slice %input[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  %3 = vector.transfer_write %val0, %input[%c0, %c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+  %4 = vector.transfer_write %val1, %3[%c0, %c0, %c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+  %5 = tensor.insert_slice %0 into %4[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %5 : tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_disjoint_ranges
+// CHECK: %[[W:.+]] = vector.transfer_write
+// CHECK: tensor.insert_slice %{{.+}} into %[[W]]
+func.func @not_disjoint_ranges(%input: tensor<1x2x2x4xf32>, %val: vector<4xf32>, %slice: tensor<1x2x4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = vector.transfer_write %val, %input[%c0, %c1, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+  %1 = tensor.insert_slice %slice into %0[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %1 : tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_minor_identity_map
+// CHECK: %[[W:.+]] = vector.transfer_write
+// CHECK: tensor.insert_slice %{{.+}} into %[[W]]
+func.func @not_minor_identity_map(%input: tensor<1x2x2x4xf32>, %val: vector<2xf32>, %slice: tensor<1x2x4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = vector.transfer_write %val, %input[%c0, %c0, %c0, %c0] {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2)>} : vector<2xf32>, tensor<1x2x2x4xf32>
+  %1 = tensor.insert_slice %slice into %0[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %1 : tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rank_reduced_in_trailing_dimensions
+// CHECK: %[[W:.+]] = vector.transfer_write
+// CHECK: tensor.insert_slice %{{.+}} into %[[W]]
+func.func @rank_reduced_in_trailing_dimensions(%input: tensor<1x2x2x4xf32>, %val: vector<4xf32>, %slice: tensor<1x2x4xf32>) -> tensor<1x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = vector.transfer_write %val, %input[%c0, %c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x2x4xf32>
+  %1 = tensor.insert_slice %slice into %0[0, 1, 0, 0] [1, 2, 1, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %1 : tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @double_sandwiched_transfer_write
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x2x2x4xf32>, %[[VAL0:.+]]: vector<4xf32>, %[[VAL1:.+]]: vector<4xf32>, %[[VAL2:.+]]: vector<4xf32>, %[[VAL3:.+]]: vector<4xf32>)
+// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[W0:.+]] = vector.transfer_write %[[VAL2]], %[[INPUT]][%[[C6]], %[[C1]], %[[C0]], %[[C0]]]
+// CHECK: %[[W1:.+]] = vector.transfer_write %[[VAL3]], %[[W0]][%[[C6]], %[[C1]], %[[C1]], %[[C0]]]
+// CHECK: %[[W2:.+]] = vector.transfer_write %[[VAL0]], %[[W1]][%[[C4]], %[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[W3:.+]] = vector.transfer_write %[[VAL1]], %[[W2]][%[[C4]], %[[C0]], %[[C1]], %[[C0]]]
+// CHECK: return %[[W3]]
+func.func @double_sandwiched_transfer_write(%input: tensor<8x2x2x4xf32>, %val0: vector<4xf32>, %val1: vector<4xf32>, %val2: vector<4xf32>, %val3: vector<4xf32>) -> tensor<8x2x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = tensor.extract_slice %input[6, %c1, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<8x2x2x4xf32> to tensor<1x2x4xf32>
+  %1 = tensor.extract_slice %input[4, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<8x2x2x4xf32> to tensor<1x2x4xf32>
+  %2 = vector.transfer_write %val0, %1[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %3 = vector.transfer_write %val1, %2[%c0, %c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %4 = vector.transfer_write %val2, %0[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %5 = vector.transfer_write %val3, %4[%c0, %c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<1x2x4xf32>
+  %6 = tensor.insert_slice %3 into %input[4, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<8x2x2x4xf32>
+  %7 = tensor.insert_slice %5 into %6[6, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<8x2x2x4xf32>
+  return %7 : tensor<8x2x2x4xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -901,6 +902,29 @@
   }
 };
 
+struct TestVectorReorderTransfer
+    : public PassWrapper<TestVectorReorderTransfer,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorReorderTransfer)
+
+  StringRef getArgument() const final {
+    return "test-vector-reorder-transfer";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns that reorder vector transfer ops with tensor "
+           "insert/extract slice ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, tensor::TensorDialect>();
+  }
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateVectorReorderTransferExtractInsertSlicePatterns(patterns);
+    tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
+    tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -939,6 +963,8 @@
   PassRegistration();
   PassRegistration();
+
+  PassRegistration<TestVectorReorderTransfer>();
 }
 } // namespace test
 } // namespace mlir
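
Note (not part of the patch): outside of the test pass above, the new populate entry point is meant to be driven by any greedy rewrite together with the tensor slice canonicalizations, which fold the extract/insert pairs that the reordering exposes. The sketch below shows one way a downstream pass could wire this up; the pass name is hypothetical and the boilerplate simply mirrors TestVectorReorderTransfer, assuming the same MLIR revision as this patch.

// Sketch only: a hypothetical downstream pass applying the patterns added by
// this patch, followed by the slice canonicalizations that fold the
// extract/insert pairs brought next to each other.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct ReorderTransferSliceSketchPass
    : public mlir::PassWrapper<ReorderTransferSliceSketchPass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReorderTransferSliceSketchPass)

  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    // Reorder vector.transfer_write ops around tensor.insert_slice ops.
    mlir::vector::populateVectorReorderTransferExtractInsertSlicePatterns(
        patterns);
    // Fold the extract_slice/insert_slice pairs exposed by the reordering.
    mlir::tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
    mlir::tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace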