diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -8,6 +8,9 @@
 
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "llvm/ADT/PostOrderIterator.h"
 
 using namespace mlir;
 
@@ -272,18 +275,88 @@
          "incorrect # of replacement values");
   op->replaceAllUsesWith(newValues);
 
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationRemoved(op);
-  op->erase();
+  // Erase op and notify listener.
+  eraseOp(op);
+}
+
+#ifndef NDEBUG
+/// Return true if the region that contains the given op is known to have SSA
+/// dominance. Return false if that region may be a graph region: either the
+/// parent op says so via RegionKindInterface, or the parent op is unregistered
+/// and its region kind is unknown.
+static bool parentRegionHasSSADominance(Operation *op) {
+  Operation *parent = op->getParentOp();
+  if (!parent)
+    return true;
+  if (!parent->isRegistered())
+    return false;
+  auto regionKindOp = dyn_cast<RegionKindInterface>(parent);
+  if (!regionKindOp)
+    return true;
+  return regionKindOp.hasSSADominance(op->getParentRegion()->getRegionNumber());
 }
+#endif // NDEBUG
+
+/// Erase `op` together with all ops nested in its regions, one op at a time,
+/// notifying `listener` right before each op is erased. Users are erased
+/// before their definitions, so that the listener has a consistent view of
+/// the IR every time a notification is triggered:
+///   * nested regions are emptied before the op that owns them,
+///   * blocks are processed in CFG post-order, i.e., successors before
+///     predecessors (a block is processed after all blocks it dominates),
+///   * ops within a block are processed in reverse order.
+static void eraseOpTree(Operation *op, RewriterBase::Listener *listener) {
+  for (Region &region : llvm::reverse(op->getRegions())) {
+    // The post-order traversal only reaches blocks that are reachable from
+    // the entry block. Repeat until the region is empty, so that unreachable
+    // blocks are erased as well.
+    while (!region.empty()) {
+      // Blocks are erased only after the traversal has finished: erasing a
+      // block in the middle of a post_order traversal is not supported.
+      SmallVector<Block *> visitedBlocks;
+      for (Block *block : llvm::post_order(&region.front())) {
+        for (Operation &nestedOp :
+             llvm::make_early_inc_range(llvm::reverse(*block)))
+          eraseOpTree(&nestedOp, listener);
+        visitedBlocks.push_back(block);
+      }
+      for (Block *block : visitedBlocks) {
+        // Uses may remain if the block graph is cyclic; drop them explicitly.
+        for (BlockArgument arg : block->getArguments())
+          arg.dropAllUses();
+        block->dropAllUses();
+        block->erase();
+      }
+    }
+  }
+
+  // All users of the results must be gone by now, unless the op lives in a
+  // region without SSA dominance (e.g., a graph region with cyclic uses).
+  assert((op->use_empty() || !parentRegionHasSSADominance(op)) &&
+         "expected that op has no uses");
+  listener->notifyOperationRemoved(op);
+  // In a graph region, erasing users before definitions is not enough: the op
+  // may still have uses, which must be dropped explicitly. The remaining users
+  // are erased eventually.
+  op->dropAllUses();
+  op->erase();
+}
 
 /// This method erases an operation that is known to have no uses. The uses of
 /// the given operation *must* be known to be dead.
 void RewriterBase::eraseOp(Operation *op) {
   assert(op->use_empty() && "expected 'op' to have no uses");
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationRemoved(op);
-  op->erase();
+  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
+
+  // Fast path: If no listener is attached, the op can be dropped in one go.
+  if (!rewriteListener) {
+    op->erase();
+    return;
+  }
+
+  // Otherwise, nested ops must be erased one-by-one, so that listeners have a
+  // consistent view of the IR every time a notification is triggered.
+  eraseOpTree(op, rewriteListener);
 }
 
 void RewriterBase::eraseBlock(Block *block) {
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -165,7 +165,8 @@
 void OperationFolder::eraseOp(Operation *op) {
   notifyRemoval(op);
   if (listener)
-    listener->notifyOperationRemoved(op);
+    op->walk(
+        [&](Operation *op) { listener->notifyOperationRemoved(op); });
   op->erase();
 }
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -187,8 +187,7 @@
 
     // If the operation is trivially dead - remove it.
     if (isOpTriviallyDead(op)) {
-      notifyOperationRemoved(op);
-      op->erase();
+      eraseOp(op);
       changed = true;
 
       LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -341,10 +340,8 @@
     config.listener->notifyOperationRemoved(op);
 
   addOperandsToWorklist(op->getOperands());
-  op->walk([this](Operation *operation) {
-    removeFromWorklist(operation);
-    folder.notifyRemoval(operation);
-  });
+  removeFromWorklist(op);
+  folder.notifyRemoval(op);
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.erase(op);
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -83,3 +83,50 @@
   // in turn, replaces the successor with bb3.
   "test.implicit_change_op"() [^bb1] : () -> ()
 }
+
+// -----
+
+// CHECK-AN-LABEL: func @test_remove_nested_ops()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_nested_ops() {
+  "test.erase_op"() ({
+    test.graph_region {
+      %0 = "test.foo"(%1) : (i1) -> (i1)
+      %1 = "test.foo"(%0) : (i1) -> (i1)
+    }
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN-LABEL: func @test_remove_nested_ops_with_ssa_uses()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_nested_ops_with_ssa_uses() {
+  "test.erase_op"() ({
+    "test.ssacfg_region"() ({
+      %0 = "test.foo"() : () -> (i1)
+      %1 = "test.foo"(%0) : (i1) -> (i1)
+      "test.foo"(%0, %1) : (i1, i1) -> ()
+    }) : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_cyclic_blocks() {
+  "test.erase_op"() ({
+    %x = "test.foo"() : () -> (i1)
+    cf.br ^bb1(%x: i1)
+  ^bb1(%arg0: i1):
+    "test.foo"(%x) : (i1) -> ()
+    cf.br ^bb2(%arg0: i1)
+  ^bb2(%arg1: i1):
+    "test.foo"(%arg1) : (i1) -> ()
+    cf.br ^bb1(%arg1: i1)
+  }) : () -> ()
+  return
+}