diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -94,25 +94,41 @@
 /// in absence of convergence.
 ///
 /// Return success if the iterative process converged and no more patterns can
-/// be matched in the result operation regions.
+/// be matched in the result operation regions. If `changed` is non-null, it is
+/// set to true if the IR was modified at all and to false otherwise.
 ///
 /// Note: This does not apply patterns to the top-level operation itself.
 /// These methods also perform folding and simple dead-code elimination
 /// before attempting to match any of the provided patterns.
 ///
 /// You may configure several aspects of this with GreedyRewriteConfig.
-LogicalResult applyPatternsAndFoldGreedily(
-    Region &region, const FrozenRewritePatternSet &patterns,
-    GreedyRewriteConfig config = GreedyRewriteConfig());
+LogicalResult
+applyPatternsAndFoldGreedily(Region &region,
+                             const FrozenRewritePatternSet &patterns,
+                             GreedyRewriteConfig config = GreedyRewriteConfig(),
+                             bool *changed = nullptr);
 
 /// Rewrite ops in all regions of the given op, which must be isolated from
-/// above.
-inline LogicalResult applyPatternsAndFoldGreedily(
-    Operation *op, const FrozenRewritePatternSet &patterns,
-    GreedyRewriteConfig config = GreedyRewriteConfig()) {
+/// above. If `changed` is non-null, it is set to true if the IR of any region
+/// was modified and to false otherwise.
+inline LogicalResult
+applyPatternsAndFoldGreedily(Operation *op,
+                             const FrozenRewritePatternSet &patterns,
+                             GreedyRewriteConfig config = GreedyRewriteConfig(),
+                             bool *changed = nullptr) {
+  // Accumulate locally and assign once so `*changed` never inherits a stale
+  // (or uninitialized) value from the caller.
+  bool anyRegionChanged = false;
   bool failed = false;
-  for (Region &region : op->getRegions())
-    failed |= applyPatternsAndFoldGreedily(region, patterns, config).failed();
+  for (Region &region : op->getRegions()) {
+    bool regionChanged = false;
+    failed |=
+        applyPatternsAndFoldGreedily(region, patterns, config, &regionChanged)
+            .failed();
+    anyRegionChanged |= regionChanged;
+  }
+  if (changed)
+    *changed = anyRegionChanged;
   return failure(failed);
 }
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -616,7 +616,8 @@
 
   /// Simplify ops inside `region` and simplify the region itself. Return
   /// success if the transformation converged.
-  LogicalResult simplify() &&;
+  /// If `changed` is non-null, it is set to whether the IR was modified.
+  LogicalResult simplify(bool *changed) &&;
 
 private:
   /// The region that is simplified.
@@ -652,7 +653,7 @@
 };
 } // namespace
 
-LogicalResult RegionPatternRewriteDriver::simplify() && {
+LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
   auto insertKnownConstant = [&](Operation *op) {
     // Check for existing constants when populating the worklist. This avoids
     // accidentally reversing the constant order during processing.
@@ -663,12 +664,12 @@
     return false;
   };
 
-  bool changed = false;
+  bool continueRewrites = false;
   int64_t iteration = 0;
   MLIRContext *ctx = getContext();
   do {
     // Check if the iteration limit was reached.
-    if (iteration++ >= config.maxIterations &&
+    if (++iteration > config.maxIterations &&
         config.maxIterations != GreedyRewriteConfig::kNoLimit)
       break;
 
@@ -696,24 +697,28 @@
 
     ctx->executeAction<GreedyPatternRewriteIteration>(
         [&] {
-          changed = processWorklist();
+          continueRewrites = processWorklist();
 
           // After applying patterns, make sure that the CFG of each of the
           // regions is kept up to date.
           if (config.enableRegionSimplification)
-            changed |= succeeded(simplifyRegions(*this, region));
+            continueRewrites |= succeeded(simplifyRegions(*this, region));
         },
         {&region}, iteration);
-  } while (changed);
+  } while (continueRewrites);
+
+  // `iteration` only advances past 1 once an iteration has modified the IR.
+  if (changed)
+    *changed = iteration > 1;
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
-  return success(!changed);
+  return success(!continueRewrites);
 }
 
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(Region &region,
                                    const FrozenRewritePatternSet &patterns,
-                                   GreedyRewriteConfig config) {
+                                   GreedyRewriteConfig config, bool *changed) {
   // The top-level operation must be known to be isolated from above to
   // prevent performing canonicalizations on operations defined at or above
   // the region containing 'op'.
@@ -727,7 +732,7 @@
   // Start the pattern driver.
   RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
                                     region);
-  LogicalResult converged = std::move(driver).simplify();
+  LogicalResult converged = std::move(driver).simplify(changed);
   LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
                  << config.maxIterations << " times\n";