diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -505,7 +505,22 @@
   /// Find uses of `from` and replace them with `to`. It also marks every
   /// modified uses and notifies the rewriter that an in-place operation
   /// modification is about to happen.
-  void replaceAllUsesWith(Value from, Value to);
+  void replaceAllUsesWith(Value from, Value to) {
+    return replaceAllUsesWith(from.getImpl(), to);
+  }
+  /// Same as above, but operates directly on any object that maintains a use
+  /// list (e.g. the implementation of a Value, or a Block). Each rewritten
+  /// use is wrapped in updateRootInPlace on its owning operation so that the
+  /// rewriter is notified of the in-place modification.
+  template <typename OperandType, typename ValueT>
+  void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
+    // Early-increment iteration: setting an operand unlinks it from the use
+    // list that is being traversed.
+    for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
+      Operation *op = operand.getOwner();
+      updateRootInPlace(op, [&]() { operand.set(to); });
+    }
+  }
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. It also marks every modified uses and notifies the rewriter that an
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -309,14 +309,6 @@
   source->erase();
 }
 
-/// Find uses of `from` and replace it with `to`
-void RewriterBase::replaceAllUsesWith(Value from, Value to) {
-  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
-    Operation *op = operand.getOwner();
-    updateRootInPlace(op, [&]() { operand.set(to); });
-  }
-}
-
 /// Find uses of `from` and replace them with `to` if the `functor` returns
 /// true. It also marks every modified uses and notifies the rewriter that an
 /// in-place operation modification is about to happen.
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -1,3 +1,7 @@
+// RUN: mlir-opt \
+// RUN:     -test-strict-pattern-driver="strictness=AnyOp" \
+// RUN:     --split-input-file %s | FileCheck %s --check-prefix=CHECK-AN
+
 // RUN: mlir-opt \
 // RUN:     -test-strict-pattern-driver="strictness=ExistingAndNewOps" \
 // RUN:     --split-input-file %s | FileCheck %s --check-prefix=CHECK-EN
@@ -58,3 +62,24 @@
   "test.replace_with_new_op"() {create_erase_op} : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN-LABEL: func @test_trigger_rewrite_through_block
+// CHECK-AN: "test.change_block_op"()[^[[BB0:.*]], ^[[BB0]]]
+// CHECK-AN: return
+// CHECK-AN: ^[[BB1:[^:]*]]:
+// CHECK-AN: "test.implicit_change_op"()[^[[BB1]]]
+func.func @test_trigger_rewrite_through_block() {
+  return
+^bb1:
+  // Uses bb1. ChangeBlockOp replaces that and all other usages of bb1 with bb2.
+  "test.change_block_op"() [^bb1, ^bb2] : () -> ()
+^bb2:
+  return
+^bb3:
+  // Also uses bb1. ChangeBlockOp replaces that usage with bb2. This triggers
+  // this op being put on the worklist, which triggers ImplicitChangeOp, which,
+  // in turn, replaces the successor with bb3.
+  "test.implicit_change_op"() [^bb1] : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -256,11 +256,19 @@
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     mlir::RewritePatternSet patterns(ctx);
-    patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(ctx);
+    patterns.add<
+        // clang-format off
+        InsertSameOp,
+        ReplaceWithNewOp,
+        EraseOp,
+        ChangeBlockOp,
+        ImplicitChangeOp
+        // clang-format on
+        >(ctx);
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
-      if (opName == "test.insert_same_op" ||
+      if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
           opName == "test.replace_with_new_op" || opName == "test.erase_op") {
         ops.push_back(op);
       }
@@ -342,7 +350,7 @@
     }
   };
 
-  // Remove an operation may introduce the re-visiting of its opreands.
+  // Removing an operation may introduce the re-visiting of its operands.
   class EraseOp : public RewritePattern {
   public:
     EraseOp(MLIRContext *context)
@@ -353,6 +361,55 @@
       return success();
     }
   };
+
+  // The following two patterns test RewriterBase::replaceAllUsesWith.
+  //
+  // That function replaces all usages of a Block (or a Value) with another one
+  // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
+  // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
+  // worklist: when an op is modified, it is added to the worklist. The two
+  // patterns below make the tracking observable: ChangeBlockOp replaces all
+  // usages of a block and that pattern is applied because the corresponding ops
+  // are put on the initial worklist (see above). ImplicitChangeOp does an
+  // unrelated change but ops of the corresponding type are *not* on the initial
+  // worklist, so the effect of the second pattern is only visible if the
+  // tracking and subsequent adding to the worklist actually works.
+
+  // Replace all usages of the first successor with the second successor.
+  class ChangeBlockOp : public RewritePattern {
+  public:
+    ChangeBlockOp(MLIRContext *context)
+        : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      if (op->getNumSuccessors() < 2)
+        return failure();
+      Block *firstSuccessor = op->getSuccessor(0);
+      Block *secondSuccessor = op->getSuccessor(1);
+      if (firstSuccessor == secondSuccessor)
+        return failure();
+      // This is the function being tested:
+      rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
+      // Using the following line instead would make the test fail:
+      // firstSuccessor->replaceAllUsesWith(secondSuccessor);
+      return success();
+    }
+  };
+
+  // Changes the successor to the parent block.
+  class ImplicitChangeOp : public RewritePattern {
+  public:
+    ImplicitChangeOp(MLIRContext *context)
+        : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
+        return failure();
+      rewriter.updateRootInPlace(
+          op, [&]() { op->setSuccessor(op->getBlock(), 0); });
+      return success();
+    }
+  };
 };
 } // namespace
 