diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -481,7 +481,10 @@
   /// If an operation is about to be removed, mark it so that we can let clients
   /// know.
   void notifyOperationRemoved(Operation *op) override {
-    opErasedViaPatternRewrites = true;
+    // This hook runs for every op removed through this rewriter; only the
+    // removal of the op being simplified means that op was erased.
+    if (op == opToSimplify)
+      opErasedViaPatternRewrites = true;
   }
 
   // When a root is going to be replaced, its removal will be notified as well.
@@ -489,6 +492,9 @@
   void notifyRootReplaced(Operation *op, ValueRange replacement) override {}
 
 private:
+  /// The operation currently being simplified by `simplifyLocally`.
+  Operation *opToSimplify = nullptr;
+
   /// The low-level pattern applicator.
   PatternApplicator matcher;
 
@@ -509,6 +515,7 @@
 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
                                                       int64_t maxNumRewrites,
                                                       bool &erased) {
+  opToSimplify = op;
   bool changed = false;
   erased = false;
   opErasedViaPatternRewrites = false;