diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -71,7 +71,15 @@ std::optional jitCodeGenOptLevel; /// If `sharedLibPaths` are provided, the underlying JIT-compilation will - /// open and link the shared libraries for symbol resolution. + /// open and link the shared libraries for symbol resolution. Libraries that + /// are designed to be used with the `ExecutionEngine` may implement a + /// loading and unloading protocol: if they implement the two functions with + /// the names defined in `kLibraryInitFnName` and `kLibraryDestroyFnName`, + /// these functions will be called upon loading the library and upon + /// destruction of the `ExecutionEngine`. In the init function, the library + /// may provide a list of symbols that it wants to make available to code + /// run by the `ExecutionEngine`. If the two functions are not defined, only + /// symbols with public visibility are available to the executed code. ArrayRef sharedLibPaths = {}; /// Specifies an existing `sectionMemoryMapper` to be associated with the @@ -105,9 +113,32 @@ /// be used to invoke the JIT-compiled function. class ExecutionEngine { public: + /// Name of init functions of shared libraries. If a library provides a + /// function with this name and the one of the destroy function, this function + /// is called upon loading the library. + static constexpr const char *const kLibraryInitFnName = + "__mlir_execution_engine_init"; + + /// Name of destroy functions of shared libraries. If a library provides a + /// function with this name and the one of the init function, this function is + /// called upon destructing the `ExecutionEngine`. + static constexpr const char *const kLibraryDestroyFnName = + "__mlir_execution_engine_destroy"; + + /// Function type for init functions of shared libraries. 
The library may + /// provide a list of symbols that it wants to make available to code run by + /// the `ExecutionEngine`. If the two functions are not defined, only symbols + /// with public visibility are available to the executed code. + using LibraryInitFn = void (*)(llvm::StringMap &); + + /// Function type for destroy functions of shared libraries. + using LibraryDestroyFn = void (*)(); + ExecutionEngine(bool enableObjectDump, bool enableGDBNotificationListener, bool enablePerfNotificationListener); + ~ExecutionEngine(); + /// Creates an execution engine for the given MLIR IR. If TargetMachine is /// not provided, default TM is created (i.e. ignoring any command line flags /// that could affect the set-up). @@ -216,6 +247,10 @@ /// Perf notification listener. llvm::JITEventListener *perfListener; + + /// Destroy functions in the libraries loaded by the ExecutionEngine that are + /// called when this ExecutionEngine is destructed. + SmallVector destroyFns; }; } // namespace mlir diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -468,10 +468,11 @@ // The bug is fixed in VS2019 16.1. Separating the declaration and definition is // a work around for older versions of Visual Studio. // NOLINTNEXTLINE(*-identifier-naming): externally called. -extern "C" API void __mlir_runner_init(llvm::StringMap &exportSymbols); +extern "C" API void +__mlir_execution_engine_init(llvm::StringMap &exportSymbols); // NOLINTNEXTLINE(*-identifier-naming): externally called. 
-void __mlir_runner_init(llvm::StringMap &exportSymbols) { +void __mlir_execution_engine_init(llvm::StringMap &exportSymbols) { auto exportSymbol = [&](llvm::StringRef name, auto ptr) { assert(exportSymbols.count(name) == 0 && "symbol already exists"); exportSymbols[name] = reinterpret_cast(ptr); @@ -526,7 +527,9 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. -extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); } +extern "C" API void __mlir_execution_engine_destroy() { + resetDefaultAsyncRuntime(); +} } // namespace runtime } // namespace mlir diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -222,6 +222,12 @@ } } +ExecutionEngine::~ExecutionEngine() { + // Run all dynamic library destroy callbacks to prepare for the shutdown. + for (LibraryDestroyFn destroy : destroyFns) + destroy(); +} + Expected> ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options, std::unique_ptr tm) { @@ -267,6 +273,16 @@ auto dataLayout = llvmModule->getDataLayout(); + // Use absolute library path so that gdb can find the symbol table. + SmallVector, 4> sharedLibPaths; + transform( + options.sharedLibPaths, std::back_inserter(sharedLibPaths), + [](StringRef libPath) { + SmallString<256> absPath(libPath.begin(), libPath.end()); + cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath))); + return absPath; + }); + // Callback to create the object layer with symbol resolution to current // process and dynamically linked libraries. auto objectLinkingLayerCreator = [&](ExecutionSession &session, @@ -292,7 +308,7 @@ } // Resolve symbols from shared libraries. 
- for (auto libPath : options.sharedLibPaths) { + for (auto &libPath : sharedLibPaths) { auto mb = llvm::MemoryBuffer::getFile(libPath); if (!mb) { errs() << "Failed to create MemoryBuffer for: " << libPath @@ -301,7 +317,7 @@ } auto &jd = session.createBareJITDylib(std::string(libPath)); auto loaded = DynamicLibrarySearchGenerator::Load( - libPath.data(), dataLayout.getGlobalPrefix()); + libPath.str().str().c_str(), dataLayout.getGlobalPrefix()); if (!loaded) { errs() << "Could not load " << libPath << ":\n " << loaded.takeError() << "\n"; @@ -346,6 +362,42 @@ cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( dataLayout.getGlobalPrefix()))); + // If shared library implements custom execution layer library init and + // destroy functions, we'll use them to register the library. + + llvm::StringMap exportSymbols; + SmallVector destroyFns; + + for (auto &libPath : sharedLibPaths) { + auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary( + libPath.str().str().c_str()); + void *initSym = lib.getAddressOfSymbol(kLibraryInitFnName); + void *destroySim = lib.getAddressOfSymbol(kLibraryDestroyFnName); + + // Library does not provide callbacks, rely on symbol visibility. + if (!initSym || !destroySim) { + continue; + } + + auto initFn = reinterpret_cast(initSym); + initFn(exportSymbols); + + auto destroyFn = reinterpret_cast(destroySim); + destroyFns.push_back(destroyFn); + } + engine->destroyFns = std::move(destroyFns); + + // Build a runtime symbol map from the exported symbols and register them. 
+ auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) { + auto symbolMap = llvm::orc::SymbolMap(); + for (auto &exportSymbol : exportSymbols) + symbolMap[interner(exportSymbol.getKey())] = { + llvm::orc::ExecutorAddr::fromPtr(exportSymbol.getValue()), + llvm::JITSymbolFlags::Exported}; + return symbolMap; + }; + engine->registerSymbols(runtimeSymbolMap); + return std::move(engine); } diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -185,65 +185,15 @@ if (auto clOptLevel = getCommandLineOptLevel(options)) jitCodeGenOptLevel = static_cast(*clOptLevel); - // If shared library implements custom mlir-runner library init and destroy - // functions, we'll use them to register the library with the execution - // engine. Otherwise we'll pass library directly to the execution engine. - SmallVector, 4> libPaths; - - // Use absolute library path so that gdb can find the symbol table. - transform( - options.clSharedLibs, std::back_inserter(libPaths), - [](std::string libPath) { - SmallString<256> absPath(libPath.begin(), libPath.end()); - cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath))); - return absPath; - }); - - // Libraries that we'll pass to the ExecutionEngine for loading. - SmallVector executionEngineLibs; - - using MlirRunnerInitFn = void (*)(llvm::StringMap &); - using MlirRunnerDestroyFn = void (*)(); - - llvm::StringMap exportSymbols; - SmallVector destroyFns; - - // Handle libraries that do support mlir-runner init/destroy callbacks. - for (auto &libPath : libPaths) { - auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str()); - void *initSym = lib.getAddressOfSymbol("__mlir_runner_init"); - void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy"); - - // Library does not support mlir runner, load it with ExecutionEngine. 
- if (!initSym || !destroySim) { - executionEngineLibs.push_back(libPath); - continue; - } - - auto initFn = reinterpret_cast(initSym); - initFn(exportSymbols); - - auto destroyFn = reinterpret_cast(destroySim); - destroyFns.push_back(destroyFn); - } - - // Build a runtime symbol map from the config and exported symbols. - auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) { - auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner) - : llvm::orc::SymbolMap(); - for (auto &exportSymbol : exportSymbols) - symbolMap[interner(exportSymbol.getKey())] = - { llvm::orc::ExecutorAddr::fromPtr(exportSymbol.getValue()), - llvm::JITSymbolFlags::Exported }; - return symbolMap; - }; + SmallVector sharedLibs(options.clSharedLibs.begin(), + options.clSharedLibs.end()); mlir::ExecutionEngineOptions engineOptions; engineOptions.llvmModuleBuilder = config.llvmModuleBuilder; if (config.transformer) engineOptions.transformer = config.transformer; engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel; - engineOptions.sharedLibPaths = executionEngineLibs; + engineOptions.sharedLibPaths = sharedLibs; engineOptions.enableObjectDump = true; auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions, std::move(tm)); @@ -251,7 +201,6 @@ return expectedEngine.takeError(); auto engine = std::move(*expectedEngine); - engine->registerSymbols(runtimeSymbolMap); auto expectedFPtr = engine->lookupPacked(entryPoint); if (!expectedFPtr) @@ -265,10 +214,6 @@ void (*fptr)(void **) = *expectedFPtr; (*fptr)(args); - // Run all dynamic library destroy callbacks to prepare for the shutdown. - for (MlirRunnerDestroyFn destroy : destroyFns) - destroy(); - return Error::success(); }