diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -122,8 +122,8 @@ /// Default allocation function that is used by the comprehensive bufferization /// pass. The default currently creates a ranked memref using `memref.alloc`. -Optional defaultAllocationFn(OpBuilder &b, Location loc, - Value shapedValue); +Optional defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type, + const SmallVector &dynShape); /// Default deallocation function that is used by the comprehensive /// bufferization pass. It expects to recieve back the value called from the @@ -140,8 +140,8 @@ /// caller. The `deallocationFn` is gauranteed to recieve the `Value` returned /// by the `allocationFn`. struct AllocationCallbacks { - using AllocationFn = - std::function(OpBuilder &, Location, Value)>; + using AllocationFn = std::function( + OpBuilder &, Location, MemRefType, const SmallVector &)>; using DeallocationFn = std::function; using MemCpyFn = std::function; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -1208,6 +1208,61 @@ return nullptr; } +/// Compute the type of the `memref` to use for allocating the buffer for +/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the +/// dynamic dimensions in the returned `memref` type. The function also sets the +/// insertion point of the builder `b` to the position where the allocation is +/// to be inserted. +static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc, + Value shapedValue, + SmallVectorImpl &dynShape) { + MemRefType allocMemRefType = + getContiguousMemRefType(shapedValue.getType().cast()); + if (auto bbArg = shapedValue.dyn_cast()) { + b.setInsertionPointToStart(bbArg.getOwner()); + loc = bbArg.getOwner()->getParentOp()->getLoc(); + } else { + b.setInsertionPoint(shapedValue.getDefiningOp()); + loc = shapedValue.getDefiningOp()->getLoc(); + } + + // Compute the dynamic part of the shape. + bool foundDynamicShapes = false; + if (auto rankedOp = dyn_cast_or_null( + shapedValue.getDefiningOp())) { + ReifiedRankedShapedTypeDims resultDims; + if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { + foundDynamicShapes = true; + OpResult resultValue = shapedValue.dyn_cast(); + auto &shape = resultDims[resultValue.getResultNumber()]; + for (auto dim : enumerate(allocMemRefType.getShape())) + if (dim.value() == ShapedType::kDynamicSize) + dynShape.push_back(shape[dim.index()]); + } + } + if (!foundDynamicShapes) { + for (auto dim : enumerate(allocMemRefType.getShape())) + if (dim.value() == ShapedType::kDynamicSize) + dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index())); + } + + // If the buffer is statically shaped, try to hoist it to the first enclosing + // parallel region. + // TODO: this concept of parallel region and threadlocal needs interfaces. + // TODO: also hoist in the dynamic case. For now this relies on subsequent + // calls to LICM and buffer hoisting which will most likely not succeed. + // TODO: when packing, allocate a static bounding box which will enable more + // hoisting. + if (dynShape.empty()) { + Operation *parent = + getFirstParentOfType(shapedValue); + if (parent) + b.setInsertionPointToStart(&(parent->getRegion(0).front())); + } + return allocMemRefType; +} + /// Create an Allocop/DeAllocOp pair, where the AllocOp is after /// `shapedValue.getDefiningOp` (or at the top of the block in case of a /// bbArg) and the DeallocOp is at the end of the block. @@ -1217,20 +1272,26 @@ // Take a guard before anything else. OpBuilder::InsertionGuard g(b); + // 1. Create memory allocation. assert(shapedValue.getType().isa()); MemRefType memRefType = shapedValue.getType().dyn_cast(); - - Optional allocated = allocationFns.allocationFn(b, loc, shapedValue); + SmallVector dynShape; + // Note: getAllocationTypeAndShape also sets the insertion point. + MemRefType allocMemRefType = + getAllocationTypeAndShape(b, loc, shapedValue, dynShape); + Optional allocated = + allocationFns.allocationFn(b, loc, allocMemRefType, dynShape); // TODO: For now just assert the value is returned. Eventually need to // error-propagate. assert(allocated && "allocation failed"); Value casted = allocated.getValue(); - MemRefType allocMemRefType = allocated->getType().cast(); if (memRefType && memRefType != allocMemRefType) { casted = b.create(loc, memRefType, allocated.getValue()); aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue()); } + // 2. Create memory deallocation. + b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator()); allocationFns.deallocationFn(b, loc, allocated.getValue()); return casted; } @@ -1595,88 +1656,24 @@ // Bufferization entry-point for functions. //===----------------------------------------------------------------------===// -/// Compute the type of the `memref` to use for allocating the buffer for -/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the -/// dynamic dimensions in the returned `memref` type. The function also sets the -/// insertion point of the builder `b` to the position where the allocation is -/// to be inserted. -static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc, - Value shapedValue, - SmallVectorImpl &dynShape) { - MemRefType allocMemRefType = - getContiguousMemRefType(shapedValue.getType().cast()); - if (auto bbArg = shapedValue.dyn_cast()) { - b.setInsertionPointToStart(bbArg.getOwner()); - loc = bbArg.getOwner()->getParentOp()->getLoc(); - } else { - b.setInsertionPoint(shapedValue.getDefiningOp()); - loc = shapedValue.getDefiningOp()->getLoc(); - } - - // Compute the dynamic part of the shape. - bool foundDynamicShapes = false; - if (auto rankedOp = dyn_cast_or_null( - shapedValue.getDefiningOp())) { - ReifiedRankedShapedTypeDims resultDims; - if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { - foundDynamicShapes = true; - OpResult resultValue = shapedValue.dyn_cast(); - auto &shape = resultDims[resultValue.getResultNumber()]; - for (auto dim : enumerate(allocMemRefType.getShape())) - if (dim.value() == ShapedType::kDynamicSize) - dynShape.push_back(shape[dim.index()]); - } - } - if (!foundDynamicShapes) { - for (auto dim : enumerate(allocMemRefType.getShape())) - if (dim.value() == ShapedType::kDynamicSize) - dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index())); - } - - // If the buffer is statically shaped, try to hoist it to the first enclosing - // parallel region. - // TODO: this concept of parallel region and threadlocal needs interfaces. - // TODO: also hoist in the dynamic case. For now this relies on subsequent - // calls to LICM and buffer hoisting which will most likely not succeed. - // TODO: when packing, allocate a static bounding box which will enable more - // hoisting. - if (dynShape.empty()) { - Operation *parent = - getFirstParentOfType(shapedValue); - if (parent) - b.setInsertionPointToStart(&(parent->getRegion(0).front())); - } - return allocMemRefType; -} - -Optional mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc, - Value shapedValue) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - SmallVector dynShape; - MemRefType allocMemRefType = - getAllocationTypeAndShape(b, loc, shapedValue, dynShape); +Optional +mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type, + const SmallVector &dynShape) { Value allocated = b.create( - loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments)); + loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated; } -static Optional allocationFnUsingAlloca(OpBuilder &b, Location loc, - Value shapedValue) { - OpBuilder::InsertionGuard g(b); - SmallVector dynShape; - MemRefType allocMemRefType = - getAllocationTypeAndShape(b, loc, shapedValue, dynShape); +static Optional +allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type, + const SmallVector &dynShape) { Value allocated = b.create( - loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments)); + loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated; } void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer) { - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator()); b.create(loc, allocatedBuffer); }