diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -415,13 +415,22 @@ BufferizationState(const AnalysisState &analysisState) : analysisState(analysisState) {} + /// Creates a memref allocation with the given type and dynamic extents. + FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape); + + /// Creates a memref allocation for the given shaped value. This function may + /// perform additional optimizations such as buffer allocation hoisting. + // TODO: Allocation hoisting should be a cleanup pass. + FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue); + /// Return the buffer (memref) for a given OpOperand (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization was decided. FailureOr getBuffer(RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace = false, - Optional customCopyInsertionPoint = None) const; + Optional customCopyInsertionPoint = None); /// Return a reference to the BufferizationOptions. const BufferizationOptions &getOptions() const { @@ -477,27 +486,6 @@ MemRefLayoutAttrInterface layout = {}, Attribute memorySpace = {}); -/// Creates a memref allocation with the given type and dynamic extents. -FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, - ValueRange dynShape, - const BufferizationOptions &options); - -/// Creates a memref allocation with the given type and dynamic extents. If -/// `createDealloc`, a deallocation op is inserted at the point where the -/// allocation goes out of scope. 
-FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, - ValueRange dynShape, bool deallocMemref, - const BufferizationOptions &options); - -/// Creates a memref allocation for the given shaped value. This function may -/// perform additional optimizations such as buffer allocation hoisting. If -/// `createDealloc`, a deallocation op is inserted at the point where the -/// allocation goes out of scope. -// TODO: Allocation hoisting should be a cleanup pass. -FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue, - bool deallocMemref, - const BufferizationOptions &options); - /// Creates a memref deallocation. The given memref buffer must have been /// allocated using `createAlloc`. LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, @@ -507,6 +495,10 @@ LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to, const BufferizationOptions &options); +/// Finalize all buffer allocations, i.e., create alloc ops as specified in the +/// bufferization options and deallocate all buffers. +LogicalResult finalizeBuffers(Operation *op, + const BufferizationOptions &options); } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -70,13 +70,6 @@ // TODO: Extract `options` from `state` and pass as separate argument. LogicalResult bufferizeOp(Operation *op, const AnalysisState &analysisState); -/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. -/// Reuse an existing `BufferizationState`. -/// -/// Note: This function overload is useful for extending the bufferization. 
-LogicalResult bufferizeOp(Operation *op, - BufferizationState &bufferizationState); - /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. /// Buffers are duplicated and copied before any tensor use that bufferizes to /// a memory write. @@ -87,6 +80,16 @@ BufferizationOptions getPartialBufferizationOptions(); +//===----------------------------------------------------------------------===// +// Helper functions for extending Bufferization +//===----------------------------------------------------------------------===// + +/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Reuse an existing `BufferizationState`. +/// +/// Note: This function overload is useful for extending the bufferization. +LogicalResult bufferizeOp(Operation *op, + BufferizationState &bufferizationState); } // namespace bufferization } // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -42,6 +42,8 @@ constexpr const ::llvm::StringLiteral bufferization::BufferizableOpInterface::kInplaceableAttrName; +static const char *kBufferAllocationAttr = "bufferization.allocation"; + //===----------------------------------------------------------------------===// // BufferizationOptions //===----------------------------------------------------------------------===// @@ -243,9 +245,10 @@ /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. 
-FailureOr BufferizationState::getBuffer( - RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace, - Optional customCopyInsertionPoint) const { +FailureOr +BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand, + bool forceInPlace, + Optional customCopyInsertionPoint) { const BufferizationOptions &options = analysisState.getOptions(); OpBuilder::InsertionGuard guard(rewriter); Operation *op = opOperand.getOwner(); @@ -261,8 +264,7 @@ // allocation should be inserted (in the absence of allocation hoisting). setInsertionPointAfter(rewriter, operandBuffer); // Allocate the result buffer. - FailureOr resultBuffer = createAlloc(rewriter, loc, operandBuffer, - options.createDeallocs, options); + FailureOr resultBuffer = createAlloc(rewriter, loc, operandBuffer); if (failed(resultBuffer)) return failure(); // Do not copy if the last preceding writes of `operand` are ops that do @@ -358,6 +360,33 @@ // Bufferization-specific scoped alloc/dealloc insertion support. //===----------------------------------------------------------------------===// +/// Create a memref allocation with the given type and dynamic extents. +static FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape, + const BufferizationOptions &options) { + if (options.allocationFn) + return (*options.allocationFn)(b, loc, type, dynShape, + options.bufferAlignment); + + // Default buffer allocation via AllocOp. + Value allocated = b.create( + loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment)); + return allocated; +} + +/// Creates a memref deallocation. The given memref buffer must have been +/// allocated using `createAlloc`. +LogicalResult +bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, + const BufferizationOptions &options) { + if (options.deallocationFn) + return (*options.deallocationFn)(b, loc, allocatedBuffer); + + // Default buffer deallocation via DeallocOp. 
+ b.create(loc, allocatedBuffer); + return success(); +} + /// Move the insertion point of the given builder to the beginning of a /// surrounding block as much as possible, while not crossing any allocation /// hoisting barriers. @@ -436,92 +465,39 @@ return allocMemRefType; } -/// Create an AllocOp/DeallocOp pair, where the AllocOp is after -/// `shapedValue.getDefiningOp` (or at the top of the block in case of a -/// bbArg) and the DeallocOp is at the end of the block. -FailureOr -bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue, - bool deallocMemref, - const BufferizationOptions &options) { +static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape) { + auto allocaOp = b.create(loc, type, dynShape); + allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr()); + return allocaOp.getResult(); +} + +/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the +/// block in case of a bbArg). +FailureOr BufferizationState::createAlloc(OpBuilder &b, Location loc, + Value shapedValue) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); - - // 1. Create memory allocation. assert(shapedValue.getType().isa()); MemRefType memRefType = shapedValue.getType().dyn_cast(); SmallVector dynShape; // Note: getAllocationTypeAndShape also sets the insertion point. 
MemRefType allocMemRefType = getAllocationTypeAndShape(b, loc, shapedValue, dynShape); - FailureOr allocated = - createAlloc(b, loc, allocMemRefType, dynShape, options); - if (failed(allocated)) - return failure(); - Value casted = allocated.getValue(); + Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape); if (memRefType && memRefType != allocMemRefType) { - assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(), - memRefType) && + assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) && "createAlloc: cast incompatible"); - casted = b.create(loc, memRefType, allocated.getValue()); - } - - if (deallocMemref) { - // 2. Create memory deallocation. - b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator()); - if (failed(createDealloc(b, loc, allocated.getValue(), options))) - return failure(); - } - - return casted; -} - -/// Create a memref allocation with the given type and dynamic extents. -FailureOr -bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type, - ValueRange dynShape, - const BufferizationOptions &options) { - if (options.allocationFn) - return (*options.allocationFn)(b, loc, type, dynShape, - options.bufferAlignment); - - // Default bufferallocation via AllocOp. - Value allocated = b.create( - loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment)); - return allocated; -} - -/// Create a memref allocation with the given type and dynamic extents. May also -/// deallocate the memref again. -FailureOr -bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type, - ValueRange dynShape, bool deallocMemref, - const BufferizationOptions &options) { - OpBuilder::InsertionGuard g(b); - - FailureOr alloc = createAlloc(b, loc, type, dynShape, options); - if (failed(alloc)) - return failure(); - - if (deallocMemref) { - // Dealloc at the end of the block. 
- b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator()); - if (failed(createDealloc(b, loc, *alloc, options))) - return failure(); + alloc = b.create(loc, memRefType, alloc); } - return alloc; } -/// Create a memref deallocation. -LogicalResult -bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, - const BufferizationOptions &options) { - if (options.deallocationFn) - return (*options.deallocationFn)(b, loc, allocatedBuffer); - - // Default buffer deallocation via DeallocOp. - b.create(loc, allocatedBuffer); - return success(); +/// Create a memref allocation with the given type and dynamic extents. +FailureOr BufferizationState::createAlloc(OpBuilder &b, Location loc, + MemRefType type, + ValueRange dynShape) { + return createBufferAllocation(b, loc, type, dynShape); } /// Create a memory copy between two memref buffers. @@ -535,6 +511,41 @@ return success(); } +LogicalResult +bufferization::finalizeBuffers(Operation *op, + const BufferizationOptions &options) { + IRRewriter rewriter(op->getContext()); + + // Bufferization creates memref.alloca ops. After bufferization, these must be + // rewritten to alloc/dealloc ops as specified in the bufferization options. + WalkResult status = op->walk([&](memref::AllocaOp allocaOp) { + // Ignore memref.alloca ops that were not created by the bufferization. + if (!allocaOp->hasAttr(kBufferAllocationAttr)) + return WalkResult::skip(); + + Block *block = allocaOp->getBlock(); + rewriter.setInsertionPoint(allocaOp); + FailureOr alloc = + createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(), + allocaOp.dynamicSizes(), options); + if (failed(alloc)) + return WalkResult::interrupt(); + rewriter.replaceOp(allocaOp, *alloc); + + // Stop here if deallocations are deactivated. 
+ if (!options.createDeallocs) + return WalkResult::advance(); + + rewriter.setInsertionPoint(block->getTerminator()); + if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + return success(!status.wasInterrupted()); +} + //===----------------------------------------------------------------------===// // Bufferization-specific BlockAndValueMapping support with debugging. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -302,7 +302,11 @@ LogicalResult bufferization::bufferizeOp(Operation *op, const AnalysisState &analysisState) { BufferizationState bufferizationState(analysisState); - return bufferizeOp(op, bufferizationState); + if (failed(bufferizeOp(op, bufferizationState))) + return failure(); + if (failed(finalizeBuffers(op, analysisState.getOptions()))) + return failure(); + return success(); } LogicalResult @@ -332,7 +336,10 @@ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) return failure(); - return checkBufferizationResult(op, bufferizationState.getOptions()); + if (failed(checkBufferizationResult(op, bufferizationState.getOptions()))) + return failure(); + + return success(); } namespace { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -1054,6 +1054,10 @@ } } + // Finalize all buffers. 
+ if (failed(finalizeBuffers(moduleOp, options))) + return failure(); + // Perform a post-processing pass of layout modification at function boundary // according to the kBufferLayoutAttrName. layoutPostProcessing(moduleOp); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -235,9 +235,8 @@ if (initTensorOp->getUses().empty()) return success(); - FailureOr alloc = - createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(), - state.getOptions().createDeallocs, state.getOptions()); + FailureOr alloc = state.createAlloc(rewriter, initTensorOp->getLoc(), + initTensorOp.result()); if (failed(alloc)) return failure(); replaceOpWithBufferizedValues(rewriter, op, *alloc); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -228,8 +228,7 @@ Value alloc; if (!inplace) { FailureOr allocOrFailure = - createAlloc(rewriter, loc, extractSliceOp.result(), - state.getOptions().createDeallocs, state.getOptions()); + state.createAlloc(rewriter, loc, extractSliceOp.result()); if (failed(allocOrFailure)) return failure(); alloc = *allocOrFailure; @@ -338,9 +337,7 @@ auto shape = tensorType.getShape(); MemRefType resultType = getContiguousMemRefType(tensorType); FailureOr maybeBuffer = - createAlloc(rewriter, loc, resultType, {}, - /*deallocMemref=*/state.getOptions().createDeallocs, - state.getOptions()); + state.createAlloc(rewriter, loc, resultType, {}); if (failed(maybeBuffer)) return failure(); Value buffer = *maybeBuffer; @@ -389,10 +386,8 @@ Location loc = op->getLoc(); MemRefType 
memrefType = getContiguousMemRefType(generateOp.getType().cast()); - FailureOr maybeResult = - createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(), - /*deallocMemref=*/state.getOptions().createDeallocs, - state.getOptions()); + FailureOr maybeResult = state.createAlloc( + rewriter, loc, memrefType, generateOp.dynamicExtents()); if (failed(maybeResult)) return failure(); Value result = *maybeResult;