diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -211,16 +211,14 @@ // functions in the future. struct AllocationCallbacks { using AllocationFn = std::function( - OpBuilder &, Location, MemRefType, const SmallVector &)>; + OpBuilder &, Location, MemRefType, ArrayRef)>; using DeallocationFn = std::function; using MemCpyFn = std::function; - using CreateAllocDeallocFn = - std::function; + std::function; AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn, - MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn) - : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn), - createAllocDeallocFn(allocDeallocFn) {} + MemCpyFn copyFn) + : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {} /// A function that allocates memory. AllocationFn allocationFn; @@ -230,11 +228,6 @@ /// A function that copies memory between two allocations. MemCpyFn memCpyFn; - - /// A function that creates an alloc-dealloc pair. This function may perform - /// additional optimizations such as buffer allocation hoisting. This function - /// calls `allocationFn` and `deallocationFn` to create (de)allocations. - CreateAllocDeallocFn createAllocDeallocFn; }; /// BufferizationState keeps track of bufferization state and provides access to @@ -247,6 +240,11 @@ // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; + /// A function that creates an alloc-dealloc pair. This function may perform + /// additional optimizations such as buffer allocation hoisting. This function + /// calls `allocationFn` and `deallocationFn` to create (de)allocations. + Value createAllocDeallocFn(OpBuilder &builder, Location loc, Value v); + /// Map tensor values to memref buffers. void mapBuffer(ValueRange tensors, ValueRange buffers); diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -39,7 +39,7 @@ /// Default allocation function that is used by the comprehensive bufferization /// pass. The default currently creates a ranked memref using `memref.alloc`. Optional defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type, - const SmallVector &dynShape); + ArrayRef dynShape); /// Default deallocation function that is used by the comprehensive /// bufferization pass. It expects to recieve back the value called from the diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -358,8 +358,7 @@ b.setInsertionPointAfter(operandBuffer.getDefiningOp()); } // Allocate the result buffer. - Value resultBuffer = - state.allocationFns.createAllocDeallocFn(b, loc, operandBuffer, state); + Value resultBuffer = state.createAllocDeallocFn(b, loc, operandBuffer); bool skipCopy = false; // Do not copy if the last preceding write of `operand` is an op that does // not write (skipping ops that merely create aliases). E.g., InitTensorOp. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -835,8 +835,8 @@ /// Create an Allocop/DeAllocOp pair, where the AllocOp is after /// `shapedValue.getDefiningOp` (or at the top of the block in case of a /// bbArg) and the DeallocOp is at the end of the block. -static Value createNewAllocDeallocPairForShapedValue( - OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) { +Value BufferizationState::createAllocDeallocFn(OpBuilder &b, Location loc, + Value shapedValue) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -848,19 +848,19 @@ MemRefType allocMemRefType = getAllocationTypeAndShape(b, loc, shapedValue, dynShape); Optional allocated = - state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape); + allocationFns.allocationFn(b, loc, allocMemRefType, dynShape); // TODO: For now just assert the value is returned. Eventually need to // error-propagate. assert(allocated && "allocation failed"); Value casted = allocated.getValue(); if (memRefType && memRefType != allocMemRefType) { casted = b.create(loc, memRefType, allocated.getValue()); - state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue()); + aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue()); } // 2. Create memory deallocation. b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator()); - state.allocationFns.deallocationFn(b, loc, allocated.getValue()); + allocationFns.deallocationFn(b, loc, allocated.getValue()); return casted; } @@ -1162,8 +1162,7 @@ //===----------------------------------------------------------------------===// Optional mlir::linalg::comprehensive_bufferize::defaultAllocationFn( - OpBuilder &b, Location loc, MemRefType type, - const SmallVector &dynShape) { + OpBuilder &b, Location loc, MemRefType type, ArrayRef dynShape) { Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated; @@ -1727,8 +1726,7 @@ std::unique_ptr mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() { return std::make_unique( - defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn, - createNewAllocDeallocPairForShapedValue); + defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn); } // Default constructor for BufferizationOptions that sets all allocation @@ -2260,8 +2258,7 @@ bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0)); Value alloc; if (!inplace) - alloc = createNewAllocDeallocPairForShapedValue( - b, loc, extractSliceOp.result(), state); + alloc = state.createAllocDeallocFn(b, loc, extractSliceOp.result()); // Bufferize to subview. auto subviewMemRefType = diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -167,8 +167,8 @@ OpBuilder::InsertionGuard g(b); b.setInsertionPoint(initTensorOp); - Value alloc = state.allocationFns.createAllocDeallocFn( - b, initTensorOp->getLoc(), initTensorOp.result(), state); + Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(), + initTensorOp.result()); state.mapBuffer(initTensorOp.result(), alloc); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -48,9 +48,9 @@ (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); } -static Optional -allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type, - const SmallVector &dynShape) { +static Optional allocationFnUsingAlloca(OpBuilder &b, Location loc, + MemRefType type, + ArrayRef dynShape) { Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated;