diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -133,6 +133,10 @@ /// the boundaries. bool allowUnknownOps = false; + /// Specifies whether dealloc ops should be generated along with alloc ops. If + /// not, new memory allocations will leak. + bool createDeallocs = true; + /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated. /// Should be used only with `testAnalysisOnly = true`. unsigned analysisFuzzerSeed = 0; @@ -369,10 +373,12 @@ Optional createAlloc(OpBuilder &b, Location loc, MemRefType type, ArrayRef dynShape); - /// Creates an alloc-dealloc pair. This function may perform additional - /// optimizations such as buffer allocation hoisting. - Value createAllocDeallocPair(OpBuilder &builder, Location loc, - Value shapedValue); + /// Creates a memref allocation for the given shaped value. This function may + /// perform additional optimizations such as buffer allocation hoisting. If + /// `deallocMemref`, a deallocation op is inserted at the point where the + /// allocation goes out of scope. + Value createAlloc(OpBuilder &b, Location loc, Value shapedValue, + bool deallocMemref = true); /// Creates a memref deallocation. The given memref buffer must have been /// allocated using `createAlloc`. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -393,7 +393,8 @@ // allocation should be inserted (in the absence of allocation hoisting). 
setInsertionPointAfter(rewriter, operandBuffer); // Allocate the result buffer. - Value resultBuffer = createAllocDeallocPair(rewriter, loc, operandBuffer); + Value resultBuffer = + createAlloc(rewriter, loc, operandBuffer, options.createDeallocs); bool skipCopy = false; // Do not copy if the last preceding write of `operand` is an op that does // not write (skipping ops that merely create aliases). E.g., InitTensorOp. @@ -536,11 +537,11 @@ return allocMemRefType; } -/// Create an Allocop/DeAllocOp pair, where the AllocOp is after +/// Create a memref allocation for `shapedValue`, where the AllocOp is after /// `shapedValue.getDefiningOp` (or at the top of the block in case of a -/// bbArg) and the DeallocOp is at the end of the block. -Value mlir::linalg::comprehensive_bufferize::BufferizationState:: - createAllocDeallocPair(OpBuilder &b, Location loc, Value shapedValue) { +/// bbArg). If `deallocMemref`, a DeallocOp is placed at the end of the block. +Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( + OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -560,9 +561,12 @@ casted = b.create(loc, memRefType, allocated.getValue()); } - // 2. Create memory deallocation. - b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator()); - createDealloc(b, loc, allocated.getValue()); + if (deallocMemref) { + // 2. Create memory deallocation.
+ b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator()); + createDealloc(b, loc, allocated.getValue()); + } + return casted; } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -199,8 +199,9 @@ if (initTensorOp->getUses().empty()) return success(); - Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(), - initTensorOp.result()); + Value alloc = state.createAlloc(rewriter, initTensorOp->getLoc(), + initTensorOp.result(), + state.getOptions().createDeallocs); state.replaceOpWithBufferizedValues(rewriter, op, alloc); return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -143,8 +143,8 @@ bool inplace = state.isInPlace(extractSliceOp->getResult(0)); Value alloc; if (!inplace) - alloc = - state.createAllocDeallocPair(rewriter, loc, extractSliceOp.result()); + alloc = state.createAlloc(rewriter, loc, extractSliceOp.result(), + state.getOptions().createDeallocs); // Bufferize to subview. auto subviewMemRefType =