diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -42,10 +42,11 @@
   /// Simplify the operations within the given regions.
   bool simplify(MutableArrayRef<Region> regions);
 
-  /// Add the given operation to the worklist. Parent ops may or may not be
-  /// added to the worklist, depending on the type of rewrite driver. By
-  /// default, parent ops are added.
-  virtual void addToWorklist(Operation *op);
+  /// Add the given operation and its ancestors to the worklist. If a scope is
+  /// set, only ancestors nested within the scope are added, and nothing is
+  /// added if `op` itself is not nested within the scope. Each op is enqueued
+  /// via `addSingleOpToWorklist`, which subclasses may override to filter ops.
+  void addToWorklist(Operation *op);
 
   /// Pop the next operation from the worklist.
   Operation *popFromWorklist();
@@ -59,7 +60,8 @@
 
 protected:
-  /// Add the given operation to the worklist.
-  void addSingleOpToWorklist(Operation *op);
+  /// Add the given operation to the worklist. Subclasses may override this
+  /// to filter which ops are enqueued.
+  virtual void addSingleOpToWorklist(Operation *op);
 
   // Implement the hook for inserting operations, and make sure that newly
   // inserted ops are added to the worklist for processing.
@@ -105,7 +107,8 @@
 private:
   /// Only ops within this scope are simplified. This is set at the beginning
-  /// of `simplify()` to the current scope the rewriter operates on.
-  DenseSet<Region *> scope;
+  /// of `simplify()` to the current scope the rewriter operates on. If no
+  /// scope is set, `addToWorklist` does not filter ops by region.
+  std::optional<DenseSet<Region *>> scope;
 
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
@@ -125,8 +128,9 @@
 }
 
 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
+  DenseSet<Region *> &s = scope.emplace();
   for (Region &r : regions)
-    scope.insert(&r);
+    s.insert(&r);
 
 #ifndef NDEBUG
   const char *logLineComment =
@@ -316,13 +320,13 @@
 
 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
   // Gather potential ancestors while looking for a "scope" parent region.
+  bool inScope = !scope.has_value();
   SmallVector<Operation *, 8> ancestors;
   ancestors.push_back(op);
   while (Region *region = op->getParentRegion()) {
-    if (scope.contains(region)) {
+    if (scope.has_value() && scope->contains(region)) {
       // All gathered ops are in fact ancestors.
-      for (Operation *op : ancestors)
-        addSingleOpToWorklist(op);
+      inScope = true;
       break;
     }
     op = region->getParentOp();
@@ -330,6 +334,11 @@
       break;
     ancestors.push_back(op);
   }
+
+  // Enqueue ops if "scope" parent region was reached or no scope was specified.
+  if (inScope)
+    for (Operation *ancestor : ancestors)
+      addSingleOpToWorklist(ancestor);
 }
 
 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
@@ -590,7 +599,9 @@
                                 bool *changed = nullptr,
                                 DenseSet<Operation *> *erased = nullptr);
 
-  void addToWorklist(Operation *op) override {
+protected:
+  /// Enqueue `op` only if it passes the strict-mode filter.
+  void addSingleOpToWorklist(Operation *op) override {
     if (strictMode == GreedyRewriteStrictness::AnyOp ||
         strictModeFilteredOps.contains(op))
       GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
@@ -641,7 +652,8 @@
   worklist.clear();
   worklistMap.clear();
+  // Seed the worklist with `ops` themselves; their ancestors are not enqueued.
   for (Operation *op : ops)
-    addToWorklist(op);
+    addSingleOpToWorklist(op);
 
   // These are scratch vectors used in the folding loop below.
   SmallVector<Value, 8> originalOperands, resultValues;