diff --git a/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h b/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h
@@ -0,0 +1,130 @@
+//===- JITLinkRedirectableSymbolManager.h - JITLink redirection -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Redirectable Symbol Manager implementation using JITLink
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_EXECUTIONENGINE_ORC_JITLINKREDIRECTABLESYMBOLMANAGER_H
+#define LLVM_EXECUTIONENGINE_ORC_JITLINKREDIRECTABLESYMBOLMANAGER_H
+
+#include "llvm/ExecutionEngine/JITLink/JITLink.h"
+#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
+#include "llvm/ExecutionEngine/Orc/RedirectionManager.h"
+#include "llvm/Support/StringSaver.h"
+
+#include <mutex>
+#include <vector>
+
+namespace llvm {
+namespace orc {
+
+/// RedirectableSymbolManager that implements each redirectable symbol as a
+/// JITLink-built jump stub that jumps through a pointer; redirecting a symbol
+/// rewrites that pointer in the executor.
+class JITLinkRedirectableSymbolManager : public RedirectableSymbolManager,
+                                         public ResourceManager {
+public:
+  /// Create a redirection manager that uses a JITLink based implementation.
+  ///
+  /// Returns an error if JITLink cannot build pointers or pointer jump stubs
+  /// for the session's target triple.
+  static Expected<std::unique_ptr<RedirectableSymbolManager>>
+  Create(ExecutionSession &ES, ObjectLinkingLayer &ObjLinkingLayer,
+         JITDylib &JD) {
+    // Check target support before constructing: the constructor registers
+    // this object with ES, and the destructor unconditionally deregisters.
+    auto AnonymousPtrCreator =
+        jitlink::getAnonymousPointerCreator(ES.getTargetTriple());
+    auto PtrJumpStubCreator =
+        jitlink::getPointerJumpStubCreator(ES.getTargetTriple());
+    if (!AnonymousPtrCreator || !PtrJumpStubCreator)
+      return make_error<StringError>("Architecture not supported",
+                                     inconvertibleErrorCode());
+    return std::unique_ptr<RedirectableSymbolManager>(
+        new JITLinkRedirectableSymbolManager(ES, ObjLinkingLayer, JD,
+                                             std::move(AnonymousPtrCreator),
+                                             std::move(PtrJumpStubCreator)));
+  }
+
+  ~JITLinkRedirectableSymbolManager() override {
+    ES.deregisterResourceManager(*this);
+  }
+
+  // Keep the single-symbol redirect overload from the base class visible.
+  using RedirectableSymbolManager::redirect;
+
+  Error createRedirectableSymbols(const SymbolAddrMap &InitialDests,
+                                  ResourceTrackerSP RT) override;
+
+  Error redirect(const SymbolAddrMap &NewDests) override;
+
+  Error handleRemoveResources(JITDylib &JD, ResourceKey K) override;
+
+  void handleTransferResources(JITDylib &JD, ResourceKey DstK,
+                               ResourceKey SrcK) override;
+
+private:
+  using StubHandle = unsigned;
+  /// Stubs are allocated in multiples of this many at a time.
+  constexpr static unsigned StubBlockSize = 256;
+  /// Symbol-name prefixes for the I'th jump stub and the I'th stub pointer.
+  constexpr static StringRef JumpStubPrefix = "$IND_JUMP_";
+  constexpr static StringRef StubPtrPrefix = "$IND_JUMP_PTR_";
+  /// Names of the LinkGraph sections holding the stubs and the pointers.
+  constexpr static StringRef JumpStubTableName = "$__IND_JUMP_STUBS";
+  constexpr static StringRef StubPtrTableName = "$__IND_JUMP_PTRS";
+
+  JITLinkRedirectableSymbolManager(
+      ExecutionSession &ES, ObjectLinkingLayer &ObjLinkingLayer, JITDylib &JD,
+      jitlink::AnonymousPointerCreator AnonymousPtrCreator,
+      jitlink::PointerJumpStubCreator PtrJumpStubCreator)
+      : ES(ES), ObjLinkingLayer(ObjLinkingLayer), JD(JD),
+        AnonymousPtrCreator(std::move(AnonymousPtrCreator)),
+        PtrJumpStubCreator(std::move(PtrJumpStubCreator)) {
+    ES.registerResourceManager(*this);
+  }
+
+  StringRef JumpStubSymbolName(unsigned I) {
+    return StringPool.save((JumpStubPrefix + Twine(I)).str());
+  }
+
+  StringRef StubPtrSymbolName(unsigned I) {
+    return StringPool.save((StubPtrPrefix + Twine(I)).str());
+  }
+
+  unsigned GetNumAvailableStubs() const { return AvailableStubs.size(); }
+
+  Error redirectInner(const SymbolAddrMap &NewDests);
+  Error grow(unsigned Need);
+
+  ExecutionSession &ES;
+  ObjectLinkingLayer &ObjLinkingLayer;
+  JITDylib &JD;
+  jitlink::AnonymousPointerCreator AnonymousPtrCreator;
+  jitlink::PointerJumpStubCreator PtrJumpStubCreator;
+
+  /// Stub slots not currently bound to a symbol.
+  std::vector<StubHandle> AvailableStubs;
+  DenseMap<SymbolStringPtr, StubHandle> SymbolToStubs;
+  /// Executor addresses of each stub and of the pointer it jumps through,
+  /// indexed by StubHandle.
+  std::vector<ExecutorSymbolDef> JumpStubs;
+  std::vector<ExecutorSymbolDef> StubPointers;
+  /// Redirectable symbols created under each resource tracker.
+  DenseMap<ResourceKey, std::vector<SymbolStringPtr>> TrackedResources;
+
+  BumpPtrAllocator BAlloc;
+  StringSaver StringPool{BAlloc};
+  std::mutex Mutex;
+};
+
+} // namespace orc
+} // namespace llvm
+
+#endif // LLVM_EXECUTIONENGINE_ORC_JITLINKREDIRECTABLESYMBOLMANAGER_H
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h b/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h
@@ -0,0 +1,65 @@
+//===- RedirectionManager.h - Redirection manager interface -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Redirection manager interface that redirects calls to a symbol to another.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_EXECUTIONENGINE_ORC_REDIRECTIONMANAGER_H
+#define LLVM_EXECUTIONENGINE_ORC_REDIRECTIONMANAGER_H
+
+#include "llvm/ExecutionEngine/Orc/Core.h"
+
+namespace llvm {
+namespace orc {
+
+/// Base class for performing redirection of calls to a symbol to another
+/// symbol at runtime.
+class RedirectionManager {
+public:
+  /// Symbol name to symbol definition map.
+  using SymbolAddrMap = DenseMap<SymbolStringPtr, ExecutorSymbolDef>;
+
+  virtual ~RedirectionManager() = default;
+
+  /// Change the redirection destination of the given symbols to the new
+  /// destination symbols.
+  virtual Error redirect(const SymbolAddrMap &NewDests) = 0;
+
+  /// Change the redirection destination of the given symbol to the new
+  /// destination symbol.
+  virtual Error redirect(SymbolStringPtr Symbol, ExecutorSymbolDef NewDest) {
+    return redirect({{Symbol, NewDest}});
+  }
+
+private:
+  virtual void anchor();
+};
+
+/// Base class for managing redirectable symbols in which a call
+/// gets redirected to another symbol at runtime.
+class RedirectableSymbolManager : public RedirectionManager {
+public:
+  /// Create redirectable symbols with the given symbol names and initial
+  /// destination symbols.
+  virtual Error createRedirectableSymbols(const SymbolAddrMap &InitialDests,
+                                          ResourceTrackerSP RT) = 0;
+
+  /// Create a single redirectable symbol with the given symbol name and
+  /// initial destination symbol.
+  virtual Error createRedirectableSymbol(SymbolStringPtr Symbol,
+                                         ExecutorSymbolDef InitialDest,
+                                         ResourceTrackerSP RT) {
+    return createRedirectableSymbols({{Symbol, InitialDest}}, RT);
+  }
+};
+
+} // namespace orc
+} // namespace llvm
+
+#endif // LLVM_EXECUTIONENGINE_ORC_REDIRECTIONMANAGER_H
diff --git a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt
--- a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt
+++ b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt
@@ -48,6 +48,8 @@
   ExecutorProcessControl.cpp
   TaskDispatch.cpp
   ThreadSafeModule.cpp
