diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -197,24 +197,47 @@ /// is returned regardless of whether it is a memory write or not. Value findLastPrecedingWrite(Value value); -/// Callback functions that are used by the comprehensive bufferization pass to -/// allocate/deallocate memory. The `deallocationFn` is gauranteed to recieve -/// the `Value` returned by the `allocationFn`. +/// Callback functions that are used to allocate/deallocate/copy memory buffers. +/// Comprehensive Bufferize provides default implementations of these functions. +// TODO: Could be replaced with a "bufferization strategy" object with virtual +// functions in the future. struct AllocationCallbacks { using AllocationFn = std::function( OpBuilder &, Location, MemRefType, const SmallVector &)>; using DeallocationFn = std::function; using MemCpyFn = std::function; + using CreateAllocDeallocFn = + std::function; AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn, - MemCpyFn copyFn) - : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {} + MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn) + : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn), + createAllocDeallocFn(allocDeallocFn) {} + /// A function that allocates memory. AllocationFn allocationFn; + + /// A function that deallocates memory. Must be allocated by `allocationFn`. DeallocationFn deallocationFn; + + /// A function that copies memory between two allocations. MemCpyFn memCpyFn; + + /// A function that creates an alloc-dealloc pair. This function may perform + /// additional optimizations such as buffer allocation hoisting. 
This function + /// calls `allocationFn` and `deallocationFn` to create (de)allocations. + CreateAllocDeallocFn createAllocDeallocFn; }; +/// Return the result buffer (memref) for a given OpResult (tensor). Allocate +/// a new buffer and copy over data from the existing buffer if out-of-place +/// bufferization is necessary. +Value getResultBuffer(OpBuilder &b, OpResult result, + const BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + AllocationCallbacks allocationFns); + } // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Operation.h" #include "llvm/Support/Debug.h" @@ -313,3 +314,70 @@ assert(result.size() == 1 && "expected exactly one result"); return result.front(); } + +/// Return the result buffer (memref) for a given OpResult (tensor). Allocate +/// a new buffer and copy over data from the existing buffer if out-of-place +/// bufferization is necessary. 
+Value mlir::linalg::comprehensive_bufferize::getResultBuffer( + OpBuilder &b, OpResult result, const BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, AllocationCallbacks allocationFns) { + OpBuilder::InsertionGuard guard(b); + Operation *op = result.getOwner(); + SmallVector aliasingOperands = getAliasingOpOperand(result); + assert(!aliasingOperands.empty() && "could not get aliasing OpOperand"); + OpOperand *opOperand = aliasingOperands.front(); + Value operand = opOperand->get(); + Value operandBuffer = bvm.lookupOrNull(operand); + assert(operandBuffer && "operand buffer not found"); + // Make sure that all OpOperands are the same buffer. If this is not the case, + // we would have to materialize a memref value. + // TODO: Should be checking for "equivalent buffers" instead of + // operator== here, but equivalent buffers for scf.if yield values are not + // set up yet. + if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) { + return bvm.lookup(o->get()) == operandBuffer; + })) { + op->emitError("result buffer is ambiguous"); + return Value(); + } + + // If bufferizing out-of-place, allocate a new buffer. + if (!aliasInfo.isInPlace(result)) { + // Ops with multiple aliasing operands cannot currently bufferize + // out-of-place. + assert( + aliasingOperands.size() == 1 && + "ops with multiple aliasing OpOperands cannot bufferize out-of-place"); + Location loc = op->getLoc(); + // Allocate the result buffer. + Value resultBuffer = allocationFns.createAllocDeallocFn( + b, loc, operand, aliasInfo, allocationFns); + bool skipCopy = false; + // Do not copy if the last preceding write of `operand` is an op that does + // not write (skipping ops that merely create aliases). E.g., InitTensorOp. + // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA + // use-def chain, it returns that value, regardless of whether it is a + // memory write or not. 
+ Value lastWrite = findLastPrecedingWrite(operand); + if (auto bufferizableOp = + lastWrite.getDefiningOp()) + if (!bufferizableOp.isMemoryWrite(lastWrite.cast())) + skipCopy = true; + // Do not copy if the copied data is never read. + if (!isValueRead(result)) + skipCopy = true; + // Do not copy if this op does not read the data, but writes it. + if (bufferizesToMemoryWrite(*opOperand) && + !bufferizesToMemoryRead(*opOperand)) + skipCopy = true; + if (!skipCopy) { + // Set insertion point now that potential alloc/dealloc are introduced. + b.setInsertionPoint(op); + allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer); + } + return resultBuffer; + } + + // Bufferizing in-place. No need to allocate a new buffer. + return operandBuffer; +} diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -877,74 +877,6 @@ // Bufferization as simple BlockAndValueMapping rewrites. //===----------------------------------------------------------------------===// -/// Return the result buffer (memref) for a given OpResult (tensor). Allocate -/// a new buffer and copy over data from the existing buffer if out-of-place -/// bufferization is necessary. 
-static Value getResultBuffer(OpBuilder &b, OpResult result, - const BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks allocationFns) { - OpBuilder::InsertionGuard guard(b); - Operation *op = result.getOwner(); - SmallVector aliasingOperands = getAliasingOpOperand(result); - assert(!aliasingOperands.empty() && "could not get aliasing OpOperand"); - OpOperand *opOperand = aliasingOperands.front(); - Value operand = opOperand->get(); - Value operandBuffer = lookup(bvm, operand); - assert(operandBuffer && "operand buffer not found"); - // Make sure that all OpOperands are the same buffer. If this is not the case, - // we would have to materialize a memref value. - // TODO: Should be looking for checking for "equivalent buffers" instead of - // operator== here, but equivalent buffers for scf.if yield values are not - // set up yet. - if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) { - return lookup(bvm, o->get()) == operandBuffer; - })) { - op->emitError("result buffer is ambiguous"); - return Value(); - } - - // If bufferizing out-of-place, allocate a new buffer. - if (!aliasInfo.isInPlace(result)) { - // Ops with multiple aliasing operands can currently not bufferize - // out-of-place. - assert( - aliasingOperands.size() == 1 && - "ops with multiple aliasing OpOperands cannot bufferize out-of-place"); - Location loc = op->getLoc(); - // Allocate the result buffer. - Value resultBuffer = createNewAllocDeallocPairForShapedValue( - b, loc, operand, aliasInfo, allocationFns); - bool skipCopy = false; - // Do not copy if the last preceding write of `operand` is an op that does - // not write (skipping ops that merely create aliases). E.g., InitTensorOp. - // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA - // use-def chain, it returns that value, regardless of whether it is a - // memory write or not. 
- Value lastWrite = findLastPrecedingWrite(operand); - if (auto bufferizableOp = - lastWrite.getDefiningOp()) - if (!bufferizableOp.isMemoryWrite(lastWrite.cast())) - skipCopy = true; - // Do not copy if the copied data is never read. - if (!isValueRead(result)) - skipCopy = true; - // Do not copy if this op does not read the data, but writes it. - if (bufferizesToMemoryWrite(*opOperand) && - !bufferizesToMemoryRead(*opOperand)) - skipCopy = true; - if (!skipCopy) { - // Set insertion point now that potential alloc/dealloc are introduced. - b.setInsertionPoint(op); - allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer); - } - return resultBuffer; - } - - // Bufferizing in-place. No need to allocate a new buffer. - return operandBuffer; -} - /// In a first approximation, all the function arguments of a FuncOp are marked /// inplaceable. For now, it is the responsibility of the `callOp` bufferization /// to allow FuncOp that are inplaceable to write inPlace. @@ -1906,7 +1838,8 @@ // callbacks to their default functions. BufferizationOptions::BufferizationOptions() : allocationFns(std::make_unique( - defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn)) {} + defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn, + createNewAllocDeallocPairForShapedValue)) {} //===----------------------------------------------------------------------===// // BufferizableOpInterface Implementations