diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -555,9 +555,8 @@ /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. It also marks every modified uses and notifies the rewriter that an /// in-place operation modification is about to happen. - void - replaceUsesWithIf(Value from, Value to, - llvm::unique_function<bool(OpOperand &) const> functor); + void replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor); /// Find uses of `from` and replace them with `to` except if the user is /// `exceptedUser`. It also marks every modified uses and notifies the diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -235,14 +235,14 @@ assert(op->getNumResults() == newValues.size() && "incorrect number of values to replace operation"); - // Notify the rewriter subclass that we're about to replace this root. + // Notify the listener that we're about to replace this op. if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) rewriteListener->notifyOperationReplaced(op, newValues); // Replace each use of the results when the functor is true. bool replacedAllUses = true; for (auto it : llvm::zip(op->getResults(), newValues)) { - std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor); + replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor); replacedAllUses &= std::get<0>(it).use_empty(); } if (allUsesReplaced) @@ -264,17 +264,19 @@ /// values. The number of provided values must match the number of results of /// the operation. void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { - // Notify the rewriter subclass that we're about to replace this root. 
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) - rewriteListener->notifyOperationReplaced(op, newValues); - assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - op->replaceAllUsesWith(newValues); + // Notify the listener that we're about to replace this op. if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) - rewriteListener->notifyOperationRemoved(op); - op->erase(); + rewriteListener->notifyOperationReplaced(op, newValues); + + // Replace results one-by-one. Also notifies the listener of modifications. + for (auto it : llvm::zip(op->getResults(), newValues)) + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); + + // Erase op and notify listener. + eraseOp(op); } /// This method erases an operation that is known to have no uses. The uses of @@ -307,14 +309,14 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest, ValueRange argValues) { assert(llvm::all_of(source->getPredecessors(), - [dest](Block *succ) { return succ == dest; }) && + [dest](Block *b) { return b == dest; }) && "expected 'source' to have no predecessors or only 'dest'"); assert(argValues.size() == source->getNumArguments() && "incorrect # of argument replacement values"); // Replace all of the successor arguments with the provided values. for (auto it : llvm::zip(source->getArguments(), argValues)) - std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); // Splice the operations of the 'source' block into the 'dest' block and erase // it. @@ -326,9 +328,8 @@ /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. It also marks every modified uses and notifies the rewriter that an /// in-place operation modification is about to happen. 
-void RewriterBase::replaceUsesWithIf( - Value from, Value to, - llvm::unique_function<bool(OpOperand &) const> functor) { +void RewriterBase::replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor) { for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { if (functor(operand)) updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); }); diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir --- a/mlir/test/Transforms/test-strict-pattern-driver.mlir +++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir @@ -18,7 +18,7 @@ func.func @test_erase() { %0 = "test.arg0"() : () -> (i32) %1 = "test.arg1"() : () -> (i32) - %erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32) + %erase = "test.erase_op"(%0, %1) {worklist} : (i32, i32) -> (i32) return } @@ -29,7 +29,7 @@ // CHECK-EN: "test.insert_same_op"() {skip = true} // CHECK-EN: "test.insert_same_op"() {skip = true} func.func @test_insert_same_op() { - %0 = "test.insert_same_op"() : () -> (i32) + %0 = "test.insert_same_op"() {worklist} : () -> (i32) return } @@ -41,7 +41,7 @@ // CHECK-EN: "test.dummy_user"(%[[n]]) // CHECK-EN: "test.dummy_user"(%[[n]]) func.func @test_replace_with_new_op() { - %0 = "test.replace_with_new_op"() : () -> (i32) + %0 = "test.replace_with_new_op"() {worklist} : () -> (i32) %1 = "test.dummy_user"(%0) : (i32) -> (i32) %2 = "test.dummy_user"(%0) : (i32) -> (i32) return @@ -59,7 +59,7 @@ // CHECK-EX-NOT: test.replace_with_new_op // CHECK-EX: test.erase_op func.func @test_replace_with_erase_op() { - "test.replace_with_new_op"() {create_erase_op} : () -> () + "test.replace_with_new_op"() {create_erase_op, worklist} : () -> () return } @@ -74,7 +74,7 @@ return ^bb1: // Uses bb1. ChangeBlockOp replaces that and all other usages of bb1 with bb2. 
- "test.change_block_op"() [^bb1, ^bb2] : () -> () + "test.change_block_op"() [^bb1, ^bb2] {worklist} : () -> () ^bb2: return ^bb3: @@ -83,3 +83,26 @@ // in turn, replaces the successor with bb3. "test.implicit_change_op"() [^bb1] : () -> () } + +// ----- + +// Make sure that "test.erase_op" is put on the worklist during mergeBlocks and +// subsequently deleted. + +// CHECK-EN-LABEL: func @test_merge_blocks( +// CHECK-EX-LABEL: func @test_merge_blocks( +// CHECK-AN-LABEL: func @test_merge_blocks( +// CHECK-AN: "test.merge_blocks"() ({ +// CHECK-AN-NEXT: "test.return" +// CHECK-AN-NEXT: }) : () -> i32 +// CHECK-AN-NEXT: "test.return" +func.func @test_merge_blocks(%arg0: i32) -> () { + %0 = "test.merge_blocks"() ({ + ^bb0: + cf.br ^bb1 (%arg0: i32) + ^bb1(%arg3 : i32): + "test.erase_op"(%arg3) : (i32) -> () + "test.return"(%arg3) : (i32) -> () + }) {worklist} : () -> (i32) + "test.return"(%0) : (i32) -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -261,14 +261,14 @@ ReplaceWithNewOp, EraseOp, ChangeBlockOp, - ImplicitChangeOp + ImplicitChangeOp, + MergeBlocks // clang-format on >(ctx); SmallVector<Operation *> ops; getOperation()->walk([&](Operation *op) { - StringRef opName = op->getName().getStringRef(); - if (opName == "test.insert_same_op" || opName == "test.change_block_op" || - opName == "test.replace_with_new_op" || opName == "test.erase_op") { + if (op->hasAttr("worklist")) { + op->removeAttr("worklist"); ops.push_back(op); } }); @@ -361,6 +361,26 @@ } }; + // Users of replaced block arguments are revisited after merging two blocks. 
+ class MergeBlocks : public OpRewritePattern<TestMergeBlocksOp> { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TestMergeBlocksOp op, + PatternRewriter &rewriter) const override { + if (op.getBody().getBlocks().size() != 2) + return failure(); + Block &destBlock = op.getBody().front(); + Operation *branchOp = destBlock.getTerminator(); + Block *sourceBlock = &*(std::next(op.getBody().begin())); + auto succOperands = branchOp->getOperands(); + SmallVector<Value> replacements(succOperands); + rewriter.eraseOp(branchOp); + rewriter.mergeBlocks(sourceBlock, &destBlock, replacements); + return success(); + } + }; + // The following two patterns test RewriterBase::replaceAllUsesWith. // // That function replaces all usages of a Block (or a Value) with another one