diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -48,6 +48,12 @@
 void populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that drop redundant tensor.insert_slice
+/// rank expansions, i.e., rank-expanding tensor.insert_slice ops whose only
+/// user is a tensor.extract_slice that drops the same dims again.
+void populateDropRedundantInsertSliceRankExpansionPatterns(
+    RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold `tensor.expand_shape` and
 /// `tensor.collapse_shape` into other ops.
 void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -42,6 +42,11 @@
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
+/// source tensor or inserts the source tensor into a destination tensor with
+/// the same shape.
+bool isCastLikeInsertSliceOp(InsertSliceOp op);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
@@ -13,6 +13,6 @@
   MLIRSCFDialect
   MLIRTensorDialect
   MLIRTensorTransforms
+  MLIRTensorUtils
   MLIRTransformDialect
-  MLIRValueBoundsOpInterface
 )
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -12,9 +12,9 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -24,29 +24,6 @@
 // TrackingListener
 //===----------------------------------------------------------------------===//
 
-/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
-/// source tensor or inserts the source tensor into a destination tensor with
-/// the same shape.
-static bool isCastLikeInsertSliceOp(InsertSliceOp op) {
-  llvm::SmallBitVector droppedDims = op.getDroppedDims();
-  int64_t srcDim = 0;
-  // Source dims and destination dims (apart from dropped dims) must have the
-  // same size.
-  for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
-       ++resultDim) {
-    if (droppedDims.test(resultDim)) {
-      continue;
-    }
-    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
-        op.getSource(), op.getResult(), srcDim, resultDim);
-    if (failed(equalDimSize) || !*equalDimSize)
-      return false;
-    ++srcDim;
-  }
-
-  return true;
-}
-
 Operation *
 tensor::TrackingListener::findReplacementOp(Operation *op,
                                             ValueRange newValues) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@
   MLIRPass
   MLIRSCFDialect
   MLIRTensorDialect
+  MLIRTensorUtils
   MLIRTilingInterface
   MLIRTransforms
   MLIRVectorDialect
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -76,6 +77,72 @@
     return success();
   }
 };
+
+/// Drop redundant rank expansion. I.e., rank expansions that are directly
+/// followed by rank reductions. E.g.:
+///   %0 = tensor.insert_slice %src into ...
+///        : tensor<5x10xf32> into tensor<1x1x5x10xf32>
+///   %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
+///        : tensor<1x1x5x10xf32> to tensor<2x2xf32>
+/// is rewritten to:
+///   %1 = tensor.extract_slice %src[2, 3] [2, 2] [1, 1]
+///        : tensor<5x10xf32> to tensor<2x2xf32>
+struct DropRedundantInsertSliceRankExpansion
+    : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to do if no dims are dropped. `none()` is true when no bit is
+    // set, whereas `empty()` only tests for a zero-length bit vector.
+    llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
+    if (droppedDims.none())
+      return failure();
+
+    // Look for tensor.insert_slice op that has an inverse rank expansion.
+    auto insertSliceOp =
+        extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
+    if (!insertSliceOp)
+      return failure();
+    llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
+
+    // TODO: This could be extended to support cases where the dropped dims are
+    // a subset of the expanded dims.
+    if (expandedDims != droppedDims)
+      return failure();
+
+    // The tensor.insert_slice may not be redundant if it has multiple users.
+    if (!insertSliceOp->hasOneUse())
+      return failure();
+
+    // Only consider tensor.insert_slice ops that are pure rank expansions.
+    // I.e., no elements are taken from the destination.
+    if (!isCastLikeInsertSliceOp(insertSliceOp))
+      return failure();
+
+    // Extract directly from the source, keeping only the offsets, sizes and
+    // strides of the dims that are not dropped.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(extractSliceOp);
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+    for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
+         ++i) {
+      if (droppedDims.test(i))
+        continue;
+      newOffsets.push_back(mixedOffsets[i]);
+      newSizes.push_back(mixedSizes[i]);
+      newStrides.push_back(mixedStrides[i]);
+    }
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
+        newSizes, newStrides);
+    rewriter.eraseOp(insertSliceOp);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
@@ -85,3 +152,8 @@
                MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
       patterns.getContext());
 }
