Add a `promote-buffer-results-to-out-params` option to One-Shot Bufferize.

Summary of the hunks below:
  * BufferizableOpInterface.h: new `promoteBufferResultsToOutParams` flag on
    the bufferization options (default `false`).
  * Passes.h / BufferResultsToOutParams.cpp: the body of
    `BufferResultsToOutParamsPass::runOnOperation` is moved into a free
    function `bufferization::promoteBufferResultsToOutParams(ModuleOp)`; the
    pass now just calls it and signals failure if it fails.
  * Passes.td / Bufferize.cpp: new pass option wired into the options struct.
    The promotion runs after allocation ops for leaking buffers were created
    and before deallocation ops are created; it is attempted only if the
    option is set, leaking allocs were found, and the op is a `ModuleOp`.
  * New FileCheck test one-shot-module-bufferize-out-params.mlir.

NOTE(review): this patch text was recovered from a flattened copy in which
line breaks and C++ template arguments had been lost. Line breaks were
restored to agree with every hunk header's line counts. The template
arguments (`<Pass>`, `<OperationPass<func::FuncOp>>`, `<func::FuncOp>`,
`<BlockArgument, 6>`, `<BufferResultsToOutParamsPass>`, `<ModuleOp>`) and the
leading indentation of context lines are reconstructions -- confirm against
the upstream tree before applying.

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -203,6 +203,10 @@
   /// For debugging only. Should be used together with `testAnalysisOnly`.
   bool printConflicts = false;
 
+  /// If set to `true`, buffers that are returned from functions are replaced
+  /// with buffer "out" parameters. At the call site, new buffers are allocated.
+  bool promoteBufferResultsToOutParams = false;
+
   /// If set to `true`, an `getAliasingOpResult` will return the corresponding
   /// "out"/"dest" OpOperand for every op that has the notion of an "out"/"dest"
   /// operand. I.e., the aliasing OpOperand of the i-th tensor OpResult is
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -4,6 +4,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class ModuleOp;
+
 namespace func {
 class FuncOp;
 } // namespace func
@@ -33,6 +35,10 @@
 /// Creates a pass that converts memref function results to out-params.
 std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
 
+/// Replace buffers that are returned from a function with an out parameter.
+/// Also update all call sites.
+LogicalResult promoteBufferResultsToOutParams(ModuleOp module);
+
 /// Creates a pass that finalizes a partial bufferization by removing remaining
 /// bufferization.to_tensor and bufferization.to_memref operations.
 std::unique_ptr<OperationPass<func::FuncOp>> createFinalizingBufferizePass();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -264,6 +264,9 @@
            /*default=*/"false",
            "Test only: Annotate IR with RaW conflicts. Requires "
            "test-analysis-only.">,
+    Option<"promoteBufferResultsToOutParams",
+           "promote-buffer-results-to-out-params", "bool", /*default=*/"false",
+           "Replace returned buffers (that were not dropped) with out params.">,
   ];
   let constructor = "mlir::bufferization::createOneShotBufferizePass()";
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -122,20 +122,25 @@
   return failure(didFail);
 }
 
+LogicalResult
+mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
+  for (auto func : module.getOps<func::FuncOp>()) {
+    SmallVector<BlockArgument, 6> appendedEntryArgs;
+    updateFuncOp(func, appendedEntryArgs);
+    if (func.isExternal())
+      continue;
+    updateReturnOps(func, appendedEntryArgs);
+  }
+  if (failed(updateCalls(module)))
+    return failure();
+  return success();
+}
+
 namespace {
 struct BufferResultsToOutParamsPass
     : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> {
   void runOnOperation() override {
-    ModuleOp module = getOperation();
-
-    for (auto func : module.getOps<func::FuncOp>()) {
-      SmallVector<BlockArgument, 6> appendedEntryArgs;
-      updateFuncOp(func, appendedEntryArgs);
-      if (func.isExternal())
-        continue;
-      updateReturnOps(func, appendedEntryArgs);
-    }
-    if (failed(updateCalls(module)))
+    if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
       return signalPassFailure();
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -179,6 +179,7 @@
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
+      opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
 
       BufferizationOptions::OpFilterEntry::FilterFn filterFn =
           [&](Operation *op) {
@@ -263,17 +264,29 @@
   if (failed(hoistBufferAllocations(op, options)))
     return failure();
 
-  // Deallocate buffers that escape block boundaries ("leaking buffers") with
-  // the buffer deallocation pass.
-  bool hasLeakingAlloc = false;
+  // Create allocation ops for "leaking buffers", i.e., buffer allocations that
+  // escape block boundaries. If there are no leaking allocs, `hasLeakingAllocs`
+  // is set to `false`.
+  bool hasLeakingAllocs = false;
   if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
-                                   &hasLeakingAlloc)))
+                                   &hasLeakingAllocs)))
     return failure();
-  if (options.createDeallocs && hasLeakingAlloc &&
+
+  // Promote returned buffers to "out" parameters.
+  // TODO: Pass options to support custom dealloc ops.
+  if (options.promoteBufferResultsToOutParams && hasLeakingAllocs &&
+      isa<ModuleOp>(op) &&
+      failed(promoteBufferResultsToOutParams(cast<ModuleOp>(op))))
+    return failure();
+
+  // Create deallocation ops for all "leaking buffers" and all buffer
+  // allocations that were added during the above promotion process.
+  // TODO: Pass options to support custom dealloc ops.
+  if (options.createDeallocs && hasLeakingAllocs &&
       failed(deallocateBuffers(op)))
     return failure();
 
-  // Deallocate all remaining buffers at the end of the block.
+  // Deallocate all remaining buffers at the end of their parent blocks.
   if (failed(createAllocDeallocOps(op, options)))
     return failure();
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs promote-buffer-results-to-out-params" -split-input-file | FileCheck %s
+
+// Note: This bufferization is not very efficient yet, but it works.
+
+// CHECK: #[[$map1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-LABEL: func @callee(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, #[[$map1]]>,
+// CHECK-SAME: %[[arg1:.*]]: memref<5xf32>) {
+// CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+// CHECK: memref.copy %[[arg0]], %[[alloc]]
+// CHECK: memref.store %{{.*}}, %[[alloc]]
+// CHECK: memref.copy %[[alloc]], %[[arg1]]
+// CHECK: memref.dealloc %[[alloc]]
+// CHECK: return
+// CHECK: }
+func.func @callee(%t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 8.0 : f32
+  %1 = tensor.insert %cst into %t[%c0] : tensor<5xf32>
+  return %t, %1 : tensor<5xf32>, tensor<5xf32>
+}
+
+// CHECK: func @main(%[[arg0:.*]]: memref<5xf32, #[[$map1]]>) -> (f32, f32) {
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<5xf32>
+// CHECK: call @callee(%[[arg0]], %[[alloc]])
+// CHECK: %[[l1:.*]] = memref.load %[[arg0]]
+// CHECK: %[[l2:.*]] = memref.load %[[alloc]]
+// CHECK: memref.dealloc %[[alloc]]
+// CHECK: return %[[l1]], %[[l2]]
+// CHECK: }
+func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
+  %c0 = arith.constant 0 : index
+  %0, %1 = func.call @callee(%t)
+      : (tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
+  %2 = tensor.extract %0[%c0] : tensor<5xf32>
+  %3 = tensor.extract %1[%c0] : tensor<5xf32>
+  return %2, %3 : f32, f32
+}
+