diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -38,9 +38,17 @@
 /// Creates a pass that converts memref function results to out-params.
 std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
 
+/// Same as above, but configured with the given `options`.
+std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
+    const BufferResultsToOutParamsOptions &options);
+
 /// Replace buffers that are returned from a function with an out parameter.
-/// Also update all call sites.
-LogicalResult promoteBufferResultsToOutParams(ModuleOp module);
+/// Also update all call sites. If `options.restrictToMarkedFunctions` is set,
+/// only functions carrying the `bufferize.results_to_out_params` unit
+/// attribute, and the calls to them, are rewritten.
+LogicalResult
+promoteBufferResultsToOutParams(ModuleOp module,
+                                const BufferResultsToOutParamsOptions &options);
 
 /// Creates a pass that drops memref function results that are equivalent to a
 /// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -129,6 +129,12 @@
     buffers for results need to be allocated in the caller. This currently only
     works for static shaped memrefs.
   }];
+  let options = [
+    Option<"restrictToMarkedFunctions", "restrict-to-marked-functions", "bool",
+           /*default=*/"false",
+           "Restrict transformation to functions with the "
+           "'bufferize.results_to_out_params' attribute">,
+  ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -22,6 +22,10 @@
 
 using namespace mlir;
 
+// Marks the functions to rewrite when `restrictToMarkedFunctions` is set.
+static constexpr llvm::StringLiteral kResultsToOutParamsAttr =
+    "bufferize.results_to_out_params";
+
 /// Return `true` if the given MemRef type has a fully dynamic layout.
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
   int64_t offset;
@@ -41,6 +45,19 @@
   return type.getLayout().isIdentity();
 }
 
+// Returns whether the transformation must leave `func` and the calls to it
+// untouched.
+static bool
+shouldSkip(func::FuncOp func,
+           const bufferization::BufferResultsToOutParamsOptions &options) {
+  // If the option isn't enabled, do not skip anything.
+  if (!options.restrictToMarkedFunctions)
+    return false;
+
+  // Otherwise, skip all functions except those marked with the unit attribute.
+  return !func->hasAttrOfType<UnitAttr>(kResultsToOutParamsAttr);
+}
+
 // Updates the func op and entry block.
 //
 // Any args appended to the entry block are added to `appendedEntryArgs`.
@@ -119,9 +136,22 @@
 
 // Updates all CallOps in the scope of the given ModuleOp by allocating
 // temporary buffers for newly introduced out params.
-static LogicalResult updateCalls(ModuleOp module) {
+static LogicalResult
+updateCalls(ModuleOp module,
+            const bufferization::BufferResultsToOutParamsOptions &options) {
   bool didFail = false;
+  SymbolTable symtab(module);
   module.walk([&](func::CallOp op) {
+    auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
+    if (!callee) {
+      op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
+                     << "symbol table";
+      didFail = true;
+      return;
+    }
+    // Calls to functions that keep their signature are left as they are.
+    if (shouldSkip(callee, options))
+      return;
     SmallVector<Value, 6> replaceWithNewCallResults;
     SmallVector<Value, 6> replaceWithOutParams;
     for (OpResult result : op.getResults()) {
@@ -169,9 +199,13 @@
   return failure(didFail);
 }
 
-LogicalResult
-mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
+LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
+    ModuleOp module,
+    const bufferization::BufferResultsToOutParamsOptions &options) {
   for (auto func : module.getOps<func::FuncOp>()) {
+    // Functions excluded by `options` keep their original signature.
+    if (shouldSkip(func, options))
+      continue;
     SmallVector<BlockArgument, 6> appendedEntryArgs;
     if (failed(updateFuncOp(func, appendedEntryArgs)))
       return failure();
@@ -179,7 +213,7 @@
       continue;
     updateReturnOps(func, appendedEntryArgs);
   }
-  if (failed(updateCalls(module)))
+  if (failed(updateCalls(module, options)))
     return failure();
   return success();
 }
@@ -188,10 +222,22 @@
 struct BufferResultsToOutParamsPass
     : bufferization::impl::BufferResultsToOutParamsBase<
           BufferResultsToOutParamsPass> {
+  BufferResultsToOutParamsPass() = default;
+
+  // Store the configuration in the declared pass option, not in a side
+  // member, so that it is also the value printed with the pass pipeline.
+  explicit BufferResultsToOutParamsPass(
+      const bufferization::BufferResultsToOutParamsOptions &options) {
+    restrictToMarkedFunctions = options.restrictToMarkedFunctions;
+  }
+
   void runOnOperation() override {
-    if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
+    bufferization::BufferResultsToOutParamsOptions options;
+    options.restrictToMarkedFunctions = restrictToMarkedFunctions;
+    if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
+                                                              options)))
       return signalPassFailure();
   }
 };
 } // namespace
 
@@ -199,3 +245,8 @@
 mlir::bufferization::createBufferResultsToOutParamsPass() {
   return std::make_unique<BufferResultsToOutParamsPass>();
 }
