diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -212,16 +212,13 @@
 // functions in the future.
 struct AllocationCallbacks {
   using AllocationFn = std::function<Optional<Value>(
-      OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
+      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
   using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
   using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
-  using CreateAllocDeallocFn =
-      std::function<Value(OpBuilder &, Location, Value, BufferizationState &)>;
 
   AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
-                      MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn)
-      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn),
-        createAllocDeallocFn(allocDeallocFn) {}
+                      MemCpyFn copyFn)
+      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
 
   /// A function that allocates memory.
   AllocationFn allocationFn;
@@ -231,11 +228,6 @@
 
   /// A function that copies memory between two allocations.
   MemCpyFn memCpyFn;
-
-  /// A function that creates an alloc-dealloc pair. This function may perform
-  /// additional optimizations such as buffer allocation hoisting. This function
-  /// calls `allocationFn` and `deallocationFn` to create (de)allocations.
-  CreateAllocDeallocFn createAllocDeallocFn;
 };
 
 /// BufferizationState keeps track of bufferization state and provides access to
@@ -247,6 +239,12 @@
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
 
+  /// Create an alloc-dealloc pair for `shapedValue` and return the new buffer.
+  /// This may perform additional optimizations such as allocation hoisting.
+  /// (De)allocations are created via the callbacks in `allocationFns`.
+  Value createAllocDeallocFn(OpBuilder &builder, Location loc,
+                             Value shapedValue);
+
   /// Map tensor values to memref buffers.
   void mapBuffer(ValueRange tensors, ValueRange buffers);
 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -359,8 +360,7 @@
       b.setInsertionPointAfter(operandBuffer.getDefiningOp());
     }
     // Allocate the result buffer.
-    Value resultBuffer =
-        state.allocationFns.createAllocDeallocFn(b, loc, operandBuffer, state);
+    Value resultBuffer = state.createAllocDeallocFn(b, loc, operandBuffer);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -416,6 +416,128 @@
   return op->emitError() << "unsupported op with tensors";
 }
 
