diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -383,17 +383,18 @@ /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values. - void replaceOp(RewriterBase &rewriter, Operation *op, ValueRange values); + void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, + ValueRange values); /// Replace an op with a new op. Tensor OpResults must be replaced with memref /// values. template - OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op, - Args &&...args) { - Operation *newOp = + OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op, + Args &&...args) { + auto newOp = rewriter.create(op->getLoc(), std::forward(args)...); - replaceOp(rewriter, op, newOp->getResults()); - return cast(newOp); + replaceOpWithBufferizedValues(rewriter, op, newOp->getResults()); + return newOp; } /// Lookup the memref buffer that is associated to the given tensor value. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp @@ -35,7 +35,7 @@ GlobalCreator globalCreator(moduleOp); auto globalMemref = globalCreator.getGlobalFor(constantOp); - state.replaceOpWithNewOp( + state.replaceOpWithNewBufferizedOp( rewriter, op, globalMemref.type(), globalMemref.getName()); return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -423,8 +423,9 @@ return operandBuffer; } -void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp( - RewriterBase &rewriter, Operation *op, ValueRange values) { +void mlir::linalg::comprehensive_bufferize::BufferizationState:: + replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, + ValueRange values) { OpBuilder::InsertionGuard g(rewriter); // Replace all OpResults with the given values. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -67,7 +67,7 @@ op.clone(rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands); // Replace the results of the old op with the new output buffers. - state.replaceOp(rewriter, op, newOutputBuffers); + state.replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers); return success(); } @@ -201,7 +201,7 @@ Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(), initTensorOp.result()); - state.replaceOp(rewriter, op, alloc); + state.replaceOpWithBufferizedValues(rewriter, op, alloc); return success(); } }; @@ -333,7 +333,7 @@ rewriter.eraseOp(oldTerminator); // Replace results and delete old op. - state.replaceOp(rewriter, op, newResults); + state.replaceOpWithBufferizedValues(rewriter, op, newResults); return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -639,7 +639,7 @@ } // 5. Replace the old op with the new op. - state.replaceOp(rewriter, callOp, replacementValues); + state.replaceOpWithBufferizedValues(rewriter, callOp, replacementValues); return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -193,7 +193,7 @@ } // Replace op results. - state.replaceOp(rewriter, op, newIfOp->getResults()); + state.replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults()); return success(); } @@ -326,7 +326,7 @@ yieldOp.getResultsMutable().assign(yieldValues); // Replace loop results. - state.replaceOp(rewriter, op, newForOp->getResults()); + state.replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults()); return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -83,8 +83,8 @@ : MemRefLayoutAttrInterface(); Type memRefType = getContiguousOrUnrankedMemRefType( castOp.getResult().getType(), layout, memorySpace); - state.replaceOpWithNewOp(rewriter, op, memRefType, - resultBuffer); + state.replaceOpWithNewBufferizedOp(rewriter, op, memRefType, + resultBuffer); return success(); } }; @@ -113,7 +113,8 @@ if (!dimOp.source().getType().isa()) return dimOp.emitError("unranked tensor not supported"); Value v = state.lookupBuffer(rewriter, dimOp.source()); - state.replaceOpWithNewOp(rewriter, op, v, dimOp.index()); + state.replaceOpWithNewBufferizedOp(rewriter, op, v, + dimOp.index()); return success(); } }; @@ -179,7 +180,7 @@ subView = alloc; } - state.replaceOp(rewriter, op, subView); + state.replaceOpWithBufferizedValues(rewriter, op, subView); return success(); } }; @@ -206,8 +207,8 @@ BufferizationState &state) const { auto extractOp = cast(op); Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor()); - state.replaceOpWithNewOp(rewriter, op, srcMemref, - extractOp.indices()); + state.replaceOpWithNewBufferizedOp(rewriter, op, srcMemref, + extractOp.indices()); return success(); } }; @@ -246,7 +247,7 @@ state.getResultBuffer(rewriter, insertOp->getOpResult(0)); rewriter.create(loc, insertOp.scalar(), destMemref, insertOp.indices()); - state.replaceOp(rewriter, op, destMemref); + state.replaceOpWithBufferizedValues(rewriter, op, destMemref); return success(); } @@ -451,7 +452,7 @@ loc, insertSliceOp.source()); } - state.replaceOp(rewriter, op, dstMemref); + state.replaceOpWithBufferizedValues(rewriter, op, dstMemref); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -47,11 +47,10 @@ // TransferReadOp always reads from the bufferized op.source(). Value buffer = state.lookupBuffer(rewriter, readOp.source()); - Value read = rewriter.create( - readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(), + state.replaceOpWithNewBufferizedOp( + rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(), readOp.permutation_map(), readOp.padding(), readOp.mask(), readOp.in_boundsAttr()); - state.replaceOp(rewriter, op, read); return success(); } }; @@ -101,7 +100,7 @@ rewriter.create( writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); - state.replaceOp(rewriter, op, resultBuffer); + state.replaceOpWithBufferizedValues(rewriter, op, resultBuffer); return success(); }