diff --git a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
--- a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
@@ -45,6 +45,22 @@
     SmallVector<OpFoldResult> &combinedSizes,
     SmallVector<OpFoldResult> &combinedStrides);
 
+/// Returns true if the given two n-D ranges can be proven as disjoint, i.e.
+/// some dimension has all-constant offsets/sizes whose 1-D ranges do not
+/// overlap. Returns false otherwise.
+/// This function assumes all input arrays have the same size.
+bool areDisjointRanges(ArrayRef<OpFoldResult> aOffsets,
+                       ArrayRef<OpFoldResult> aSizes,
+                       ArrayRef<OpFoldResult> bOffsets,
+                       ArrayRef<OpFoldResult> bSizes);
+
+/// Returns true if the given two slices can be proven as disjoint. Returns
+/// false otherwise, including whenever either slice has a non-unit stride.
+///
+/// This function assumes the two slices have the same rank.
+bool areDisjointSlices(OffsetSizeAndStrideOpInterface aSlice,
+                       OffsetSizeAndStrideOpInterface bSlice);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -36,6 +36,10 @@
 void populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns);
 
+/// Collects patterns to make tensor.extract_slice read from the destination
+/// of its producer tensor.insert_slice op when their slices are disjoint.
+void populateExtractFromInsertSliceDestOpPatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
@@ -74,3 +74,52 @@
       droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
       combinedOffsets, combinedSizes, combinedStrides);
 }
+
+bool mlir::areDisjointRanges(ArrayRef<OpFoldResult> aOffsets,
+                             ArrayRef<OpFoldResult> aSizes,
+                             ArrayRef<OpFoldResult> bOffsets,
+                             ArrayRef<OpFoldResult> bSizes) {
+  assert(llvm::all_equal(
+      {aOffsets.size(), aSizes.size(), bOffsets.size(), bSizes.size()}));
+
+  for (const auto &t : llvm::zip(aOffsets, aSizes, bOffsets, bSizes)) {
+    auto [aBeginVal, aSizeVal, bBeginVal, bSizeVal] = t;
+    Optional<int64_t> aBegin = getConstantIntValue(aBeginVal);
+    Optional<int64_t> aSize = getConstantIntValue(aSizeVal);
+    Optional<int64_t> bBegin = getConstantIntValue(bBeginVal);
+    Optional<int64_t> bSize = getConstantIntValue(bSizeVal);
+
+    // If there are dynamic offsets/sizes, we cannot prove this dimension is
+    // disjoint. Look at other dimensions.
+    if (!aBegin || !aSize || !bBegin || !bSize)
+      continue;
+
+    int64_t aEnd = *aBegin + *aSize;
+    int64_t bEnd = *bBegin + *bSize;
+    // As long as one dimension is disjoint, the whole slices are disjoint.
+    if (aEnd <= *bBegin || bEnd <= *aBegin)
+      return true;
+  }
+  return false;
+}
+
+bool mlir::areDisjointSlices(OffsetSizeAndStrideOpInterface aSlice,
+                             OffsetSizeAndStrideOpInterface bSlice) {
+  SmallVector<OpFoldResult> aOffsets = aSlice.getMixedOffsets();
+  SmallVector<OpFoldResult> bOffsets = bSlice.getMixedOffsets();
+  SmallVector<OpFoldResult> aSizes = aSlice.getMixedSizes();
+  SmallVector<OpFoldResult> bSizes = bSlice.getMixedSizes();
+  SmallVector<OpFoldResult> aStrides = aSlice.getMixedStrides();
+  SmallVector<OpFoldResult> bStrides = bSlice.getMixedStrides();
+
+  // For simplicity only look at stride 1 cases for now.
+  auto hasAllOnes = [](ArrayRef<OpFoldResult> strides) {
+    return llvm::all_of(strides, [](::mlir::OpFoldResult ofr) {
+      return getConstantIntValue(ofr) == static_cast<int64_t>(1);
+    });
+  };
+  if (!hasAllOnes(aStrides) || !hasAllOnes(bStrides))
+    return false;
+
+  return areDisjointRanges(aOffsets, aSizes, bOffsets, bSizes);
+}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ExtractFromInsertSliceDestPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
   SplitPaddingPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractFromInsertSliceDestPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractFromInsertSliceDestPatterns.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractFromInsertSliceDestPatterns.cpp
