diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -24,6 +24,8 @@
 #include
 #include
 
+#include "llvm/ADT/StringMap.h"
+
 using namespace mlir::runtime;
 
 //===----------------------------------------------------------------------===//
@@ -109,9 +111,19 @@
 } // namespace
 
 // Returns the default per-process instance of an async runtime.
-static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
+static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
   static auto runtime = std::make_unique<AsyncRuntime>();
-  return runtime.get();
+  return runtime;
+}
+
+// Destroys the default async runtime instance.
+static void resetDefaultAsyncRuntime() {
+  getDefaultAsyncRuntimeInstance().reset();
+}
+
+// Returns a non-owning pointer to the default async runtime instance.
+static AsyncRuntime *getDefaultAsyncRuntime() {
+  return getDefaultAsyncRuntimeInstance().get();
 }
 
 // Async token provides a mechanism to signal asynchronous operation completion.
@@ -184,19 +196,19 @@
 
 // Creates a new `async.token` in not-ready state.
 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
-  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
+  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
   return token;
 }
 
 // Creates a new `async.value` in not-ready state.
 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
-  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
+  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
   return value;
 }
 
 // Create a new `async.group` in empty state.
 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
-  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
+  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
   return group;
 }
 
@@ -342,4 +354,41 @@
   std::cout << "Current thread id: " << thisId << std::endl;
 }
 
+//===----------------------------------------------------------------------===//
+// MLIR Runner (JitRunner) dynamic library integration.
+//===----------------------------------------------------------------------===//
+
+extern "C" void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
+  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
+    exportSymbols[name] = reinterpret_cast<void *>(ptr);
+  };
+
+  exportSymbol("mlirAsyncRuntimeAddRef",
+               &mlir::runtime::mlirAsyncRuntimeAddRef);
+  exportSymbol("mlirAsyncRuntimeDropRef",
+               &mlir::runtime::mlirAsyncRuntimeDropRef);
+  exportSymbol("mlirAsyncRuntimeExecute",
+               &mlir::runtime::mlirAsyncRuntimeExecute);
+  exportSymbol("mlirAsyncRuntimeCreateToken",
+               &mlir::runtime::mlirAsyncRuntimeCreateToken);
+  exportSymbol("mlirAsyncRuntimeEmplaceToken",
+               &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
+  exportSymbol("mlirAsyncRuntimeAwaitToken",
+               &mlir::runtime::mlirAsyncRuntimeAwaitToken);
+  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
+               &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
+  exportSymbol("mlirAsyncRuntimeCreateGroup",
+               &mlir::runtime::mlirAsyncRuntimeCreateGroup);
+  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
+               &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
+  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
+               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
+  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
+               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
+  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
+               &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
+}
+
+extern "C" void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
+
 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -155,17 +155,64 @@
   if (auto clOptLevel = getCommandLineOptLevel(options))
     jitCodeGenOptLevel =
         static_cast(clOptLevel.getValue());
+
+  // If a shared library implements the custom mlir-runner init and destroy
+  // functions, we use them to register the library with the execution
+  // engine. Otherwise the library is passed directly to the execution engine.
   SmallVector libs(options.clSharedLibs.begin(),
                    options.clSharedLibs.end());
+
+  // Libraries that we'll pass to the ExecutionEngine for loading.
+  decltype(libs) executionEngineLibs;
+
+  // Signatures of the optional `__mlir_runner_init` and
+  // `__mlir_runner_destroy` entry points looked up in each library.
+  using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
+  using MlirRunnerDestroyFn = void (*)();
+
+  // Symbols published by the init callbacks, and the destroy callbacks to
+  // invoke once the entry point has returned.
+  llvm::StringMap<void *> exportSymbols;
+  SmallVector<MlirRunnerDestroyFn, 4> destroyFns;
+
+  // Handle libraries that support the mlir-runner init/destroy callbacks.
+  for (auto libPath : libs) {
+    auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.data());
+    void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
+    void *destroySym = lib.getAddressOfSymbol("__mlir_runner_destroy");
+
+    // Library does not support mlir runner, load it with ExecutionEngine.
+    if (!initSym || !destroySym) {
+      executionEngineLibs.push_back(libPath);
+      continue;
+    }
+
+    auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
+    initFn(exportSymbols);
+
+    auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySym);
+    destroyFns.push_back(destroyFn);
+  }
+
+  // Build a runtime symbol map from the config and exported symbols. Exported
+  // symbols are assigned last, so they replace config entries of the same name.
+  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
+    auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
+                                             : llvm::orc::SymbolMap();
+    for (auto &exportSymbol : exportSymbols)
+      symbolMap[interner(exportSymbol.getKey())] =
+          llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
+    return symbolMap;
+  };
+
   auto expectedEngine = mlir::ExecutionEngine::create(
       module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
-      libs);
+      executionEngineLibs);
   if (!expectedEngine)
     return expectedEngine.takeError();
 
   auto engine = std::move(*expectedEngine);
-  if (config.runtimeSymbolMap)
-    engine->registerSymbols(config.runtimeSymbolMap);
+  engine->registerSymbols(runtimeSymbolMap);
 
   auto expectedFPtr = engine->lookup(entryPoint);
   if (!expectedFPtr)
@@ -179,6 +226,12 @@
   void (*fptr)(void **) = *expectedFPtr;
   (*fptr)(args);
 
+  // Run all dynamic library destroy callbacks to prepare for the shutdown.
+  // NOTE(review): the early error returns above leave without running these
+  // callbacks even though the init callbacks already ran - confirm that is
+  // acceptable.
+  llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
+
   return Error::success();
 }
 