+
+void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DropRedundantInsertSliceRankExpansion>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
@@ -10,4 +10,5 @@
   MLIRArithUtils
   MLIRIR
   MLIRTensorDialect
+  MLIRValueBoundsOpInterface
 )
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -102,3 +103,29 @@
       RTTBuilder(rankedTensorType).setShape(transposedShape);
   return transposedTensorType;
 }
+
+bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
+  llvm::SmallBitVector droppedDims = op.getDroppedDims();
+  RankedTensorType destType = op.getDestType();
+  int64_t srcDim = 0;
+  for (int64_t resultDim = 0, e = destType.getRank(); resultDim < e;
+       ++resultDim) {
+    if (droppedDims.test(resultDim)) {
+      // A dropped dim is a unit dim of the inserted slice. If the destination
+      // is larger than 1 in that dim, the result also contains elements of
+      // the destination, so the op is not cast-like.
+      if (destType.getDimSize(resultDim) != 1)
+        return false;
+      continue;
+    }
+    // Source dims and destination dims (apart from dropped dims) must have
+    // the same size.
+    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
+        op.getSource(), op.getResult(), srcDim, resultDim);
+    if (failed(equalDimSize) || !*equalDimSize)
+      return false;
+    ++srcDim;
+  }
+
+  return true;
+}
diff --git a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-drop-redundant-insert-slice-rank-expansion %s | FileCheck %s
+
+// CHECK-LABEL: func @test_drop_rank_expansion(
+//  CHECK-SAME:     %[[src:.*]]: tensor<128x480xf32>,
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[src]][0, 0] [123, 456] [1, 1] : tensor<128x480xf32> to tensor<123x456xf32>
+//       CHECK:   return %[[extract]]
+func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1x128x480xf32>) -> tensor<123x456xf32> {
+  %inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x1x128x480xf32>
+  %extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
+  return %extracted_slice : tensor<123x456xf32>
+}
+
+// -----
+
+// The insert_slice is not a pure rank expansion: the second dropped dim has
+// size 4 in the destination, so the extract_slice may read destination data.
+// CHECK-LABEL: func @test_keep_non_unit_dest_dim(
+//       CHECK:   tensor.insert_slice
+//       CHECK:   tensor.extract_slice
+func.func @test_keep_non_unit_dest_dim(%src: tensor<128x480xf32>, %dest: tensor<1x4x128x480xf32>) -> tensor<123x456xf32> {
+  %inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x4x128x480xf32>
+  %extracted_slice = tensor.extract_slice %inserted_slice[0, 1, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x4x128x480xf32> to tensor<123x456xf32>
+  return %extracted_slice : tensor<123x456xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -62,6 +62,11 @@
                      "with loop nest"),
       llvm::cl::init(false)};
 
+  Option<bool> testDropRedundantInsertSliceRankExpansion{
+      *this, "test-drop-redundant-insert-slice-rank-expansion",
+      llvm::cl::desc("Test dropping redundant insert_slice rank expansions"),
+      llvm::cl::init(false)};
+
   Option<bool> testReassociativeReshapeFolding{
       *this, "test-reassociative-reshape-folding",
       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
@@ -135,6 +140,13 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void
+applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 static void applySimplifyPackPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateSimplifyTensorPack(patterns);
@@ -367,6 +379,8 @@
     applyFoldConstantExtractSlicePatterns(rootOp);
   if (testFoldConsecutiveInsertExtractSlice)
     applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
+  if (testDropRedundantInsertSliceRankExpansion)
+    applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
   if (testReassociativeReshapeFolding)
     applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testEmptyOpFolding)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5935,6 +5935,7 @@
         ":ArithUtils",
         ":DialectUtils",
         ":TensorDialect",
+        ":ValueBoundsOpInterface",
         "//llvm:Support",
     ],
 )
@@ -5988,6 +5989,7 @@
         ":SCFDialect",
         ":TensorDialect",
         ":TensorPassIncGen",
+        ":TensorUtils",
         ":TilingInterface",
         ":Transforms",
         ":ValueBoundsOpInterface",
@@ -6039,6 +6041,7 @@
         ":TensorDialect",
         ":TensorTransformOpsIncGen",
         ":TensorTransforms",
+        ":TensorUtils",
         ":TransformDialect",
         ":ValueBoundsOpInterface",
         "//llvm:Support",