diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -522,6 +522,7 @@
   // about any ops that got inserted inside those regions as part of cloning.
   if (listener) {
     auto walkFn = [&](Operation *walkedOp) {
+      // TODO: Notify listener of block creation.
       listener->notifyOperationInserted(walkedOp);
     };
     for (Region &region : newOp->getRegions())
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -380,6 +380,20 @@
                                      Region::iterator before,
                                      IRMapping &mapping) {
   region.cloneInto(&parent, before, mapping);
+
+  // An empty region clones no blocks; bail out before touching region.front().
+  if (region.empty())
+    return;
+
+  // Notify the listener about all ops in the newly cloned blocks. They occupy
+  // the range [clone of region.front(), before) in `parent`.
+  if (Listener *listener = getListener()) {
+    for (auto it = mapping.lookup(&region.front())->getIterator();
+         it != before; ++it) {
+      // TODO: Notify listener of block creation (including nested blocks).
+      it->walk([&](Operation *op) { listener->notifyOperationInserted(op); });
+    }
+  }
 }
 void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
                                      Region::iterator before) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1625,9 +1625,14 @@
                                                   Region &parent,
                                                   Region::iterator before,
                                                   IRMapping &mapping) {
+  // Note: Do not call PatternRewriter::cloneRegionBefore. We cannot rely on its
+  // "op created" notifications because we need to populate the list of created
+  // ops in a certain order. This is done in `notifyRegionWasClonedBefore`.
+
   if (region.empty())
     return;
-  PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
+
+  region.cloneInto(&parent, before, mapping);
 
   // Collect the range of the cloned blocks.
   auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -18,7 +18,7 @@
 func.func @test_erase() {
   %0 = "test.arg0"() : () -> (i32)
   %1 = "test.arg1"() : () -> (i32)
-  %erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32)
+  %erase = "test.erase_op"(%0, %1) {worklist} : (i32, i32) -> (i32)
   return
 }
 
@@ -29,7 +29,7 @@
 // CHECK-EN: "test.insert_same_op"() {skip = true}
 // CHECK-EN: "test.insert_same_op"() {skip = true}
 func.func @test_insert_same_op() {
-  %0 = "test.insert_same_op"() : () -> (i32)
+  %0 = "test.insert_same_op"() {worklist} : () -> (i32)
   return
 }
 
@@ -41,7 +41,7 @@
 // CHECK-EN: "test.dummy_user"(%[[n]])
 // CHECK-EN: "test.dummy_user"(%[[n]])
 func.func @test_replace_with_new_op() {
-  %0 = "test.replace_with_new_op"() : () -> (i32)
+  %0 = "test.replace_with_new_op"() {worklist} : () -> (i32)
   %1 = "test.dummy_user"(%0) : (i32) -> (i32)
   %2 = "test.dummy_user"(%0) : (i32) -> (i32)
   return
@@ -59,7 +59,7 @@
 // CHECK-EX-NOT: test.replace_with_new_op
 // CHECK-EX: test.erase_op
 func.func @test_replace_with_erase_op() {
-  "test.replace_with_new_op"() {create_erase_op} : () -> ()
+  "test.replace_with_new_op"() {create_erase_op, worklist} : () -> ()
   return
 }
 
@@ -74,7 +74,7 @@
   return
 ^bb1:
   // Uses bb1. ChangeBlockOp replaces that and all other usages of bb1 with bb2.
-  "test.change_block_op"() [^bb1, ^bb2] : () -> ()
+  "test.change_block_op"() [^bb1, ^bb2] {worklist} : () -> ()
 ^bb2:
   return
 ^bb3:
@@ -83,3 +83,39 @@
   // in turn, replaces the successor with bb3.
   "test.implicit_change_op"() [^bb1] : () -> ()
 }
+
+// -----
+
+// With strictness=ExistingAndNewOps, ops cloned by test.clone_region_in_parent
+// are added to the worklist, so the cloned test.erase_op is deleted.
+
+// CHECK-EN-LABEL: func @test_add_cloned_ops_to_worklist
+// CHECK-EN-NEXT: "test.dummy_op"() ({
+// CHECK-EN-NEXT: "test.another_op"() : () -> ()
+// CHECK-EN-NEXT: ^bb1: // no predecessors
+// CHECK-EN-NEXT: "test.inner_op"() : () -> ()
+// CHECK-EN-NEXT: }) : () -> ()
+// CHECK-EN-NEXT: }
+
+// When strictness=ExistingOps, test.erase_op is not deleted because it was not
+// on the initial worklist.
+
+// CHECK-EX-LABEL: func @test_add_cloned_ops_to_worklist
+// CHECK-EX-NEXT: "test.dummy_op"() ({
+// CHECK-EX-NEXT: "test.another_op"() : () -> ()
+// CHECK-EX-NEXT: ^bb1: // no predecessors
+// CHECK-EX-NEXT: "test.inner_op"() : () -> ()
+// CHECK-EX-NEXT: "test.erase_op"() : () -> ()
+// CHECK-EX-NEXT: }) : () -> ()
+// CHECK-EX-NEXT: }
+func.func @test_add_cloned_ops_to_worklist() {
+  "test.dummy_op"() ({
+  ^bb0():
+    "test.clone_region_in_parent"() ({
+    ^bb1():
+      "test.inner_op"() : () -> ()
+      "test.erase_op"() : () -> ()
+    }) {worklist} : () -> ()
+    "test.another_op"() : () -> ()
+  }) : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -258,18 +258,18 @@
     mlir::RewritePatternSet patterns(ctx);
     patterns.add<
         // clang-format off
-        InsertSameOp,
-        ReplaceWithNewOp,
         EraseOp,
         ChangeBlockOp,
-        ImplicitChangeOp
+        CloneRegionInParentOp,
+        ImplicitChangeOp,
+        InsertSameOp,
+        ReplaceWithNewOp
         // clang-format on
         >(ctx);
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
-      StringRef opName = op->getName().getStringRef();
-      if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
-          opName == "test.replace_with_new_op" || opName == "test.erase_op") {
+      if (op->hasAttr("worklist")) {
+        op->removeAttr("worklist");
         ops.push_back(op);
       }
     });
@@ -304,6 +304,28 @@
       llvm::cl::init("AnyOp")};
 
 private:
+  // Clones the region of a "test.clone_region_in_parent" op to the end of the
+  // region that contains the op, then erases the op.
+  class CloneRegionInParentOp : public RewritePattern {
+  public:
+    CloneRegionInParentOp(MLIRContext *context)
+        : RewritePattern("test.clone_region_in_parent", /*benefit=*/1,
+                         context) {}
+
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      // Do not assert on malformed IR: the op must be nested in a region
+      // (the clone target) and carry exactly one region (the clone source).
+      Region *parentRegion = op->getParentRegion();
+      if (!parentRegion || op->getNumRegions() != 1)
+        return failure();
+      rewriter.cloneRegionBefore(op->getRegion(0), *parentRegion,
+                                 parentRegion->end());
+      rewriter.eraseOp(op);
+      return success();
+    }
+  };
+
   // New inserted operation is valid for further transformation.
   class InsertSameOp : public RewritePattern {
   public: