diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -36,6 +36,10 @@ void populateMergeConsecutiveInsertExtractSlicePatterns( RewritePatternSet &patterns); +/// Populates `patterns` with patterns that fold `tensor.expand_shape` and +/// `tensor.collapse_shape` into other ops. +void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns); + } // namespace tensor } // namespace mlir diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ Bufferize.cpp ExtractSliceFromReshapeUtils.cpp MergeConsecutiveInsertExtractSlicePatterns.cpp + ReshapePatterns.cpp SplitPaddingPatterns.cpp SwapExtractSliceWithProducerPatterns.cpp @@ -26,4 +27,4 @@ MLIRTensorDialect MLIRTilingInterface MLIRTransforms - ) +) diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -0,0 +1,57 @@ +//===- ReshapePatterns.cpp - Patterns related to reshape folding ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-tensor-reshape-patterns"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+/// Fold an expand_shape that undoes the rank reduction of its extract_slice.
+struct FoldExpandOfRankReducingExtract
+    : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType resultType = expandShapeOp.getResultType();
+    auto extractSliceOp =
+        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
+    if (!extractSliceOp)
+      return failure();
+    RankedTensorType srcType = extractSliceOp.getSourceType();
+
+    // Bail out unless the expand_shape result type equals the type that the
+    // extract_slice would have without rank reduction: only then can the
+    // expand_shape be dropped and replaced by a non-rank-reducing slice.
+    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
+        srcType, extractSliceOp.getStaticOffsets(),
+        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+    if (nonReducingExtractType != resultType)
+      return failure();
+
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
+        mixedStrides);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-reassociative-reshape-folding %s | FileCheck %s
+
+// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
+// CHECK-SAME: %[[t:.*]]: tensor
+// CHECK-DAG: %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor to tensor
+// CHECK-DAG: %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor to tensor
+// CHECK: return %[[extract1]], %[[extract2]]
+func.func @expand_shape_of_rank_reducing_extract(
+    %t: tensor, %idx: index)
+  -> (tensor, tensor)
+{
+  %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
+      : tensor to tensor
+  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
+      : tensor into tensor
+  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
+      : tensor into tensor
+  return %1, %2 : tensor, tensor
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp 
b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -65,6 +65,11 @@
                      "with loop nest"),
       llvm::cl::init(false)};
 
+  Option<bool> testReassociativeReshapeFolding{
+      *this, "test-reassociative-reshape-folding",
+      llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
+      llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -74,6 +79,14 @@
 };
 } // namespace
 
+/// Greedily applies the reassociative reshape folding patterns to `rootOp`.
+/// The driver's LogicalResult is deliberately discarded.
+static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 static void applySplitPaddingPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateSplitPaddingPatterns(patterns);
@@ -262,6 +275,8 @@
     applyFoldConstantExtractSlicePatterns(rootOp);
   if (testFoldConsecutiveInsertExtractSlice)
     applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
+  if (testReassociativeReshapeFolding)
+    applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))