diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -39,10 +39,11 @@ public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config); + const GreedyRewriteConfig &config, + const DenseSet &scope); /// Simplify the operations within the given regions. - bool simplify(MutableArrayRef regions); + bool simplify(MutableArrayRef regions) &&; /// Add the given operation and its ancestors to the worklist. void addToWorklist(Operation *op); @@ -100,12 +101,10 @@ protected: /// Configuration information for how to simplify. - GreedyRewriteConfig config; + const GreedyRewriteConfig config; - /// Only ops within this scope are simplified. This is set at the beginning - /// of `simplify()` and `simplifyLocally()` to the current scope the rewriter - /// operates on. - DenseSet scope; + /// Only ops within this scope are simplified. + const DenseSet scope; private: #ifndef NDEBUG @@ -117,19 +116,16 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config) - : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { + const GreedyRewriteConfig &config, const DenseSet &scope) + : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config), + scope(scope) { worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. 
matcher.applyDefaultCostModel(); } -bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) { - scope.clear(); - for (Region &r : regions) - scope.insert(&r); - +bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) && { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -449,9 +445,15 @@ assert(llvm::all_of(regions, regionIsIsolated) && "patterns can only be applied to operations IsolatedFromAbove"); + // Limit ops on the worklist to this scope. + DenseSet scope; + for (Region &r : regions) + scope.insert(&r); + // Start the pattern driver. - GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config); - bool converged = driver.simplify(regions); + GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config, + scope); + bool converged = std::move(driver).simplify(regions); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " << config.maxIterations << " times\n"; @@ -472,11 +474,12 @@ /// ops are not considered. class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: - explicit MultiOpPatternRewriteDriver(MLIRContext *ctx, - const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode) - : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()), - strictMode(strictMode) {} + explicit MultiOpPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const DenseSet &scope, GreedyRewriteStrictness strictMode, + llvm::SmallDenseSet *survivingOps = nullptr) + : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope), + strictMode(strictMode), survivingOps(survivingOps) {} /// Performs the specified rewrites on `ops` while also trying to fold these /// ops. `strictMode` controls which other ops are simplified. 
Only ops @@ -486,11 +489,9 @@ /// Note that ops in `ops` could be erased as a result of folding, becoming /// dead, or via pattern rewrites. The return value indicates convergence. /// - /// All `ops` that survived the rewrite are stored in `surviving`. - LogicalResult - simplifyLocally(ArrayRef ops, bool *changed = nullptr, - llvm::SmallDenseSet *surviving = nullptr, - Region *scope = nullptr); + /// All `ops` that survived are stored in `survivingOps` (if non-null). + LogicalResult simplifyLocally(ArrayRef ops, + bool *changed = nullptr) &&; protected: void addSingleOpToWorklist(Operation *op) override { @@ -516,7 +517,7 @@ /// `strictMode` control which ops are added to the worklist during /// simplification. - GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; + const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; /// The list of ops we are restricting our rewrites to. These include the /// supplied set of ops as well as new ops created while rewriting those ops @@ -524,20 +525,18 @@ /// is GreedyRewriteStrictness::AnyOp. llvm::SmallDenseSet strictModeFilteredOps; /// An optional set of ops that survived the rewrite. This set is populated - /// at the beginning of `simplifyLocally` with the inititally provided list + /// at the beginning of `simplifyLocally` with the initially provided list /// of ops. 
- llvm::SmallDenseSet *survivingOps = nullptr; + llvm::SmallDenseSet *const survivingOps = nullptr; }; } // namespace -LogicalResult MultiOpPatternRewriteDriver::simplifyLocally( - ArrayRef ops, bool *changed, - llvm::SmallDenseSet *surviving, Region *scope) { - auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; }); - if (surviving) { - survivingOps = surviving; +LogicalResult +MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, + bool *changed) && { + if (survivingOps) { survivingOps->clear(); survivingOps->insert(ops.begin(), ops.end()); } @@ -547,10 +546,6 @@ strictModeFilteredOps.insert(ops.begin(), ops.end()); } - assert(scope && "scope is mandatory"); - this->scope.clear(); - this->scope.insert(scope); - if (changed) *changed = false; worklist.clear(); @@ -684,11 +679,14 @@ } // Start the pattern driver. - MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - strictMode); llvm::SmallDenseSet surviving; - LogicalResult converged = driver.simplifyLocally( - ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope); + assert(scope && "scope is mandatory"); + DenseSet scopeSet; + scopeSet.insert(scope); + MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, + scopeSet, strictMode, + allErased ? &surviving : nullptr); + LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); if (allErased) *allErased = surviving.empty(); LLVM_DEBUG(if (failed(converged)) {