diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -24,20 +24,40 @@
 
 /// Rewrite the regions of the specified operation, which must be isolated from
 /// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return success if no more patterns can be matched
-/// in the result operation regions.
-/// Note: This does not apply patterns to the top-level operation itself. Note:
+/// work-list driven manner.
+/// This variant may stop after a predefined number of iterations, see the
+/// alternative below to provide a specific number of iterations before stopping
+/// in absence of convergence.
+/// Return success if the iterative process converged and no more patterns can
+/// be matched in the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself.
 ///       These methods also perform folding and simple dead-code elimination
 ///       before attempting to match any of the provided patterns.
-///
 LogicalResult
 applyPatternsAndFoldGreedily(Operation *op,
                              const FrozenRewritePatternList &patterns);
+
+/// Rewrite the regions of the specified operation, with a user-provided limit
+/// on iterations to attempt before reaching convergence.
+/// Return success if the process converged within `maxIterations` iterations.
+LogicalResult
+applyPatternsAndFoldGreedily(Operation *op,
+                             const FrozenRewritePatternList &patterns,
+                             unsigned maxIterations);
+
 /// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                              const FrozenRewritePatternList &patterns);
 
+/// Rewrite the given regions, with a user-provided limit on iterations to
+/// attempt before reaching convergence.
+/// Return success if the process converged within `maxIterations` iterations.
+LogicalResult
+applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+                             const FrozenRewritePatternList &patterns,
+                             unsigned maxIterations);
+
 /// Applies the specified patterns on `op` alone while also trying to fold it,
 /// by selecting the highest benefits patterns in a greedy manner. Returns
 /// success if no more patterns can be matched. `erased` is set to true if `op`
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -220,12 +220,28 @@
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(Operation *op,
                                    const FrozenRewritePatternList &patterns) {
-  return applyPatternsAndFoldGreedily(op->getRegions(), patterns);
+  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations);
+}
+/// Rewrite the regions of `op` using at most `maxIterations` iterations.
+LogicalResult
+mlir::applyPatternsAndFoldGreedily(Operation *op,
+                                   const FrozenRewritePatternList &patterns,
+                                   unsigned maxIterations) {
+  return applyPatternsAndFoldGreedily(op->getRegions(), patterns,
+                                      maxIterations);
 }
 /// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                    const FrozenRewritePatternList &patterns) {
+  return applyPatternsAndFoldGreedily(regions, patterns,
+                                      maxPatternMatchIterations);
+}
+/// Rewrite the given regions using at most `maxIterations` iterations.
+LogicalResult
+mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+                                   const FrozenRewritePatternList &patterns,
+                                   unsigned maxIterations) {
   if (regions.empty())
     return success();
 
@@ -241,10 +257,10 @@
 
   // Start the pattern driver.
   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
-  bool converged = driver.simplify(regions, maxPatternMatchIterations);
+  bool converged = driver.simplify(regions, maxIterations);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
-                 << maxPatternMatchIterations << " times";
+                 << maxIterations << " times";
   });
   return success(converged);
 }