diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -516,11 +516,16 @@ LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to, const BufferizationOptions &options); -/// Finalize all buffer allocations. -/// * Hoist buffer allocations as much as possible. -/// * Create alloc/dealloc ops as specified by the bufferization options. -LogicalResult finalizeBuffers(Operation *op, - const BufferizationOptions &options); +/// Try to hoist all new buffer allocations until the next hoisting barrier. +LogicalResult hoistBufferAllocations(Operation *op, + const BufferizationOptions &options); + +/// Create alloc/dealloc ops as specified in the bufferization options. If +/// `onlyLeakingAllocs`, only those buffer allocations are processed for which +/// no buffer deallocation can be created. +LogicalResult createAllocDeallocOps(Operation *op, + const BufferizationOptions &options, + bool onlyLeakingAllocs = false); } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -90,6 +90,12 @@ /// Note: This function overload is useful for extending the bufferization. LogicalResult bufferizeOp(Operation *op, BufferizationState &bufferizationState); + +/// Finalize all buffer allocations. +/// * Hoist buffer allocations as much as possible. +/// * Create alloc/dealloc ops as specified by the bufferization options. 
+LogicalResult finalizeBuffers(Operation *op, + const BufferizationOptions &options); } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -15,6 +15,9 @@ /// buffers. std::unique_ptr createBufferDeallocationPass(); +/// Run buffer deallocation. +LogicalResult deallocateBuffers(Operation *op); + /// Creates a pass that moves allocations upwards to reduce the number of /// required copies that are inserted during the BufferDeallocation pass. std::unique_ptr createBufferHoistingPass(); @@ -55,6 +58,9 @@ // Registration //===----------------------------------------------------------------------===// +/// Register external models for AllocationOpInterface. +void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry); + /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -170,11 +170,11 @@ example, `tensor.generate` is not in destination-passing style and always results in a new buffer allocation. - One-Shot Bufferize deallocates all buffers that it allocates. Yielding newly - allocated buffers from a block is not supported yet and such IR will be - rejected. For testing purposes and compatibility with partial bufferization, - One-Shot Bufferize can be run with `allow-return-allocs=1 create-dealloc=0` - to allow such IR. + One-Shot Bufferize deallocates all buffers that it allocates. 
Returning or + yielding newly allocated buffers from a block can lead to bad performance + because additional buffer copies would be inserted. By default, such IR is + rejected by One-Shot Bufferize. If performance is not important, such IR can + be allowed with `allow-return-allocs=1`. One-Shot Bufferize will by default reject IR that contains non-bufferizable op, i.e., ops that do not implemement BufferizableOpInterface. Such IR can @@ -204,7 +204,7 @@ let options = [ Option<"allowReturnAllocs", "allow-return-allocs", "bool", /*default=*/"false", - "Allows the return of new allocations (for testing purposes only)">, + "Allows returning/yielding new allocations from a block.">, Option<"allowUnknownOps", "allow-unknown-ops", "bool", /*default=*/"false", "Allows unknown (not bufferizable) ops in the input IR.">, diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -45,7 +45,7 @@ "Annotates IR with RaW conflicts. 
Requires test-analysis-only.">, Option<"allowReturnAllocs", "allow-return-allocs", "bool", /*default=*/"false", - "Allows the return of new allocations (for testing purposes only)">, + "Allows returning/yielding new allocations from a block.">, Option<"allowUnknownOps", "allow-unknown-ops", "bool", /*default=*/"false", "Allows unknown (not bufferizable) ops in the input IR.">, diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -508,8 +508,10 @@ return success(); } -static LogicalResult -createAllocDeallocOps(Operation *op, const BufferizationOptions &options) { +LogicalResult +bufferization::createAllocDeallocOps(Operation *op, + const BufferizationOptions &options, + bool onlyLeakingAllocs) { IRRewriter rewriter(op->getContext()); // Bufferization creates memref.alloca ops. After bufferization, these must be @@ -518,7 +520,11 @@ // Ignore memref.alloca ops that were not created by the bufferization. if (!allocaOp->hasAttr(kBufferAllocationAttr)) return WalkResult::skip(); + // If `onlyLeakingAllocs`, process only ops that are marked as + // "skip dealloc". bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr); + if (onlyLeakingAllocs && !skipDealloc) + return WalkResult::skip(); // Create alloc. Block *block = allocaOp->getBlock(); @@ -547,8 +553,9 @@ /// Try to hoist all new buffer allocations until the next hoisting barrier. // TODO: Consolidate this function with the existing buffer hoisting pass. -static LogicalResult -hoistBufferAllocations(Operation *op, const BufferizationOptions &options) { +LogicalResult +bufferization::hoistBufferAllocations(Operation *op, + const BufferizationOptions &options) { // Nothing to do if allocation hoisting is deactivated. 
if (!options.hoistAllocations) return success(); @@ -601,17 +608,6 @@ return success(); } -LogicalResult -bufferization::finalizeBuffers(Operation *op, - const BufferizationOptions &options) { - if (failed(hoistBufferAllocations(op, options))) - return failure(); - if (failed(createAllocDeallocOps(op, options))) - return failure(); - - return success(); -} - //===----------------------------------------------------------------------===// // Bufferization-specific BlockAndValueMapping support with debugging. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -77,13 +77,15 @@ return success(); } -/// Checks if all operations in a given region that have at least one attached -/// region implement the RegionBranchOpInterface. This is not required in edge -/// cases, where we have a single attached region and the parent operation has -/// no results. -static bool validateSupportedControlFlow(Region &region) { - bool success = true; - region.walk([&success](Operation *operation) { +/// Checks if all operations that have at least one attached region implement +/// the RegionBranchOpInterface. This is not required in edge cases, where we +/// have a single attached region and the parent operation has no results. +static bool validateSupportedControlFlow(Operation *op) { + WalkResult result = op->walk([&](Operation *operation) { + // Only check ops that are inside a function. + if (!operation->getParentOfType()) + return WalkResult::advance(); + auto regions = operation->getRegions(); // Walk over all operations in a region and check if the operation has at // least one region and implements the RegionBranchOpInterface. 
If there @@ -96,10 +98,13 @@ !dyn_cast(operation)) { operation->emitError("All operations with attached regions need to " "implement the RegionBranchOpInterface."); - success = false; + // Interrupt the walk so that the failure is propagated to the caller. + return WalkResult::interrupt(); } + + return WalkResult::advance(); }); - return success; + return !result.wasInterrupted(); } namespace { @@ -639,7 +644,7 @@ void getDependentDialects(DialectRegistry &registry) const override { registry.insert(); registry.insert(); - registry.addOpInterface(); + registerAllocationOpInterfaceExternalModels(registry); } void runOnOperation() override { @@ -647,32 +652,53 @@ if (func.isExternal()) return; - // Ensure that there are supported loops only. - Backedges backedges(func); - if (backedges.size()) { - func.emitError("Only structured control-flow loops are supported."); - return signalPassFailure(); - } - - // Check that the control flow structures are supported. - if (!validateSupportedControlFlow(func.getRegion())) - return signalPassFailure(); + if (failed(deallocateBuffers(func))) + signalPassFailure(); + } +}; - // Gather all required allocation nodes and prepare the deallocation phase. - BufferDeallocation deallocation(func); +} // namespace - // Check for supported AllocationOpInterface implementations and prepare the - // internal deallocation pass. - if (failed(deallocation.prepare())) - return signalPassFailure(); +LogicalResult bufferization::deallocateBuffers(Operation *op) { + if (isa(op)) { + WalkResult result = op->walk([&](FuncOp funcOp) { + if (failed(deallocateBuffers(funcOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + } - // Place all required temporary clone and dealloc nodes. - if (failed(deallocation.deallocate())) - return signalPassFailure(); + // Ensure that there are supported loops only. 
+ Backedges backedges(op); + if (backedges.size()) { + op->emitError("Only structured control-flow loops are supported."); + return failure(); } -}; -} // namespace + // Check that the control flow structures are supported. + if (!validateSupportedControlFlow(op)) + return failure(); + + // Gather all required allocation nodes and prepare the deallocation phase. + BufferDeallocation deallocation(op); + + // Check for supported AllocationOpInterface implementations and prepare the + // internal deallocation pass. + if (failed(deallocation.prepare())) + return failure(); + + // Place all required temporary clone and dealloc nodes. + if (failed(deallocation.deallocate())) + return failure(); + + return success(); +} + +void bufferization::registerAllocationOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addOpInterface(); +} //===----------------------------------------------------------------------===// // BufferDeallocationPass construction diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -157,7 +157,9 @@ : options(options) {} void getDependentDialects(DialectRegistry &registry) const override { - registry.insert(); + registry + .insert(); + registerAllocationOpInterfaceExternalModels(registry); } void runOnOperation() override { @@ -299,6 +301,21 @@ return success(); } +LogicalResult +bufferization::finalizeBuffers(Operation *op, + const BufferizationOptions &options) { + if (failed(hoistBufferAllocations(op, options))) + return failure(); + if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true))) + return failure(); + if (options.createDeallocs && failed(deallocateBuffers(op))) + return failure(); + if (failed(createAllocDeallocOps(op, options))) + return failure(); + + return success(); +} + LogicalResult 
bufferization::bufferizeOp(Operation *op, const AnalysisState &analysisState) { BufferizationState bufferizationState(analysisState); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/Passes.h" @@ -51,6 +52,7 @@ vector::VectorDialect, scf::SCFDialect, arith::ArithmeticDialect, func::FuncDialect, AffineDialect>(); arith::registerBufferizableOpInterfaceExternalModels(registry); + bufferization::registerAllocationOpInterfaceExternalModels(registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir @@ -21,12 +21,13 @@ } else { // CHECK: } else { // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] - // CHECK: scf.yield %[[m]] + // CHECK: %[[cloned:.*]] = bufferization.clone %[[m]] + // CHECK: scf.yield %[[cloned]] scf.yield %t : tensor } // CHECK: } - // CHECK-NOT: dealloc // CHECK: %[[r_tensor:.*]] = 
bufferization.to_tensor %[[r]] + // CHECK: memref.dealloc %[[r]] // CHECK: return %[[r_tensor]] return %r : tensor } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -81,7 +81,8 @@ // CHECK: linalg.fill ins(%[[F0]] : f32) outs(%[[ALLOC]] : memref) %r = linalg.fill ins(%f0 : f32) outs(%A : tensor) -> tensor - // CHECK: return %[[ALLOC]] : memref + // CHECK-NOT: dealloc + // CHECK: return %[[ALLOC]] : memref return %r: tensor } @@ -110,6 +111,7 @@ outs(%A: tensor) -> tensor + // CHECK: memref.dealloc %[[ALLOC]] // CHECK: return // CHECK-NOT: tensor return %r: tensor @@ -124,6 +126,7 @@ %r = linalg.matmul ins(%A, %A: tensor, tensor) outs(%A: tensor) -> tensor + // CHECK-NOT: dealloc return %r: tensor } // ----- @@ -328,6 +331,7 @@ } // CHECK: return %[[CASTED]] : memref + // CHECK-NOT: dealloc return %r0, %r1: tensor, tensor } @@ -419,6 +423,7 @@ // CHECK: call @some_external_func(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> () call @some_external_func(%A) : (tensor<4xi32>) -> () +// CHECK: memref.dealloc %[[alloc]] return } @@ -444,6 +449,7 @@ scf.yield } +// CHECK: memref.dealloc %[[alloc]] return } @@ -1192,3 +1198,23 @@ } return %5: tensor } + +// ----- + +// Note: This bufferization is inefficient, but it bufferizes correctly. 
+ +// CHECK-LABEL: func @scf_execute_region_yield_non_equivalent( +// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}) +// CHECK: %[[clone:.*]] = bufferization.clone %[[alloc]] +// CHECK: memref.dealloc %[[alloc]] +// CHECK: %[[r:.*]] = memref.load %[[clone]][%{{.*}}] +// CHECK: memref.dealloc %[[clone]] +// CHECK: return %[[r]] +func @scf_execute_region_yield_non_equivalent(%i: index, %j: index) -> f32 { + %r = scf.execute_region -> (tensor) { + %t2 = linalg.init_tensor [%i] : tensor + scf.yield %t2 : tensor + } + %f = tensor.extract %r[%j] : tensor + return %f : f32 +}