diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h --- a/mlir/include/mlir/IR/RegionKindInterface.h +++ b/mlir/include/mlir/IR/RegionKindInterface.h @@ -43,6 +43,12 @@ /// not implement the RegionKindInterface. bool mayHaveSSADominance(Region &region); +/// Return "true" if the given region may be a graph region without SSA +/// dominance. This function returns "true" in case the owner op is an +/// unregistered op. It returns "false" if it is a registered op that does not +/// implement the RegionKindInterface. +bool mayBeGraphRegion(Region &region); + } // namespace mlir #include "mlir/IR/RegionKindInterface.h.inc" diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -8,6 +8,8 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Iterators.h" +#include "mlir/IR/RegionKindInterface.h" using namespace mlir; @@ -275,18 +277,81 @@ for (auto it : llvm::zip(op->getResults(), newValues)) replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - if (auto *rewriteListener = dyn_cast_if_present(listener)) - rewriteListener->notifyOperationRemoved(op); - op->erase(); + // Erase op and notify listener. + eraseOp(op); } /// This method erases an operation that is known to have no uses. The uses of /// the given operation *must* be known to be dead. void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); - if (auto *rewriteListener = dyn_cast_if_present(listener)) + auto *rewriteListener = dyn_cast_if_present(listener); + + // Fast path: If no listener is attached, the op can be dropped in one go. + if (!rewriteListener) { + op->erase(); + return; + } + + // Helper function that erases a single op. + auto eraseSingleOp = [&](Operation *op) { +#ifndef NDEBUG + // All nested ops should have been erased already.
+ assert( + llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) && + "expected empty regions"); + // All users should have been erased already if the op is in a region with + // SSA dominance. + if (!op->use_empty() && op->getParentOp()) + assert(mayBeGraphRegion(*op->getParentRegion()) && + "expected that op has no uses"); +#endif // NDEBUG rewriteListener->notifyOperationRemoved(op); - op->erase(); + + // Explicitly drop all uses in case the op is in a graph region. + op->dropAllUses(); + op->erase(); + }; + + // Nested ops must be erased one-by-one, so that listeners have a consistent + // view of the IR every time a notification is triggered. Users must be + // erased before definitions. I.e., post-order, reverse dominance. + std::function eraseTree = [&](Operation *op) { + // Erase nested ops. + for (Region &r : llvm::reverse(op->getRegions())) { + // Erase all blocks in the right order. Successors should be erased + // before predecessors because successor blocks may use values defined + // in predecessor blocks. A post-order traversal of blocks within a + // region visits successors before predecessors. Repeat the traversal + // until the region is empty. (The block graph could be disconnected.) + while (!r.empty()) { + SmallVector erasedBlocks; + for (Block *b : llvm::post_order(&r.front())) { + // Visit ops in reverse order. + for (Operation &op : + llvm::make_early_inc_range(ReverseIterator::makeIterable(*b))) + eraseTree(&op); + // Do not erase the block immediately. This is not supported by the + // post_order iterator. + erasedBlocks.push_back(b); + } + for (Block *b : erasedBlocks) { + // Explicitly drop all uses in case there is a cycle in the block + // graph. Also drop the uses of the block arguments: ops in blocks that + // are only visited in a later iteration (e.g., unreachable blocks) may + // still use them. + for (BlockArgument bbArg : b->getArguments()) + bbArg.dropAllUses(); + b->dropAllUses(); + b->erase(); + } + } + } + // Then erase the enclosing op.
+ eraseSingleOp(op); + }; + + eraseTree(op); } void RewriterBase::eraseBlock(Block *block) { diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp --- a/mlir/lib/IR/RegionKindInterface.cpp +++ b/mlir/lib/IR/RegionKindInterface.cpp @@ -18,9 +18,17 @@ #include "mlir/IR/RegionKindInterface.cpp.inc" bool mlir::mayHaveSSADominance(Region &region) { - auto regionKindOp = - dyn_cast_if_present(region.getParentOp()); + auto regionKindOp = dyn_cast(region.getParentOp()); if (!regionKindOp) return true; return regionKindOp.hasSSADominance(region.getRegionNumber()); } + +bool mlir::mayBeGraphRegion(Region &region) { + if (!region.getParentOp()->isRegistered()) + return true; + auto regionKindOp = dyn_cast(region.getParentOp()); + if (!regionKindOp) + return false; + return !regionKindOp.hasSSADominance(region.getRegionNumber()); +} diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -187,8 +187,7 @@ // If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) { - notifyOperationRemoved(op); - op->erase(); + eraseOp(op); changed = true; LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); @@ -343,10 +342,8 @@ config.listener->notifyOperationRemoved(op); addOperandsToWorklist(op->getOperands()); - op->walk([this](Operation *operation) { - removeFromWorklist(operation); - folder.notifyRemoval(operation); - }); + removeFromWorklist(op); + folder.notifyRemoval(op); if (config.strictMode != GreedyRewriteStrictness::AnyOp) strictModeFilteredOps.erase(op); diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir --- a/mlir/test/Transforms/test-strict-pattern-driver.mlir +++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir @@ -12,9 +12,9 @@ // CHECK-EN-LABEL: func @test_erase // CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true} -// CHECK-EN: test.arg0 -// CHECK-EN: test.arg1 -// CHECK-EN-NOT: test.erase_op +// CHECK-EN: "test.arg0" +// CHECK-EN: "test.arg1" +// CHECK-EN-NOT: "test.erase_op" func.func @test_erase() { %0 = "test.arg0"() : () -> (i32) %1 = "test.arg1"() : () -> (i32) @@ -51,13 +51,13 @@ // CHECK-EN-LABEL: func @test_replace_with_erase_op // CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true} -// CHECK-EN-NOT: test.replace_with_new_op -// CHECK-EN-NOT: test.erase_op +// CHECK-EN-NOT: "test.replace_with_new_op" +// CHECK-EN-NOT: "test.erase_op" // CHECK-EX-LABEL: func @test_replace_with_erase_op // CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true} -// CHECK-EX-NOT: test.replace_with_new_op -// CHECK-EX: test.erase_op +// CHECK-EX-NOT: "test.replace_with_new_op" +// CHECK-EX: "test.erase_op" func.func @test_replace_with_erase_op() { "test.replace_with_new_op"() {create_erase_op} : () -> () return @@ -83,3 +83,149 @@ // in turn, replaces the successor with bb3.
"test.implicit_change_op"() [^bb1] : () -> () } + // ----- + +// CHECK-AN: notifyOperationRemoved: test.foo_b +// CHECK-AN: notifyOperationRemoved: test.foo_a +// CHECK-AN: notifyOperationRemoved: test.graph_region +// CHECK-AN: notifyOperationRemoved: test.erase_op +// CHECK-AN-LABEL: func @test_remove_graph_region() +// CHECK-AN-NEXT: return +func.func @test_remove_graph_region() { + "test.erase_op"() ({ + test.graph_region { + %0 = "test.foo_a"(%1) : (i1) -> (i1) + %1 = "test.foo_b"(%0) : (i1) -> (i1) + } + }) : () -> () + return +} + +// ----- + +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.bar +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.foo +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.dummy_op +// CHECK-AN: notifyOperationRemoved: test.erase_op +// CHECK-AN-LABEL: func @test_remove_cyclic_blocks() +// CHECK-AN-NEXT: return +func.func @test_remove_cyclic_blocks() { + "test.erase_op"() ({ + %x = "test.dummy_op"() : () -> (i1) + cf.br ^bb1(%x: i1) + ^bb1(%arg0: i1): + "test.foo"(%x) : (i1) -> () + cf.br ^bb2(%arg0: i1) + ^bb2(%arg1: i1): + "test.bar"(%x) : (i1) -> () + cf.br ^bb1(%arg1: i1) + }) : () -> () + return +} + +// ----- + +// CHECK-AN: notifyOperationRemoved: test.dummy_op +// CHECK-AN: notifyOperationRemoved: test.bar +// CHECK-AN: notifyOperationRemoved: test.qux +// CHECK-AN: notifyOperationRemoved: test.qux_unreachable +// CHECK-AN: notifyOperationRemoved: test.nested_dummy +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.foo +// CHECK-AN: notifyOperationRemoved: test.erase_op +// CHECK-AN-LABEL: func @test_remove_dead_blocks() +// CHECK-AN-NEXT: return +func.func @test_remove_dead_blocks() { + "test.erase_op"() ({ + "test.dummy_op"() : () -> (i1) + // The following blocks are not reachable. Still, ^bb2 should be deleted + // before ^bb1.
+ ^bb1(%arg0: i1): + "test.foo"() : () -> () + cf.br ^bb2(%arg0: i1) + ^bb2(%arg1: i1): + "test.nested_dummy"() ({ + "test.qux"() : () -> () + // The following block is unreachable. + ^bb3: + "test.qux_unreachable"() : () -> () + }) : () -> () + "test.bar"() : () -> () + }) : () -> () + return +} + +// ----- + +// test.nested_* must be deleted before test.foo. +// test.bar must be deleted before test.foo. + +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.bar +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.nested_b +// CHECK-AN: notifyOperationRemoved: test.nested_a +// CHECK-AN: notifyOperationRemoved: test.nested_d +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.nested_e +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.nested_c +// CHECK-AN: notifyOperationRemoved: test.foo +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.dummy_op +// CHECK-AN: notifyOperationRemoved: test.erase_op +// CHECK-AN-LABEL: func @test_remove_nested_ops() +// CHECK-AN-NEXT: return +func.func @test_remove_nested_ops() { + "test.erase_op"() ({ + %x = "test.dummy_op"() : () -> (i1) + cf.br ^bb1(%x: i1) + ^bb1(%arg0: i1): + "test.foo"() ({ + "test.nested_a"() : () -> () + "test.nested_b"() : () -> () + ^dead1: + "test.nested_c"() : () -> () + cf.br ^dead3 + ^dead2: + "test.nested_d"() : () -> () + ^dead3: + "test.nested_e"() : () -> () + cf.br ^dead2 + }) : () -> () + cf.br ^bb2(%arg0: i1) + ^bb2(%arg1: i1): + "test.bar"(%x) : (i1) -> () + cf.br ^bb1(%arg1: i1) + }) : () -> () + return +} + +// ----- + +// CHECK-AN: notifyOperationRemoved: test.qux +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.foo +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.bar +// CHECK-AN: notifyOperationRemoved: cf.cond_br +// CHECK-AN-LABEL: func
@test_remove_diamond( +// CHECK-AN-NEXT: return +func.func @test_remove_diamond(%c: i1) { + "test.erase_op"() ({ + cf.cond_br %c, ^bb1, ^bb2 + ^bb1: + "test.foo"() : () -> () + cf.br ^bb3 + ^bb2: + "test.bar"() : () -> () + cf.br ^bb3 + ^bb3: + "test.qux"() : () -> () + }) : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -236,6 +236,12 @@ llvm::cl::init(GreedyRewriteConfig().maxIterations)}; }; +struct DumpNotifications : public RewriterBase::Listener { + void notifyOperationRemoved(Operation *op) override { + llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n"; + } +}; + struct TestStrictPatternDriver : public PassWrapper> { public: @@ -272,7 +278,9 @@ } }); + DumpNotifications dumpNotifications; GreedyRewriteConfig config; + config.listener = &dumpNotifications; if (strictMode == "AnyOp") { config.strictMode = GreedyRewriteStrictness::AnyOp; } else if (strictMode == "ExistingAndNewOps") {