diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -33,10 +33,21 @@ LogicalResult applyPatternsAndFoldGreedily(Operation *op, const FrozenRewritePatternList &patterns); +/// Rewrite the regions of `op`, bounding the driver by `maxIterations`. +LogicalResult +applyPatternsAndFoldGreedily(Operation *op, + const FrozenRewritePatternList &patterns, + unsigned maxIterations); + /// Rewrite the given regions, which must be isolated from above. LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef regions, const FrozenRewritePatternList &patterns); +/// Rewrite `regions`, bounding the driver by `maxIterations`. +LogicalResult +applyPatternsAndFoldGreedily(MutableArrayRef regions, + const FrozenRewritePatternList &patterns, + unsigned maxIterations); /// Applies the specified patterns on `op` alone while also trying to fold it, /// by selecting the highest benefits patterns in a greedy manner. Returns diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -220,12 +220,24 @@ LogicalResult mlir::applyPatternsAndFoldGreedily(Operation *op, const FrozenRewritePatternList &patterns) { - return applyPatternsAndFoldGreedily(op->getRegions(), patterns); + return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations); +} +LogicalResult +mlir::applyPatternsAndFoldGreedily(Operation *op, + const FrozenRewritePatternList &patterns, + unsigned maxIterations) { + return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations); } /// Rewrite the given regions, which must be isolated from above. 
LogicalResult mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, const FrozenRewritePatternList &patterns) { + return applyPatternsAndFoldGreedily(regions, patterns, maxPatternMatchIterations); +} +LogicalResult +mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, + const FrozenRewritePatternList &patterns, + unsigned maxIterations) { if (regions.empty()) return success(); @@ -241,10 +253,10 @@ // Start the pattern driver. GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns); - bool converged = driver.simplify(regions, maxPatternMatchIterations); + bool converged = driver.simplify(regions, maxIterations); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " - << maxPatternMatchIterations << " times"; + << maxIterations << " times"; }); return success(converged); }