diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -295,7 +295,7 @@
 /// the results of the analysis.
 struct BufferizationState {
   BufferizationState(Operation *op, const BufferizationOptions &options)
-      : aliasInfo(op), options(options) {}
+      : aliasInfo(op), options(options), builder(op->getContext()) {}
 
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
@@ -326,6 +326,11 @@
   /// Return `true` if the given value is mapped.
   bool isMapped(Value value) const;
 
+  /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
+  /// a new buffer and copy over data from the existing buffer if out-of-place
+  /// bufferization is necessary.
+  Value getResultBuffer(OpResult result);
+
   /// Mark `op` as obsolete, so that it is deleted after bufferization.
   void markOpObsolete(Operation *op);
 
@@ -355,12 +360,10 @@
 
   /// A reference to current bufferization options.
   const BufferizationOptions &options;
-};
 
-/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
-/// a new buffer and copy over data from the existing buffer if out-of-place
-/// bufferization is necessary.
-Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
+  /// The OpBuilder used during bufferization.
+  OpBuilder builder;
+};
 
 /// Bufferize all ops in the given region.
 LogicalResult bufferize(Region *region, BufferizationState &state);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -31,21 +31,14 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto constantOp = cast<arith::ConstantOp>(op);
-    if (!constantOp.getResult().getType().isa<RankedTensorType>())
-      return success();
     assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
            "not a constant ranked tensor");
     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
-    if (!moduleOp) {
+    if (!moduleOp)
       return constantOp.emitError(
           "cannot bufferize constants not within builtin.module op");
-    }
-    GlobalCreator globalCreator(moduleOp);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(constantOp);
 
+    GlobalCreator globalCreator(moduleOp);
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
     Value memref = b.create<memref::GetGlobalOp>(
         constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -342,15 +342,15 @@
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
-Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
-    OpBuilder &b, OpResult result, BufferizationState &state) {
-  OpBuilder::InsertionGuard guard(b);
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::
+    getResultBuffer(OpResult result) {
+  OpBuilder::InsertionGuard guard(builder);
   Operation *op = result.getOwner();
   SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
   assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
   OpOperand *opOperand = aliasingOperands.front();
   Value operand = opOperand->get();
-  Value operandBuffer = state.lookupBuffer(operand);
+  Value operandBuffer = lookupBuffer(operand);
   // Make sure that all OpOperands are the same buffer. If this is not the
   // case, we would have to materialize a memref value.
   // TODO: Should be looking for checking for "equivalent buffers" instead of
@@ -358,14 +358,14 @@
   // set up yet.
   if (aliasingOperands.size() > 1 &&
       !llvm::all_of(aliasingOperands, [&](OpOperand *o) {
-        return state.lookupBuffer(o->get()) == operandBuffer;
+        return lookupBuffer(o->get()) == operandBuffer;
       })) {
     op->emitError("result buffer is ambiguous");
     return Value();
   }
 
   // If bufferizing out-of-place, allocate a new buffer.
-  if (!state.aliasInfo.isInPlace(result)) {
+  if (!aliasInfo.isInPlace(result)) {
     // Ops with multiple aliasing operands can currently not bufferize
     // out-of-place.
     assert(
@@ -374,9 +374,9 @@
     Location loc = op->getLoc();
     // Move insertion point right after `operandBuffer`. That is where the
     // allocation should be inserted (in the absence of allocation hoisting).
-    setInsertionPointAfter(b, operandBuffer);
+    setInsertionPointAfter(builder, operandBuffer);
     // Allocate the result buffer.
-    Value resultBuffer = state.createAllocDeallocFn(b, loc, operandBuffer);
+    Value resultBuffer = createAllocDeallocFn(builder, loc, operandBuffer);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -397,9 +397,9 @@
       skipCopy = true;
     if (!skipCopy) {
       // The copy happens right before the op that is bufferized.
-      b.setInsertionPoint(op);
-      state.options.allocationFns->memCpyFn(b, loc, operandBuffer,
-                                            resultBuffer);
+      builder.setInsertionPoint(op);
+      options.allocationFns->memCpyFn(builder, loc, operandBuffer,
+                                      resultBuffer);
     }
     return resultBuffer;
   }
@@ -435,7 +435,7 @@
 LogicalResult
 mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
                                                  BufferizationState &state) {
-  OpBuilder b(op->getContext());
+  OpBuilder &b = state.builder;
 
   // Check if op has tensor results or operands.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -38,7 +38,7 @@
       OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
                               .getAliasingOpResult(*opOperand);
       assert(opResult && "could not find correspond OpResult");
-      Value resultBuffer = getResultBuffer(b, opResult, state);
+      Value resultBuffer = state.getResultBuffer(opResult);
       if (!resultBuffer)
         return failure();
       resultBuffers.push_back(resultBuffer);
@@ -164,10 +164,6 @@
     if (initTensorOp->getUses().empty())
       return success();
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(initTensorOp);
-
     Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(),
                                              initTensorOp.result());
     state.mapBuffer(initTensorOp.result(), alloc);
@@ -266,7 +262,7 @@
 
       const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
       OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-      Value resultBuffer = getResultBuffer(b, opResult, state);
+      Value resultBuffer = state.getResultBuffer(opResult);
       if (!resultBuffer)
         return failure();
 
@@ -358,11 +354,6 @@
                             BufferizationState &state) const {
     auto yieldOp = cast<linalg::YieldOp>(op);
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    // Cannot create IR past a yieldOp.
-    b.setInsertionPoint(yieldOp);
-
     // No tensors -> success.
     if (!llvm::any_of(yieldOp.getOperandTypes(),
                       [](Type t) { return t.isa<TensorType>(); }))
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -477,10 +477,6 @@
            "expected CallOp to a FuncOp");
     ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(callOp);
-
     // 1. Filter return types:
     //    - if the callee is bodiless / external, we cannot inspect it and we
     //      cannot assume anything. We can just assert that it does not return a
@@ -604,14 +600,9 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto returnOp = cast<ReturnOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    // Cannot insert after returnOp.
-    b.setInsertionPoint(returnOp);
-
     assert(isa<FuncOp>(returnOp->getParentOp()) &&
            "only support FuncOp parent for ReturnOp");
+
     for (OpOperand &operand : returnOp->getOpOperands()) {
       auto tensorType = operand.get().getType().dyn_cast<TensorType>();
       if (!tensorType)
@@ -631,9 +622,6 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto funcOp = cast<FuncOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
     b.setInsertionPointToStart(&funcOp.body().front());
 
     // Create BufferCastOps for function args.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -145,7 +145,7 @@
       assert(opResult.getType().isa<RankedTensorType>() &&
             "unsupported unranked tensor");
 
-      Value resultBuffer = getResultBuffer(b, opResult, state);
+      Value resultBuffer = state.getResultBuffer(opResult);
       if (!resultBuffer)
         return failure();
 
@@ -255,7 +255,7 @@
     // Construct a new scf.for op with memref instead of tensor values.
     SmallVector<Value> initArgs =
         convert(forOp.initArgs(), [&](Value val, int64_t index) {
-          return getResultBuffer(rewriter, forOp->getOpResult(index), state);
+          return state.getResultBuffer(forOp->getOpResult(index));
           // return state.lookupBuffer(val);
         });
     auto newForOp =
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -66,11 +66,7 @@
                             BufferizationState &state) const {
     auto castOp = cast<tensor::CastOp>(op);
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(castOp);
-
-    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
+    Value resultBuffer = state.getResultBuffer(castOp->getResult(0));
     if (!resultBuffer)
       return failure();
     Type sourceType = resultBuffer.getType();
@@ -112,11 +108,6 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(dimOp);
-
     if (dimOp.source().getType().isa<TensorType>()) {
       Value v = state.lookupBuffer(dimOp.source());
       dimOp.result().replaceAllUsesWith(
@@ -156,11 +147,6 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(extractSliceOp);
-
     Location loc = extractSliceOp.getLoc();
     Value srcMemref = state.lookupBuffer(extractSliceOp.source());
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
@@ -218,11 +204,6 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(extractOp);
-
     Location loc = extractOp.getLoc();
     Value srcMemref = state.lookupBuffer(extractOp.tensor());
     Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
@@ -256,13 +237,8 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto insertOp = cast<tensor::InsertOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(insertOp);
-
     Location loc = insertOp.getLoc();
-    Value destMemref = getResultBuffer(b, insertOp->getOpResult(0), state);
+    Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
     b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
                               insertOp.indices());
     state.mapBuffer(insertOp, destMemref);
@@ -436,15 +412,11 @@
     //   catastrophically bad scheduling decision.
     // TODO: be very loud about it or even consider failing the pass.
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    TensorBufferizationState &tensorState = getTensorBufferizationState(state);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(insertSliceOp);
     Location loc = insertSliceOp.getLoc();
+    TensorBufferizationState &tensorState = getTensorBufferizationState(state);
 
     // When bufferizing out-of-place, `getResultBuffer` allocates.
-    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
+    Value dstMemref = state.getResultBuffer(insertSliceOp->getResult(0));
     if (!dstMemref)
       return failure();
 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -39,14 +39,10 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto transferReadOp = cast<vector::TransferReadOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(op);
-
-    // TransferReadOp always reads from the bufferized op.source().
     assert(transferReadOp.getShapedType().isa<TensorType>() &&
            "only tensor types expected");
+
+    // TransferReadOp always reads from the bufferized op.source().
     Value v = state.lookupBuffer(transferReadOp.source());
     transferReadOp.sourceMutable().assign(v);
     return success();
@@ -87,17 +83,13 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto writeOp = cast<vector::TransferWriteOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(op);
+    assert(writeOp.getShapedType().isa<TensorType>() &&
+           "only tensor types expected");
 
     // Create a new transfer_write on buffer that doesn't have a return value.
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
-    assert(writeOp.getShapedType().isa<TensorType>() &&
-           "only tensor types expected");
-    Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
+    Value resultBuffer = state.getResultBuffer(op->getResult(0));
     if (!resultBuffer)
       return failure();
     b.create<vector::TransferWriteOp>(
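
Usage note (editor's sketch, not part of the patch): after this change, op interface implementations obtain result buffers through the state object instead of the free `getResultBuffer(b, result, state)` function, and `bufferize()` hands out the state-owned builder rather than constructing a fresh one. A minimal sketch of what an implementation looks like post-patch follows; `MyOp` and its accessors are hypothetical, while `getResultBuffer`, `mapBuffer`, and the `bufferize` signature are taken from the diff above:

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto myOp = cast<MyOp>(op); // `MyOp` is a hypothetical one-result op.
    // Buffer for the result tensor. Per the diff, this allocates a new
    // buffer and copies the operand buffer when the analysis decided on
    // out-of-place bufferization; a null Value signals failure.
    Value resultBuffer = state.getResultBuffer(myOp->getResult(0));
    if (!resultBuffer)
      return failure();
    // ... create memref-based ops via `b`, which per this patch aliases
    // state.builder ...
    state.mapBuffer(myOp->getResult(0), resultBuffer);
    return success();
  }

Because `getResultBuffer` now takes an InsertionGuard on the state-owned builder and positions it itself, the per-implementation "take a guard before anything else" blocks deleted throughout this patch are no longer necessary.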