diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -415,13 +415,25 @@
   BufferizationState(const AnalysisState &analysisState)
       : analysisState(analysisState) {}
 
+  /// Creates a memref allocation with the given type and dynamic extents.
+  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                               ValueRange dynShape);
+
+  /// Creates a memref allocation for the given shaped value. This function may
+  /// perform additional optimizations such as buffer allocation hoisting.
+  // TODO: Allocation hoisting should be a cleanup pass.
+  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue);
+
+  /// Deallocate all buffers.
+  LogicalResult deallocateAllBuffers(OpBuilder &b);
+
   /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization was decided.
   FailureOr<Value>
   getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
             bool forceInPlace = false,
-            Optional<Operation *> customCopyInsertionPoint = None) const;
+            Optional<Operation *> customCopyInsertionPoint = None);
 
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const {
@@ -436,6 +448,9 @@
 
 private:
   const AnalysisState &analysisState;
+
+  /// A list of all buffer allocations.
+  SmallVector<Value> allocations;
 };
 
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
@@ -482,22 +497,6 @@
                              ValueRange dynShape,
                              const BufferizationOptions &options);
 
-/// Creates a memref allocation with the given type and dynamic extents. If
-/// `createDealloc`, a deallocation op is inserted at the point where the
-/// allocation goes out of scope.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                             ValueRange dynShape, bool deallocMemref,
-                             const BufferizationOptions &options);
-
-/// Creates a memref allocation for the given shaped value. This function may
-/// perform additional optimizations such as buffer allocation hoisting. If
-/// `createDealloc`, a deallocation op is inserted at the point where the
-/// allocation goes out of scope.
-// TODO: Allocation hoisting should be a cleanup pass.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
-                             bool deallocMemref,
-                             const BufferizationOptions &options);
-
 /// Creates a memref deallocation. The given memref buffer must have been
 /// allocated using `createAlloc`.
 LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -243,9 +243,10 @@
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
-FailureOr<Value> BufferizationState::getBuffer(
-    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
-    Optional<Operation *> customCopyInsertionPoint) const {
+FailureOr<Value>
+BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
+                              bool forceInPlace,
+                              Optional<Operation *> customCopyInsertionPoint) {
   const BufferizationOptions &options = analysisState.getOptions();
   OpBuilder::InsertionGuard guard(rewriter);
   Operation *op = opOperand.getOwner();
@@ -261,8 +262,7 @@
   // allocation should be inserted (in the absence of allocation hoisting).
   setInsertionPointAfter(rewriter, operandBuffer);
   // Allocate the result buffer.
-  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
-                                              options.createDeallocs, options);
+  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
   if (failed(resultBuffer))
     return failure();
   // Do not copy if the last preceding writes of `operand` are ops that do
@@ -436,27 +436,23 @@
   return allocMemRefType;
 }
 
-/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
-/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
-/// bbArg) and the DeallocOp is at the end of the block.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
-                           bool deallocMemref,
-                           const BufferizationOptions &options) {
+/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
+/// block in case of a bbArg).
+FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
+                                                 Value shapedValue) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
-
-  // 1. Create memory allocation.
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
   SmallVector<Value> dynShape;
   // Note: getAllocationTypeAndShape also sets the insertion point.
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  FailureOr<Value> allocated =
-      createAlloc(b, loc, allocMemRefType, dynShape, options);
+  FailureOr<Value> allocated = bufferization::createAlloc(
+      b, loc, allocMemRefType, dynShape, getOptions());
   if (failed(allocated))
     return failure();
+  allocations.push_back(*allocated);
   Value casted = allocated.getValue();
   if (memRefType && memRefType != allocMemRefType) {
     assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
@@ -464,15 +460,34 @@
            "createAlloc: cast incompatible");
     casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
   }
+  return casted;
+}
 
