diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -29,6 +29,13 @@
 FailureOr<Value> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
+/// Collects patterns to merge consecutive tensor.insert_slice/extract_slice
+/// into one. These patterns are in this separate entry point because
+/// bufferization is sensitive to IR structure, particularly to the
+/// tensor.extract_slice and tensor.insert_slice ops that create the slices.
+void populateMergeConsecutiveInsertExtractSlicePatterns(
+    RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ExtractSliceFromReshape.cpp
+  MergeConsecutiveInsertExtractSlicePatterns.cpp
   SplitPadding.cpp
   SwapExtractSliceWithProducer.cpp
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -0,0 +1,117 @@
+//===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Returns the element-wise sums of `offsets1` and `offsets2`, each built as a
+/// composed affine.apply on index values. Both lists must have the same size.
+static SmallVector<OpFoldResult> mergeOffsets(Location loc,
+                                              ArrayRef<OpFoldResult> offsets1,
+                                              ArrayRef<OpFoldResult> offsets2,
+                                              OpBuilder &builder) {
+  SmallVector<OpFoldResult> foldedOffsets;
+  assert(offsets1.size() == offsets2.size());
+  foldedOffsets.reserve(offsets1.size());
+
+  AffineExpr dim1, dim2;
+  bindDims(builder.getContext(), dim1, dim2);
+
+  for (const auto &pair : llvm::zip(offsets1, offsets2)) {
+    auto offset0 =
+        getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
+    auto offset1 =
+        getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
+    auto foldedOffset =
+        makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
+    foldedOffsets.push_back(foldedOffset.getResult());
+  }
+  return foldedOffsets;
+}
+
+namespace {
+/// Merges consecutive tensor.extract_slice ops into one.
+struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
+                                PatternRewriter &rewriter) const override {
+    auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
+    if (!prevOp)
+      return failure();
+
+    if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+      return failure();
+
+    auto prevResultType = prevOp.getType().cast<RankedTensorType>();
+    if (prevOp.getSourceType().getRank() != prevResultType.getRank())
+      return rewriter.notifyMatchFailure(
+          prevOp, "rank-reducing producer case unimplemented");
+
+    Location loc = nextOp.getLoc();
+
+    // Both ops have unit strides and `prevOp` keeps the source rank, so the
+    // offsets compose by addition; sizes and strides come from `nextOp`.
+    SmallVector<OpFoldResult> foldedOffsets = mergeOffsets(
+        loc, prevOp.getMixedOffsets(), nextOp.getMixedOffsets(), rewriter);
+
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
+        nextOp.getMixedSizes(), nextOp.getMixedStrides());
+    return success();
+  }
+};
+
+/// Merges consecutive tensor.insert_slice ops into one.
+struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp nextOp,
+                                PatternRewriter &rewriter) const override {
+    auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>();
+    if (!prevOp)
+      return failure();
+
+    if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+      return failure();
+
+    // The first insert_slice must overwrite its entire destination, so its
+    // source must match the destination type up to dropped unit dims.
+    SliceVerificationResult result =
+        isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
+    if (result != SliceVerificationResult::Success)
+      return failure();
+
+    // Dynamic dimensions can pass the rank-reducing check above, e.g. when
+    // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
+    // the dynamic size covers the full tensor.
+    if (!prevOp.getSourceType().hasStaticShape() ||
+        !prevOp.getDestType().hasStaticShape())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
+        nextOp.getMixedSizes(), nextOp.getMixedStrides());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>(
+      patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-consecutive-insert-extract-slice -canonicalize -mlir-print-local-scope %s | FileCheck %s
+
+func.func @extract_slice_same_rank(
+  %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x16x32x?xf32> {
+  %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
+  %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [8, 16, 32, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<8x16x32x?xf32>
+  return %1: tensor<8x16x32x?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_same_rank
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
+//       CHECK:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]]
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
+//       CHECK:   return %[[EXTRACT]] : tensor<8x16x32x?xf32>
+
+func.func @extract_slice_rank_reducing_consumer(
+  %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
+  %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
+  %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [1, 16, 1, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<16x?xf32>
+  return %1: tensor<16x?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
+//       CHECK:   tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
+
+func.func @extract_slice_rank_reducing_producer(
+  %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
+  %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
+  %1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>
+  return %1: tensor<8x?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
+// CHECK-COUNT-2: tensor.extract_slice
+
+// -----
+
+func.func @insert_slice_rank_reducing(
+  %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x16x1xf32>, %src: tensor<16xf32>, %offset: index) -> tensor<128x128x128x128xf32> {
+  %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, 16, 1] [1, 1, 1] : tensor<16xf32> into tensor<1x16x1xf32>
+  %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, 16, 1] [1, 1, 1, 1] : tensor<1x16x1xf32> into tensor<128x128x128x128xf32>
+  return %1: tensor<128x128x128x128xf32>
+}
+
+// CHECK-LABEL: func.func @insert_slice_rank_reducing
+//  CHECK-SAME: (%[[DST:.+]]: tensor<128x128x128x128xf32>, %{{.+}}: tensor<1x16x1xf32>, %[[SRC:.+]]: tensor<16xf32>, %[[IDX:.+]]: index)
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
+//       CHECK:   return %[[INSERT]]
+
+func.func @insert_slice_rank_reducing_dynamic_shape(
+  %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
+  %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>
+  %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<1x?x1xf32> into tensor<128x128x128x128xf32>
+  return %1: tensor<128x128x128x128xf32>
+}
+
+// CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape
+// CHECK-COUNT-2: tensor.insert_slice
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -53,6 +53,12 @@
       llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
       llvm::cl::init(false)};
 
+  Option<bool> testFoldConsecutiveInsertExtractSlice{
+      *this, "test-fold-consecutive-insert-extract-slice",
+      llvm::cl::desc(
+          "Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
+      llvm::cl::init(false)};
+
   Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
       *this, "test-rewrite-extract-slice-from-collapse-shape",
       llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
@@ -90,6 +96,12 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that @@ -233,6 +245,8 @@ applySplitPaddingPatterns(rootOp); if (testFoldConstantExtractSlice) applyFoldConstantExtractSlicePatterns(rootOp); + if (testFoldConsecutiveInsertExtractSlice) + applyFoldConsecutiveInsertExtractSlicePatterns(rootOp); if (testRewriteExtractSliceWithTiledCollapseShape) { if (failed( applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))