diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -38,34 +38,6 @@ class BufferizationState; struct PostAnalysisStep; -/// Callback functions that are used to allocate/deallocate/copy memory buffers. -/// Comprehensive Bufferize provides default implementations of these functions. -// TODO: Could be replaced with a "bufferization strategy" object with virtual -// functions in the future. -struct AllocationCallbacks { - using AllocationFn = std::function( - OpBuilder &, Location, MemRefType, ArrayRef)>; - using DeallocationFn = std::function; - using MemCpyFn = std::function; - - AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn, - MemCpyFn copyFn) - : allocationFn(std::move(allocFn)), deallocationFn(std::move(deallocFn)), - memCpyFn(std::move(copyFn)) {} - - /// A function that allocates memory. - AllocationFn allocationFn; - - /// A function that deallocated memory. Must be allocated by `allocationFn`. - DeallocationFn deallocationFn; - - /// A function that copies memory between two allocations. - MemCpyFn memCpyFn; -}; - -/// Return default allocation callbacks. -std::unique_ptr defaultAllocationCallbacks(); - /// PostAnalysisSteps can be registered with `BufferizationOptions` and are /// executed after the analysis, but before bufferization. They can be used to /// implement custom dialect-specific optimizations. @@ -84,6 +56,13 @@ /// Options for ComprehensiveBufferize. struct BufferizationOptions { + using AllocationFn = std::function( + OpBuilder &, Location, MemRefType, ArrayRef)>; + using DeallocationFn = + std::function; + using MemCpyFn = + std::function; + BufferizationOptions(); // BufferizationOptions cannot be copied.
@@ -126,7 +105,9 @@ BufferizableOpInterface dynCastBufferizableOp(Value value) const; /// Helper functions for allocation, deallocation, memory copying. - std::unique_ptr allocationFns; + Optional allocationFn; // If unset, an AllocOp is created. + Optional deallocationFn; // If unset, a DeallocOp is created. + Optional memCpyFn; // If unset, a default copy op is created. /// Specifies whether returning newly allocated memrefs should be allowed. /// Otherwise, a pass failure is triggered. @@ -362,24 +343,6 @@ /// is returned regardless of whether it is a memory write or not. SetVector findLastPrecedingWrite(Value value) const; - /// Creates a memref allocation. - FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape) const; - - /// Creates a memref allocation for the given shaped value. This function may - /// perform additional optimizations such as buffer allocation hoisting. If - /// `createDealloc`, a deallocation op is inserted at the point where the - /// allocation goes out of scope. - FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue, - bool deallocMemref) const; - - /// Creates a memref deallocation. The given memref buffer must have been - /// allocated using `createAlloc`. - void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer) const; - - /// Creates a memcpy between two given buffers. - void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const; - /// Return `true` if the given OpResult has been decided to bufferize inplace. bool isInPlace(OpOperand &opOperand) const; @@ -458,6 +421,28 @@ MemRefType getDynamicMemRefType(RankedTensorType tensorType, unsigned addressSpace = 0); +/// Creates a memref allocation using `options.allocationFn` if set, or an AllocOp otherwise. +FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, + ArrayRef dynShape, + const BufferizationOptions &options); + +/// Creates a memref allocation for the given shaped value. This function may +/// perform additional optimizations such as buffer allocation hoisting.
If +/// `deallocMemref` is true, a deallocation op is inserted at the point where the +/// allocation goes out of scope. +FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue, + bool deallocMemref, + const BufferizationOptions &options); + +/// Creates a memref deallocation. The given memref buffer must have been +/// allocated using `createAlloc`. Uses `options.deallocationFn` if set, or a DeallocOp otherwise. +LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, + const BufferizationOptions &options); + +/// Creates a memcpy from buffer `from` to buffer `to`, using `options.memCpyFn` if set. +LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to, + const BufferizationOptions &options); + } // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -39,40 +39,8 @@ // BufferizationOptions //===----------------------------------------------------------------------===// -/// Default allocation function that is used by the comprehensive bufferization -/// pass. The default currently creates a ranked memref using `memref.alloc`. -static FailureOr defaultAllocationFn(OpBuilder &b, Location loc, - MemRefType type, - ArrayRef dynShape) { - Value allocated = b.create( - loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); - return allocated; -} - -/// Default deallocation function that is used by the comprehensive -/// bufferization pass. It expects to recieve back the value called from the -/// `defaultAllocationFn`. -static void defaultDeallocationFn(OpBuilder &b, Location loc, - Value allocatedBuffer) { - b.create(loc, allocatedBuffer); -} - -/// Default memory copy function that is used by the comprehensive bufferization -/// pass. Creates a `memref.copy` op.
-static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) { - b.create(loc, from, to); -} - -std::unique_ptr -mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() { - return std::make_unique( - defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn); -} - -// Default constructor for BufferizationOptions that sets all allocation -// callbacks to their default functions. -BufferizationOptions::BufferizationOptions() - : allocationFns(defaultAllocationCallbacks()) {} +// Default constructor for BufferizationOptions. All callbacks stay unset, so the default allocation, deallocation and copy implementations are used. +BufferizationOptions::BufferizationOptions() {} BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: BufferizationOptions::dynCastBufferizableOp(Operation *op) const { @@ -393,8 +361,8 @@ // allocation should be inserted (in the absence of allocation hoisting). setInsertionPointAfter(rewriter, operandBuffer); // Allocate the result buffer. - FailureOr resultBuffer = - createAlloc(rewriter, loc, operandBuffer, options.createDeallocs); + FailureOr resultBuffer = createAlloc(rewriter, loc, operandBuffer, + options.createDeallocs, options); if (failed(resultBuffer)) return failure(); // Do not copy if the last preceding writes of `operand` are ops that do @@ -425,7 +393,9 @@ // The copy happens right before the op that is bufferized. rewriter.setInsertionPoint(op); } - createMemCpy(rewriter, loc, operandBuffer, *resultBuffer); + if (failed( + createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options))) + return failure(); return resultBuffer; } @@ -545,9 +515,9 @@ /// Create an AllocOp/DeallocOp pair, where the AllocOp is after /// `shapedValue.getDefiningOp` (or at the top of the block in case of a /// bbArg) and the DeallocOp is at the end of the block.
-FailureOr -mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( - OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const { +FailureOr mlir::linalg::comprehensive_bufferize::createAlloc( + OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref, + const BufferizationOptions &options) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -558,7 +528,8 @@ // Note: getAllocationTypeAndShape also sets the insertion point. MemRefType allocMemRefType = getAllocationTypeAndShape(b, loc, shapedValue, dynShape); - FailureOr allocated = createAlloc(b, loc, allocMemRefType, dynShape); + FailureOr allocated = + createAlloc(b, loc, allocMemRefType, dynShape, options); if (failed(allocated)) return failure(); Value casted = allocated.getValue(); @@ -572,30 +543,47 @@ if (deallocMemref) { // 2. Create memory deallocation. b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator()); - createDealloc(b, loc, allocated.getValue()); + if (failed(createDealloc(b, loc, allocated.getValue(), options))) + return failure(); } return casted; } /// Create a memref allocation. -FailureOr -mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( - OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape) const { - return options.allocationFns->allocationFn(b, loc, type, dynShape); +FailureOr mlir::linalg::comprehensive_bufferize::createAlloc( + OpBuilder &b, Location loc, MemRefType type, ArrayRef dynShape, + const BufferizationOptions &options) { + if (options.allocationFn) + return (*options.allocationFn)(b, loc, type, dynShape); + + // Default buffer allocation via AllocOp, aligned to `kBufferAlignments`. + Value allocated = b.create( + loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); + return allocated; } /// Create a memref deallocation.
-void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc( - OpBuilder &b, Location loc, Value allocatedBuffer) const { - return options.allocationFns->deallocationFn(b, loc, allocatedBuffer); +LogicalResult mlir::linalg::comprehensive_bufferize::createDealloc( + OpBuilder &b, Location loc, Value allocatedBuffer, + const BufferizationOptions &options) { + if (options.deallocationFn) + return (*options.deallocationFn)(b, loc, allocatedBuffer); + + // Default buffer deallocation via DeallocOp. + b.create(loc, allocatedBuffer); + return success(); } /// Create a memory copy between two memref buffers. -void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy( - OpBuilder &b, Location loc, Value from, Value to) const { - return options.allocationFns->memCpyFn(b, loc, from, to); +LogicalResult mlir::linalg::comprehensive_bufferize::createMemCpy( + OpBuilder &b, Location loc, Value from, Value to, + const BufferizationOptions &options) { + if (options.memCpyFn) + return (*options.memCpyFn)(b, loc, from, to); + + b.create(loc, from, to); // Default memory copy from `from` to `to`. + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -221,9 +221,9 @@ if (initTensorOp->getUses().empty()) return success(); - FailureOr alloc = state.createAlloc( - rewriter, initTensorOp->getLoc(), initTensorOp.result(), - state.getOptions().createDeallocs); + FailureOr alloc = + createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(), + state.getOptions().createDeallocs, state.getOptions()); if (failed(alloc)) return failure(); replaceOpWithBufferizedValues(rewriter, op, *alloc); @@ -367,7 +367,9 @@ Value output =
std::get<1>(it); Value toMemrefOp = rewriter.create( newTerminator.getLoc(), output.getType(), std::get<0>(it)); - state.createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp, output); + if (failed(createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp, + output, state.getOptions()))) + return failure(); } // Erase old terminator. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -158,8 +158,8 @@ Value alloc; if (!inplace) { FailureOr allocOrFailure = - state.createAlloc(rewriter, loc, extractSliceOp.result(), - state.getOptions().createDeallocs); + createAlloc(rewriter, loc, extractSliceOp.result(), + state.getOptions().createDeallocs, state.getOptions()); if (failed(allocOrFailure)) return failure(); alloc = *allocOrFailure; @@ -191,7 +191,9 @@ if (!inplace) { // Do not copy if the copied data is never read. if (state.isValueRead(extractSliceOp.result())) - state.createMemCpy(rewriter, extractSliceOp.getLoc(), subView, alloc); + if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView, + alloc, state.getOptions()))) + return failure(); subView = alloc; } @@ -461,7 +463,9 @@ // tensor.extract_slice, the copy operation will eventually fold away.
Value srcMemref = *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); - state.createMemCpy(rewriter, loc, srcMemref, subView); + if (failed(createMemCpy(rewriter, loc, srcMemref, subView, + state.getOptions()))) + return failure(); replaceOpWithBufferizedValues(rewriter, op, *dstMemref); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -77,9 +77,10 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() { auto options = std::make_unique(); if (useAlloca) { - options->allocationFns->allocationFn = allocationFnUsingAlloca; - options->allocationFns->deallocationFn = [](OpBuilder &b, Location loc, - Value v) {}; + options->allocationFn = allocationFnUsingAlloca; + options->deallocationFn = [](OpBuilder &b, Location loc, Value v) { // No-op: no deallocation op is created for these buffers. + return success(); + }; } options->allowReturnMemref = allowReturnMemref;