diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -539,12 +539,31 @@
     }
   }
 
+  /// Forwards the notification to the base driver; in strict mode the newly
+  /// inserted op is additionally recorded in `strictModeFilteredOps`.
+  void notifyOperationInserted(Operation *op) override {
+    GreedyPatternRewriteDriver::notifyOperationInserted(op);
+    if (strictMode)
+      strictModeFilteredOps.insert(op);
+  }
+
   void notifyOperationRemoved(Operation *op) override {
     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
     if (strictMode)
       strictModeFilteredOps.erase(op);
   }
 
+  /// Adds the users of `op`'s results to the worklist; in strict mode only
+  /// users already present in `strictModeFilteredOps` are added.
+  void notifyRootReplaced(Operation *op) override {
+    for (auto result : op->getResults()) {
+      for (auto *user : result.getUsers()) {
+        if (!strictMode || strictModeFilteredOps.contains(user))
+          addToWorklist(user);
+      }
+    }
+  }
+
   /// If `strictMode` is true, any pre-existing ops outside of
   /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
   /// If `strictMode` is false, operations that use results of (or supply
@@ -592,6 +611,10 @@
   SmallVector originalOperands, resultValues;
   while (!worklist.empty()) {
     Operation *op = popFromWorklist();
+    // A null entry stands for a removed op (see the comment below) and is
+    // exempt from the strict-mode membership check.
+    assert((!strictMode || !op || strictModeFilteredOps.contains(op)) &&
+           "unexpected op was inserted under strict mode");
 
     // Nulls get added to the worklist when operations are removed, ignore
     // them.