+  RedirectionManager.cpp
+  JITLinkRedirectableSymbolManager.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc
diff --git a/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp b/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp
@@ -0,0 +1,214 @@
+//===-- JITLinkRedirectableSymbolManager.cpp - JITLink redirection in Orc -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h"
+
+#define DEBUG_TYPE "orc"
+
+using namespace llvm;
+using namespace llvm::orc;
+
+Error JITLinkRedirectableSymbolManager::createRedirectableSymbols(
+    const SymbolAddrMap &InitialDests, ResourceTrackerSP RT) {
+  std::lock_guard<std::mutex> Lock(Mutex);
+
+  // Like JITDylib::define, treat a null tracker as the default tracker.
+  if (!RT)
+    RT = JD.getDefaultResourceTracker();
+
+  // Reject duplicates before touching any state so that failure leaves the
+  // manager unchanged.
+  for (auto &KV : InitialDests)
+    if (SymbolToStubs.count(KV.first))
+      return make_error<StringError>(
+          "Tried to create duplicate redirectable symbols",
+          inconvertibleErrorCode());
+
+  if (GetNumAvailableStubs() < InitialDests.size())
+    if (auto Err = grow(InitialDests.size() - GetNumAvailableStubs()))
+      return Err;
+
+  // Bind each symbol to a free stub; the stub's address is the symbol's
+  // definition.
+  SymbolMap NewSymbolDefs;
+  std::vector<SymbolStringPtr> Symbols;
+  for (auto &[K, V] : InitialDests) {
+    StubHandle StubID = AvailableStubs.back();
+    AvailableStubs.pop_back();
+    SymbolToStubs[K] = StubID;
+    NewSymbolDefs[K] = JumpStubs[StubID];
+    Symbols.push_back(K);
+  }
+
+  if (auto Err = JD.define(absoluteSymbols(std::move(NewSymbolDefs)), RT)) {
+    // Return the claimed stubs to the free list so they are not leaked.
+    for (auto &Sym : Symbols) {
+      AvailableStubs.push_back(SymbolToStubs.at(Sym));
+      SymbolToStubs.erase(Sym);
+    }
+    return Err;
+  }
+
+  if (auto Err = redirectInner(InitialDests))
+    return Err;
+
+  return RT->withResourceKeyDo([&](ResourceKey Key) {
+    auto &Tracked = TrackedResources[Key];
+    Tracked.insert(Tracked.end(), Symbols.begin(), Symbols.end());
+  });
+}
+
+Error JITLinkRedirectableSymbolManager::redirect(
+    const SymbolAddrMap &NewDests) {
+  std::lock_guard<std::mutex> Lock(Mutex);
+  return redirectInner(NewDests);
+}
+
+Error JITLinkRedirectableSymbolManager::redirectInner(
+    const SymbolAddrMap &NewDests) {
+  // Resolve every symbol before writing so that an unknown name leaves all
+  // existing redirections untouched.
+  std::vector<std::pair<ExecutorAddr, ExecutorAddr>> PtrWrites;
+  PtrWrites.reserve(NewDests.size());
+  for (auto &[K, V] : NewDests) {
+    auto I = SymbolToStubs.find(K);
+    if (I == SymbolToStubs.end())
+      return make_error<StringError>(
+          "Tried to redirect non-existent redirectable symbol",
+          inconvertibleErrorCode());
+    PtrWrites.push_back({StubPointers[I->second].getAddress(), V.getAddress()});
+  }
+
+  // Each stub jumps through its pointer, so retargeting a stub means writing
+  // the new destination into that pointer using the target's pointer width.
+  auto &MemAccess = ES.getExecutorProcessControl().getMemoryAccess();
+  if (ES.getTargetTriple().isArch64Bit()) {
+    std::vector<tpctypes::UInt64Write> NativeWrites;
+    NativeWrites.reserve(PtrWrites.size());
+    for (auto &[Ptr, Target] : PtrWrites)
+      NativeWrites.push_back(tpctypes::UInt64Write(Ptr, Target.getValue()));
+    return MemAccess.writeUInt64s(NativeWrites);
+  }
+
+  assert(ES.getTargetTriple().isArch32Bit() && "Unsupported pointer size");
+  std::vector<tpctypes::UInt32Write> NativeWrites;
+  NativeWrites.reserve(PtrWrites.size());
+  for (auto &[Ptr, Target] : PtrWrites)
+    NativeWrites.push_back(tpctypes::UInt32Write(Ptr, Target.getValue()));
+  return MemAccess.writeUInt32s(NativeWrites);
+}
+
+Error JITLinkRedirectableSymbolManager::grow(unsigned Need) {
+  unsigned OldSize = JumpStubs.size();
+  unsigned NumNewStubs = alignTo(Need, StubBlockSize);
+  unsigned NewSize = OldSize + NumNewStubs;
+
+  JumpStubs.resize(NewSize);
+  StubPointers.resize(NewSize);
+  AvailableStubs.reserve(NewSize);
+
+  SymbolLookupSet LookupSymbols;
+  DenseMap<SymbolStringPtr, ExecutorSymbolDef *> NewDefsMap;
+
+  const Triple &TT = ES.getTargetTriple();
+  auto G = std::make_unique<jitlink::LinkGraph>(
+      "<INDIRECT_STUBS>", TT, TT.isArch64Bit() ? 8 : 4,
+      TT.isLittleEndian() ? support::little : support::big,
+      jitlink::getGenericEdgeKindName);
+  auto &PointerSection =
+      G->createSection(StubPtrTableName, MemProt::Write | MemProt::Read);
+  auto &StubsSection =
+      G->createSection(JumpStubTableName, MemProt::Exec | MemProt::Read);
+
+  // Each new slot gets a pointer (initially null) and a stub that jumps
+  // through it. Both get default scope so that they can be looked up below.
+  for (size_t I = OldSize; I < NewSize; I++) {
+    auto Pointer = AnonymousPtrCreator(*G, PointerSection, nullptr, 0);
+    if (auto Err = Pointer.takeError())
+      return Err;
+
+    StringRef PtrSymName = StubPtrSymbolName(I);
+    Pointer->setName(PtrSymName);
+    Pointer->setScope(jitlink::Scope::Default);
+    auto PtrSym = ES.intern(PtrSymName);
+    LookupSymbols.add(PtrSym);
+    NewDefsMap[PtrSym] = &StubPointers[I];
+
+    auto Stub = PtrJumpStubCreator(*G, StubsSection, *Pointer);
+    if (auto Err = Stub.takeError())
+      return Err;
+
+    StringRef JumpStubSymName = JumpStubSymbolName(I);
+    Stub->setName(JumpStubSymName);
+    Stub->setScope(jitlink::Scope::Default);
+    auto StubSym = ES.intern(JumpStubSymName);
+    LookupSymbols.add(StubSym);
+    NewDefsMap[StubSym] = &JumpStubs[I];
+  }
+
+  if (auto Err = ObjLinkingLayer.add(JD, std::move(G)))
+    return Err;
+
+  // Looking up the new symbols materializes the graph added above and yields
+  // their final addresses.
+  auto LookupResult = ES.lookup(makeJITDylibSearchOrder(&JD), LookupSymbols);
+  if (auto Err = LookupResult.takeError())
+    return Err;
+
+  for (auto &[K, V] : *LookupResult)
+    *NewDefsMap.at(K) = V;
+
+  for (size_t I = OldSize; I < NewSize; I++)
+    AvailableStubs.push_back(I);
+
+  return Error::success();
+}
+
+Error JITLinkRedirectableSymbolManager::handleRemoveResources(JITDylib &JD,
+                                                              ResourceKey K) {
+  if (&JD != &this->JD)
+    return Error::success();
+
+  std::lock_guard<std::mutex> Lock(Mutex);
+  // Use find rather than operator[] so that trackers owning no redirectable
+  // symbols do not create empty entries.
+  auto I = TrackedResources.find(K);
+  if (I == TrackedResources.end())
+    return Error::success();
+
+  for (auto &Symbol : I->second) {
+    auto SI = SymbolToStubs.find(Symbol);
+    if (SI == SymbolToStubs.end())
+      return make_error<StringError>(
+          "Tried to remove non-existent redirectable symbol",
+          inconvertibleErrorCode());
+    AvailableStubs.push_back(SI->second);
+    SymbolToStubs.erase(SI);
+  }
+  TrackedResources.erase(I);
+
+  return Error::success();
+}
+
+void JITLinkRedirectableSymbolManager::handleTransferResources(
+    JITDylib &JD, ResourceKey DstK, ResourceKey SrcK) {
+  if (&JD != &this->JD)
+    return;
+
+  std::lock_guard<std::mutex> Lock(Mutex);
+  auto I = TrackedResources.find(SrcK);
+  if (I == TrackedResources.end())
+    return;
+
+  // Move the source list out before creating the DstK entry: inserting into
+  // the DenseMap may grow it and invalidate any reference into its buckets.
+  std::vector<SymbolStringPtr> Moved = std::move(I->second);
+  TrackedResources.erase(I);
+  auto &Dst = TrackedResources[DstK];
+  Dst.insert(Dst.end(), Moved.begin(), Moved.end());
+}
diff --git a/llvm/lib/ExecutionEngine/Orc/RedirectionManager.cpp b/llvm/lib/ExecutionEngine/Orc/RedirectionManager.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/ExecutionEngine/Orc/RedirectionManager.cpp
@@ -0,0 +1,16 @@
+//===---- RedirectionManager.cpp - Redirection manager interface in Orc ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/RedirectionManager.h"
+
+#define DEBUG_TYPE "orc"
+
+using namespace llvm;
+using namespace llvm::orc;
+
+void RedirectionManager::anchor() {}
diff --git a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt
--- a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt
+++ b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt
@@ -40,6 +40,7 @@
   TaskDispatchTest.cpp
   ThreadSafeModuleTest.cpp
   WrapperFunctionUtilsTest.cpp