@@ -0,0 +1,63 @@
+//===- ExtractFromInsertSliceDestPatterns.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+/// Updates extract_slice to extract from insert_slice op's destination tensor
+/// when the extract_slice and insert_slice cover provably disjoint slices.
+///
+/// Example:
+/// ```mlir
+/// %i = tensor.insert_slice %src into %dst[0, 0, 0, 0][1, 1, 2, 4][1, 1, 1, 1]
+///   : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+/// %e = tensor.extract_slice %i[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+///   : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+/// ```
+/// Can be converted into
+/// ```mlir
+/// %i = tensor.insert_slice %src into %dst[0, 0, 0, 0][1, 1, 2, 4][1, 1, 1, 1]
+///   : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+/// %e = tensor.extract_slice %dst[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+///   : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+/// ```
+/// This helps to break the chain of insert_slice and extract_slices, which
+/// might enable further optimizations.
+struct ExtractFromInsertDest final : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
+    if (!insertOp)
+      return failure();
+
+    if (!areDisjointSlices(insertOp, extractOp))
+      return rewriter.notifyMatchFailure(extractOp, "not disjoint");
+
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        extractOp, extractOp.getType(), insertOp.getDest(),
+        extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
+        extractOp.getMixedStrides());
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::tensor::populateExtractFromInsertSliceDestOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractFromInsertDest>(patterns.getContext());
+}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 
diff --git a/mlir/test/Dialect/Tensor/extract-from-insert-slice-dest.mlir b/mlir/test/Dialect/Tensor/extract-from-insert-slice-dest.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/extract-from-insert-slice-dest.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-extract-from-insert-slice-dest -canonicalize %s | FileCheck %s
+
+func.func @disjoint_insert_extract_slice_static_shape(%src: tensor<1x2x4xf32>, %dst: tensor<1x2x2x4xf32>) -> tensor<1x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %insert = tensor.insert_slice %src into %dst[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  %extract = tensor.extract_slice %insert[%c0, %c1, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  return %extract : tensor<1x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @disjoint_insert_extract_slice_static_shape
+//  CHECK-SAME: (%{{.+}}: tensor<1x2x4xf32>, %[[DST:.+]]: tensor<1x2x2x4xf32>)
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[DST]][0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+//       CHECK:   return %[[EXTRACT]]
+
+// -----
+
+func.func @disjoint_insert_extract_slice_static_shape(%src: tensor<1x2x4xf32>, %dst: tensor<1x2x2x4xf32>) -> tensor<1x2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %insert = tensor.insert_slice %src into %dst[0, 1, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  %extract = tensor.extract_slice %insert[%c0, %c0, %c0, %c0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  return %extract : tensor<1x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @disjoint_insert_extract_slice_static_shape
+//  CHECK-SAME: (%{{.+}}: tensor<1x2x4xf32>, %[[DST:.+]]: tensor<1x2x2x4xf32>)
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[DST]][0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+//       CHECK:   return %[[EXTRACT]]
+
+// -----
+
+func.func @disjoint_insert_extract_slice_dynamic_shape(%src: tensor<2x?xf32>, %dst: tensor<8x?xf32>, %size: index) -> tensor<3x?xf32> {
+  %insert = tensor.insert_slice %src into %dst[2, 0] [2, %size] [1, 1] : tensor<2x?xf32> into tensor<8x?xf32>
+  %extract = tensor.extract_slice %insert[5, 0] [3, %size] [1, 1] : tensor<8x?xf32> to tensor<3x?xf32>
+  return %extract : tensor<3x?xf32>
+}
+
+// CHECK-LABEL: func.func @disjoint_insert_extract_slice_dynamic_shape
+//  CHECK-SAME: (%{{.+}}: tensor<2x?xf32>, %[[DST:.+]]: tensor<8x?xf32>, %[[SIZE:.+]]: index)
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[DST]][5, 0] [3, %[[SIZE]]] [1, 1]
+//       CHECK:   return %[[EXTRACT]]
+
+// -----
+
+func.func @joint_insert_extract_slice_dynamic_shape(%src: tensor<2x?xf32>, %dst: tensor<8x?xf32>, %size: index) -> tensor<3x?xf32> {
+  %insert = tensor.insert_slice %src into %dst[2, 0] [2, %size] [1, 1] : tensor<2x?xf32> into tensor<8x?xf32>
+  %extract = tensor.extract_slice %insert[3, 0] [3, %size] [1, 1] : tensor<8x?xf32> to tensor<3x?xf32>
+  return %extract : tensor<3x?xf32>
+}
+
+// CHECK-LABEL: func.func @joint_insert_extract_slice_dynamic_shape
+//       CHECK:   tensor.insert_slice
+//       CHECK:   tensor.extract_slice
+
+// -----
+
+func.func @joint_insert_extract_slice_dynamic_shape(%src: tensor<2x?xf32>, %dst: tensor<8x?xf32>, %size: index) -> tensor<3x?xf32> {
+  %insert = tensor.insert_slice %src into %dst[2, 0] [2, %size] [1, 1] : tensor<2x?xf32> into tensor<8x?xf32>
+  %extract = tensor.extract_slice %insert[1, 0] [3, %size] [1, 1] : tensor<8x?xf32> to tensor<3x?xf32>
+  return %extract : tensor<3x?xf32>
+}
+
+// CHECK-LABEL: func.func @joint_insert_extract_slice_dynamic_shape
+//       CHECK:   tensor.insert_slice
+//       CHECK:   tensor.extract_slice
+
+
+// -----
+
+func.func @joint_insert_extract_slice_dynamic_shape(%src: tensor<2x?xf32>, %dst: tensor<8x?xf32>, %offset: index, %size: index) -> tensor<3x?xf32> {
+  %insert = tensor.insert_slice %src into %dst[2, 0] [2, %size] [1, 1] : tensor<2x?xf32> into tensor<8x?xf32>
+  %extract = tensor.extract_slice %insert[%offset, 0] [3, %size] [1, 1] : tensor<8x?xf32> to tensor<3x?xf32>
+  return %extract : tensor<3x?xf32>
+}
+
+// CHECK-LABEL: func.func @joint_insert_extract_slice_dynamic_shape
+//       CHECK:   tensor.insert_slice
+//       CHECK:   tensor.extract_slice
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -59,6 +59,12 @@
           "Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
       llvm::cl::init(false)};
 
+  Option<bool> testExtractFromInsertSliceDest{
+      *this, "test-extract-from-insert-slice-dest",
+      llvm::cl::desc("Test tensor.extract_slice from tensor.insert_slice "
+                     "destination tensor"),
+      llvm::cl::init(false)};
+
   Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
       *this, "test-rewrite-extract-slice-from-collapse-shape",
       llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
@@ -102,6 +108,12 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyExtractFromInsertSliceDestPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateExtractFromInsertSliceDestOpPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -247,6 +259,8 @@
     applyFoldConstantExtractSlicePatterns(rootOp);
   if (testFoldConsecutiveInsertExtractSlice)
     applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
+  if (testExtractFromInsertSliceDest)
+    applyExtractFromInsertSliceDestPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))