diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -272,9 +272,8 @@
          "incorrect # of replacement values");
   op->replaceAllUsesWith(newValues);
 
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationRemoved(op);
-  op->erase();
+  // Erase op; eraseOp notifies the listener about op and all nested ops.
+  eraseOp(op);
 }
 
 /// This method erases an operation that is known to have no uses. The uses of
@@ -282,7 +281,12 @@
 void RewriterBase::eraseOp(Operation *op) {
   assert(op->use_empty() && "expected 'op' to have no uses");
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationRemoved(op);
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) {
+    // Notify about every op in the erased tree, not just the root. The walk
+    // is post-order, so nested ops are reported before their parents.
+    op->walk([&](Operation *nestedOp) {
+      rewriteListener->notifyOperationRemoved(nestedOp);
+    });
+  }
   op->erase();
 }
 
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -165,7 +165,12 @@
 void OperationFolder::eraseOp(Operation *op) {
-  notifyRemoval(op);
-  if (listener)
-    listener->notifyOperationRemoved(op);
+  // Let the folder forget every op of the erased tree (not just the root) and
+  // report each of them to the listener. The walk is post-order, so nested
+  // ops are handled before their parents.
+  op->walk([&](Operation *nestedOp) {
+    notifyRemoval(nestedOp);
+    if (listener)
+      listener->notifyOperationRemoved(nestedOp);
+  });
   op->erase();
 }
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -181,8 +181,7 @@
 
     // If the operation is trivially dead - remove it.
     if (isOpTriviallyDead(op)) {
-      notifyOperationRemoved(op);
-      op->erase();
+      eraseOp(op);
       changed = true;
 
       LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -320,10 +319,10 @@
   if (config.listener)
     config.listener->notifyOperationRemoved(op);
 
-  op->walk([this](Operation *operation) {
-    removeFromWorklist(operation);
-    folder.notifyRemoval(operation);
-  });
+  // RewriterBase::eraseOp and OperationFolder::eraseOp report every nested op
+  // separately, so only `op` itself has to be handled here.
+  removeFromWorklist(op);
+  folder.notifyRemoval(op);
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.erase(op);
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -83,3 +83,17 @@
   // in turn, replaces the successor with bb3.
   "test.implicit_change_op"() [^bb1] : () -> ()
 }
+
+// -----
+
+// CHECK-AN-LABEL: func @test_remove_nested_ops()
+// CHECK-AN-NEXT: return
+func.func @test_remove_nested_ops() {
+  "test.erase_op"() ({
+    test.graph_region {
+      %0 = "test.foo"(%1) : (i1) -> (i1)
+      %1 = "test.foo"(%0) : (i1) -> (i1)
+    }
+  }) : () -> ()
+  return
+}