+  JITLinkRedirectionManagerTest.cpp
   )
 
 target_link_libraries(OrcJITTests PRIVATE
diff --git a/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp b/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp
@@ -0,0 +1,115 @@
+//===- JITLinkRedirectionManagerTest.cpp - Redirection manager tests ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "OrcTestCommon.h"
+#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h"
+#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
+#include "llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h"
+#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
+#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
+#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
+#include "llvm/Testing/Support/Error.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace llvm::orc;
+using namespace llvm::jitlink;
+
+static int initialTarget() { return 42; }
+static int middleTarget() { return 13; }
+static int finalTarget() { return 53; }
+
+class JITLinkRedirectionManagerTest : public testing::Test {
+public:
+  ~JITLinkRedirectionManagerTest() {
+    if (ES)
+      if (auto Err = ES->endSession())
+        ES->reportError(std::move(Err));
+  }
+
+protected:
+  void SetUp() override {
+    auto JTMB = JITTargetMachineBuilder::detectHost();
+    // Bail out if we can not detect the host.
+    if (!JTMB) {
+      consumeError(JTMB.takeError());
+      GTEST_SKIP();
+    }
+
+    // Let the memory manager use the host's real page size.
+    auto MemMgr = InProcessMemoryManager::Create();
+    if (!MemMgr) {
+      consumeError(MemMgr.takeError());
+      GTEST_SKIP();
+    }
+
+    ES = std::make_unique<ExecutionSession>(
+        std::make_unique<UnsupportedExecutorProcessControl>(
+            nullptr, nullptr, JTMB->getTargetTriple().getTriple()));
+    JD = &ES->createBareJITDylib("main");
+    ObjLinkingLayer =
+        std::make_unique<ObjectLinkingLayer>(*ES, std::move(*MemMgr));
+  }
+
+  JITDylib *JD{nullptr};
+  std::unique_ptr<ExecutionSession> ES;
+  std::unique_ptr<ObjectLinkingLayer> ObjLinkingLayer;
+};
+
+TEST_F(JITLinkRedirectionManagerTest, BasicRedirectionOperation) {
+  auto RM =
+      JITLinkRedirectableSymbolManager::Create(*ES, *ObjLinkingLayer, *JD);
+  // Bail out if redirection is not supported on this target.
+  if (!RM) {
+    consumeError(RM.takeError());
+    GTEST_SKIP();
+  }
+
+  // Defines TargetName in JD so that it resolves to Addr, then looks it up.
+  auto DefineTarget = [&](StringRef TargetName, ExecutorAddr Addr) {
+    SymbolStringPtr Target = ES->intern(TargetName);
+    cantFail(JD->define(std::make_unique<SimpleMaterializationUnit>(
+        SymbolFlagsMap({{Target, JITSymbolFlags::Exported}}),
+        [Target, Addr](std::unique_ptr<MaterializationResponsibility> R) {
+          // No dependencies registered, can't fail.
+          cantFail(
+              R->notifyResolved({{Target, {Addr, JITSymbolFlags::Exported}}}));
+          cantFail(R->notifyEmitted());
+        })));
+    return cantFail(ES->lookup({JD}, TargetName));
+  };
+
+  auto InitialTarget =
+      DefineTarget("InitialTarget", ExecutorAddr::fromPtr(&initialTarget));
+  auto MiddleTarget =
+      DefineTarget("MiddleTarget", ExecutorAddr::fromPtr(&middleTarget));
+  auto FinalTarget =
+      DefineTarget("FinalTarget", ExecutorAddr::fromPtr(&finalTarget));
+
+  auto RedirectableSymbol = ES->intern("RedirectableTarget");
+  // Every later step depends on this one, so stop the test if it fails.
+  ASSERT_THAT_ERROR(
+      (*RM)->createRedirectableSymbols({{RedirectableSymbol, InitialTarget}},
+                                       JD->getDefaultResourceTracker()),
+      Succeeded());
+  auto RTDef = cantFail(ES->lookup({JD}, RedirectableSymbol));
+
+  auto RTPtr = RTDef.getAddress().toPtr<int (*)()>();
+  auto Result = RTPtr();
+  EXPECT_EQ(Result, 42) << "Failed to call initial target";
+
+  EXPECT_THAT_ERROR((*RM)->redirect({{RedirectableSymbol, MiddleTarget}}),
+                    Succeeded());
+  Result = RTPtr();
+  EXPECT_EQ(Result, 13) << "Failed to call middle redirected target";
+
+  EXPECT_THAT_ERROR((*RM)->redirect({{RedirectableSymbol, FinalTarget}}),
+                    Succeeded());
+  Result = RTPtr();
+  EXPECT_EQ(Result, 53) << "Failed to call redirected target";
+}