diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -380,22 +380,6 @@
   /// Creates a memcpy between two given buffers.
   void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
 
-  /// Replace an op with replacement values. The op is deleted. Tensor OpResults
-  /// must be replaced with memref values.
-  void replaceOp(RewriterBase &rewriter, Operation *op,
-                 ValueRange values) const;
-
-  /// Replace an op with a new op. Tensor OpResults must be replaced with memref
-  /// values.
-  template <typename OpTy, typename... Args>
-  OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
-                          Args &&...args) const {
-    Operation *newOp =
-        rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
-    replaceOp(rewriter, op, newOp->getResults());
-    return cast<OpTy>(newOp);
-  }
-
   /// Lookup the memref buffer that is associated to the given tensor value.
   /// Asserts if no buffer is associated.
   Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
@@ -444,6 +428,21 @@
   const BufferizationOptions &options;
 };
 
+/// Replace an op with replacement values. The op is deleted. Tensor OpResults
+/// must be replaced with memref values.
+void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
+                                   ValueRange values);
+
+/// Replace an op with a new op. Tensor OpResults must be replaced with memref
+/// values.
+template <typename OpTy, typename... Args>
+OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
+                                  Args &&...args) {
+  auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+  replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
+  return newOp;
+}
+
 /// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
 /// with the same shape as `shapedType` and specified `layout` and
 /// `addressSpace`.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -35,7 +35,7 @@
 
     GlobalCreator globalCreator(moduleOp);
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
-    state.replaceOpWithNewOp<memref::GetGlobalOp>(
+    replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
         rewriter, op, globalMemref.type(), globalMemref.getName());
     return success();
   }
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -422,8 +422,8 @@
   return operandBuffer;
 }
 
-void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
-    RewriterBase &rewriter, Operation *op, ValueRange values) const {
+void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues(
+    RewriterBase &rewriter, Operation *op, ValueRange values) {
   OpBuilder::InsertionGuard g(rewriter);
 
   // Replace all OpResults with the given values.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -67,7 +67,7 @@
   op.clone(rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands);
 
   // Replace the results of the old op with the new output buffers.
-  state.replaceOp(rewriter, op, newOutputBuffers);
+  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
   return success();
 }
 
@@ -201,7 +201,7 @@
 
     Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
                                                initTensorOp.result());
-    state.replaceOp(rewriter, op, alloc);
+    replaceOpWithBufferizedValues(rewriter, op, alloc);
     return success();
   }
 };
@@ -342,7 +342,7 @@
     rewriter.eraseOp(oldTerminator);
 
     // Replace results and delete old op.
-    state.replaceOp(rewriter, op, newResults);
+    replaceOpWithBufferizedValues(rewriter, op, newResults);
     return success();
   }
 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -634,7 +634,7 @@
     }
 
     // 5. Replace the old op with the new op.
-    state.replaceOp(rewriter, callOp, replacementValues);
+    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);
     return success();
   }
 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -192,7 +192,7 @@
     }
 
     // Replace op results.
-    state.replaceOp(rewriter, op, newIfOp->getResults());
+    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
     return success();
   }
 
@@ -326,7 +326,7 @@
     yieldOp.getResultsMutable().assign(yieldValues);
 
     // Replace loop results.
-    state.replaceOp(rewriter, op, newForOp->getResults());
+    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
     return success();
   }
 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -68,8 +68,8 @@
             : MemRefLayoutAttrInterface();
     Type memRefType = getContiguousOrUnrankedMemRefType(
         castOp.getResult().getType(), layout, memorySpace);
-    state.replaceOpWithNewOp<memref::CastOp>(rewriter, op, memRefType,
-                                             resultBuffer);
+    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, memRefType,
+                                                 resultBuffer);
     return success();
   }
 };
@@ -98,7 +98,7 @@
     if (!dimOp.source().getType().isa<RankedTensorType>())
       return dimOp.emitError("unranked tensor not supported");
     Value v = state.lookupBuffer(rewriter, dimOp.source());
-    state.replaceOpWithNewOp<memref::DimOp>(rewriter, op, v, dimOp.index());
+    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
     return success();
   }
 };
@@ -164,7 +164,7 @@
       subView = alloc;
     }
 
-    state.replaceOp(rewriter, op, subView);
+    replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
   }
 };
@@ -191,8 +191,8 @@
                           const BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
     Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
-    state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
-                                             extractOp.indices());
+    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
+                                                 extractOp.indices());
     return success();
   }
 };
@@ -231,7 +231,7 @@
         state.getResultBuffer(rewriter, insertOp->getOpResult(0));
     rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
                                      insertOp.indices());
-    state.replaceOp(rewriter, op, destMemref);
+    replaceOpWithBufferizedValues(rewriter, op, destMemref);
     return success();
   }
 
@@ -413,7 +413,7 @@
     Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
     state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
 
-    state.replaceOp(rewriter, op, dstMemref);
+    replaceOpWithBufferizedValues(rewriter, op, dstMemref);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -47,11 +47,10 @@
 
     // TransferReadOp always reads from the bufferized op.source().
     Value buffer = state.lookupBuffer(rewriter, readOp.source());
-    Value read = rewriter.create<vector::TransferReadOp>(
-        readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
+    replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
+        rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
         readOp.permutation_map(), readOp.padding(), readOp.mask(),
         readOp.in_boundsAttr());
-    state.replaceOp(rewriter, op, read);
     return success();
   }
 };
@@ -101,7 +100,7 @@
     rewriter.create<vector::TransferWriteOp>(
         writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
-    state.replaceOp(rewriter, op, resultBuffer);
+    replaceOpWithBufferizedValues(rewriter, op, resultBuffer);
     return success();
   }
 