diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -26,7 +26,6 @@ namespace mlir { namespace transform { -class ApplyPatternsOp; enum class FailurePropagationMode : uint32_t; class FailurePropagationModeAttr; @@ -166,9 +165,6 @@ /// be prefixed with the dialect to which the patterns belong. void registerPatterns(StringRef identifier, PopulatePatternsFn &&fn); -protected: - friend class ApplyPatternsOp; - /// Returns "true" if patterns are registered with the specified identifier. bool hasPatterns(StringAttr identifier) const; diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -163,6 +163,8 @@ registry.registerPatterns( "tensor.fold_into_pack_and_unpack", tensor::populateFoldIntoPackAndUnpackPatterns); + registry.registerPatterns("tensor.simplify_tensor_pack", + tensor::populateSimplifyTensorPack); }); } }; diff --git a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir --- a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir +++ b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-drop-redundant-insert-slice-rank-expansion %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-registered-patterns="patterns=tensor.drop_redundant_insert_slice_rank_expansion" %s | FileCheck %s // CHECK-LABEL: func @test_drop_rank_expansion( // CHECK-SAME: %[[src:.*]]: tensor<128x480xf32>, diff --git
a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir --- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir +++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-consecutive-insert-extract-slice -canonicalize -mlir-print-local-scope %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-registered-patterns="patterns=tensor.merge_consecutive_insert_extract_slice" -canonicalize -mlir-print-local-scope %s | FileCheck %s func.func @extract_slice_same_rank( %src: tensor, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x16x32x?xf32> { diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir --- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir +++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-empty-op-folding %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-registered-patterns="patterns=tensor.fold_tensor_empty" %s | FileCheck %s func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> { %0 = tensor.empty(%arg0) : tensor<6x5x?xf32> diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir --- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir +++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-registered-patterns="patterns=tensor.fold_into_pack_and_unpack" %s | FileCheck %s func.func @fold_unpack_slice(%arg0 : tensor, %arg1 : tensor, %arg2 : index, %arg3 : index) -> tensor { diff --git
a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir --- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir +++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-reassociative-reshape-folding %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-registered-patterns="patterns=tensor.reassociative_reshape_folding" %s | FileCheck %s // CHECK-LABEL: func @expand_shape_of_rank_reducing_extract( // CHECK-SAME: %[[t:.*]]: tensor diff --git a/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir b/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir --- a/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir +++ b/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-patterns" %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-registered-patterns="patterns=tensor.simplify_tensor_pack" %s | FileCheck %s // CHECK: func.func @single_dim_packing( // CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>) diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -50,37 +50,12 @@ llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"), llvm::cl::init(false)}; - Option<bool> testFoldConsecutiveInsertExtractSlice{ - *this, "test-fold-consecutive-insert-extract-slice", - llvm::cl::desc( - "Test folding consecutive tensor.insert_slice/tensor.extract_slice"), - llvm::cl::init(false)}; - Option<bool> testRewriteExtractSliceWithTiledCollapseShape{ *this, "test-rewrite-extract-slice-from-collapse-shape", llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape " "with loop nest"), llvm::cl::init(false)}; - Option<bool>
testDropRedundantInsertSliceRankExpansion{ - *this, "test-drop-redundant-insert-slice-rank-expansion", - llvm::cl::desc("Test dropping redundant insert_slice rank expansions"), - llvm::cl::init(false)}; - - Option<bool> testReassociativeReshapeFolding{ - *this, "test-reassociative-reshape-folding", - llvm::cl::desc("Test folding of expand_shape/collapse_shape"), - llvm::cl::init(false)}; - - Option<bool> testEmptyOpFolding{ - *this, "test-empty-op-folding", - llvm::cl::desc("Test folding of tensor.empty"), llvm::cl::init(false)}; - - Option<bool> testFoldIntoPackAndUnpack{ - *this, "test-fold-into-pack-and-unpack", - llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"), - llvm::cl::init(false)}; - Option<bool> useForeach{ *this, "use-foreach", llvm::cl::desc( @@ -88,11 +63,6 @@ "the extract_slice of collapse_shape pattern"), llvm::cl::init(false)}; - Option<bool> testSimplifyPackPatterns{ - *this, "test-simplify-pack-patterns", - llvm::cl::desc("Test patterns to simplify tensor.pack"), - llvm::cl::init(false)}; - Option<bool> testTrackingListener{ *this, "test-tracking-listener", llvm::cl::desc("Test tensor TrackingListener for the transform dialect"), @@ -100,24 +70,6 @@ }; } // namespace -static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) { - RewritePatternSet patterns(rootOp->getContext()); - tensor::populateReassociativeReshapeFoldingPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); -} - -static void applyEmptyOpFoldingPatterns(Operation *rootOp) { - RewritePatternSet patterns(rootOp->getContext()); - tensor::populateFoldTensorEmptyPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); -} - -static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) { - RewritePatternSet patterns(rootOp->getContext()); - tensor::populateFoldIntoPackAndUnpackPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); -} - static void
applyFoldConstantExtractSlicePatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::ControlConstantExtractSliceFusionFn controlFn = @@ -134,25 +86,6 @@ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } -static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) { - RewritePatternSet patterns(rootOp->getContext()); - tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); -} - -static void -applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) { - RewritePatternSet patterns(rootOp->getContext()); - tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); -} - -static void applySimplifyPackPatterns(Operation *rootOp) { - RewritePatternSet patterns(rootOp->getContext()); - tensor::populateSimplifyTensorPack(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); -} - namespace { /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that @@ -373,20 +306,8 @@ void TestTensorTransforms::runOnOperation() { Operation *rootOp = getOperation(); - if (testSimplifyPackPatterns) - applySimplifyPackPatterns(rootOp); if (testFoldConstantExtractSlice) applyFoldConstantExtractSlicePatterns(rootOp); - if (testFoldConsecutiveInsertExtractSlice) - applyFoldConsecutiveInsertExtractSlicePatterns(rootOp); - if (testDropRedundantInsertSliceRankExpansion) - applyDropRedundantInsertSliceRankExpansionPatterns(rootOp); - if (testReassociativeReshapeFolding) - applyReassociativeReshapeFoldingPatterns(rootOp); - if (testEmptyOpFolding) - applyEmptyOpFoldingPatterns(rootOp); - if (testFoldIntoPackAndUnpack) - applyFoldIntoPackAndUnpackPatterns(rootOp); if (testRewriteExtractSliceWithTiledCollapseShape) { if (failed( applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt --- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_library(MLIRTestTransformDialect TestTransformDialectExtension.cpp TestTransformDialectInterpreter.cpp + TestRegisteredPatterns.cpp TestTransformStateExtension.cpp EXCLUDE_FROM_LIBMLIR @@ -22,4 +23,5 @@ MLIRTransformDialect MLIRTransformDialectTransforms MLIRTransformPDLExtension + MLIRTransforms ) diff --git a/mlir/test/lib/Dialect/Transform/TestRegisteredPatterns.cpp b/mlir/test/lib/Dialect/Transform/TestRegisteredPatterns.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Transform/TestRegisteredPatterns.cpp @@ -0,0 +1,77 @@ +//===- TestRegisteredPatterns.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +/// Test pass that greedily applies the pattern sets registered in the +/// transform dialect's PatternRegistry under the identifiers listed in the +/// `patterns` option. +struct TestRegisteredPatternsPass + : public PassWrapper<TestRegisteredPatternsPass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRegisteredPatternsPass) + + TestRegisteredPatternsPass() = default; + TestRegisteredPatternsPass(const TestRegisteredPatternsPass &pass) + : PassWrapper(pass) {} + + StringRef getArgument() const final { return "test-registered-patterns"; } + StringRef getDescription() const final { + return "Test greedy application of registered patterns."; + } + + void getDependentDialects(DialectRegistry &registry) const override { + // runOnOperation reads the PatternRegistry from the loaded transform + // dialect, so make sure that dialect gets loaded. + registry.insert<transform::TransformDialect>(); + } + + void runOnOperation() override; + + ListOption<std::string> patterns{ + *this, "patterns", llvm::cl::desc("Patterns that should be applied")}; +}; +} // namespace + +void TestRegisteredPatternsPass::runOnOperation() { + MLIRContext *ctx = getOperation()->getContext(); + Builder builder(ctx); + const auto &registry = ctx->getLoadedDialect<transform::TransformDialect>() + ->getExtraData<transform::PatternRegistry>(); + RewritePatternSet set(ctx); + for (const std::string &identifier : patterns) { + StringAttr name = builder.getStringAttr(identifier); + // Diagnose identifiers that were never registered instead of handing + // them to populatePatterns. + if (!registry.hasPatterns(name)) { + getOperation()->emitOpError("no patterns registered with identifier '") + << identifier << "'"; + return signalPassFailure(); + } + registry.populatePatterns(name, set); + } + + LogicalResult result = + applyPatternsAndFoldGreedily(getOperation(), std::move(set)); + if (failed(result)) { + getOperation()->emitOpError("greedy pattern application failed"); + return signalPassFailure(); + } +} + +namespace mlir { +namespace test { +void registerTestRegisteredPatternsPass() { + PassRegistration<TestRegisteredPatternsPass> reg; +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -119,6 +119,7 @@ void registerTestPDLLPasses(); void
registerTestPreparationPassWithAllowedMemrefResults(); void registerTestRecursiveTypesPass(); +void registerTestRegisteredPatternsPass(); void registerTestSCFUtilsPass(); void registerTestSCFWhileOpBuilderPass(); void registerTestShapeMappingPass(); @@ -235,6 +236,7 @@ mlir::test::registerTestPDLByteCodePass(); mlir::test::registerTestPDLLPasses(); mlir::test::registerTestRecursiveTypesPass(); + mlir::test::registerTestRegisteredPatternsPass(); mlir::test::registerTestSCFUtilsPass(); mlir::test::registerTestSCFWhileOpBuilderPass(); mlir::test::registerTestShapeMappingPass(); diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -336,6 +336,7 @@ "//mlir:TransformDialect", "//mlir:TransformDialectTransforms", "//mlir:TransformPDLExtension", + "//mlir:Transforms", ], )