diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -16,6 +16,10 @@
 
 #include "mlir/Rewrite/FrozenRewritePatternList.h"
 
+/// Default upper bound on the number of scans performed by the greedy pattern
+/// rewrite driver, used when callers do not pass an explicit `max_iterations`.
+#define MAX_PATTERN_MATCH_ITERATIONS 10
+
 namespace mlir {
 
 //===----------------------------------------------------------------------===//
@@ -32,12 +36,17 @@
 ///
+/// `max_iterations` bounds how many times the regions of `op` are scanned.
 LogicalResult
 applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternList &patterns);
+                             const FrozenRewritePatternList &patterns,
+                             unsigned max_iterations = MAX_PATTERN_MATCH_ITERATIONS);
 
 /// Rewrite the given regions, which must be isolated from above.
+/// The regions are scanned at most `max_iterations` times; failure is
+/// returned if the rewrite has not converged within that bound.
 LogicalResult
 applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternList &patterns);
+                             const FrozenRewritePatternList &patterns,
+                             unsigned max_iterations = MAX_PATTERN_MATCH_ITERATIONS);
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
 /// by selecting the highest benefits patterns in a greedy manner. Returns
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -25,7 +25,7 @@
 #define DEBUG_TYPE "pattern-matcher"
 
 /// The max number of iterations scanning for pattern match.
-static unsigned maxPatternMatchIterations = 10;
+static unsigned maxPatternMatchIterations = MAX_PATTERN_MATCH_ITERATIONS;
 
 //===----------------------------------------------------------------------===//
 // GreedyPatternRewriteDriver
@@ -219,13 +219,18 @@
 ///
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(Operation *op,
-                                   const FrozenRewritePatternList &patterns) {
-  return applyPatternsAndFoldGreedily(op->getRegions(), patterns);
+                                   const FrozenRewritePatternList &patterns,
+                                   unsigned max_iterations
+                                   /*=MAX_PATTERN_MATCH_ITERATIONS*/) {
+  // Forward the caller-supplied limit rather than the regions default.
+  return applyPatternsAndFoldGreedily(op->getRegions(), patterns,
+                                      max_iterations);
 }
 
 /// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                                   const FrozenRewritePatternList &patterns) {
+                                   const FrozenRewritePatternList &patterns,
+                                   unsigned max_iterations
+                                   /*=MAX_PATTERN_MATCH_ITERATIONS*/) {
   if (regions.empty())
     return success();
@@ -241,10 +246,10 @@
 
   // Start the pattern driver.
   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
-  bool converged = driver.simplify(regions, maxPatternMatchIterations);
+  bool converged = driver.simplify(regions, max_iterations);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
-                 << maxPatternMatchIterations << " times";
+                 << max_iterations << " times\n";
   });
   return success(converged);
 }