+//===----------------------------------------------------------------------===//
+// Bufferization-specific scoped alloc/dealloc insertion support.
+//===----------------------------------------------------------------------===//
+
+/// Return a MemRefType with the same shape and element type as `shapedType`
+/// and the given `layout` and `memorySpace`. With the default (null) `layout`,
+/// the result has the canonical identity layout, i.e. it is contiguous.
+static MemRefType getContiguousMemRefType(ShapedType shapedType,
+                                          MemRefLayoutAttrInterface layout = {},
+                                          Attribute memorySpace = {}) {
+  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
+                         layout, memorySpace);
+}
+
+/// Move the insertion point of the given builder to the beginning of a
+/// surrounding block as much as possible, while not crossing any allocation
+/// hoisting barriers.
+static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
+  Operation *op = b.getInsertionBlock()->getParentOp();
+  while (op) {
+    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+      if (bufferizableOp.isAllocationHoistingBarrier())
+        break;
+    op = op->getParentOp();
+  }
+
+  // FuncOp is an allocation hoisting barrier, so the above loop should never
+  // run out of parents.
+  assert(
+      (op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
+      "expected traversal to end at allocation hoisting barrier");
+
+  // TODO: Handle cases where allocation hoisting barrier has more than one
+  // region or block.
+  assert(op->getNumRegions() == 1 &&
+         "allocation hoisting barriers with >1 regions not supported");
+  assert(op->getRegion(0).getBlocks().size() == 1 &&
+         "allocation hoisting barriers with >1 blocks not supported");
+  b.setInsertionPointToStart(&(op->getRegion(0).front()));
+}
+
+/// Compute the type of the `memref` to use for allocating the buffer for
+/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
+/// dynamic dimensions in the returned `memref` type. The function may also set
+/// the insertion point to an earlier location, where the allocation should
+/// happen ("allocation hoisting").
+static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
+                                            Value shapedValue,
+                                            SmallVectorImpl<Value> &dynShape) {
+  MemRefType allocMemRefType =
+      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+
+  // Compute the dynamic part of the shape.
+  bool reifiedShapes = false;
+  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+          shapedValue.getDefiningOp())) {
+    ReifiedRankedShapedTypeDims resultDims;
+    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
+      reifiedShapes = true;
+      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
+      auto &shape = resultDims[resultValue.getResultNumber()];
+      for (auto dim : enumerate(allocMemRefType.getShape()))
+        if (ShapedType::isDynamic(dim.value()))
+          dynShape.push_back(shape[dim.index()]);
+    }
+  }
+
+  if (!reifiedShapes) {
+    for (auto dim : enumerate(allocMemRefType.getShape()))
+      if (ShapedType::isDynamic(dim.value())) {
+        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
+                shapedValue.getType().isa<MemRefType>()) &&
+               "expected MemRef type");
+        dynShape.push_back(
+            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
+      }
+  }
+
+  // If the buffer is statically shaped, try to hoist it to the first enclosing
+  // parallel region.
+  // TODO: also hoist in the dynamic case. For now this relies on subsequent
+  // calls to LICM and buffer hoisting which will most likely not succeed.
+  // TODO: when packing, allocate a static bounding box which will enable more
+  // hoisting.
+  if (dynShape.empty())
+    moveInsertionPointToAllocationHoistingBarrier(b);
+
+  return allocMemRefType;
+}
+
+/// Create an allocation/deallocation pair for `shapedValue`. The allocation is
+/// created at the builder's insertion point (hoisted when statically shaped),
+/// the deallocation right before the terminator of the allocation's block.
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::
+    createAllocDeallocFn(OpBuilder &b, Location loc, Value shapedValue) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+
+  // 1. Create memory allocation.
+  assert(shapedValue.getType().isa<ShapedType>());
+  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
+  SmallVector<Value> dynShape;
+  // Note: getAllocationTypeAndShape also sets the insertion point.
+  MemRefType allocMemRefType =
+      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+  Optional<Value> allocated =
+      allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
+  // TODO: For now just assert the value is returned. Eventually need to
+  // error-propagate.
+  assert(allocated && "allocation failed");
+  Value casted = allocated.getValue();
+  if (memRefType && memRefType != allocMemRefType) {
+    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
+    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
+  }
+
+  // 2. Create memory deallocation.
+  b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
+  allocationFns.deallocationFn(b, loc, allocated.getValue());
+  return casted;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -667,118 +667,6 @@
   return it2.first->second;
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization-specific scoped alloc/dealloc insertion support.
-//===----------------------------------------------------------------------===//
-
-/// Move the insertion point of the given builder to the beginning of a
-/// surrounding block as much as possible, while not crossing any allocation
-/// hoisting barriers.
-static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
-  Operation *op = b.getInsertionBlock()->getParentOp();
-  while (op) {
-    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
-      if (bufferizableOp.isAllocationHoistingBarrier())
-        break;
-    op = op->getParentOp();
-  }
-
-  // FuncOp is an allocation hoisting barrier, so the above loop should never
-  // run out of parents.
-  assert(
-      (op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
-      "expected traversal to end at allocation hoisting barrier");
-
-  // TODO: Handle cases where allocation hoisting barrier has more than one
-  // region or block.
-  assert(op->getNumRegions() == 1 &&
-         "allocation hoisting barriers with >1 regions not supported");
-  assert(op->getRegion(0).getBlocks().size() == 1 &&
-         "allocation hoisting barriers with >1 blocks not supported");
-  b.setInsertionPointToStart(&(op->getRegion(0).front()));
-}
-
-/// Compute the type of the `memref` to use for allocating the buffer for
-/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
-/// dynamic dimensions in the returned `memref` type. The function may also set
-/// the insertion point to an earlier location, where the allocation should
-/// happen ("allocation hoisting").
-static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
-                                            Value shapedValue,
-                                            SmallVectorImpl<Value> &dynShape) {
-  MemRefType allocMemRefType =
-      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
-
-  // Compute the dynamic part of the shape.
-  bool reifiedShapes = false;
-  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
-          shapedValue.getDefiningOp())) {
-    ReifiedRankedShapedTypeDims resultDims;
-    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
-      reifiedShapes = true;
-      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
-      auto &shape = resultDims[resultValue.getResultNumber()];
-      for (auto dim : enumerate(allocMemRefType.getShape()))
-        if (ShapedType::isDynamic(dim.value()))
-          dynShape.push_back(shape[dim.index()]);
-    }
-  }
-
-  if (!reifiedShapes) {
-    for (auto dim : enumerate(allocMemRefType.getShape()))
-      if (ShapedType::isDynamic(dim.value())) {
-        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
-                shapedValue.getType().isa<MemRefType>()) &&
-               "expected MemRef type");
-        dynShape.push_back(
-            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
-      }
-  }
-
-  // If the buffer is statically shaped, try to hoist it to the first enclosing
-  // parallel region.
-  // TODO: also hoist in the dynamic case. For now this relies on subsequent
-  // calls to LICM and buffer hoisting which will most likely not succeed.
-  // TODO: when packing, allocate a static bounding box which will enable more
-  // hoisting.
-  if (dynShape.empty())
-    moveInsertionPointToAllocationHoistingBarrier(b);
-
-  return allocMemRefType;
-}
-
-/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
-/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
-/// bbArg) and the DeallocOp is at the end of the block.
-static Value createNewAllocDeallocPairForShapedValue(
-    OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) {
-  // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
-
-  // 1. Create memory allocation.
-  assert(shapedValue.getType().isa<ShapedType>());
-  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
-  SmallVector<Value> dynShape;
-  // Note: getAllocationTypeAndShape also sets the insertion point.
-  MemRefType allocMemRefType =
-      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  Optional<Value> allocated =
-      state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
-  // TODO: For now just assert the value is returned. Eventually need to
-  // error-propagate.
-  assert(allocated && "allocation failed");
-  Value casted = allocated.getValue();
-  if (memRefType && memRefType != allocMemRefType) {
-    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
-    state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
-  }
-
-  // 2. Create memory deallocation.
-  b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-  state.allocationFns.deallocationFn(b, loc, allocated.getValue());
-  return casted;
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization as simple BlockAndValueMapping rewrites.
 //===----------------------------------------------------------------------===//
@@ -1426,7 +1314,7 @@
 /// pass. The default currently creates a ranked memref using `memref.alloc`.
static Optional defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type, - const SmallVector &dynShape) { + ArrayRef dynShape) { Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated; @@ -1449,8 +1337,7 @@ std::unique_ptr mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() { return std::make_unique( - defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn, - createNewAllocDeallocPairForShapedValue); + defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn); } // Default constructor for BufferizationOptions that sets all allocation @@ -2154,8 +2041,7 @@ bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0)); Value alloc; if (!inplace) - alloc = createNewAllocDeallocPairForShapedValue( - b, loc, extractSliceOp.result(), state); + alloc = state.createAllocDeallocFn(b, loc, extractSliceOp.result()); // Bufferize to subview. auto subviewMemRefType = diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -167,8 +167,8 @@ OpBuilder::InsertionGuard g(b); b.setInsertionPoint(initTensorOp); - Value alloc = state.allocationFns.createAllocDeallocFn( - b, initTensorOp->getLoc(), initTensorOp.result(), state); + Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(), + initTensorOp.result()); state.mapBuffer(initTensorOp.result(), alloc); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -48,9 +48,9 @@ (void)applyPatternsAndFoldGreedily(moduleOp, 
std::move(patterns)); } -static Optional -allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type, - const SmallVector &dynShape) { +static Optional allocationFnUsingAlloca(OpBuilder &b, Location loc, + MemRefType type, + ArrayRef dynShape) { Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated;