diff --git a/mlir/include/mlir/ExecutionEngine/JitRunner.h b/mlir/include/mlir/ExecutionEngine/JitRunner.h --- a/mlir/include/mlir/ExecutionEngine/JitRunner.h +++ b/mlir/include/mlir/ExecutionEngine/JitRunner.h @@ -18,29 +18,42 @@ #ifndef MLIR_SUPPORT_JITRUNNER_H_ #define MLIR_SUPPORT_JITRUNNER_H_ -#include "mlir/IR/Module.h" - #include "llvm/ADT/STLExtras.h" -#include "llvm/IR/Module.h" +#include "llvm/ExecutionEngine/Orc/Core.h" -namespace mlir { +namespace llvm { +class Module; +class LLVMContext; -using TranslationCallback = llvm::function_ref( - ModuleOp, llvm::LLVMContext &)>; +namespace orc { +class MangleAndInterner; +} // namespace orc +} // namespace llvm + +namespace mlir { class ModuleOp; struct LogicalResult; +struct JitRunnerConfig { + /// MLIR transformer applied after parsing the input into MLIR IR and before + /// passing the MLIR module to the ExecutionEngine. + llvm::function_ref mlirTransformer = nullptr; + + /// A custom function that is passed to ExecutionEngine. It processes MLIR + /// module and creates LLVM IR module. + llvm::function_ref(ModuleOp, + llvm::LLVMContext &)> + llvmModuleBuilder = nullptr; + + /// A callback to register symbols with ExecutionEngine at runtime. + llvm::function_ref + runtimesymbolMap = nullptr; +}; + // Entry point for all CPU runners. Expects the common argc/argv arguments for -// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`. -/// `mlirTransformer` is applied after parsing the input into MLIR IR and before -/// passing the MLIR module to the ExecutionEngine. -/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine. -/// It processes MLIR module and creates LLVM IR module. -int JitRunnerMain( - int argc, char **argv, - llvm::function_ref mlirTransformer, - TranslationCallback llvmModuleBuilder = nullptr); +// standard C++ main functions. 
+int JitRunnerMain(int argc, char **argv, JitRunnerConfig config = {}); } // namespace mlir diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -92,6 +92,23 @@ "object-filename", llvm::cl::desc("Dump JITted-compiled object to file .o")}; }; + +struct CompileAndExecuteConfig { + /// LLVM module transformer that is passed to ExecutionEngine. + llvm::function_ref transformer; + + /// A custom function that is passed to ExecutionEngine. It processes MLIR + /// module and creates LLVM IR module. + llvm::function_ref(ModuleOp, + llvm::LLVMContext &)> + llvmModuleBuilder; + + /// A custom function that is passed to ExecutionEngine to register symbols at + /// runtime. + llvm::function_ref + runtimeSymbolMap; +}; + } // end anonymous namespace static OwningModuleRef parseMLIRInput(StringRef inputFilename, @@ -131,11 +148,9 @@ } // JIT-compile the given module and run "entryPoint" with "args" as arguments. 
-static Error -compileAndExecute(Options &options, ModuleOp module, - TranslationCallback llvmModuleBuilder, StringRef entryPoint, - std::function transformer, - void **args) { +static Error compileAndExecute(Options &options, ModuleOp module, + StringRef entryPoint, + CompileAndExecuteConfig config, void **args) { Optional jitCodeGenOptLevel; if (auto clOptLevel = getCommandLineOptLevel(options)) jitCodeGenOptLevel = @@ -143,11 +158,15 @@ SmallVector libs(options.clSharedLibs.begin(), options.clSharedLibs.end()); auto expectedEngine = mlir::ExecutionEngine::create( - module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs); + module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel, + libs); if (!expectedEngine) return expectedEngine.takeError(); auto engine = std::move(*expectedEngine); + if (config.runtimeSymbolMap) + engine->registerSymbols(config.runtimeSymbolMap); + auto expectedFPtr = engine->lookup(entryPoint); if (!expectedFPtr) return expectedFPtr.takeError(); @@ -163,16 +182,14 @@ return Error::success(); } -static Error compileAndExecuteVoidFunction( - Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder, - StringRef entryPoint, - std::function transformer) { +static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module, + StringRef entryPoint, + CompileAndExecuteConfig config) { auto mainFunction = module.lookupSymbol(entryPoint); if (!mainFunction || mainFunction.empty()) return make_string_error("entry point not found"); void *empty = nullptr; - return compileAndExecute(options, module, llvmModuleBuilder, entryPoint, - transformer, &empty); + return compileAndExecute(options, module, entryPoint, config, &empty); } template @@ -196,10 +213,9 @@ return Error::success(); } template -Error compileAndExecuteSingleReturnFunction( - Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder, - StringRef entryPoint, - std::function transformer) { +Error 
compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module, + StringRef entryPoint, + CompileAndExecuteConfig config) { auto mainFunction = module.lookupSymbol(entryPoint); if (!mainFunction || mainFunction.isExternal()) return make_string_error("entry point not found"); @@ -215,8 +231,8 @@ void *data; } data; data.data = &res; - if (auto error = compileAndExecute(options, module, llvmModuleBuilder, - entryPoint, transformer, (void **)&data)) + if (auto error = compileAndExecute(options, module, entryPoint, config, + (void **)&data)) return error; // Intentional printing of the output so we can test. @@ -226,15 +242,8 @@ } /// Entry point for all CPU runners. Expects the common argc/argv arguments for -/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`. -/// `mlirTransformer` is applied after parsing the input into MLIR IR and before -/// passing the MLIR module to the ExecutionEngine. -/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine. -/// It processes MLIR module and creates LLVM IR module. -int mlir::JitRunnerMain( - int argc, char **argv, - function_ref mlirTransformer, - TranslationCallback llvmModuleBuilder) { +/// standard C++ main functions. +int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) { // Create the options struct containing the command line options for the // runner. This must come before the command line options are parsed. 
Options options; @@ -274,8 +283,8 @@ return 1; } - if (mlirTransformer) - if (failed(mlirTransformer(m.get()))) + if (config.mlirTransformer) + if (failed(config.mlirTransformer(m.get()))) return EXIT_FAILURE; auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); @@ -292,10 +301,14 @@ auto transformer = mlir::makeLLVMPassesTransformer( passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition); + CompileAndExecuteConfig compileAndExecuteConfig; + compileAndExecuteConfig.transformer = transformer; + compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder; + compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap; + // Get the function used to compile and execute the module. using CompileAndExecuteFnT = - Error (*)(Options &, ModuleOp, TranslationCallback, StringRef, - std::function); + Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig); auto compileAndExecuteFn = StringSwitch(options.mainFuncType.getValue()) .Case("i32", compileAndExecuteSingleReturnFunction) @@ -304,11 +317,11 @@ .Case("void", compileAndExecuteVoidFunction) .Default(nullptr); - Error error = - compileAndExecuteFn - ? compileAndExecuteFn(options, m.get(), llvmModuleBuilder, - options.mainFuncName.getValue(), transformer) - : make_string_error("unsupported function type"); + Error error = compileAndExecuteFn + ? 
compileAndExecuteFn(options, m.get(), + options.mainFuncName.getValue(), + compileAndExecuteConfig) + : make_string_error("unsupported function type"); int exitCode = EXIT_SUCCESS; llvm::handleAllErrors(std::move(error), diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp @@ -24,5 +24,5 @@ llvm::InitializeNativeTargetAsmPrinter(); mlir::initializeLLVMPasses(); - return mlir::JitRunnerMain(argc, argv, nullptr); + return mlir::JitRunnerMain(argc, argv); } diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -136,5 +136,9 @@ LLVMInitializeNVPTXAsmPrinter(); mlir::initializeLLVMPasses(); - return mlir::JitRunnerMain(argc, argv, &runMLIRPasses); + + mlir::JitRunnerConfig jitRunnerConfig; + jitRunnerConfig.mlirTransformer = &runMLIRPasses; + + return mlir::JitRunnerMain(argc, argv, jitRunnerConfig); } diff --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp --- a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp +++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp @@ -86,5 +86,9 @@ llvm::InitializeNativeTargetAsmPrinter(); mlir::initializeLLVMPasses(); - return mlir::JitRunnerMain(argc, argv, &runMLIRPasses, &convertMLIRModule); + mlir::JitRunnerConfig jitRunnerConfig; + jitRunnerConfig.mlirTransformer = &runMLIRPasses; + jitRunnerConfig.llvmModuleBuilder = &convertMLIRModule; + + return mlir::JitRunnerMain(argc, argv, jitRunnerConfig); } diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ 
b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -58,5 +58,8 @@ llvm::InitializeNativeTargetAsmPrinter(); mlir::initializeLLVMPasses(); - return mlir::JitRunnerMain(argc, argv, &runMLIRPasses); + mlir::JitRunnerConfig jitRunnerConfig; + jitRunnerConfig.mlirTransformer = &runMLIRPasses; + + return mlir::JitRunnerMain(argc, argv, jitRunnerConfig); }