diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -33,8 +33,13 @@
   bool enableRegionSimplification = true;
 
   /// This specifies the maximum number of times the rewriter will iterate
-  /// between applying patterns and simplifying regions.
-  unsigned maxIterations = 10;
+  /// between applying patterns and simplifying regions. Use `kNoIterationLimit`
+  /// to disable this iteration limit.
+  int64_t maxIterations = 10;
+
+  /// Value of `maxIterations` that disables the iteration limit. The driver
+  /// may then never terminate if the patterns do not converge.
+  static constexpr int64_t kNoIterationLimit = -1;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -378,7 +378,8 @@
     Option<"enableRegionSimplification", "region-simplify", "bool",
            /*default=*/"true",
            "Seed the worklist in general top-down order">,
-    Option<"maxIterations", "max-iterations", "unsigned",
+    Option<"maxIterations", "max-iterations", "int64_t",
            /*default=*/"10",
-           "Seed the worklist in general top-down order">
+           "Max. iterations between applying patterns / simplifying regions"
+           " (-1 for no limit)">
   ] # RewritePassUtils.options;
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -222,7 +222,9 @@
     // is kept up to date.
     if (config.enableRegionSimplification)
       changed |= succeeded(simplifyRegions(*this, regions));
-  } while (changed && ++iteration < config.maxIterations);
+  } while (changed &&
+           (++iteration < config.maxIterations ||
+            config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
@@ -345,7 +347,9 @@
     changed |= succeeded(matcher.matchAndRewrite(op, *this));
     if ((erased = opErasedViaPatternRewrites))
       return success();
-  } while (changed && ++iterations < maxIterations);
+  } while (changed &&
+           (++iterations < maxIterations ||
+            maxIterations == GreedyRewriteConfig::kNoIterationLimit));
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return failure(changed);