diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -96,10 +96,11 @@ /// Non-pattern based folder for operations. OperationFolder folder; -private: +protected: /// Configuration information for how to simplify. GreedyRewriteConfig config; +private: #ifndef NDEBUG /// A logger used to emit information during the application process. llvm::ScopedPrinter logger{llvm::dbgs()}; @@ -147,8 +148,13 @@ }; bool changed = false; - unsigned iteration = 0; + int64_t iteration = 0; do { + // Check if the iteration limit was reached. + if (iteration++ >= config.maxIterations && + config.maxIterations != GreedyRewriteConfig::kNoLimit) + break; + worklist.clear(); worklistMap.clear(); @@ -184,7 +190,9 @@ changed = false; int64_t numRewrites = 0; - while (!worklist.empty()) { + while (!worklist.empty() && + (numRewrites < config.maxNumRewrites || + config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { auto *op = popFromWorklist(); // Nulls get added to the worklist when operations are removed, ignore @@ -280,11 +288,10 @@ #else LogicalResult matchResult = matcher.matchAndRewrite(op, *this); #endif + if (succeeded(matchResult)) { changed = true; - if (numRewrites++ >= config.maxNumRewrites && - config.maxNumRewrites != GreedyRewriteConfig::kNoLimit) - break; + ++numRewrites; } } @@ -292,8 +299,7 @@ // is kept up to date. if (config.enableRegionSimplification) changed |= succeeded(simplifyRegions(*this, regions)); - } while (changed && (iteration++ < config.maxIterations || - config.maxIterations == GreedyRewriteConfig::kNoLimit)); + } while (changed); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. return !changed; @@ -421,7 +427,7 @@ GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config); bool converged = driver.simplify(regions); LLVM_DEBUG(if (!converged) { - llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " + llvm::dbgs() << "The pattern rewrite did not converge after scanning " << config.maxIterations << " times\n"; }); return success(converged); @@ -443,7 +449,8 @@ matcher.applyDefaultCostModel(); } - LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased); + LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites, + bool &erased); // These are hooks implemented for PatternRewriter. protected: @@ -473,18 +480,22 @@ /// Performs the rewrites and folding only on `op`. The simplification /// converges if the op is erased as a result of being folded, replaced, or /// becoming dead, or no more changes happen in an iteration. Returns success if -/// the rewrite converges in `maxIterations`. `erased` is set to true if `op` +/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op` /// gets erased. LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, - int maxIterations, + int64_t maxNumRewrites, bool &erased) { bool changed = false; erased = false; opErasedViaPatternRewrites = false; - int iterations = 0; - // Iterate until convergence or until maxIterations. Deletion of the op as + int64_t numRewrites = 0; + // Iterate until convergence or until maxNumRewrites. Deletion of the op as // a result of being dead or folded is convergence. do { + if (numRewrites >= maxNumRewrites && + maxNumRewrites != GreedyRewriteConfig::kNoLimit) + break; + changed = false; // If the operation is trivially dead - remove it. @@ -508,11 +519,13 @@ // Try to match one of the patterns. The rewriter is automatically // notified of any necessary changes, so there is nothing else to do here. - changed |= succeeded(matcher.matchAndRewrite(op, *this)); + if (succeeded(matcher.matchAndRewrite(op, *this))) { + changed = true; + ++numRewrites; + } if ((erased = opErasedViaPatternRewrites)) return success(); - } while (changed && (++iterations < maxIterations || - maxIterations == GreedyRewriteConfig::kNoLimit)); + } while (changed); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. return failure(changed); @@ -601,7 +614,10 @@ // These are scratch vectors used in the folding loop below. SmallVector originalOperands, resultValues; - while (!worklist.empty()) { + int64_t numRewrites = 0; + while (!worklist.empty() && + (numRewrites < config.maxNumRewrites || + config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { Operation *op = popFromWorklist(); // Nulls get added to the worklist when operations are removed, ignore @@ -656,7 +672,10 @@ // Try to match one of the patterns. The rewriter is automatically // notified of any necessary changes, so there is nothing else to do // here. - changed |= succeeded(matcher.matchAndRewrite(op, *this)); + if (succeeded(matcher.matchAndRewrite(op, *this))) { + changed = true; + ++numRewrites; + } } return changed; @@ -672,12 +691,12 @@ OpPatternRewriteDriver driver(op->getContext(), patterns); bool opErased; LogicalResult converged = - driver.simplifyLocally(op, config.maxIterations, opErased); + driver.simplifyLocally(op, config.maxNumRewrites, opErased); if (erased) *erased = opErased; LLVM_DEBUG(if (failed(converged)) { - llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " - << config.maxIterations << " times"; + llvm::dbgs() << "The pattern rewrite did not converge after " + << config.maxNumRewrites << " rewrites"; }); return converged; }