diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -226,6 +226,8 @@
     setInsertionPoint(block, insertPoint);
   }
 
+  virtual ~OpBuilder();
+
   /// Create a builder and set the insertion point to before the first operation
   /// in the block but still inside the block.
   static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) {
@@ -419,6 +421,19 @@
   Block *createBlock(Block *insertBefore, TypeRange argTypes = std::nullopt,
                      ArrayRef<Location> locs = std::nullopt);
 
+  /// Clone the blocks that belong to "region" before the given position in
+  /// another region "parent". The two regions must be different. The caller is
+  /// responsible for creating or updating the operation transferring flow of
+  /// control to the region and passing it the correct block arguments.
+  /// `mapping` is extended with the mapping from the original blocks to their
+  /// clones. If a listener is attached, it is notified of every cloned
+  /// operation; block insertions are not reported.
+  virtual void cloneRegionBefore(Region &region, Region &parent,
+                                 Region::iterator before, IRMapping &mapping);
+  void cloneRegionBefore(Region &region, Region &parent,
+                         Region::iterator before);
+  void cloneRegionBefore(Region &region, Block *before);
+
   //===--------------------------------------------------------------------===//
   // Operation Creation
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -406,16 +406,6 @@
                           Region::iterator before);
   void inlineRegionBefore(Region &region, Block *before);
 
-  /// Clone the blocks that belong to "region" before the given position in
-  /// another region "parent". The two regions must be different. The caller is
-  /// responsible for creating or updating the operation transferring flow of
-  /// control to the region and passing it the correct block arguments.
-  virtual void cloneRegionBefore(Region &region, Region &parent,
-                                 Region::iterator before, IRMapping &mapping);
-  void cloneRegionBefore(Region &region, Region &parent,
-                         Region::iterator before);
-  void cloneRegionBefore(Region &region, Block *before);
-
   /// This method replaces the uses of the results of `op` with the values in
   /// `newValues` when the provided `functor` returns true for a specific use.
   /// The number of values in `newValues` is required to match the number of
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -382,6 +382,8 @@
 
 OpBuilder::Listener::~Listener() = default;
 
+OpBuilder::~OpBuilder() = default;
+
 /// Insert the given operation at the current insertion point and return it.
 Operation *OpBuilder::insert(Operation *op) {
   if (block)
@@ -514,6 +516,7 @@
   // about any ops that got inserted inside those regions as part of cloning.
   if (listener) {
     auto walkFn = [&](Operation *walkedOp) {
+      // TODO: Notify listener of block creation.
       listener->notifyOperationInserted(walkedOp);
     };
     for (Region &region : newOp->getRegions())
@@ -526,3 +529,29 @@
   IRMapping mapper;
   return clone(op, mapper);
 }
+
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+                                  Region::iterator before, IRMapping &mapping) {
+  region.cloneInto(&parent, before, mapping);
+
+  // Nothing to report if no listener is attached. Also bail out on an empty
+  // region: nothing was cloned and `region.front()` must not be called on it.
+  if (!listener || region.empty())
+    return;
+
+  // The clones were inserted as a contiguous range ending right before
+  // `before`; walk that range and report each cloned operation.
+  for (Region::iterator it = mapping.lookup(&region.front())->getIterator();
+       it != before; ++it) {
+    // TODO: Notify listener of nested blocks.
+    it->walk([&](Operation *op) { listener->notifyOperationInserted(op); });
+  }
+}
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+                                  Region::iterator before) {
+  IRMapping mapping;
+  cloneRegionBefore(region, parent, before, mapping);
+}
+void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
+  cloneRegionBefore(region, *before->getParent(), before->getIterator());
+}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -379,21 +379,3 @@
 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
 }
