Add a `maxNumRewrites` option to the greedy pattern rewrite driver.

GreedyRewriteConfig gains `maxNumRewrites`, which caps the number of pattern
rewrites within one driver iteration; the canonicalizer pass exposes it as
`max-num-rewrites`. `kNoIterationLimit` is renamed to `kNoLimit` and disables
either limit.

diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -33,11 +33,15 @@
   bool enableRegionSimplification = true;
 
   /// This specifies the maximum number of times the rewriter will iterate
-  /// between applying patterns and simplifying regions. Use `kNoIterationLimit`
-  /// to disable this iteration limit.
+  /// between applying patterns and simplifying regions. Use `kNoLimit` to
+  /// disable this iteration limit.
   int64_t maxIterations = 10;
 
-  static constexpr int64_t kNoIterationLimit = -1;
+  /// This specifies the maximum number of rewrites within an iteration. Use
+  /// `kNoLimit` to disable this limit.
+  int64_t maxNumRewrites = kNoLimit;
+
+  static constexpr int64_t kNoLimit = -1;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -30,10 +30,12 @@
            "Seed the worklist in general top-down order">,
     Option<"enableRegionSimplification", "region-simplify", "bool",
            /*default=*/"true",
-           "Seed the worklist in general top-down order">,
+           "Perform control flow optimizations to the region tree">,
     Option<"maxIterations", "max-iterations", "int64_t",
            /*default=*/"10",
-           "Seed the worklist in general top-down order">
+           "Max. iterations between applying patterns / simplifying regions">,
+    Option<"maxNumRewrites", "max-num-rewrites", "int64_t", /*default=*/"-1",
+           "Max. number of pattern rewrites within an iteration">
   ] # RewritePassUtils.options;
 }
 
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -33,6 +33,7 @@
     this->topDownProcessingEnabled = config.useTopDownTraversal;
     this->enableRegionSimplification = config.enableRegionSimplification;
     this->maxIterations = config.maxIterations;
+    this->maxNumRewrites = config.maxNumRewrites;
     this->disabledPatterns = disabledPatterns;
     this->enabledPatterns = enabledPatterns;
   }
@@ -55,6 +56,7 @@
     config.useTopDownTraversal = topDownProcessingEnabled;
     config.enableRegionSimplification = enableRegionSimplification;
     config.maxIterations = maxIterations;
+    config.maxNumRewrites = maxNumRewrites;
     (void)applyPatternsAndFoldGreedily(getOperation(), patterns, config);
   }
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -183,6 +183,11 @@
     SmallVector<Value, 8> originalOperands, resultValues;
 
     changed = false;
-    while (!worklist.empty()) {
+    int64_t numRewrites = 0;
+    // Stop draining the worklist once the per-iteration rewrite budget is
+    // exhausted; `kNoLimit` disables the budget.
+    while (!worklist.empty() &&
+           (config.maxNumRewrites == GreedyRewriteConfig::kNoLimit ||
+            numRewrites < config.maxNumRewrites)) {
       auto *op = popFromWorklist();
 
@@ -279,16 +284,19 @@
 #else
       LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
 #endif
-      changed |= succeeded(matchResult);
+      if (succeeded(matchResult)) {
+        changed = true;
+        // Each successful pattern application consumes one unit of budget.
+        ++numRewrites;
+      }
     }
 
     // After applying patterns, make sure that the CFG of each of the regions
     // is kept up to date.
     if (config.enableRegionSimplification)
       changed |= succeeded(simplifyRegions(*this, regions));
-  } while (changed &&
-           (iteration++ < config.maxIterations ||
-            config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
+  } while (changed && (iteration++ < config.maxIterations ||
+                       config.maxIterations == GreedyRewriteConfig::kNoLimit));
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
@@ -506,9 +514,8 @@
     changed |= succeeded(matcher.matchAndRewrite(op, *this));
     if ((erased = opErasedViaPatternRewrites))
       return success();
-  } while (changed &&
-           (++iterations < maxIterations ||
-            maxIterations == GreedyRewriteConfig::kNoIterationLimit));
+  } while (changed && (++iterations < maxIterations ||
+                       maxIterations == GreedyRewriteConfig::kNoLimit));
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return failure(changed);