diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -297,7 +297,8 @@ /// the results of the analysis. struct BufferizationState { BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options) - : aliasInfo(moduleOp), options(options) {} + : aliasInfo(moduleOp), options(options), + builder(moduleOp->getContext()) {} // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; @@ -321,6 +322,11 @@ /// Return `true` if the given value is mapped. bool isMapped(Value value) const; + /// Return the result buffer (memref) for a given OpResult (tensor). Allocate + /// a new buffer and copy over data from the existing buffer if out-of-place + /// bufferization is necessary. + Value getResultBuffer(OpResult result); + /// Mark `op` as obsolete, so that it is deleted after bufferization. void markOpObsolete(Operation *op); @@ -349,12 +355,10 @@ /// A reference to current bufferization options. const BufferizationOptions &options; -}; -/// Return the result buffer (memref) for a given OpResult (tensor). Allocate -/// a new buffer and copy over data from the existing buffer if out-of-place -/// bufferization is necessary. -Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state); + /// The OpBuilder used during bufferization. + OpBuilder builder; +}; /// Bufferize all ops in the given region. LogicalResult bufferize(Region *region, BufferizationState &state); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp @@ -26,21 +26,14 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto constantOp = cast(op); - if (!constantOp.getResult().getType().isa()) - return success(); assert(constantOp.getType().dyn_cast() && "not a constant ranked tensor"); auto moduleOp = constantOp->getParentOfType(); - if (!moduleOp) { + if (!moduleOp) return constantOp.emitError( "cannot bufferize constants not within builtin.module op"); - } - GlobalCreator globalCreator(moduleOp); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(constantOp); + GlobalCreator globalCreator(moduleOp); auto globalMemref = globalCreator.getGlobalFor(constantOp); Value memref = b.create( constantOp.getLoc(), globalMemref.type(), globalMemref.getName()); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -372,15 +372,15 @@ /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. -Value mlir::linalg::comprehensive_bufferize::getResultBuffer( - OpBuilder &b, OpResult result, BufferizationState &state) { - OpBuilder::InsertionGuard guard(b); +Value mlir::linalg::comprehensive_bufferize::BufferizationState:: + getResultBuffer(OpResult result) { + OpBuilder::InsertionGuard guard(builder); Operation *op = result.getOwner(); SmallVector aliasingOperands = getAliasingOpOperand(result); assert(!aliasingOperands.empty() && "could not get aliasing OpOperand"); OpOperand *opOperand = aliasingOperands.front(); Value operand = opOperand->get(); - Value operandBuffer = state.lookupBuffer(operand); + Value operandBuffer = lookupBuffer(operand); // Make sure that all OpOperands are the same buffer. If this is not the case, // we would have to materialize a memref value. // TODO: Should be looking for checking for "equivalent buffers" instead of @@ -388,14 +388,14 @@ // set up yet. if (aliasingOperands.size() > 1 && !llvm::all_of(aliasingOperands, [&](OpOperand *o) { - return state.lookupBuffer(o->get()) == operandBuffer; + return lookupBuffer(o->get()) == operandBuffer; })) { op->emitError("result buffer is ambiguous"); return Value(); } // If bufferizing out-of-place, allocate a new buffer. - if (!state.aliasInfo.isInPlace(result)) { + if (!aliasInfo.isInPlace(result)) { // Ops with multiple aliasing operands can currently not bufferize // out-of-place. assert( @@ -404,9 +404,9 @@ Location loc = op->getLoc(); // Move insertion point right after `operandBuffer`. That is where the // allocation should be inserted (in the absence of allocation hoisting). - setInsertionPointAfter(b, operandBuffer); + setInsertionPointAfter(builder, operandBuffer); // Allocate the result buffer. - Value resultBuffer = state.createAllocDeallocFn(b, loc, operandBuffer); + Value resultBuffer = createAllocDeallocFn(builder, loc, operandBuffer); bool skipCopy = false; // Do not copy if the last preceding write of `operand` is an op that does // not write (skipping ops that merely create aliases). E.g., InitTensorOp. @@ -427,9 +427,9 @@ skipCopy = true; if (!skipCopy) { // The copy happens right before the op that is bufferized. - b.setInsertionPoint(op); - state.options.allocationFns->memCpyFn(b, loc, operandBuffer, - resultBuffer); + builder.setInsertionPoint(op); + options.allocationFns->memCpyFn(builder, loc, operandBuffer, + resultBuffer); } return resultBuffer; } @@ -459,7 +459,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(Operation *op, BufferizationState &state) { - OpBuilder b(op->getContext()); + OpBuilder &b = state.builder; // Check if op has tensor results or operands. auto isaTensor = [](Type t) { return t.isa(); }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -38,7 +38,7 @@ OpResult opResult = cast(op.getOperation()) .getAliasingOpResult(*opOperand); assert(opResult && "could not find correspond OpResult"); - Value resultBuffer = getResultBuffer(b, opResult, state); + Value resultBuffer = state.getResultBuffer(opResult); if (!resultBuffer) return failure(); resultBuffers.push_back(resultBuffer); @@ -158,10 +158,6 @@ if (initTensorOp->getUses().empty()) return success(); - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(initTensorOp); - Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(), initTensorOp.result()); state.mapBuffer(initTensorOp.result(), alloc); @@ -250,7 +246,7 @@ const OpResult &opResult = tiledLoopOp->getResult(resultIndex); OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex); - Value resultBuffer = getResultBuffer(b, opResult, state); + Value resultBuffer = state.getResultBuffer(opResult); if (!resultBuffer) return failure(); @@ -350,11 +346,6 @@ BufferizationState &state) const { auto yieldOp = cast(op); - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - // Cannot create IR past a yieldOp. - b.setInsertionPoint(yieldOp); - // No tensors -> success. if (!llvm::any_of(yieldOp.getOperandTypes(), [](Type t) { return t.isa(); })) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -469,10 +469,6 @@ "expected Callop to a FuncOp"); ModuleBufferizationState &moduleState = getModuleBufferizationState(state); - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(callOp); - // 1. Filter return types: // - if the callee is bodiless / external, we cannot inspect it and we // cannot assume anything. We can just assert that it does not return a @@ -600,14 +596,9 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto returnOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - // Cannot insert after returnOp. - b.setInsertionPoint(returnOp); - assert(isa(returnOp->getParentOp()) && "only support FuncOp parent for ReturnOp"); + for (OpOperand &operand : returnOp->getOpOperands()) { auto tensorType = operand.get().getType().dyn_cast(); if (!tensorType) @@ -628,9 +619,6 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto funcOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); b.setInsertionPointToStart(&funcOp.body().front()); // Create BufferCastOps for function args. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -138,7 +138,7 @@ assert(opResult.getType().isa() && "unsupported unranked tensor"); - Value resultBuffer = getResultBuffer(b, opResult, state); + Value resultBuffer = state.getResultBuffer(opResult); if (!resultBuffer) return failure(); @@ -204,7 +204,7 @@ "unsupported unranked tensor"); // TODO: More general: Matching bbArg does not bufferize to a read. - Value resultBuffer = getResultBuffer(b, opResult, state); + Value resultBuffer = state.getResultBuffer(opResult); if (!resultBuffer) return failure(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -60,11 +60,7 @@ BufferizationState &state) const { auto castOp = cast(op); - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(castOp); - - Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state); + Value resultBuffer = state.getResultBuffer(castOp->getResult(0)); if (!resultBuffer) return failure(); Type sourceType = resultBuffer.getType(); @@ -107,11 +103,6 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto dimOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(dimOp); - if (dimOp.source().getType().isa()) { Value v = state.lookupBuffer(dimOp.source()); dimOp.result().replaceAllUsesWith( @@ -145,11 +136,6 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto extractSliceOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(extractSliceOp); - Location loc = extractSliceOp.getLoc(); Value srcMemref = state.lookupBuffer(extractSliceOp.source()); auto srcMemrefType = srcMemref.getType().cast(); @@ -207,11 +193,6 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto extractOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(extractOp); - Location loc = extractOp.getLoc(); Value srcMemref = state.lookupBuffer(extractOp.tensor()); Value l = b.create(loc, srcMemref, extractOp.indices()); @@ -245,13 +226,8 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto insertOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(insertOp); - Location loc = insertOp.getLoc(); - Value destMemref = getResultBuffer(b, insertOp->getOpResult(0), state); + Value destMemref = state.getResultBuffer(insertOp->getOpResult(0)); b.create(loc, insertOp.scalar(), destMemref, insertOp.indices()); state.mapBuffer(insertOp, destMemref); @@ -419,15 +395,11 @@ // catastrophically bad scheduling decision. // TODO: be very loud about it or even consider failing the pass. auto insertSliceOp = cast(op); - TensorBufferizationState &tensorState = getTensorBufferizationState(state); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(insertSliceOp); Location loc = insertSliceOp.getLoc(); + TensorBufferizationState &tensorState = getTensorBufferizationState(state); // When bufferizing out-of-place, `getResultBuffer` allocates. - Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state); + Value dstMemref = state.getResultBuffer(insertSliceOp->getResult(0)); if (!dstMemref) return failure(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -39,14 +39,10 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto transferReadOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - - // TransferReadOp always reads from the bufferized op.source(). assert(transferReadOp.getShapedType().isa() && "only tensor types expected"); + + // TransferReadOp always reads from the bufferized op.source(). Value v = state.lookupBuffer(transferReadOp.source()); transferReadOp.sourceMutable().assign(v); return success(); @@ -81,17 +77,13 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto writeOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); + assert(writeOp.getShapedType().isa() && + "only tensor types expected"); // Create a new transfer_write on buffer that doesn't have a return value. // Leave the previous transfer_write to dead code as it still has uses at // this point. - assert(writeOp.getShapedType().isa() && - "only tensor types expected"); - Value resultBuffer = getResultBuffer(b, op->getResult(0), state); + Value resultBuffer = state.getResultBuffer(op->getResult(0)); if (!resultBuffer) return failure(); b.create(