diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -35,12 +35,29 @@
 /// reallocations inside of loops.
 std::unique_ptr<Pass> createBufferLoopHoistingPass();
 
+/// Options for the BufferResultsToOutParams pass.
+/// Note: defined only here, not in tablegen.
+struct BufferResultsToOutParamsOptions {
+  /// Filter function; returns true if the function should be converted.
+  /// Defaults to true, i.e. all functions are converted.
+  ///
+  /// This is an owning std::function, not an llvm::function_ref: the pass
+  /// keeps a copy of the options and calls the filter when it runs, so a
+  /// non-owning reference to a temporary lambda would dangle by then.
+  std::function<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
+    return true;
+  };
+};
+
 /// Creates a pass that converts memref function results to out-params.
-std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
+std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
+    const BufferResultsToOutParamsOptions &options = {});
 
 /// Replace buffers that are returned from a function with an out parameter.
 /// Also update all call sites.
-LogicalResult promoteBufferResultsToOutParams(ModuleOp module);
+/// Functions rejected by `options.filterFn`, and calls to them, are skipped.
+LogicalResult promoteBufferResultsToOutParams(
+    ModuleOp module, const BufferResultsToOutParamsOptions &options = {});
 
 /// Creates a pass that drops memref function results that are equivalent to a
 /// function argument.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -119,9 +119,22 @@
 
 // Updates all CallOps in the scope of the given ModuleOp by allocating
 // temporary buffers for newly introduced out params.
-static LogicalResult updateCalls(ModuleOp module) {
+static LogicalResult
+updateCalls(ModuleOp module,
+            const bufferization::BufferResultsToOutParamsOptions &options) {
   bool didFail = false;
+  SymbolTable symtab(module);
   module.walk([&](func::CallOp op) {
+    auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
+    if (!callee) {
+      op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
+                     << "symbol table";
+      didFail = true;
+      return;
+    }
+    if (!options.filterFn(&callee)) {
+      return;
+    }
     SmallVector<Value, 6> replaceWithNewCallResults;
     SmallVector<Value, 6> replaceWithOutParams;
     for (OpResult result : op.getResults()) {
@@ -169,9 +182,13 @@
   return failure(didFail);
 }
 
-LogicalResult
-mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
+LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
+    ModuleOp module,
+    const bufferization::BufferResultsToOutParamsOptions &options) {
   for (auto func : module.getOps<func::FuncOp>()) {
+    if (!options.filterFn(&func)) {
+      continue;
+    }
     SmallVector<BlockArgument, 6> appendedEntryArgs;
     if (failed(updateFuncOp(func, appendedEntryArgs)))
       return failure();
@@ -179,7 +196,7 @@
       continue;
     updateReturnOps(func, appendedEntryArgs);
   }
-  if (failed(updateCalls(module)))
+  if (failed(updateCalls(module, options)))
     return failure();
   return success();
 }
@@ -188,14 +205,22 @@
 struct BufferResultsToOutParamsPass
     : bufferization::impl::BufferResultsToOutParamsBase<
           BufferResultsToOutParamsPass> {
+  explicit BufferResultsToOutParamsPass(
+      const bufferization::BufferResultsToOutParamsOptions &options)
+      : options(options) {}
+
   void runOnOperation() override {
-    if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
+    if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
+                                                              options)))
       return signalPassFailure();
   }
+
+private:
+  bufferization::BufferResultsToOutParamsOptions options;
 };
 } // namespace
 
-std::unique_ptr<Pass>
-mlir::bufferization::createBufferResultsToOutParamsPass() {
-  return std::make_unique<BufferResultsToOutParamsPass>();
+std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
+    const bufferization::BufferResultsToOutParamsOptions &options) {
+  return std::make_unique<BufferResultsToOutParamsPass>(options);
 }