diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -481,7 +481,8 @@
-  /// If an operation is about to be removed, mark it so that we can let clients
-  /// know.
+  /// If the operation being simplified is about to be removed, mark it so that
+  /// we can let clients know. Removal of any other op is ignored.
   void notifyOperationRemoved(Operation *op) override {
-    opErasedViaPatternRewrites = true;
+    if (op == currentOp)
+      opErasedViaPatternRewrites = true;
   }
 
   // When a root is going to be replaced, its removal will be notified as well.
@@ -495,6 +496,9 @@
   /// Non-pattern based folder for operations.
   OperationFolder folder;
 
+  /// The op currently being simplified by `simplifyLocally`.
+  Operation *currentOp = nullptr;
+
   /// Set to true if the operation has been erased via pattern rewrites.
   bool opErasedViaPatternRewrites = false;
 };
@@ -509,6 +513,7 @@
 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
                                                       int64_t maxNumRewrites,
                                                       bool &erased) {
+  currentOp = op;
   bool changed = false;
   erased = false;
   opErasedViaPatternRewrites = false;