diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -37,10 +37,11 @@ public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config); + const GreedyRewriteConfig &config, + const DenseSet &scope); /// Simplify the operations within the given regions. - bool simplify(MutableArrayRef regions); + bool simplify(MutableArrayRef regions) &&; /// Add the given operation and its ancestors to the worklist. void addToWorklist(Operation *op); @@ -98,7 +99,7 @@ protected: /// Configuration information for how to simplify. - GreedyRewriteConfig config; + const GreedyRewriteConfig config; - /// Only ops within this scope are simplified. This is set at the beginning - /// of `simplify()` to the current scope the rewriter operates on. + /// Only ops within this scope are simplified. The scope is initialized from + /// the `scope` argument of the constructor. @@ -114,19 +115,16 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config) - : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { + const GreedyRewriteConfig &config, const DenseSet &scope) + : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config), + scope(scope) { worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit.
matcher.applyDefaultCostModel(); } -bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) { - scope.clear(); - for (Region &r : regions) - scope.insert(&r); - +bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) && { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -446,9 +444,15 @@ assert(llvm::all_of(regions, regionIsIsolated) && "patterns can only be applied to operations IsolatedFromAbove"); + // Limit ops on the worklist to this scope. + DenseSet scope; + for (Region &r : regions) + scope.insert(&r); + // Start the pattern driver. - GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config); - bool converged = driver.simplify(regions); + GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config, + scope); + bool converged = std::move(driver).simplify(regions); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " << config.maxIterations << " times\n"; @@ -469,11 +473,12 @@ /// ops are not considered. class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: - explicit MultiOpPatternRewriteDriver(MLIRContext *ctx, - const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode) - : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()), - strictMode(strictMode) {} + explicit MultiOpPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const DenseSet &scope, GreedyRewriteStrictness strictMode, + llvm::SmallDenseSet *survivingOps = nullptr) + : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope), + strictMode(strictMode), survivingOps(survivingOps) {} /// Performs the specified rewrites on `ops` while also trying to fold these /// ops. `strictMode` controls which other ops are simplified. Only ops
@@ -483,11 +488,9 @@ /// Note that ops in `ops` could be erased as a result of folding, becoming /// dead, or via pattern rewrites. The return value indicates convergence. /// - /// All `ops` that survived the rewrite are stored in `surviving`. - LogicalResult - simplifyLocally(ArrayRef ops, bool *changed = nullptr, - llvm::SmallDenseSet *surviving = nullptr, - Region *scope = nullptr); + /// All `ops` that survived the rewrite are stored in `survivingOps`. + LogicalResult simplifyLocally(ArrayRef ops, + bool *changed = nullptr) &&; protected: void addSingleOpToWorklist(Operation *op) override { @@ -513,7 +516,7 @@ /// `strictMode` control which ops are added to the worklist during /// simplificiation. - GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; + const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; /// The list of ops we are restricting our rewrites to. These include the /// supplied set of ops as well as new ops created while rewriting those ops @@ -521,17 +524,16 @@ /// is GreedyRewriteStrictness::AnyOp. llvm::SmallDenseSet strictModeFilteredOps; /// An optional set of ops that survived the rewrite. - llvm::SmallDenseSet *survivingOps = nullptr;
+ llvm::SmallDenseSet *const survivingOps = nullptr; }; } // namespace -LogicalResult MultiOpPatternRewriteDriver::simplifyLocally( - ArrayRef ops, bool *changed, - llvm::SmallDenseSet *surviving, Region *scope) { - if (surviving) { - survivingOps = surviving; +LogicalResult +MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, + bool *changed) && { + if (survivingOps) { survivingOps->clear(); survivingOps->insert(ops.begin(), ops.end()); } @@ -541,9 +543,6 @@ strictModeFilteredOps.insert(ops.begin(), ops.end()); } - this->scope.clear(); - this->scope.insert(scope); - if (changed) *changed = false; worklist.clear(); @@ -621,7 +620,6 @@ } } - surviving = nullptr; return success(worklist.empty()); } @@ -665,11 +663,13 @@ } // Start the pattern driver. - MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - strictMode); llvm::SmallDenseSet surviving; - LogicalResult converged = driver.simplifyLocally( - ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope); + DenseSet scopeSet; + scopeSet.insert(scope); + MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, + scopeSet, strictMode, + allErased ? &surviving : nullptr); + LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); if (allErased) *allErased = surviving.empty(); LLVM_DEBUG(if (failed(converged)) {