+
+std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
+    const bufferization::BufferResultsToOutParamsOptions &options) {
+  return std::make_unique<BufferResultsToOutParamsPass>(options);
+}
diff --git a/mlir/test/Transforms/buffer-results-to-out-params.mlir b/mlir/test/Transforms/buffer-results-to-out-params.mlir
--- a/mlir/test/Transforms/buffer-results-to-out-params.mlir
+++ b/mlir/test/Transforms/buffer-results-to-out-params.mlir
@@ -1,4 +1,8 @@
 // RUN: mlir-opt -buffer-results-to-out-params -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt \
+// RUN:   -buffer-results-to-out-params=restrict-to-marked-functions=true \
+// RUN:   -split-input-file -verify-diagnostics %s | \
+// RUN:   FileCheck %s --check-prefix=CHECK-RESTRICT
 
 // CHECK-LABEL:   func @basic(
 // CHECK-SAME:                %[[ARG:.*]]: memref<f32>) {
@@ -6,6 +10,12 @@
 // CHECK:           memref.copy %[[RESULT]], %[[ARG]] : memref<f32> to memref<f32>
 // CHECK:           return
 // CHECK:         }
+
+// CHECK-RESTRICT-LABEL: func @basic(
+// CHECK-RESTRICT-SAME: ) -> memref<f32> {
+// CHECK-RESTRICT: %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
+// CHECK-RESTRICT: return %[[RESULT]] : memref<f32>
+// CHECK-RESTRICT: }
 func.func @basic() -> (memref<f32>) {
   %0 = "test.source"() : () -> (memref<f32>)
   return %0 : memref<f32>
@@ -48,15 +58,24 @@
 }
 
 // CHECK: func private @external_function(memref<f32>)
+// CHECK-RESTRICT: func private @external_function() -> memref<f32>
 func.func private @external_function() -> (memref<f32>)
+
+// CHECK: func private @external_function2(memref<f32>)
+// CHECK-RESTRICT: func private @external_function2(memref<f32>)
+func.func private @external_function2() ->
+    (memref<f32>) attributes { bufferize.results_to_out_params }
+
 // CHECK: func private @result_attrs(memref<f32> {test.some_attr})
 func.func private @result_attrs() -> (memref<f32> {test.some_attr})
+
 // CHECK: func private @mixed_result_attrs(memref<1xf32>, memref<2xf32> {test.some_attr}, memref<3xf32>)
 func.func private @mixed_result_attrs() -> (memref<1xf32>, memref<2xf32> {test.some_attr}, memref<3xf32>)
 
 // -----
 
 // CHECK-LABEL: func private @callee(memref<1xf32>)
+// CHECK-RESTRICT-LABEL: func private @callee() -> memref<1xf32>
 func.func private @callee() -> memref<1xf32>
 
 // CHECK-LABEL: func @call_basic() {
@@ -65,6 +84,40 @@
 // CHECK:           "test.sink"(%[[OUTPARAM]]) : (memref<1xf32>) -> ()
 // CHECK:           return
 // CHECK:         }
+
+// CHECK-RESTRICT-LABEL: func @call_basic() {
+// CHECK-RESTRICT: %[[RESULT:.*]] = call @callee() : () -> memref<1xf32>
+// CHECK-RESTRICT: "test.sink"(%[[RESULT]]) : (memref<1xf32>) -> ()
+// CHECK-RESTRICT: return
+// CHECK-RESTRICT: }
+func.func @call_basic() {
+  %0 = call @callee() : () -> memref<1xf32>
+  "test.sink"(%0) : (memref<1xf32>) -> ()
+  return
+}
+
+// -----
+
+// Same as @call_basic above, except for @callee's attributes.
+
+// CHECK-LABEL: func private @callee(memref<1xf32>)
+// CHECK-RESTRICT-LABEL: func private @callee(memref<1xf32>)
+func.func private @callee() ->
+    memref<1xf32> attributes { bufferize.results_to_out_params }
+
+// CHECK-LABEL: func @call_basic() {
+// CHECK: %[[OUTPARAM:.*]] = memref.alloc() : memref<1xf32>
+// CHECK: call @callee(%[[OUTPARAM]]) : (memref<1xf32>) -> ()
+// CHECK: "test.sink"(%[[OUTPARAM]]) : (memref<1xf32>) -> ()
+// CHECK: return
+// CHECK: }
+
+// CHECK-RESTRICT-LABEL: func @call_basic() {
+// CHECK-RESTRICT: %[[OUTPARAM:.*]] = memref.alloc() : memref<1xf32>
+// CHECK-RESTRICT: call @callee(%[[OUTPARAM]]) : (memref<1xf32>) -> ()
+// CHECK-RESTRICT: "test.sink"(%[[OUTPARAM]]) : (memref<1xf32>) -> ()
+// CHECK-RESTRICT: return
+// CHECK-RESTRICT: }
 func.func @call_basic() {
   %0 = call @callee() : () -> memref<1xf32>
   "test.sink"(%0) : (memref<1xf32>) -> ()
@@ -104,7 +157,8 @@
 
 // -----
 
-func.func private @callee() -> (memref<?xf32>)
+func.func private @callee() ->
+    (memref<?xf32>) attributes { bufferize.results_to_out_params }
 
 func.func @call_non_memref_result() {
   // expected-error @+1 {{cannot create out param for dynamically shaped result}}