-
-/// Clone the blocks that belong to "region" before the given position in
-/// another region "parent". The two regions must be different. The caller is
-/// responsible for creating or updating the operation transferring flow of
-/// control to the region and passing it the correct block arguments.
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
-                                     Region::iterator before,
-                                     IRMapping &mapping) {
-  region.cloneInto(&parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
-                                     Region::iterator before) {
-  IRMapping mapping;
-  cloneRegionBefore(region, parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
-  cloneRegionBefore(region, *before->getParent(), before->getIterator());
-}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1624,9 +1624,14 @@
                                                  Region &parent,
                                                  Region::iterator before,
                                                  IRMapping &mapping) {
+  // Note: Do not call OpBuilder::cloneRegionBefore. We cannot rely on its "op
+  // created" notifications because we need to populate the list of created ops
+  // in a certain order.
This is done in `notifyRegionWasClonedBefore`.
+
   if (region.empty())
     return;
-  PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
+
+  region.cloneInto(&parent, before, mapping);
 
   // Collect the range of the cloned blocks.
   auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -14,7 +14,7 @@
 func.func @test_erase() {
   %0 = "test.arg0"() : () -> (i32)
   %1 = "test.arg1"() : () -> (i32)
-  %erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32)
+  %erase = "test.erase_op"(%0, %1) {worklist} : (i32, i32) -> (i32)
   return
 }
 
@@ -25,7 +25,7 @@
 // CHECK-EN: "test.insert_same_op"() {skip = true}
 // CHECK-EN: "test.insert_same_op"() {skip = true}
 func.func @test_insert_same_op() {
-  %0 = "test.insert_same_op"() : () -> (i32)
+  %0 = "test.insert_same_op"() {worklist} : () -> (i32)
   return
 }
 
@@ -37,7 +37,7 @@
 // CHECK-EN: "test.dummy_user"(%[[n]])
 // CHECK-EN: "test.dummy_user"(%[[n]])
 func.func @test_replace_with_new_op() {
-  %0 = "test.replace_with_new_op"() : () -> (i32)
+  %0 = "test.replace_with_new_op"() {worklist} : () -> (i32)
   %1 = "test.dummy_user"(%0) : (i32) -> (i32)
   %2 = "test.dummy_user"(%0) : (i32) -> (i32)
   return
@@ -55,6 +55,41 @@
 // CHECK-EX-NOT: test.replace_with_new_op
 // CHECK-EX: test.erase_op
 func.func @test_replace_with_erase_op() {
-  "test.replace_with_new_op"() {create_erase_op} : () -> ()
+  "test.replace_with_new_op"() {create_erase_op, worklist} : () -> ()
   return
 }
+
+// -----
+
+// Make sure that the cloned test.erase_op is put on the worklist and deleted.
+
+// CHECK-EN-LABEL: func @test_add_cloned_ops_to_worklist
+// CHECK-EN-NEXT: "test.dummy_op"() ({
+// CHECK-EN-NEXT: "test.another_op"() : () -> ()
+// CHECK-EN-NEXT: ^bb1: // no predecessors
+// CHECK-EN-NEXT: "test.inner_op"() : () -> ()
+// CHECK-EN-NEXT: }) : () -> ()
+// CHECK-EN-NEXT: }
+
+// When strictness=ExistingOps, the cloned test.erase_op is not deleted because
+// it is a newly created op that was not on the initial worklist.
+
+// CHECK-EX-LABEL: func @test_add_cloned_ops_to_worklist
+// CHECK-EX-NEXT: "test.dummy_op"() ({
+// CHECK-EX-NEXT: "test.another_op"() : () -> ()
+// CHECK-EX-NEXT: ^bb1: // no predecessors
+// CHECK-EX-NEXT: "test.inner_op"() : () -> ()
+// CHECK-EX-NEXT: "test.erase_op"() : () -> ()
+// CHECK-EX-NEXT: }) : () -> ()
+// CHECK-EX-NEXT: }
+func.func @test_add_cloned_ops_to_worklist() {
+  "test.dummy_op"() ({
+  ^bb0():
+    "test.clone_region_in_parent"() ({
+      ^bb1():
+        "test.inner_op"() : () -> ()
+        "test.erase_op"() : () -> ()
+      }) {worklist} : () -> ()
+    "test.another_op"() : () -> ()
+  }) : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -256,12 +256,13 @@
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     mlir::RewritePatternSet patterns(ctx);
-    patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(ctx);
+    patterns
+        .add<InsertSameOp, ReplaceWithNewOp, EraseOp, CloneRegionInParentOp>(
+            ctx);
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
-      StringRef opName = op->getName().getStringRef();
-      if (opName == "test.insert_same_op" ||
-          opName == "test.replace_with_new_op" || opName == "test.erase_op") {
+      if (op->hasAttr("worklist")) {
+        op->removeAttr("worklist");
         ops.push_back(op);
       }
     });
@@ -296,6 +297,28 @@
       llvm::cl::init("AnyOp")};
 
  private:
+  // Clone the region of the matched op to the end of its parent region, then
+  // erase the op (used by @test_add_cloned_ops_to_worklist).
+  class CloneRegionInParentOp : public RewritePattern {
+  public:
+    CloneRegionInParentOp(MLIRContext *context)
+        : RewritePattern("test.clone_region_in_parent", /*benefit=*/1,
+                         context) {}
+
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      // The op is matched by name only, so check that it has a region to clone
+      // and a parent region to clone into before dereferencing either.
+      Region *parentRegion = op->getParentRegion();
+      if (!parentRegion || op->getNumRegions() == 0)
+        return failure();
+      rewriter.cloneRegionBefore(op->getRegion(0), *parentRegion,
+                                 parentRegion->end());
+      rewriter.eraseOp(op);
+      return success();
+    }
+  };
+
   // New inserted operation is valid for further transformation.
   class InsertSameOp : public RewritePattern {
   public: