diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -406,13 +406,24 @@ virtual void notifyOperationModified(Operation *op) {} /// Notify the listener that the specified operation is about to be replaced - /// with the set of values potentially produced by new operations. This is - /// called before the uses of the operation have been changed. + /// with another operation. This is called before the uses of the old + /// operation have been changed. + /// + /// By default, this function calls the "operation replaced with values" + /// notification. + virtual void notifyOperationReplaced(Operation *op, + Operation *replacement) { + notifyOperationReplaced(op, replacement->getResults()); + } + + /// Notify the listener that the specified operation is about to be replaced + /// with a range of values, potentially produced by other operations. + /// This is called before the uses of the operation have been changed. virtual void notifyOperationReplaced(Operation *op, ValueRange replacement) {} - /// This is called on an operation that a rewrite is removing, right before - /// the operation is deleted. At this point, the operation has zero uses. + /// Notify the listener that the specified operation is about to be erased. + /// At this point, the operation has zero uses. 
virtual void notifyOperationRemoved(Operation *op) {} /// Notify the listener that the pattern failed to match the given @@ -444,6 +455,9 @@ void notifyOperationModified(Operation *op) override { listener->notifyOperationModified(op); } + void notifyOperationReplaced(Operation *op, Operation *newOp) override { + listener->notifyOperationReplaced(op, newOp); + } void notifyOperationReplaced(Operation *op, ValueRange replacement) override { listener->notifyOperationReplaced(op, replacement); @@ -505,15 +519,20 @@ /// This method replaces the results of the operation with the specified list /// of values. The number of provided values must match the number of results - /// of the operation. + /// of the operation. The replaced op is erased. virtual void replaceOp(Operation *op, ValueRange newValues); + /// This method replaces the results of the operation with the specified + /// new op (replacement). The number of results of the two operations must + /// match. The replaced op is erased. + virtual void replaceOp(Operation *op, Operation *newOp); + /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. template OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { auto newOp = create(op->getLoc(), std::forward(args)...); - replaceOpWithResultsOfAnotherOp(op, newOp.getOperation()); + replaceOp(op, newOp.getOperation()); return newOp; } @@ -666,10 +685,6 @@ private: void operator=(const RewriterBase &) = delete; RewriterBase(const RewriterBase &) = delete; - - /// 'op' and 'newOp' are known to have the same number of results, replace the - /// uses of op with uses of newOp. 
- void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -695,15 +695,17 @@ /// patterns even if a failure is encountered during the rewrite step. bool canRecoverFromRewriteFailure() const override { return true; } - /// PatternRewriter hook for replacing the results of an operation when the - /// given functor returns true. + /// PatternRewriter hook for replacing an operation when the given functor + /// returns "true". void replaceOpWithIf( Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function functor) override; - /// PatternRewriter hook for replacing the results of an operation. + /// PatternRewriter hook for replacing an operation. void replaceOp(Operation *op, ValueRange newValues) override; - using PatternRewriter::replaceOp; + + /// PatternRewriter hook for replacing an operation. + void replaceOp(Operation *op, Operation *newOp) override; /// PatternRewriter hook for erasing a dead operation. 
The uses of this /// operation *must* be made dead by the end of the conversion process, diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp --- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp @@ -139,7 +139,7 @@ loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare); rewriter.create(loc, canLeave, afterAtomic, ValueRange{}, loopBlock, atomicRes); - rewriter.replaceOp(atomicOp, {}); + rewriter.eraseOp(atomicOp); return success(); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -331,10 +331,7 @@ alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(), alloc.getAlignmentAttr()); // Insert a cast so we have the same type as the old alloc. - auto resultCast = - rewriter.create(alloc.getLoc(), alloc.getType(), newAlloc); - - rewriter.replaceOp(alloc, {resultCast}); + rewriter.replaceOpWithNewOp(alloc, alloc.getType(), newAlloc); return success(); } }; diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -262,12 +262,12 @@ /// This method replaces the results of the operation with the specified list of /// values. The number of provided values must match the number of results of -/// the operation. +/// the operation. The replaced op is erased. void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - // Notify the listener that we're about to remove this op. + // Notify the listener that we're about to replace this op. 
if (auto *rewriteListener = dyn_cast_if_present(listener)) rewriteListener->notifyOperationReplaced(op, newValues); @@ -275,9 +275,28 @@ for (auto it : llvm::zip(op->getResults(), newValues)) replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); + // Erase the op. + eraseOp(op); +} + +/// This method replaces the results of the operation with the specified new op +/// (replacement). The number of results of the two operations must match. The +/// replaced op is erased. +void RewriterBase::replaceOp(Operation *op, Operation *newOp) { + assert(op && newOp && "expected non-null op"); + assert(op->getNumResults() == newOp->getNumResults() && + "ops have different number of results"); + + // Notify the listener that we're about to replace this op. if (auto *rewriteListener = dyn_cast_if_present(listener)) - rewriteListener->notifyOperationRemoved(op); - op->erase(); + rewriteListener->notifyOperationReplaced(op, newOp); + + // Replace results one-by-one. Also notifies the listener of modifications. + for (auto it : llvm::zip(op->getResults(), newOp->getResults())) + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); + + // Erase the old op. + eraseOp(op); } /// This method erases an operation that is known to have no uses. The uses of @@ -364,17 +383,6 @@ return block->splitBlock(before); } -/// 'op' and 'newOp' are known to have the same number of results, replace the -/// uses of op with uses of newOp -void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op, - Operation *newOp) { - assert(op->getNumResults() == newOp->getNumResults() && - "replacement op doesn't match results of original op"); - if (op->getNumResults() == 1) - return replaceOp(op, newOp->getResult(0)); - return replaceOp(op, newOp->getResults()); -} - /// Move the blocks that belong to "region" before the given position in /// another region. The two regions must be different. 
The caller is in /// charge to update create the operation transferring the control flow to the diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1452,7 +1452,14 @@ "replaceOpWithIf is currently not supported by DialectConversion"); } +void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { + assert(op && newOp && "expected non-null op"); + replaceOp(op, newOp->getResults()); +} + void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { + assert(op->getNumResults() == newValues.size() && + "incorrect # of replacement values"); LLVM_DEBUG({ impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -601,7 +601,7 @@ Location loc = op->getLoc(); rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc}); rewriter.create(loc); - rewriter.replaceOp(op, {}); + rewriter.eraseOp(op); return success(); } }; @@ -621,7 +621,7 @@ // Create an illegal op to ensure the conversion fails. rewriter.create(loc, i32Type); rewriter.create(loc); - rewriter.replaceOp(op, {}); + rewriter.eraseOp(op); return success(); } }; @@ -793,8 +793,8 @@ auto illegalOp = rewriter.create(op->getLoc(), resultType); auto legalOp = rewriter.create(op->getLoc(), resultType); - rewriter.replaceOp(illegalOp, {legalOp}); - rewriter.replaceOp(op, {illegalOp}); + rewriter.replaceOp(illegalOp, legalOp); + rewriter.replaceOp(op, illegalOp); return success(); } };