diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -297,8 +297,7 @@ /// * `replaceOp` replaces an op with new values. class BufferizationState { public: - BufferizationState(Operation *op, const BufferizationOptions &options, - RewriterBase &rewriter); + BufferizationState(Operation *op, const BufferizationOptions &options); // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; @@ -384,7 +383,7 @@ /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values. - void replaceOp(Operation *op, ValueRange values); + void replaceOp(RewriterBase &rewriter, Operation *op, ValueRange values); /// Replace an op with a new op. Tensor OpResults must be replaced with memref /// values. @@ -393,13 +392,13 @@ Args &&...args) { Operation *newOp = rewriter.create(op->getLoc(), std::forward(args)...); - replaceOp(op, newOp->getResults()); + replaceOp(rewriter, op, newOp->getResults()); return cast(newOp); } /// Lookup the memref buffer that is associated to the given tensor value. /// Asserts if no buffer is associated. - Value lookupBuffer(Value tensor); + Value lookupBuffer(RewriterBase &rewriter, Value tensor); /// Return `true` if the given OpResult has been decided to bufferize inplace. bool isInPlace(OpResult opResult) const; @@ -407,7 +406,7 @@ /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. - Value getResultBuffer(OpResult result); + Value getResultBuffer(RewriterBase &rewriter, OpResult result); /// Return dialect-specific bufferization state. template StateT &getDialectState(StringRef name) { @@ -420,9 +419,6 @@ /// Return a reference to the BufferizationOptions. const BufferizationOptions &getOptions() const { return options; } - /// Return a reference to the rewriter. - RewriterBase &getRewriter() { return rewriter; } - private: friend LogicalResult runComprehensiveBufferize(Operation *op, const BufferizationOptions &options, @@ -441,21 +437,21 @@ /// A reference to current bufferization options. const BufferizationOptions &options; - - /// The OpBuilder used during bufferization. - RewriterBase &rewriter; }; /// Bufferize all ops in the given region. -LogicalResult bufferize(Region *region, BufferizationState &state); +LogicalResult bufferize(RewriterBase &rewriter, Region *region, + BufferizationState &state); /// Bufferize all ops in the given block. -LogicalResult bufferize(Block *block, BufferizationState &state); +LogicalResult bufferize(RewriterBase &rewriter, Block *block, + BufferizationState &state); /// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this /// function returns immediately. Otherwise, it calls the `bufferize` interface /// method of `BufferizableOpInterface`. -LogicalResult bufferize(Operation *op, BufferizationState &state); +LogicalResult bufferize(RewriterBase &rewriter, Operation *op, + BufferizationState &state); /// Return a contiguous MemRefType (i.e. with canonical/empty layout map) /// with the same shape as `shapedType` and specified `layout` and @@ -535,7 +531,7 @@ return op->emitError() << "unsupported op with tensors"; for (Region ®ion : op->getRegions()) - if (failed(comprehensive_bufferize::bufferize(®ion, state))) + if (failed(comprehensive_bufferize::bufferize(rewriter, ®ion, state))) return failure(); return success(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -333,8 +333,8 @@ } mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState( - Operation *op, const BufferizationOptions &options, RewriterBase &rewriter) - : aliasInfo(op), options(options), rewriter(rewriter) { + Operation *op, const BufferizationOptions &options) + : aliasInfo(op), options(options) { // Set up alias sets for OpResults that must bufferize in-place. This should // be done before making any other bufferization decisions. op->walk([&](BufferizableOpInterface bufferizableOp) { @@ -360,14 +360,14 @@ /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. Value mlir::linalg::comprehensive_bufferize::BufferizationState:: - getResultBuffer(OpResult result) { + getResultBuffer(RewriterBase &rewriter, OpResult result) { OpBuilder::InsertionGuard guard(rewriter); Operation *op = result.getOwner(); SmallVector aliasingOperands = getAliasingOpOperand(result); assert(!aliasingOperands.empty() && "could not get aliasing OpOperand"); OpOperand *opOperand = aliasingOperands.front(); Value operand = opOperand->get(); - Value operandBuffer = lookupBuffer(operand); + Value operandBuffer = lookupBuffer(rewriter, operand); // Make sure that all OpOperands are the same buffer. If this is not the case, // we would have to materialize a memref value. // TODO: Should be looking for checking for "equivalent buffers" instead of @@ -375,7 +375,7 @@ // set up yet. if (aliasingOperands.size() > 1 && !llvm::all_of(aliasingOperands, [&](OpOperand *o) { - return lookupBuffer(o->get()) == operandBuffer; + return lookupBuffer(rewriter, o->get()) == operandBuffer; })) { op->emitError("result buffer is ambiguous"); return Value(); @@ -424,7 +424,7 @@ } void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp( - Operation *op, ValueRange values) { + RewriterBase &rewriter, Operation *op, ValueRange values) { OpBuilder::InsertionGuard g(rewriter); // Replace all OpResults with the given values. @@ -453,18 +453,16 @@ rewriter.eraseOp(op); } -LogicalResult -mlir::linalg::comprehensive_bufferize::bufferize(Region *region, - BufferizationState &state) { +LogicalResult mlir::linalg::comprehensive_bufferize::bufferize( + RewriterBase &rewriter, Region *region, BufferizationState &state) { for (Block &block : *region) - if (failed(bufferize(&block, state))) + if (failed(bufferize(rewriter, &block, state))) return failure(); return success(); } -LogicalResult -mlir::linalg::comprehensive_bufferize::bufferize(Block *block, - BufferizationState &state) { +LogicalResult mlir::linalg::comprehensive_bufferize::bufferize( + RewriterBase &rewriter, Block *block, BufferizationState &state) { // Ops may get deleted during the traversal, so do not iterate over `block` // directly. SmallVector ops; @@ -472,16 +470,13 @@ for (Operation &op : *block) ops.push_back(&op); for (Operation *op : ops) - if (failed(bufferize(op, state))) + if (failed(bufferize(rewriter, op, state))) return failure(); return success(); } -LogicalResult -mlir::linalg::comprehensive_bufferize::bufferize(Operation *op, - BufferizationState &state) { - RewriterBase &rewriter = state.getRewriter(); - +LogicalResult mlir::linalg::comprehensive_bufferize::bufferize( + RewriterBase &rewriter, Operation *op, BufferizationState &state) { // Check if op has tensor results or operands. auto isaTensor = [](Type t) { return t.isa(); }; bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); @@ -505,7 +500,7 @@ // Bufferize all regions. for (Region ®ion : op->getRegions()) - if (failed(bufferize(®ion, state))) + if (failed(bufferize(rewriter, ®ion, state))) return failure(); return success(); @@ -654,7 +649,7 @@ } Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer( - Value tensor) { + RewriterBase &rewriter, Value tensor) { assert(tensor.getType().isa() && "unexpected non-tensor type"); // Replace "%t = to_tensor %m" with %m. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -654,8 +654,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( Operation *op, std::unique_ptr options) { - IRRewriter rewriter(op->getContext()); - BufferizationState state(op, *options, rewriter); + BufferizationState state(op, *options); return runComprehensiveBufferize(op, *options, state); } @@ -663,6 +662,7 @@ Operation *op, const BufferizationOptions &options, BufferizationState &state) { + IRRewriter rewriter(op->getContext()); DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.aliasInfo; @@ -693,7 +693,7 @@ } // Bufferize the op and its nested ops. - if (failed(bufferize(op, state))) + if (failed(bufferize(rewriter, op, state))) return failure(); return success(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -45,14 +45,14 @@ newInputBuffers.push_back(opOperand->get()); continue; } - newInputBuffers.push_back(state.lookupBuffer(opOperand->get())); + newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get())); } SmallVector newOutputBuffers; for (OpOperand *opOperand : op.getOutputOperands()) { OpResult opResult = op.getTiedOpResult(opOperand); assert(opResult && "could not find correspond OpResult"); - Value resultBuffer = state.getResultBuffer(opResult); + Value resultBuffer = state.getResultBuffer(rewriter, opResult); if (!resultBuffer) return failure(); newOutputBuffers.push_back(resultBuffer); @@ -68,9 +68,10 @@ rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); // Replace the results of the old op with the new output buffers. - state.replaceOp(op, newOutputBuffers); + state.replaceOp(rewriter, op, newOutputBuffers); - return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state); + return comprehensive_bufferize::bufferize(rewriter, bufferizedOp.getBlock(), + state); } /// Linalg OpResults usually bufferize inplace with their tied (output @@ -202,7 +203,7 @@ Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(), initTensorOp.result()); - state.replaceOp(op, alloc); + state.replaceOp(rewriter, op, alloc); return success(); } }; @@ -259,7 +260,7 @@ SmallVector newInputs, newOutputs, newResults; for (Value value : tiledLoopOp.inputs()) { if (value.getType().isa()) { - newInputs.push_back(state.lookupBuffer(value)); + newInputs.push_back(state.lookupBuffer(rewriter, value)); } else { newInputs.push_back(value); } @@ -267,8 +268,8 @@ int nextResultNum = 0; for (Value value : tiledLoopOp.outputs()) { if (value.getType().isa()) { - Value buffer = - state.getResultBuffer(tiledLoopOp->getResult(nextResultNum++)); + Value buffer = state.getResultBuffer( + rewriter, tiledLoopOp->getResult(nextResultNum++)); newOutputs.push_back(buffer); newResults.push_back(buffer); } else { @@ -328,10 +329,11 @@ rewriter.eraseOp(oldTerminator); // Replace results and delete old op. - state.replaceOp(op, newResults); + state.replaceOp(rewriter, op, newResults); // Bufferize loop body. - return comprehensive_bufferize::bufferize(newTiledLoopOp.getBody(), state); + return comprehensive_bufferize::bufferize(rewriter, + newTiledLoopOp.getBody(), state); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -219,6 +219,7 @@ /// originate from an op with an Alloc effect, they could be hoisted in the /// future. static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, + RewriterBase &rewriter, BufferizationState &state) { ModuleBufferizationState &moduleState = getModuleBufferizationState(state); @@ -277,7 +278,8 @@ continue; // Cast values at the call site if necessary. - returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal))); + returnValues.push_back( + getNonCastedValue(state.lookupBuffer(rewriter, returnVal))); } // 2. Rewrite the terminator without the inPlace bufferizable values. @@ -510,7 +512,7 @@ /// In a first approximation, all the function arguments of a FuncOp are /// marked inplaceable. For now, it is the responsibility of the `callOp` /// bufferization to allow FuncOp that are inplaceable to write inPlace. - LogicalResult bufferize(Operation *op, OpBuilder &b, + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); @@ -552,13 +554,13 @@ moduleState .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()]; Value oldRes = callOp->getResult(returnOperand.getOperandNumber()); - Value buffer = state.lookupBuffer(callOp->getOperand(idx)); + Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx)); // Add a ToTensorOp to kill all uses of the CallOp return. // Replace all uses of the CallOp results so we can erase the CallOp. // This ToTensorOp must fold/DCE away or bufferization should be // considered failed. - Value toTensorOp = - b.create(callOp.getLoc(), buffer); + Value toTensorOp = rewriter.create( + callOp.getLoc(), buffer); oldRes.replaceAllUsesWith(toTensorOp); continue; } @@ -588,7 +590,7 @@ // Tensor operands are guaranteed to have been buferized. int64_t idx = opOperand.getOperandNumber(); - Value buffer = state.lookupBuffer(tensorOperand); + Value buffer = state.lookupBuffer(rewriter, tensorOperand); // Caller / callee type mistmatch is handled with a CastOp. auto memRefType = bufferizedFuncType.getInput(idx); @@ -598,16 +600,16 @@ // that will either canonicalize away or fail compilation until we can do // something better. if (buffer.getType() != memRefType) { - Value castBuffer = - b.create(callOp.getLoc(), memRefType, buffer); + Value castBuffer = rewriter.create(callOp.getLoc(), + memRefType, buffer); buffer = castBuffer; } newOperands.push_back(buffer); } // 4. Create the new CallOp. - Operation *newCallOp = b.create(callOp.getLoc(), funcOp.sym_name(), - resultTypes, newOperands); + Operation *newCallOp = rewriter.create( + callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands); newCallOp->setAttrs(callOp->getAttrs()); // 5. Delete the op at the end of bufferization. @@ -635,7 +637,7 @@ return OpResult(); } - LogicalResult bufferize(Operation *op, OpBuilder &b, + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto returnOp = cast(op); assert(isa(returnOp->getParentOp()) && @@ -645,9 +647,9 @@ auto tensorType = operand.get().getType().dyn_cast(); if (!tensorType) continue; - Value v = state.lookupBuffer(operand.get()); - Value returnTensor = b.create( - returnOp.getLoc(), v); + Value v = state.lookupBuffer(rewriter, operand.get()); + Value returnTensor = + rewriter.create(returnOp.getLoc(), v); operand.set(returnTensor); } return success(); @@ -656,12 +658,12 @@ struct FuncOpInterface : public BufferizableOpInterface::ExternalModel { - LogicalResult bufferize(Operation *op, OpBuilder &b, + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto funcOp = cast(op); // Bufferize function body. - return comprehensive_bufferize::bufferize(&funcOp.body(), state); + return comprehensive_bufferize::bufferize(rewriter, &funcOp.body(), state); } /// Return `true` if the given function argument is writable. @@ -726,7 +728,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( ModuleOp moduleOp, std::unique_ptr options) { IRRewriter rewriter(moduleOp.getContext()); - BufferizationState state(moduleOp, *options, rewriter); + BufferizationState state(moduleOp, *options); ModuleBufferizationState &moduleState = getModuleBufferizationState(state); BufferizationAliasInfo &aliasInfo = state.aliasInfo; @@ -766,7 +768,7 @@ for (FuncOp funcOp : moduleState.orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. - if (failed(bufferizeFuncOpBoundary(funcOp, state))) + if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state))) return failure(); if (!options->allowReturnMemref && diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -70,8 +70,8 @@ if (hasTensorReturnType) return op->emitError( "scf.execute_region with tensor result not supported"); - return comprehensive_bufferize::bufferize(&executeRegionOp.getRegion(), - state); + return comprehensive_bufferize::bufferize( + rewriter, &executeRegionOp.getRegion(), state); } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -194,12 +194,14 @@ } // Replace op results. - state.replaceOp(op, newIfOp->getResults()); + state.replaceOp(rewriter, op, newIfOp->getResults()); // Bufferize then/else blocks. - if (failed(comprehensive_bufferize::bufferize(newIfOp.thenBlock(), state))) + if (failed(comprehensive_bufferize::bufferize(rewriter, newIfOp.thenBlock(), + state))) return failure(); - if (failed(comprehensive_bufferize::bufferize(newIfOp.elseBlock(), state))) + if (failed(comprehensive_bufferize::bufferize(rewriter, newIfOp.elseBlock(), + state))) return failure(); return success(); @@ -299,7 +301,7 @@ // Construct a new scf.for op with memref instead of tensor values. SmallVector initArgs = convert(forOp.getInitArgs(), [&](Value val, int64_t index) { - return state.getResultBuffer(forOp->getOpResult(index)); + return state.getResultBuffer(rewriter, forOp->getOpResult(index)); }); auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), @@ -333,10 +335,10 @@ yieldOp.getResultsMutable().assign(yieldValues); // Replace loop results. - state.replaceOp(op, newForOp->getResults()); + state.replaceOp(rewriter, op, newForOp->getResults()); // Bufferize loop body. - if (failed(comprehensive_bufferize::bufferize(loopBody, state))) + if (failed(comprehensive_bufferize::bufferize(rewriter, loopBody, state))) return failure(); return success(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -65,7 +65,7 @@ BufferizationState &state) const { auto castOp = cast(op); - Value resultBuffer = state.getResultBuffer(castOp->getResult(0)); + Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0)); if (!resultBuffer) return failure(); Type sourceType = resultBuffer.getType(); @@ -111,7 +111,7 @@ auto dimOp = cast(op); if (!dimOp.source().getType().isa()) return dimOp.emitError("unranked tensor not supported"); - Value v = state.lookupBuffer(dimOp.source()); + Value v = state.lookupBuffer(rewriter, dimOp.source()); state.replaceOpWithNewOp(rewriter, op, v, dimOp.index()); return success(); } @@ -147,7 +147,7 @@ BufferizationState &state) const { auto extractSliceOp = cast(op); Location loc = extractSliceOp.getLoc(); - Value srcMemref = state.lookupBuffer(extractSliceOp.source()); + Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source()); auto srcMemrefType = srcMemref.getType().cast(); auto dstTensorType = extractSliceOp.result().getType().cast(); @@ -178,7 +178,7 @@ subView = alloc; } - state.replaceOp(op, subView); + state.replaceOp(rewriter, op, subView); return success(); } }; @@ -204,7 +204,7 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto extractOp = cast(op); - Value srcMemref = state.lookupBuffer(extractOp.tensor()); + Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor()); state.replaceOpWithNewOp(rewriter, op, srcMemref, extractOp.indices()); return success(); @@ -241,10 +241,11 @@ BufferizationState &state) const { auto insertOp = cast(op); Location loc = insertOp.getLoc(); - Value destMemref = state.getResultBuffer(insertOp->getOpResult(0)); + Value destMemref = + state.getResultBuffer(rewriter, insertOp->getOpResult(0)); rewriter.create(loc, insertOp.scalar(), destMemref, insertOp.indices()); - state.replaceOp(op, destMemref); + state.replaceOp(rewriter, op, destMemref); return success(); } @@ -421,7 +422,8 @@ TensorBufferizationState &tensorState = getTensorBufferizationState(state); // When bufferizing out-of-place, `getResultBuffer` allocates. - Value dstMemref = state.getResultBuffer(insertSliceOp->getResult(0)); + Value dstMemref = + state.getResultBuffer(rewriter, insertSliceOp->getResult(0)); if (!dstMemref) return failure(); @@ -440,11 +442,11 @@ loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); // Copy tensor. - Value srcMemref = state.lookupBuffer(insertSliceOp.source()); + Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source()); state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView); } - state.replaceOp(op, dstMemref); + state.replaceOp(rewriter, op, dstMemref); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -46,12 +46,12 @@ "only tensor types expected"); // TransferReadOp always reads from the bufferized op.source(). - Value buffer = state.lookupBuffer(readOp.source()); + Value buffer = state.lookupBuffer(rewriter, readOp.source()); Value read = rewriter.create( readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(), readOp.permutation_map(), readOp.padding(), readOp.mask(), readOp.in_boundsAttr()); - state.replaceOp(op, read); + state.replaceOp(rewriter, op, read); return success(); } }; @@ -95,13 +95,13 @@ // Create a new transfer_write on buffer that doesn't have a return value. // Leave the previous transfer_write to dead code as it still has uses at // this point. - Value resultBuffer = state.getResultBuffer(op->getResult(0)); + Value resultBuffer = state.getResultBuffer(rewriter, op->getResult(0)); if (!resultBuffer) return failure(); rewriter.create( writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); - state.replaceOp(op, resultBuffer); + state.replaceOp(rewriter, op, resultBuffer); return success(); }