-  if (deallocMemref) {
-    // 2. Create memory deallocation.
-    b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-    if (failed(createDealloc(b, loc, allocated.getValue(), options)))
-      return failure();
-  }
+/// Create a memref allocation with the given type and dynamic extents.
+FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
+                                                 MemRefType type,
+                                                 ValueRange dynShape) {
+  OpBuilder::InsertionGuard g(b);
 
-  return casted;
+  FailureOr<Value> alloc =
+      bufferization::createAlloc(b, loc, type, dynShape, getOptions());
+  if (failed(alloc))
+    return failure();
+  allocations.push_back(*alloc);
+  return alloc;
+}
+
+LogicalResult BufferizationState::deallocateAllBuffers(OpBuilder &builder) {
+  if (getOptions().createDeallocs) {
+    for (Value value : allocations) {
+      // Dealloc at the end of the block.
+      builder.setInsertionPoint(value.getParentBlock()->getTerminator());
+      if (failed(createDealloc(builder, value.getLoc(), value, getOptions())))
+        return failure();
+    }
+  }
+  allocations.clear();
+  return success();
 }
 
 /// Create a memref allocation with the given type and dynamic extents.
@@ -490,28 +505,6 @@
   return allocated;
 }
 
-/// Create a memref allocation with the given type and dynamic extents. May also
-/// deallocate the memref again.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                           ValueRange dynShape, bool deallocMemref,
-                           const BufferizationOptions &options) {
-  OpBuilder::InsertionGuard g(b);
-
-  FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
-  if (failed(alloc))
-    return failure();
-
-  if (deallocMemref) {
-    // Dealloc at the end of the block.
-    b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
-    if (failed(createDealloc(b, loc, *alloc, options)))
-      return failure();
-  }
-
-  return alloc;
-}
-
 /// Create a memref deallocation.
 LogicalResult bufferization::createDealloc(OpBuilder &b, Location loc,
                                            Value allocatedBuffer,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -325,7 +325,15 @@
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return failure();
 
-  return checkBufferizationResult(op, bufferizationState.getOptions());
+  if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
+    return failure();
+
+  // Deallocate all buffers.
+  OpBuilder builder(op->getContext());
+  if (failed(bufferizationState.deallocateAllBuffers(builder)))
+    return failure();
+
+  return success();
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -235,9 +235,8 @@
     if (initTensorOp->getUses().empty())
       return success();
 
-    FailureOr<Value> alloc =
-        createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
-                    state.getOptions().createDeallocs, state.getOptions());
+    FailureOr<Value> alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
+                                               initTensorOp.result());
     if (failed(alloc))
       return failure();
     replaceOpWithBufferizedValues(rewriter, op, *alloc);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -228,8 +228,7 @@
     Value alloc;
     if (!inplace) {
       FailureOr<Value> allocOrFailure =
-          createAlloc(rewriter, loc, extractSliceOp.result(),
-                      state.getOptions().createDeallocs, state.getOptions());
+          state.createAlloc(rewriter, loc, extractSliceOp.result());
       if (failed(allocOrFailure))
         return failure();
       alloc = *allocOrFailure;
@@ -338,9 +337,7 @@
     auto shape = tensorType.getShape();
     MemRefType resultType = getContiguousMemRefType(tensorType);
     FailureOr<Value> maybeBuffer =
-        createAlloc(rewriter, loc, resultType, {},
-                    /*deallocMemref=*/state.getOptions().createDeallocs,
-                    state.getOptions());
+        state.createAlloc(rewriter, loc, resultType, {});
     if (failed(maybeBuffer))
       return failure();
     Value buffer = *maybeBuffer;
@@ -389,10 +386,8 @@
     Location loc = op->getLoc();
     MemRefType memrefType =
         getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
-    FailureOr<Value> maybeResult =
-        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
-                    /*deallocMemref=*/state.getOptions().createDeallocs,
-                    state.getOptions());
+    FailureOr<Value> maybeResult = state.createAlloc(
+        rewriter, loc, memrefType, generateOp.dynamicExtents());
     if (failed(maybeResult))
       return failure();
     Value result = *maybeResult;