diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -177,10 +177,6 @@ Optional deallocationFn; Optional memCpyFn; - /// Specifies whether returning newly allocated memrefs should be allowed. - /// Otherwise, a pass failure is triggered. - bool allowReturnMemref = false; - /// Specifies whether not bufferizable ops are allowed in the input. If so, /// bufferization.to_memref and bufferization.to_tensor ops are inserted at /// the boundaries. @@ -348,7 +344,14 @@ /// Return true if `v1` and `v2` bufferize to equivalent buffers. virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0; - /// Return dialect-specific analysis state. + /// Return true if the given tensor (or an aliasing tensor) is yielded from + /// the containing block. Also include all aliasing tensors in the same block. + /// + /// Note: In the absence of an analysis, an implementation may return true for + /// any given tensor. + virtual bool isTensorYielded(Value tensor) const = 0; + + /// Return dialect-specific bufferization state. template Optional getDialectState(StringRef name) const { auto it = dialectState.find(name); @@ -407,6 +410,10 @@ /// Return true if `v1` and `v2` bufferize to equivalent buffers. bool areEquivalentBufferizedValues(Value v1, Value v2) const override; + + /// Return true if the given tensor (or an aliasing tensor) is yielded from + /// the containing block. Also include all aliasing tensors in the same block. 
+ bool isTensorYielded(Value tensor) const override; }; /// BufferizationState provides helper functions for performing bufferization @@ -415,14 +422,20 @@ BufferizationState(const AnalysisState &analysisState) : analysisState(analysisState) {} - /// Creates a memref allocation with the given type and dynamic extents. - FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, - ValueRange dynShape); - - /// Creates a memref allocation for the given shaped value. This function may - /// perform additional optimizations such as buffer allocation hoisting. - // TODO: Allocation hoisting should be a cleanup pass. - FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue); + /// Creates a memref allocation for the given shaped value. `dealloc` + /// indicates whether the buffer should be deallocated or not. When `dealloc` + /// is `false`, this would create a memory leak, unless the buffer is + /// deallocated through some other mechanism. + /// + /// `dealloc` is optional. By default, this function will figure out by itself + /// if it is safe to deallocate the buffer. In essence, when returning the + /// buffer from a block, it is not safe to deallocate the buffer. This + /// information is queried via `AnalysisState::isTensorYielded`. + /// + /// Note: `shapedValue` is typically a tensor value. However, if it is a + /// memref value, `dealloc` is no longer optional and must be specified. + FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue, + Optional dealloc = None); /// Return the buffer (memref) for a given OpOperand (tensor). 
Allocate /// a new buffer and copy over data from the existing buffer if out-of-place diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -43,6 +43,10 @@ /// Registered post analysis steps. PostAnalysisStepList postAnalysisSteps; + + /// Specifies whether returning newly allocated memrefs should be allowed. + /// Otherwise, a pass failure is triggered. + bool allowReturnMemref = false; }; /// The BufferizationAliasInfo class maintains a list of buffer aliases and @@ -153,10 +157,22 @@ /// Return true if `v1` and `v2` bufferize to equivalent buffers. bool areEquivalentBufferizedValues(Value v1, Value v2) const override; + /// Return true if the given tensor (or an aliasing tensor) is yielded from + /// the containing block. Also include all aliasing tensors in the same block. + bool isTensorYielded(Value tensor) const override; + + /// Find all tensors that are yielded/returned from a block and store them in + /// `yieldedTensors`. Also include all aliasing tensors in the same block. + void gatherYieldedTensors(Operation *op); + private: /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal /// functions and `runOneShotBufferize` may access this object. BufferizationAliasInfo aliasInfo; + + /// A set of all tensors (and maybe aliasing tensors) that are yielded from a + /// block. + DenseSet yieldedTensors; }; /// Analyze `op` and its nested ops. 
Bufferization decisions are stored in diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -42,8 +42,12 @@ constexpr const ::llvm::StringLiteral bufferization::BufferizableOpInterface::kInplaceableAttrName; +/// Attribute name used to mark allocs that are created by the bufferization. static const char *kBufferAllocationAttr = "bufferization.allocation"; +/// Attribute name used to mark allocs that should not be deallocated. +static const char *kSkipDeallocAttr = "bufferization.skip_dealloc"; + //===----------------------------------------------------------------------===// // BufferizationOptions //===----------------------------------------------------------------------===// @@ -253,6 +257,8 @@ OpBuilder::InsertionGuard guard(rewriter); Operation *op = opOperand.getOwner(); Location loc = op->getLoc(); + SmallVector aliasingOpResults = + analysisState.getAliasingOpResult(opOperand); Value operand = opOperand.get(); Value operandBuffer = lookupBuffer(rewriter, operand, options); @@ -263,8 +269,13 @@ // Move insertion point right after `operandBuffer`. That is where the // allocation should be inserted (in the absence of allocation hoisting). setInsertionPointAfter(rewriter, operandBuffer); - // Allocate the result buffer. - FailureOr resultBuffer = createAlloc(rewriter, loc, operandBuffer); + // Allocate the result buffer. The buffer should be deallocated if the tensor + // is not yielded and deallocs are enabled in general. 
+ bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) { + return getAnalysisState().isTensorYielded(v); + }); + FailureOr resultBuffer = createAlloc( + rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs); if (failed(resultBuffer)) return failure(); // Do not copy if the last preceding writes of `operand` are ops that do @@ -281,8 +292,6 @@ })) return resultBuffer; // Do not copy if the copied data is never read. - SmallVector aliasingOpResults = - analysisState.getAliasingOpResult(opOperand); if (!aliasingOpResults.empty() && !analysisState.bufferizesToMemoryRead(opOperand) && llvm::none_of(aliasingOpResults, [&](OpResult opResult) { @@ -339,7 +348,12 @@ AlwaysCopyAnalysisState::AlwaysCopyAnalysisState( const BufferizationOptions &options) - : AnalysisState(options) {} + : AnalysisState(options) { + // Note: Allocations must be deallocated with a subsequent run of the buffer + // deallocation pass. + assert(!options.createDeallocs && + "cannot create deallocs with AlwaysCopyBufferizationState"); +} /// Return `true` if the given OpResult has been decided to bufferize inplace. bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const { @@ -356,6 +370,13 @@ return false; } +/// Return true if the given tensor (or an aliasing tensor) is yielded from +/// the containing block. Also include all aliasing tensors in the same block. +bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const { + // There is no analysis, so conservatively answer "true". + return true; +} + //===----------------------------------------------------------------------===// // Bufferization-specific scoped alloc/dealloc insertion support. 
//===----------------------------------------------------------------------===// @@ -426,37 +447,54 @@ } static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type, - ValueRange dynShape) { + ValueRange dynShape, bool skipDealloc) { auto allocaOp = b.create(loc, type, dynShape); allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr()); + if (skipDealloc) + allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr()); return allocaOp.getResult(); } /// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the /// block in case of a bbArg). FailureOr BufferizationState::createAlloc(OpBuilder &b, Location loc, - Value shapedValue) { + Value shapedValue, + Optional dealloc) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); + + // Compute allocation memref type. assert(shapedValue.getType().isa()); MemRefType memRefType = shapedValue.getType().dyn_cast(); SmallVector dynShape; MemRefType allocMemRefType = getAllocationTypeAndShape(b, loc, shapedValue, dynShape); - Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape); + + // Should the buffer be deallocated again or should we let it leak? + bool skipDealloc; + if (dealloc) { + skipDealloc = !dealloc.getValue(); + } else { + assert(shapedValue.getType().isa() && + "must specify `dealloc` if non-tensor value is passed"); + // Buffer should not be deallocated if deallocs are generally deactivated + // or if the tensor is yielded from a block. + skipDealloc = !getOptions().createDeallocs || + getAnalysisState().isTensorYielded(shapedValue); + } + + // Create the buffer allocation. + Value alloc = + createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc); + + // Insert a cast if a different type was requested. 
if (memRefType && memRefType != allocMemRefType) { - assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) && + assert(memref::CastOp::areCastCompatible(allocMemRefType, memRefType) && "createAlloc: cast incompatible"); alloc = b.create(loc, memRefType, alloc); } - return alloc; -} -/// Create a memref allocation with the given type and dynamic extents. -FailureOr BufferizationState::createAlloc(OpBuilder &b, Location loc, - MemRefType type, - ValueRange dynShape) { - return createBufferAllocation(b, loc, type, dynShape); + return alloc; } /// Create a memory copy between two memref buffers. @@ -480,7 +518,9 @@ // Ignore memref.alloca ops that were not created by the bufferization. if (!allocaOp->hasAttr(kBufferAllocationAttr)) return WalkResult::skip(); + bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr); + // Create alloc. Block *block = allocaOp->getBlock(); rewriter.setInsertionPoint(allocaOp); FailureOr alloc = @@ -490,10 +530,11 @@ return WalkResult::interrupt(); rewriter.replaceOp(allocaOp, *alloc); - // Stop here if deallocations are deactivated. - if (!options.createDeallocs) + // Stop here if the buffer should not be deallocated. + if (skipDealloc) return WalkResult::advance(); + // Create dealloc. 
rewriter.setInsertionPoint(block->getTerminator()); if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options))) return WalkResult::interrupt(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -379,7 +379,6 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() { BufferizationOptions options; - options.allowReturnMemref = true; options.allowUnknownOps = true; options.createDeallocs = false; options.fullyDynamicLayoutMaps = false; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -215,6 +215,43 @@ return aliasInfo.areEquivalentBufferizedValues(v1, v2); } +// Gather yielded tensors in `yieldedTensors` by querying all aliases. This is +// to ensure that such information is available during bufferization time. +// Alias information can no longer be queried through BufferizationAliasInfo +// once we have started modifying the IR. +void OneShotAnalysisState::gatherYieldedTensors(Operation *op) { + op->walk([&](Operation *returnOp) { + if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp)) + return WalkResult::advance(); + + for (OpOperand &returnValOperand : returnOp->getOpOperands()) { + Value returnVal = returnValOperand.get(); + // Skip non-tensor values. + if (!returnVal.getType().isa()) + continue; + + // Add all aliases of the returned value. But only the ones that are in + // the same block. 
+ aliasInfo.applyOnAliases(returnVal, [&](Value v) { + if (auto bbArg = v.dyn_cast()) { + if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp()) + yieldedTensors.insert(bbArg); + return; + } + Operation *definingOp = v.getDefiningOp(); + if (definingOp->getParentOp() == returnOp->getParentOp()) + yieldedTensors.insert(v); + }); + } + + return WalkResult::advance(); + }); +} + +bool OneShotAnalysisState::isTensorYielded(Value tensor) const { + return yieldedTensors.contains(tensor); +} + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -780,6 +817,9 @@ failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps)); } + // Gather all yielded tensors. + state.gatherYieldedTensors(op); + // Analysis verification: After setting up alias/equivalence sets, each op // can check for expected invariants/limitations and fail the analysis if // necessary. 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -335,9 +335,8 @@ Location loc = op->getLoc(); auto tensorType = fromElementsOp.getType().cast(); auto shape = tensorType.getShape(); - MemRefType resultType = getContiguousMemRefType(tensorType); FailureOr maybeBuffer = - state.createAlloc(rewriter, loc, resultType, {}); + state.createAlloc(rewriter, loc, fromElementsOp.result()); if (failed(maybeBuffer)) return failure(); Value buffer = *maybeBuffer; @@ -386,8 +385,8 @@ Location loc = op->getLoc(); MemRefType memrefType = getContiguousMemRefType(generateOp.getType().cast()); - FailureOr maybeResult = state.createAlloc( - rewriter, loc, memrefType, generateOp.dynamicExtents()); + FailureOr maybeResult = + state.createAlloc(rewriter, loc, generateOp.result()); if (failed(maybeResult)) return failure(); Value result = *maybeResult; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir @@ -68,4 +68,67 @@ return } +// ----- + +// CHECK-LABEL: func @read_after_write_conflict( +func @read_after_write_conflict(%cst : f32, %idx : index, %idx2 : index) + -> (f32, f32) { + // CHECK-DAG: %[[alloc:.*]] = memref.alloc + // CHECK-DAG: %[[dummy:.*]] = "test.dummy_op" + // CHECK-DAG: %[[dummy_m:.*]] = bufferization.to_memref %[[dummy]] + %t = "test.dummy_op"() : () -> (tensor<10xf32>) + + // CHECK: memref.copy %[[dummy_m]], %[[alloc]] + // CHECK: memref.store %{{.*}}, %[[alloc]] + %write = tensor.insert %cst into %t[%idx2] : tensor<10xf32> + + // CHECK: %[[read:.*]] = "test.some_use"(%[[dummy]]) + 
%read = "test.some_use"(%t) : (tensor<10xf32>) -> (f32) + // CHECK: %[[read2:.*]] = memref.load %[[alloc]] + %read2 = tensor.extract %write[%idx] : tensor<10xf32> + + // CHECK: memref.dealloc %[[alloc]] + // CHECK: return %[[read]], %[[read2]] + return %read, %read2 : f32, f32 +} + +// ----- + +// CHECK-LABEL: func @copy_deallocated( +func @copy_deallocated() -> tensor<10xf32> { + // CHECK: %[[alloc:.*]] = memref.alloc() + %0 = linalg.init_tensor[10] : tensor<10xf32> + // CHECK: %[[alloc_tensor:.*]] = bufferization.to_tensor %[[alloc]] + // CHECK: memref.dealloc %[[alloc]] + // CHECK: return %[[alloc_tensor]] + return %0 : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: func @buffer_not_deallocated( +// CHECK-SAME: %[[t:.*]]: tensor +func @buffer_not_deallocated(%t : tensor, %c : i1) -> tensor { + // CHECK: %[[r:.*]] = scf.if %{{.*}} { + %r = scf.if %c -> tensor { + // CHECK: %[[some_op:.*]] = "test.some_op" + // CHECK: %[[alloc:.*]] = memref.alloc(%[[some_op]]) + // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] + // CHECK-NOT: dealloc + // CHECK: scf.yield %[[casted]] + %sz = "test.some_op"() : () -> (index) + %0 = linalg.init_tensor[%sz] : tensor + scf.yield %0 : tensor + } else { + // CHECK: } else { + // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] + // CHECK: scf.yield %[[m]] + scf.yield %t : tensor + } + // CHECK: } + // CHECK-NOT: dealloc + // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]] + // CHECK: return %[[r_tensor]] + return %r : tensor +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -81,7 +81,6 @@ // CHECK: linalg.fill(%[[F0]], %[[ALLOC]]) : f32, memref %r = linalg.fill(%f0, %A) : f32, tensor -> tensor - // CHECK: dealloc %[[ALLOC]] : memref // CHECK: return %[[ALLOC]] : memref return %r: tensor } @@ -292,7 
+291,6 @@ // CHECK: memref.copy %[[A]], %[[ALLOC]] : memref // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref to memref<4xf32> // CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, #map> to memref<4xf32> - // CHECK: memref.dealloc %[[ALLOC]] : memref %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor // CHECK: return %{{.*}} : memref @@ -329,7 +327,6 @@ scf.yield %t : tensor } - // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref // CHECK: return %[[CASTED]] : memref return %r0, %r1: tensor, tensor } @@ -395,7 +392,6 @@ scf.yield %ttA, %ttB : tensor, tensor } - // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref // CHECK: return %[[CASTED]] : memref return %r0#0, %r0#1: tensor, tensor }