diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -138,20 +138,37 @@
     const PDLByteCode::MatchResult *pdlMatch = nullptr;
+    // Whether the op-agnostic candidate beat the op-specific one. A PDL
+    // candidate that wins afterwards is identified by `pdlMatch` and takes
+    // precedence over this flag.
+    bool bestIsAnyOp = false;
     /// Operation specific patterns.
     if (opIt != opE)
-      bestPattern = *(opIt++);
+      bestPattern = *opIt;
     /// Operation agnostic patterns.
     if (anyIt != anyE &&
-        (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit()))
-      bestPattern = *(anyIt++);
+        (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit())) {
+      bestPattern = *anyIt;
+      bestIsAnyOp = true;
+    }
     /// PDL patterns.
     if (pdlIt != pdlE &&
         (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) {
       pdlMatch = pdlIt;
-      bestPattern = (pdlIt++)->pattern;
+      bestPattern = pdlIt->pattern;
     }
     if (!bestPattern)
       break;
 
+    // Advance only the list that supplied the selected pattern, whether or not
+    // it ends up applying, so it is not attempted again. The other lists keep
+    // their current candidate for the next iteration; advancing every list
+    // that was merely inspected would skip lower-benefit patterns.
+    if (pdlMatch)
+      ++pdlIt;
+    else if (bestIsAnyOp)
+      ++anyIt;
+    else
+      ++opIt;
+
     // Check that the pattern can be applied.
     if (canApply && !canApply(*bestPattern))
       continue;
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -10,5 +10,6 @@
 add_subdirectory(Interfaces)
 add_subdirectory(IR)
 add_subdirectory(Pass)
+add_subdirectory(Rewrite)
 add_subdirectory(SDBM)
 add_subdirectory(TableGen)
diff --git a/mlir/unittests/Rewrite/CMakeLists.txt b/mlir/unittests/Rewrite/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Rewrite/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRRewriteTests
+  PatternBenefit.cpp
+)
+target_link_libraries(MLIRRewriteTests
+  PRIVATE
+  MLIRRewrite
+  MLIRTransformUtils)
diff --git a/mlir/unittests/Rewrite/PatternBenefit.cpp b/mlir/unittests/Rewrite/PatternBenefit.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Rewrite/PatternBenefit.cpp
@@ -0,0 +1,82 @@
+//===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+// Regression test for pattern selection across the op-specific and
+// op-agnostic pattern lists. Both patterns below always fail, so every
+// candidate must be attempted on the function. Pattern2 (any op, benefit 2)
+// outranks Pattern1 (FuncOp only, benefit 1); Pattern1 must still be tried
+// afterwards rather than being skipped when the higher-benefit pattern is
+// selected.
+TEST(PatternBenefitTest, BenefitOrder) {
+  MLIRContext context;
+
+  OpBuilder builder(&context);
+  auto module = ModuleOp::create(builder.getUnknownLoc());
+  builder.setInsertionPointToStart(module.getBody());
+  auto type = FunctionType::get(&context, llvm::None, llvm::None);
+  builder.create<FuncOp>(builder.getUnknownLoc(), "test", type);
+
+  // Lower-benefit pattern rooted on FuncOp; records that it was attempted.
+  struct Pattern1 : public OpRewritePattern<FuncOp> {
+    Pattern1(mlir::MLIRContext *context, bool *called)
+        : OpRewritePattern<FuncOp>(context, /*benefit*/ 1), called(called) {}
+
+    mlir::LogicalResult
+    matchAndRewrite(FuncOp /*op*/,
+                    mlir::PatternRewriter & /*rewriter*/) const override {
+      *called = true;
+      return failure();
+    }
+
+  private:
+    bool *called;
+  };
+
+  // Higher-benefit pattern that matches any operation; records that it was
+  // attempted.
+  struct Pattern2 : public RewritePattern {
+    explicit Pattern2(bool *called)
+        : RewritePattern(/*benefit*/ 2, MatchAnyOpTypeTag{}), called(called) {}
+
+    mlir::LogicalResult
+    matchAndRewrite(Operation * /*op*/,
+                    mlir::PatternRewriter & /*rewriter*/) const override {
+      *called = true;
+      return failure();
+    }
+
+  private:
+    bool *called;
+  };
+
+  OwningRewritePatternList patterns;
+
+  bool called1 = false;
+  bool called2 = false;
+
+  patterns.insert<Pattern1>(&context, &called1);
+  patterns.insert<Pattern2>(&called2);
+
+  auto res = applyPatternsAndFoldGreedily(module, std::move(patterns));
+
+  EXPECT_TRUE(succeeded(res));
+  EXPECT_TRUE(called1);
+  EXPECT_TRUE(called2);
+
+  // The module is created detached and never inserted into a parent, so
+  // erase it explicitly to avoid leaking it.
+  module.erase();
+}
+} // namespace