diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -112,6 +112,12 @@
 /// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are
 ///   simplified. All other ops are excluded.
 ///
+/// In addition to strictness, a region scope can be specified. Only ops within
+/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
+/// where only ops within the given regions are simplified. If no scope is
+/// specified, it is assumed to be the closest common enclosing region of the
+/// given ops.
+///
 /// Note that ops in `ops` could be erased as result of folding, becoming dead,
 /// or via pattern rewrites. If more far reaching simplification is desired,
 /// applyPatternsAndFoldGreedily should be used.
@@ -123,7 +129,8 @@
                                      const FrozenRewritePatternSet &patterns,
                                      GreedyRewriteStrictness strictMode,
                                      bool *changed = nullptr,
-                                     bool *allErased = nullptr);
+                                     bool *allErased = nullptr,
+                                     Region *scope = nullptr);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -42,10 +42,8 @@
   /// Simplify the operations within the given regions.
   bool simplify(MutableArrayRef<Region> regions);
 
-  /// Add the given operation to the worklist. Parent ops may or may not be
-  /// added to the worklist, depending on the type of rewrite driver. By
-  /// default, parent ops are added.
-  virtual void addToWorklist(Operation *op);
+  /// Add the given operation and its ancestors to the worklist.
+  void addToWorklist(Operation *op);
 
   /// Pop the next operation from the worklist.
   Operation *popFromWorklist();
@@ -59,7 +57,7 @@
 
 protected:
   /// Add the given operation to the worklist.
-  void addSingleOpToWorklist(Operation *op);
+  virtual void addSingleOpToWorklist(Operation *op);
 
   // Implement the hook for inserting operations, and make sure that newly
   // inserted ops are added to the worklist for processing.
@@ -102,11 +100,12 @@
   /// Configuration information for how to simplify.
   GreedyRewriteConfig config;
 
-private:
   /// Only ops within this scope are simplified. This is set at the beginning
-  /// of `simplify()` to the current scope the rewriter operates on.
+  /// of `simplify()` and `simplifyLocally()` to the current scope the rewriter
+  /// operates on.
   DenseSet<Region *> scope;
 
+private:
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
   llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -125,6 +124,7 @@
 }
 
 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
+  scope.clear();
   for (Region &r : regions)
     scope.insert(&r);
 
@@ -490,9 +490,6 @@
   void notifyRootReplaced(Operation *op, ValueRange replacement) override {}
 
 private:
-  /// Op that is being processed.
-  Operation *op = nullptr;
-
   /// The low-level pattern applicator.
   PatternApplicator matcher;
 
@@ -583,7 +580,9 @@
         strictMode(strictMode) {}
 
   /// Performs the specified rewrites on `ops` while also trying to fold these
-  /// ops. `strictMode` controls which other ops are simplified.
+  /// ops. `strictMode` controls which other ops are simplified. Only ops
+  /// within the given `scope` region are added to the worklist. `scope` must
+  /// not be null; callers usually pass the closest common region of `ops`.
   ///
   /// Note that ops in `ops` could be erased as a result of folding, becoming
   /// dead, or via pattern rewrites. The return value indicates convergence.
@@ -591,9 +590,11 @@
   /// All `ops` that survived the rewrite are stored in `surviving`.
   LogicalResult
   simplifyLocally(ArrayRef<Operation *> ops, bool *changed = nullptr,
-                  llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr);
+                  llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr,
+                  Region *scope = nullptr);
 
-  void addToWorklist(Operation *op) override {
+protected:
+  void addSingleOpToWorklist(Operation *op) override {
     if (strictMode == GreedyRewriteStrictness::AnyOp ||
         strictModeFilteredOps.contains(op))
       GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
@@ -632,7 +633,7 @@
 
 LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
     ArrayRef<Operation *> ops, bool *changed,
-    llvm::SmallDenseSet<Operation *, 4> *surviving) {
+    llvm::SmallDenseSet<Operation *, 4> *surviving, Region *scope) {
   if (surviving) {
     survivingOps = surviving;
     survivingOps->clear();
@@ -644,12 +645,16 @@
     strictModeFilteredOps.insert(ops.begin(), ops.end());
   }
 
+  assert(scope && "scope is mandatory");
+  this->scope.clear();
+  this->scope.insert(scope);
+
   if (changed)
     *changed = false;
   worklist.clear();
   worklistMap.clear();
   for (Operation *op : ops)
-    addToWorklist(op);
+    addSingleOpToWorklist(op);
 
   // These are scratch vectors used in the folding loop below.
   SmallVector<Value, 8> originalOperands, resultValues;
@@ -742,9 +747,39 @@
   return converged;
 }
 
-LogicalResult mlir::applyOpPatternsAndFold(
-    ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
-    GreedyRewriteStrictness strictMode, bool *changed, bool *allErased) {
+/// Find the region that is the closest common ancestor of all given ops.
+static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
+  assert(!ops.empty() && "expected at least one op");
+  Region *region = ops.front()->getParentRegion();
+  // Ops that are not yet known to be nested in `region`. An entry is reset to
+  // nullptr once a candidate region is found to contain that op.
+  SmallVector<Operation *> remainingOps(ops.drop_front().begin(),
+                                        ops.drop_front().end());
+  int64_t numRemainingOps = remainingOps.size();
+  while (region) {
+    // Visit every entry: `numRemainingOps` shrinks during this loop, so it
+    // cannot serve as the loop bound.
+    for (Operation *&remainingOp : remainingOps) {
+      if (remainingOp && region->findAncestorOpInRegion(*remainingOp)) {
+        remainingOp = nullptr;
+        --numRemainingOps;
+      }
+    }
+    // `region` contains all ops: it is the closest common ancestor. Do not
+    // move on to its parent.
+    if (numRemainingOps == 0)
+      break;
+    region = region->getParentRegion();
+  }
+  assert(region && "could not find common parent region");
+  return region;
+}
+
+LogicalResult
+mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+                             const FrozenRewritePatternSet &patterns,
+                             GreedyRewriteStrictness strictMode, bool *changed,
+                             bool *allErased, Region *scope) {
   if (ops.empty()) {
     if (changed)
       *changed = false;
@@ -753,12 +788,25 @@
     return success();
   }
 
+  if (!scope) {
+    // Compute scope if none was provided.
+    scope = findCommonAncestor(ops);
+  } else {
+    // If a scope was provided, make sure that all ops are in scope.
+#ifndef NDEBUG
+    bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
+      return static_cast<bool>(scope->findAncestorOpInRegion(*op));
+    });
+    assert(allOpsInScope && "ops must be within the specified scope");
+#endif // NDEBUG
+  }
+
   // Start the pattern driver.
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
                                      strictMode);
   llvm::SmallDenseSet<Operation *, 4> surviving;
-  LogicalResult converged =
-      driver.simplifyLocally(ops, changed, allErased ? &surviving : nullptr);
+  LogicalResult converged = driver.simplifyLocally(
+      ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope);
   if (allErased)
     *allErased = surviving.empty();
   return converged;