diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp @@ -48,7 +48,7 @@ return matchFailure(); // Use the rewriter to perform the replacement. - rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); return matchSuccess(); } }; diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ return matchFailure(); // Use the rewriter to perform the replacement. - rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); return matchSuccess(); } }; diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ return matchFailure(); // Use the rewriter to perform the replacement. - rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); return matchSuccess(); } }; diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ return matchFailure(); // Use the rewriter to perform the replacement. - rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); return matchSuccess(); } }; diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -71,7 +71,7 @@ return matchFailure(); // Use the rewriter to perform the replacement. - rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); return matchSuccess(); } }; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -318,33 +318,15 @@ /// This method performs the final replacement for a pattern, where the /// results of the operation are updated to use the specified list of SSA - /// values. In addition to replacing and removing the specified operation, - /// clients can specify a list of other nodes that this replacement may make - /// (perhaps transitively) dead. If any of those values are dead, this will - /// remove them as well. - virtual void replaceOp(Operation *op, ValueRange newValues, - ValueRange valuesToRemoveIfDead); - void replaceOp(Operation *op, ValueRange newValues) { - replaceOp(op, newValues, llvm::None); - } + /// values. + virtual void replaceOp(Operation *op, ValueRange newValues); /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. template void replaceOpWithNewOp(Operation *op, Args &&... args) { auto newOp = create(op->getLoc(), std::forward(args)...); - replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {}); - } - - /// Replaces the result op with a new op that is created without verification. - /// The result values of the two ops must be the same types. This allows - /// specifying a list of ops that may be removed if dead. - template - void replaceOpWithNewOp(ValueRange valuesToRemoveIfDead, Operation *op, - Args &&... args) { - auto newOp = create(op->getLoc(), std::forward(args)...); - replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), - valuesToRemoveIfDead); + replaceOpWithResultsOfAnotherOp(op, newOp.getOperation()); } /// This method erases an operation that is known to have no uses. @@ -405,10 +387,9 @@ virtual void notifyOperationRemoved(Operation *op) {} private: - /// op and newOp are known to have the same number of results, replace the - /// uses of op with uses of newOp - void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp, - ValueRange valuesToRemoveIfDead); + /// 'op' and 'newOp' are known to have the same number of results, replace the + /// uses of op with uses of newOp. + void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -332,8 +332,7 @@ //===--------------------------------------------------------------------===// /// PatternRewriter hook for replacing the results of an operation. - void replaceOp(Operation *op, ValueRange newValues, - ValueRange valuesToRemoveIfDead) override; + void replaceOp(Operation *op, ValueRange newValues) override; using PatternRewriter::replaceOp; /// PatternRewriter hook for erasing a dead operation. The uses of this diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -90,8 +90,8 @@ rewriter.getContext()); auto newConstOp = rewriter.create(fusedLoc, newConstValueType, newConstValue); - rewriter.replaceOpWithNewOp({qbarrier.arg()}, qbarrier, - qbarrier.getType(), newConstOp); + rewriter.replaceOpWithNewOp(qbarrier, qbarrier.getType(), + newConstOp); return matchSuccess(); } diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -328,7 +328,6 @@ SmallVector newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); SmallVector newOperands; - SmallVector droppedOperands; unsigned dynamicDimPos = 0; for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { @@ -342,8 +341,6 @@ if (auto constantIndexOp = dyn_cast_or_null(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); - // Record to check for zero uses later below. - droppedOperands.push_back(constantIndexOp); } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(-1); @@ -366,7 +363,7 @@ auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, alloc.getType()); - rewriter.replaceOp(alloc, {resultCast}, droppedOperands); + rewriter.replaceOp(alloc, {resultCast}); return matchSuccess(); } }; @@ -2447,7 +2444,6 @@ return matchFailure(); SmallVector newOperands; - SmallVector droppedOperands; // Fold dynamic offset operand if it is produced by a constant. auto dynamicOffset = viewOp.getDynamicOffset(); @@ -2458,7 +2454,6 @@ if (auto constantIndexOp = dyn_cast_or_null(defOp)) { // Dynamic offset will be folded into the map. newOffset = constantIndexOp.getValue(); - droppedOperands.push_back(dynamicOffset); } else { // Unable to fold dynamic offset. Add it to 'newOperands' list. newOperands.push_back(dynamicOffset); @@ -2483,8 +2478,6 @@ if (auto constantIndexOp = dyn_cast_or_null(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); - // Record to check for zero uses later below. - droppedOperands.push_back(constantIndexOp); } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(dimSize); @@ -2522,8 +2515,8 @@ auto newViewOp = rewriter.create(viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), newOperands); // Insert a cast so we have the same type as the old memref type. - rewriter.replaceOpWithNewOp(droppedOperands, viewOp, - newViewOp, viewOp.getType()); + rewriter.replaceOpWithNewOp(viewOp, newViewOp, + viewOp.getType()); return matchSuccess(); } }; @@ -2542,8 +2535,8 @@ AllocOp allocOp = dyn_cast_or_null(allocOperand.getDefiningOp()); if (!allocOp) return matchFailure(); - rewriter.replaceOpWithNewOp(memrefOperand, viewOp, viewOp.getType(), - allocOperand, viewOp.operands()); + rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), allocOperand, + viewOp.operands()); return matchSuccess(); } }; @@ -2839,8 +2832,8 @@ subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), ArrayRef(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp( - subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType()); + rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2889,8 +2882,8 @@ subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), subViewOp.sizes(), ArrayRef(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp( - subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType()); + rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2941,8 +2934,8 @@ subViewOp.getLoc(), subViewOp.source(), ArrayRef(), subViewOp.sizes(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp( - subViewOp.offsets(), subViewOp, newSubViewOp, subViewOp.getType()); + rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" + using namespace mlir; PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { @@ -72,12 +73,8 @@ /// This method performs the final replacement for a pattern, where the /// results of the operation are updated to use the specified list of SSA -/// values. In addition to replacing and removing the specified operation, -/// clients can specify a list of other nodes that this replacement may make -/// (perhaps transitively) dead. If any of those ops are dead, this will -/// remove them as well. -void PatternRewriter::replaceOp(Operation *op, ValueRange newValues, - ValueRange valuesToRemoveIfDead) { +/// values. +void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) { // Notify the rewriter subclass that we're about to replace this root. notifyRootReplaced(op); @@ -87,9 +84,6 @@ notifyOperationRemoved(op); op->erase(); - - // TODO: Process the valuesToRemoveIfDead list, removing things and calling - // the notifyOperationRemoved hook in the process. } /// This method erases an operation that is known to have no uses. The uses of @@ -129,15 +123,15 @@ return block->splitBlock(before); } -/// op and newOp are known to have the same number of results, replace the +/// 'op' and 'newOp' are known to have the same number of results, replace the /// uses of op with uses of newOp -void PatternRewriter::replaceOpWithResultsOfAnotherOp( - Operation *op, Operation *newOp, ValueRange valuesToRemoveIfDead) { +void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op, + Operation *newOp) { assert(op->getNumResults() == newOp->getNumResults() && "replacement op doesn't match results of original op"); if (op->getNumResults() == 1) - return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead); - return replaceOp(op, newOp->getResults(), valuesToRemoveIfDead); + return replaceOp(op, newOp->getResult(0)); + return replaceOp(op, newOp->getResults()); } /// Move the blocks that belong to "region" before the given position in diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -554,8 +554,7 @@ TypeConverter::SignatureConversion &conversion); /// PatternRewriter hook for replacing the results of an operation. - void replaceOp(Operation *op, ValueRange newValues, - ValueRange valuesToRemoveIfDead); + void replaceOp(Operation *op, ValueRange newValues); /// Notifies that a block was split. void notifySplitBlock(Block *block, Block *continuation); @@ -757,8 +756,7 @@ } void ConversionPatternRewriterImpl::replaceOp(Operation *op, - ValueRange newValues, - ValueRange valuesToRemoveIfDead) { + ValueRange newValues) { assert(newValues.size() == op->getNumResults()); // Create mappings for each of the new result values. @@ -838,11 +836,11 @@ ConversionPatternRewriter::~ConversionPatternRewriter() {} /// PatternRewriter hook for replacing the results of an operation. -void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues, - ValueRange valuesToRemoveIfDead) { +void ConversionPatternRewriter::replaceOp(Operation *op, + ValueRange newValues) { LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName() << "\n"); - impl->replaceOp(op, newValues, valuesToRemoveIfDead); + impl->replaceOp(op, newValues); } /// PatternRewriter hook for erasing a dead operation. The uses of this @@ -852,7 +850,7 @@ LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName() << "\n"); SmallVector nullRepls(op->getNumResults(), nullptr); - impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None); + impl->replaceOp(op, nullRepls); } /// Apply a signature conversion to the entry block of the given region.