diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -112,6 +112,12 @@ /// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are /// simplified. All other ops are excluded. /// +/// In addition to strictness, a region scope can be specified. Only ops within +/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`, +/// where only ops within the given regions are simplified. If no scope is +/// specified, it is assumed to be the first common enclosing region of the +/// given ops. +/// /// Note that ops in `ops` could be erased as result of folding, becoming dead, /// or via pattern rewrites. If more far reaching simplification is desired, /// applyPatternsAndFoldGreedily should be used. @@ -123,7 +129,8 @@ const FrozenRewritePatternSet &patterns, GreedyRewriteStrictness strictMode, bool *changed = nullptr, - bool *allErased = nullptr); + bool *allErased = nullptr, + Region *scope = nullptr); } // namespace mlir diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -16,6 +16,7 @@ #include "mlir/Rewrite/PatternApplicator.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/CommandLine.h" @@ -43,10 +44,8 @@ /// Simplify the operations within the given regions. bool simplify(MutableArrayRef regions); - /// Add the given operation to the worklist. Parent ops may or may not be - /// added to the worklist, depending on the type of rewrite driver. By - /// default, parent ops are added. 
- virtual void addToWorklist(Operation *op); + /// Add the given operation and its ancestors to the worklist. + void addToWorklist(Operation *op); /// Pop the next operation from the worklist. Operation *popFromWorklist(); @@ -60,7 +59,7 @@ protected: /// Add the given operation to the worklist. - void addSingleOpToWorklist(Operation *op); + virtual void addSingleOpToWorklist(Operation *op); // Implement the hook for inserting operations, and make sure that newly // inserted ops are added to the worklist for processing. @@ -103,11 +102,12 @@ /// Configuration information for how to simplify. GreedyRewriteConfig config; -private: /// Only ops within this scope are simplified. This is set at the beginning - /// of `simplify()` to the current scope the rewriter operates on. + /// of `simplify()` and `simplifyLocally()` to the current scope the rewriter + /// operates on. DenseSet scope; +private: #ifndef NDEBUG /// A logger used to emit information during the application process. llvm::ScopedPrinter logger{llvm::dbgs()}; @@ -126,6 +126,7 @@ } bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) { + scope.clear(); for (Region &r : regions) scope.insert(&r); @@ -491,9 +492,6 @@ void notifyRootReplaced(Operation *op, ValueRange replacement) override {} private: - /// Op that is being processed. - Operation *op = nullptr; - /// The low-level pattern applicator. PatternApplicator matcher; @@ -584,7 +582,9 @@ strictMode(strictMode) {} /// Performs the specified rewrites on `ops` while also trying to fold these - /// ops. `strictMode` controls which other ops are simplified. + /// ops. `strictMode` controls which other ops are simplified. Only ops + /// within the given scope region are added to the worklist. If no scope is + /// specified, it is assumed to be the closest common region of all `ops`. /// /// Note that ops in `ops` could be erased as a result of folding, becoming /// dead, or via pattern rewrites. The return value indicates convergence.
@@ -592,9 +592,11 @@ /// All `ops` that survived the rewrite are stored in `surviving`. LogicalResult simplifyLocally(ArrayRef ops, bool *changed = nullptr, - llvm::SmallDenseSet *surviving = nullptr); + llvm::SmallDenseSet *surviving = nullptr, + Region *scope = nullptr); - void addToWorklist(Operation *op) override { +protected: + void addSingleOpToWorklist(Operation *op) override { if (strictMode == GreedyRewriteStrictness::AnyOp || strictModeFilteredOps.contains(op)) GreedyPatternRewriteDriver::addSingleOpToWorklist(op); @@ -635,7 +637,7 @@ LogicalResult MultiOpPatternRewriteDriver::simplifyLocally( ArrayRef ops, bool *changed, - llvm::SmallDenseSet *surviving) { + llvm::SmallDenseSet *surviving, Region *scope) { auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; }); if (surviving) { survivingOps = surviving; @@ -648,12 +650,16 @@ strictModeFilteredOps.insert(ops.begin(), ops.end()); } + assert(scope && "scope is mandatory"); + this->scope.clear(); + this->scope.insert(scope); + if (changed) *changed = false; worklist.clear(); worklistMap.clear(); for (Operation *op : ops) - addToWorklist(op); + addSingleOpToWorklist(op); // These are scratch vectors used in the folding loop below. SmallVector originalOperands, resultValues; @@ -745,9 +751,37 @@ return converged; } -LogicalResult mlir::applyOpPatternsAndFold( - ArrayRef ops, const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, bool *changed, bool *allErased) { +/// Find the region that is the closest common ancestor of all given ops. +static Region *findCommonAncestor(ArrayRef ops) { + assert(!ops.empty() && "expected at least one op"); + // Fast path in case there is only one op. + int sz = ops.size(); + if (sz == 1) + return ops.front()->getParentRegion(); + + Region *region = ops.front()->getParentRegion(); + llvm::BitVector remainingOps(sz, true); + remainingOps.reset(0); + while (region) { + int pos = -1; + // Drop all remaining ops that are contained in `region`.
+ while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) { + if (region->findAncestorOpInRegion(*ops[pos])) + remainingOps.reset(pos); + } + if (remainingOps.none()) + break; + region = region->getParentRegion(); + } + assert(region && "could not find common parent region"); + return region; +} + +LogicalResult +mlir::applyOpPatternsAndFold(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + GreedyRewriteStrictness strictMode, bool *changed, + bool *allErased, Region *scope) { if (ops.empty()) { if (changed) *changed = false; @@ -756,12 +790,25 @@ return success(); } + if (!scope) { + // Compute scope if none was provided. + scope = findCommonAncestor(ops); + } else { + // If a scope was provided, make sure that all ops are in scope. +#ifndef NDEBUG + bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) { + return static_cast(scope->findAncestorOpInRegion(*op)); + }); + assert(allOpsInScope && "ops must be within the specified scope"); +#endif // NDEBUG + } + // Start the pattern driver. MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, strictMode); llvm::SmallDenseSet surviving; - LogicalResult converged = - driver.simplifyLocally(ops, changed, allErased ? &surviving : nullptr); + LogicalResult converged = driver.simplifyLocally( + ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope); if (allErased) *allErased = surviving.empty(); return converged;