diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -36,8 +36,8 @@
 
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
-  using AllocationFn = std::function<FailureOr<Value>(
-      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
+  using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
+                                                      MemRefType, ValueRange)>;
   using DeallocationFn =
       std::function<LogicalResult(OpBuilder &, Location, Value)>;
   using MemCpyFn =
@@ -298,15 +298,23 @@
 MemRefType getDynamicMemRefType(RankedTensorType tensorType,
                                 unsigned addressSpace = 0);
 
-/// Creates a memref allocation.
+/// Creates a memref allocation with the given type and dynamic extents.
 FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                             ArrayRef<Value> dynShape,
+                             ValueRange dynShape,
+                             const BufferizationOptions &options);
+
+/// Creates a memref allocation with the given type and dynamic extents. If
+/// `deallocMemref`, a deallocation op is inserted right before the terminator
+/// of the block that contains the allocation.
+FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                             ValueRange dynShape, bool deallocMemref,
                              const BufferizationOptions &options);
 
 /// Creates a memref allocation for the given shaped value. This function may
 /// perform additional optimizations such as buffer allocation hoisting. If
-/// `createDealloc`, a deallocation op is inserted at the point where the
+/// `deallocMemref`, a deallocation op is inserted at the point where the
 /// allocation goes out of scope.
+// TODO: Allocation hoisting should be a cleanup pass.
 FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
                              bool deallocMemref,
                              const BufferizationOptions &options);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -433,10 +433,10 @@
   return casted;
 }
 
-/// Create a memref allocation.
+/// Create a memref allocation with the given type and dynamic extents.
 FailureOr<Value>
 bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                           ArrayRef<Value> dynShape,
+                           ValueRange dynShape,
                            const BufferizationOptions &options) {
   if (options.allocationFn)
     return (*options.allocationFn)(b, loc, type, dynShape);
@@ -447,6 +447,28 @@
   return allocated;
 }
 
+/// Create a memref allocation with the given type and dynamic extents. If
+/// `deallocMemref`, also deallocate it right before the block terminator.
+FailureOr<Value>
+bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                           ValueRange dynShape, bool deallocMemref,
+                           const BufferizationOptions &options) {
+  OpBuilder::InsertionGuard g(b);
+
+  FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
+  if (failed(alloc))
+    return failure();
+
+  if (deallocMemref) {
+    // Dealloc at the end of the block.
+    b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
+    if (failed(createDealloc(b, loc, *alloc, options)))
+      return failure();
+  }
+
+  return alloc;
+}
+
 /// Create a memref deallocation.
 LogicalResult bufferization::createDealloc(OpBuilder &b, Location loc,
                                            Value allocatedBuffer,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -73,7 +73,7 @@
 
 static FailureOr<Value>
 allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type,
-                        ArrayRef<Value> dynShape) {
+                        ValueRange dynShape) {
   Value allocated = b.create<memref::AllocaOp>(
       loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
   return allocated;