diff --git a/mlir/include/mlir/ExecutionEngine/JitRunner.h b/mlir/include/mlir/ExecutionEngine/JitRunner.h
--- a/mlir/include/mlir/ExecutionEngine/JitRunner.h
+++ b/mlir/include/mlir/ExecutionEngine/JitRunner.h
@@ -36,11 +36,23 @@
 class Operation;
 struct LogicalResult;
 
+/// JitRunner command line options passed to JitRunnerConfig callbacks. The
+/// members are non-owning views; copy them if they must outlive the call.
+struct JitRunnerOptions {
+  /// The name of the main function.
+  llvm::StringRef mainFuncName;
+  /// The type of the main function (as string, from cmd-line).
+  llvm::StringRef mainFuncType;
+};
+
+/// Configuration to override functionality of the JitRunner.
 struct JitRunnerConfig {
   /// MLIR transformer applied after parsing the input into MLIR IR and before
   /// passing the MLIR IR to the ExecutionEngine.
-  llvm::function_ref<LogicalResult(mlir::Operation *)> mlirTransformer =
-      nullptr;
+  /// The transformer also receives the runner's command line options.
+  llvm::function_ref<LogicalResult(mlir::Operation *,
+                                   const JitRunnerOptions &options)>
+      mlirTransformer = nullptr;
 
   /// A custom function that is passed to ExecutionEngine. It processes MLIR and
   /// creates an LLVM IR module.
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -364,8 +364,9 @@
     return 1;
   }
 
+  JitRunnerOptions runnerOptions{options.mainFuncName, options.mainFuncType};
   if (config.mlirTransformer)
-    if (failed(config.mlirTransformer(m.get())))
+    if (failed(config.mlirTransformer(m.get(), runnerOptions)))
       return EXIT_FAILURE;
 
   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
diff --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
--- a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
+++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
@@ -74,7 +74,8 @@
   return mainModule;
 }
 
-static LogicalResult runMLIRPasses(Operation *module) {
+static LogicalResult runMLIRPasses(Operation *module,
+                                   const JitRunnerOptions &options) {
   PassManager passManager(module->getContext(),
                           module->getName().getStringRef());
   applyPassManagerCLOptions(passManager);
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -41,7 +41,8 @@
 
 using namespace mlir;
 
-static LogicalResult runMLIRPasses(Operation *op) {
+static LogicalResult runMLIRPasses(Operation *op,
+                                   const JitRunnerOptions &options) {
   auto module = dyn_cast<ModuleOp>(op);
   if (!module)
     return op->emitOpError("expected a 'builtin.module' op");