diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -35,12 +35,13 @@
 /// applies the locally optimal patterns in a roughly "bottom up" way.
 class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
-  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
-                                      const FrozenRewritePatternSet &patterns,
-                                      const GreedyRewriteConfig &config);
+  explicit GreedyPatternRewriteDriver(
+      MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+      const GreedyRewriteConfig &config,
+      const std::optional<DenseSet<Region *>> &scope);
 
   /// Simplify the operations within the given regions.
-  bool simplify(MutableArrayRef<Region> regions);
+  bool simplify(MutableArrayRef<Region> regions) &&;
 
   /// Add the given operation and its ancestors to the worklist.
   void addToWorklist(Operation *op);
@@ -98,12 +99,11 @@
 
 protected:
   /// Configuration information for how to simplify.
-  GreedyRewriteConfig config;
+  const GreedyRewriteConfig config;
 
 private:
-  /// Only ops within this scope are simplified. This is set at the beginning
-  /// of `simplify()` to the current scope the rewriter operates on.
-  std::optional<DenseSet<Region *>> scope = {};
+  /// Only ops within this scope are simplified.
+  const std::optional<DenseSet<Region *>> scope = {};
 
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
@@ -114,19 +114,17 @@
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-    const GreedyRewriteConfig &config)
-    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+    const GreedyRewriteConfig &config,
+    const std::optional<DenseSet<Region *>> &scope)
+    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
+      scope(scope) {
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
   matcher.applyDefaultCostModel();
 }
 
-bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
-  DenseSet<Region *> &s = scope.emplace();
-  for (Region &r : regions)
-    s.insert(&r);
-
+bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
 #ifndef NDEBUG
   const char *logLineComment =
       "//===-------------------------------------------===//\n";
@@ -451,9 +449,15 @@
   assert(llvm::all_of(regions, regionIsIsolated) &&
          "patterns can only be applied to operations IsolatedFromAbove");
 
+  // Limit ops on the worklist to this scope.
+  DenseSet<Region *> scope;
+  for (Region &r : regions)
+    scope.insert(&r);
+
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
-  bool converged = driver.simplify(regions);
+  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config,
+                                    scope);
+  bool converged = std::move(driver).simplify(regions);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
                  << config.maxIterations << " times\n";
@@ -474,11 +478,13 @@
 /// ops are not considered.
 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
-  explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
-                                       const FrozenRewritePatternSet &patterns,
-                                       GreedyRewriteStrictness strictMode)
-      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
-        strictMode(strictMode) {}
+  explicit MultiOpPatternRewriteDriver(
+      MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+      GreedyRewriteStrictness strictMode,
+      DenseSet<Operation *> *erasedOps = nullptr)
+      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(),
+                                   /*scope=*/{}),
+        strictMode(strictMode), erasedOps(erasedOps) {}
 
   /// Performs the specified rewrites on `ops` while also trying to fold these
   /// ops. `strictMode` controls which other ops are simplified.
@@ -488,8 +494,7 @@
   ///
-  /// All erased ops are stored in `erased`.
-  LogicalResult simplifyLocally(ArrayRef<Operation *> op,
-                                bool *changed = nullptr,
-                                DenseSet<Operation *> *erased = nullptr);
+  /// If `erasedOps` is non-null, all erased ops are added to it.
+  LogicalResult simplifyLocally(ArrayRef<Operation *> ops,
+                                bool *changed = nullptr) &&;
 
 protected:
   void addSingleOpToWorklist(Operation *op) override {
@@ -515,7 +520,7 @@
 
   /// `strictMode` control which ops are added to the worklist during
   /// simplificiation.
-  GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+  const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
 
   /// The list of ops we are restricting our rewrites to. These include the
   /// supplied set of ops as well as new ops created while rewriting those ops
@@ -524,15 +529,14 @@
   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
 
   /// An optional set of ops that were erased.
-  DenseSet<Operation *> *erasedOps = nullptr;
+  DenseSet<Operation *> *const erasedOps = nullptr;
 };
 
 } // namespace
 
-LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
-    ArrayRef<Operation *> ops, bool *changed, DenseSet<Operation *> *erased) {
-  erasedOps = erased;
-
+LogicalResult
+MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
+                                             bool *changed) && {
   if (strictMode != GreedyRewriteStrictness::AnyOp) {
     strictModeFilteredOps.clear();
     strictModeFilteredOps.insert(ops.begin(), ops.end());
@@ -615,7 +619,6 @@
     }
   }
 
-  erased = nullptr;
   return success(worklist.empty());
 }
 
@@ -631,11 +634,11 @@
   }
 
   // Start the pattern driver.
-  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     strictMode);
   DenseSet<Operation *> erased;
-  LogicalResult converged =
-      driver.simplifyLocally(ops, changed, allOpsErased ? &erased : nullptr);
+  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
+                                     strictMode,
+                                     allOpsErased ? &erased : nullptr);
+  LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
   if (allOpsErased)
     *allOpsErased =
         llvm::all_of(ops, [&](Operation *op) { return erased.contains(op); });