diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -10,5 +10,6 @@
 add_subdirectory(Interfaces)
 add_subdirectory(IR)
 add_subdirectory(Pass)
+add_subdirectory(Rewrite)
 add_subdirectory(SDBM)
 add_subdirectory(TableGen)
diff --git a/mlir/unittests/Rewrite/CMakeLists.txt b/mlir/unittests/Rewrite/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Rewrite/CMakeLists.txt
@@ -0,0 +1,7 @@
+# Unit tests for MLIR's pattern rewriting library (MLIRRewrite).
+add_mlir_unittest(MLIRRewriteTests PatternBenefit.cpp)
+target_link_libraries(MLIRRewriteTests
+  PRIVATE
+    MLIRRewrite
+    MLIRTransformUtils
+  )
diff --git a/mlir/unittests/Rewrite/PatternBenefit.cpp b/mlir/unittests/Rewrite/PatternBenefit.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Rewrite/PatternBenefit.cpp
@@ -0,0 +1,78 @@
+//===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+TEST(PatternBenefitTest, BenefitOrder) {
+  // Regression test: a low-benefit op-specific pattern was never tried when a
+  // higher-benefit op-agnostic pattern was registered; both must be tried.
+  MLIRContext context;
+  OpBuilder builder(&context);
+  // ModuleOp::create returns a detached op the caller must erase (done below).
+  auto module = ModuleOp::create(builder.getUnknownLoc());
+
+  // Op-specific (ModuleOp) pattern, lower benefit: records the call, fails.
+  struct Pattern1 : public OpRewritePattern<ModuleOp> {
+    Pattern1(MLIRContext *context, bool *called)
+        : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {}
+
+    LogicalResult
+    matchAndRewrite(ModuleOp /*op*/,
+                    PatternRewriter & /*rewriter*/) const override {
+      *called = true;
+      return failure();
+    }
+
+  private:
+    bool *called;
+  };
+
+  // Op-agnostic pattern (any op), higher benefit: records the call, fails.
+  struct Pattern2 : public RewritePattern {
+    explicit Pattern2(bool *called)
+        : RewritePattern(/*benefit*/ 2, MatchAnyOpTypeTag{}), called(called) {}
+
+    LogicalResult
+    matchAndRewrite(Operation * /*op*/,
+                    PatternRewriter & /*rewriter*/) const override {
+      *called = true;
+      return failure();
+    }
+
+  private:
+    bool *called;
+  };
+
+  bool called1 = false;
+  bool called2 = false;
+  OwningRewritePatternList patterns;
+  patterns.insert<Pattern1>(&context, &called1);
+  patterns.insert<Pattern2>(&called2);
+
+  FrozenRewritePatternList frozenPatterns(std::move(patterns));
+  PatternApplicator pa(frozenPatterns);
+  pa.applyDefaultCostModel();
+
+  // PatternRewriter's constructor is protected, so use a trivial subclass.
+  class MyPatternRewriter : public PatternRewriter {
+  public:
+    explicit MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
+  };
+
+  MyPatternRewriter rewriter(&context);
+  (void)pa.matchAndRewrite(module, rewriter);
+  module.erase();
+  EXPECT_TRUE(called1);
+  EXPECT_TRUE(called2);
+}
+} // namespace