diff --git a/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h b/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h --- a/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h @@ -403,10 +403,34 @@ StringMap BootstrapSymbols; }; +/// ExecutorProcessControl::MemoryAccess implementation that writes directly +/// through pointers in the current process (via ExecutorAddr::toPtr), so it +/// is only meaningful when the executor is the current process. +class InProcessMemoryAccess : public ExecutorProcessControl::MemoryAccess { +public: + InProcessMemoryAccess() = default; + void writeUInt8sAsync(ArrayRef Ws, + WriteResultFn OnWriteComplete) override; + + void writeUInt16sAsync(ArrayRef Ws, + WriteResultFn OnWriteComplete) override; + + void writeUInt32sAsync(ArrayRef Ws, + WriteResultFn OnWriteComplete) override; + + void writeUInt64sAsync(ArrayRef Ws, + WriteResultFn OnWriteComplete) override; + + void writeBuffersAsync(ArrayRef Ws, + WriteResultFn OnWriteComplete) override; +}; + -/// A ExecutorProcessControl instance that asserts if any of its methods are -/// used. Suitable for use is unit tests, and by ORC clients who haven't moved -/// to ExecutorProcessControl-based APIs yet. -class UnsupportedExecutorProcessControl : public ExecutorProcessControl { +/// An ExecutorProcessControl instance that asserts if any of its methods are +/// used, except for memory access, which is performed in-process through +/// InProcessMemoryAccess. Suitable for use in unit tests, and by ORC clients +/// who haven't moved to ExecutorProcessControl-based APIs yet. +class UnsupportedExecutorProcessControl : public ExecutorProcessControl, + private InProcessMemoryAccess { public: UnsupportedExecutorProcessControl( std::shared_ptr SSP = nullptr, @@ -418,6 +442,7 @@ : std::make_unique()) { this->TargetTriple = Triple(TT); this->PageSize = PageSize; + this->MemAccess = this; } Expected loadDylib(const char *DylibPath) override { @@ -452,9 +477,8 @@ }; /// A ExecutorProcessControl implementation targeting the current process.
-class SelfExecutorProcessControl - : public ExecutorProcessControl, - private ExecutorProcessControl::MemoryAccess { +class SelfExecutorProcessControl : public ExecutorProcessControl, + private InProcessMemoryAccess { public: SelfExecutorProcessControl( std::shared_ptr SSP, std::unique_ptr D, @@ -490,21 +514,6 @@ Error disconnect() override; private: - void writeUInt8sAsync(ArrayRef Ws, - WriteResultFn OnWriteComplete) override; - - void writeUInt16sAsync(ArrayRef Ws, - WriteResultFn OnWriteComplete) override; - - void writeUInt32sAsync(ArrayRef Ws, - WriteResultFn OnWriteComplete) override; - - void writeUInt64sAsync(ArrayRef Ws, - WriteResultFn OnWriteComplete) override; - - void writeBuffersAsync(ArrayRef Ws, - WriteResultFn OnWriteComplete) override; - static shared::CWrapperFunctionResult jitDispatchViaWrapperFunctionManager(void *Ctx, const void *FnTag, const char *Data, size_t Size); diff --git a/llvm/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h --- a/llvm/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h @@ -13,14 +13,17 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H #define LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" #include "llvm/ExecutionEngine/Orc/OrcABISupport.h" #include "llvm/Support/Error.h" #include "llvm/Support/Memory.h" #include "llvm/Support/Process.h" +#include "llvm/Support/StringSaver.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include #include @@ -29,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -55,6 +59,8 @@ namespace orc { +class ObjectLinkingLayer; + /// Base class for pools of compiler re-entry trampolines.
/// These trampolines are callable addresses that save all register state /// before calling a supplied function to return the trampoline landing @@ -309,6 +315,111 @@ virtual void anchor(); }; +/// Base class for managing redirectable symbols in which a call +/// gets redirected to another symbol at runtime. +class RedirectionManager { +public: + /// Symbol name to symbol definition map. + using SymbolAddrMap = DenseMap; + + virtual ~RedirectionManager() = default; + + /// Create redirectable symbols with given symbol names and initial + /// destination symbols. + virtual Error + createRedirectableSymbols(const SymbolAddrMap &InitialDests) = 0; + + /// Create a single redirectable symbol with given symbol name and initial + /// destination symbol. + virtual Error createRedirectableSymbol(SymbolStringPtr Symbol, + ExecutorSymbolDef InitialDest) { + return createRedirectableSymbols({{Symbol, InitialDest}}); + } + + /// Release redirectable symbols. + virtual Error releaseRedirectableSymbols(const SymbolNameSet &Symbols) = 0; + + /// Release redirectable symbol. + virtual Error releaseRedirectableSymbol(SymbolStringPtr Symbol) { + return releaseRedirectableSymbols({Symbol}); + } + + /// Change the redirection destination of given symbols to new destination + /// symbols. + virtual Error redirect(const SymbolAddrMap &NewDests) = 0; + + /// Change the redirection destination of given symbol to new destination + /// symbol. + virtual Error redirect(SymbolStringPtr Symbol, ExecutorSymbolDef NewDest) { + return redirect({{Symbol, NewDest}}); + } + +private: + virtual void anchor(); +}; + +/// RedirectionManager implementation backed by JITLink. Each redirectable +/// symbol is defined as the address of a jump stub that jumps through a +/// pointer; redirect() rewrites that pointer via the executor's MemoryAccess. +/// Stubs are allocated in multiples of StubBlockSize by linking a LinkGraph +/// into JD through ObjLinkingLayer. The overrides are serialized by Mutex. +class JITLinkRedirectionManager : public RedirectionManager { +public: + /// Create a redirection manager that uses a JITLink-based implementation.
+ static Expected> + Create(ExecutionSession &ES, ObjectLinkingLayer &ObjLinkingLayer, + const DataLayout &DL, JITDylib &JD) { + return std::unique_ptr( + new JITLinkRedirectionManager(ES, ObjLinkingLayer, DL, JD)); + } + + Error createRedirectableSymbols(const SymbolAddrMap &InitialDests) override; + + Error releaseRedirectableSymbols(const SymbolNameSet &Symbols) override; + + Error redirect(const SymbolAddrMap &NewDests) override; + +private: + using StubHandle = unsigned; + constexpr static unsigned StubBlockSize = 256; + constexpr static StringRef JumpStubPrefix = "$__IND_JUMP_STUBS"; + constexpr static StringRef StubPtrPrefix = "$IND_JUMP_PTR_"; + constexpr static StringRef JumpStubTableName = "$IND_JUMP_"; + constexpr static StringRef StubPtrTableName = "$__IND_JUMP_PTRS"; + + JITLinkRedirectionManager(ExecutionSession &ES, + ObjectLinkingLayer &ObjLinkingLayer, + const DataLayout &DL, JITDylib &JD) + : ES(ES), ObjLinkingLayer(ObjLinkingLayer), JD(JD), DL(DL) {} + + StringRef JumpStubSymbolName(unsigned I) { + return StringPool.save((JumpStubPrefix + Twine(I)).str()); + } + + StringRef StubPtrSymbolName(unsigned I) { + return StringPool.save((StubPtrPrefix + Twine(I)).str()); + } + + unsigned GetNumAvailableStubs() const { return AvailbleStubs.size(); } + + Error redirectInner(const SymbolAddrMap &NewDests); + Error grow(unsigned Need); + + ExecutionSession &ES; + ObjectLinkingLayer &ObjLinkingLayer; + JITDylib &JD; + const DataLayout DL; + + std::vector AvailbleStubs; + DenseMap SymbolToStubs; + std::vector JumpStubs; + std::vector StubPointers; + + BumpPtrAllocator BAlloc; + StringSaver StringPool{BAlloc}; + std::mutex Mutex; +}; + template class LocalIndirectStubsInfo { public: LocalIndirectStubsInfo(unsigned NumStubs, sys::OwningMemoryBlock StubsMem) @@ -332,7 +443,7 @@ return errorCodeToError(EC); sys::MemoryBlock StubsBlock(StubsAndPtrsMem.base(), ISAS.StubBytes); - auto StubsBlockMem = static_cast(StubsAndPtrsMem.base()); + auto *StubsBlockMem =
static_cast(StubsAndPtrsMem.base()); auto PtrBlockAddress = ExecutorAddr::fromPtr(StubsBlockMem) + ISAS.StubBytes; diff --git a/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp b/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp --- a/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp +++ b/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp @@ -139,35 +139,35 @@ return Error::success(); } -void SelfExecutorProcessControl::writeUInt8sAsync( - ArrayRef Ws, WriteResultFn OnWriteComplete) { +void InProcessMemoryAccess::writeUInt8sAsync(ArrayRef Ws, + WriteResultFn OnWriteComplete) { for (auto &W : Ws) *W.Addr.toPtr() = W.Value; OnWriteComplete(Error::success()); } -void SelfExecutorProcessControl::writeUInt16sAsync( +void InProcessMemoryAccess::writeUInt16sAsync( ArrayRef Ws, WriteResultFn OnWriteComplete) { for (auto &W : Ws) *W.Addr.toPtr() = W.Value; OnWriteComplete(Error::success()); } -void SelfExecutorProcessControl::writeUInt32sAsync( +void InProcessMemoryAccess::writeUInt32sAsync( ArrayRef Ws, WriteResultFn OnWriteComplete) { for (auto &W : Ws) *W.Addr.toPtr() = W.Value; OnWriteComplete(Error::success()); } -void SelfExecutorProcessControl::writeUInt64sAsync( +void InProcessMemoryAccess::writeUInt64sAsync( ArrayRef Ws, WriteResultFn OnWriteComplete) { for (auto &W : Ws) *W.Addr.toPtr() = W.Value; OnWriteComplete(Error::success()); } -void SelfExecutorProcessControl::writeBuffersAsync( +void InProcessMemoryAccess::writeBuffersAsync( ArrayRef Ws, WriteResultFn OnWriteComplete) { for (auto &W : Ws) memcpy(W.Addr.toPtr(), W.Buffer.data(), W.Buffer.size()); diff --git a/llvm/lib/ExecutionEngine/Orc/IndirectionUtils.cpp b/llvm/lib/ExecutionEngine/Orc/IndirectionUtils.cpp --- a/llvm/lib/ExecutionEngine/Orc/IndirectionUtils.cpp +++ b/llvm/lib/ExecutionEngine/Orc/IndirectionUtils.cpp @@ -8,15 +8,24 @@ #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/JITLink/JITLink.h" +#include
"llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" #include "llvm/ExecutionEngine/JITLink/x86_64.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/OrcABISupport.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" +#include "llvm/ExecutionEngine/Orc/Shared/MemoryFlags.h" +#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/MC/MCDisassembler/MCDisassembler.h" #include "llvm/MC/MCInstrAnalysis.h" #include "llvm/Support/Format.h" +#include "llvm/Support/MathExtras.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/Cloning.h" #include +#include #define DEBUG_TYPE "orc" @@ -62,6 +71,175 @@ TrampolinePool::~TrampolinePool() = default; void IndirectStubsManager::anchor() {} +void RedirectionManager::anchor() {} +
+Error JITLinkRedirectionManager::createRedirectableSymbols(
+    const SymbolAddrMap &InitialDests) {
+  std::unique_lock<std::mutex> Lock(Mutex);
+
+  // Reject duplicates before touching any state, so that a failed
+  // request leaves the stub pool and the symbol-to-stub map intact.
+  for (auto &KV : InitialDests)
+    if (SymbolToStubs.count(KV.first))
+      return make_error<StringError>(
+          "Tried to create duplicate redirectable symbols",
+          inconvertibleErrorCode());
+
+  if (GetNumAvailableStubs() < InitialDests.size())
+    if (auto Err = grow(InitialDests.size() - GetNumAvailableStubs()))
+      return Err;
+
+  // Hand one free stub to each new symbol; the symbol is defined as
+  // the address of that stub's jump code.
+  SymbolMap NewSymbolDefs;
+  for (auto &[K, V] : InitialDests) {
+    StubHandle StubID = AvailbleStubs.back();
+    AvailbleStubs.pop_back();
+    SymbolToStubs[K] = StubID;
+    NewSymbolDefs[K] = JumpStubs[StubID];
+  }
+
+  if (auto Err = JD.define(absoluteSymbols(NewSymbolDefs))) {
+    // Give the stubs back so they can be reused. NOTE: this assumes
+    // JITDylib::define adds none of the symbols when it fails.
+    for (auto &KV : InitialDests) {
+      AvailbleStubs.push_back(SymbolToStubs.at(KV.first));
+      SymbolToStubs.erase(KV.first);
+    }
+    return Err;
+  }
+
+  // Point each new stub at its initial destination.
+  if (auto Err = redirectInner(InitialDests))
+    return Err;
+
+  return Error::success();
+}
+
+Error JITLinkRedirectionManager::releaseRedirectableSymbols(
+    const SymbolNameSet &Symbols) {
+  std::unique_lock<std::mutex> Lock(Mutex);
+
+  // Validate every name before touching any state, so that a bad request
+  // does not leave the stub pool partially updated.
+  for (auto &K : Symbols)
+    if (!SymbolToStubs.count(K))
+      return make_error<StringError>(
+          "Tried to remove non-existent redirectalbe symbol",
+          inconvertibleErrorCode());
+
+  // Drop the definitions before recycling the stubs, so a stub is never
+  // reused while its old name is still defined in JD. NOTE: assumes
+  // JITDylib::remove removes nothing when it fails.
+  if (auto Err = JD.remove(Symbols))
+    return Err;
+
+  for (auto &K : Symbols) {
+    AvailbleStubs.push_back(SymbolToStubs.at(K));
+    SymbolToStubs.erase(K);
+  }
+
+  return Error::success();
+}
+ +Error JITLinkRedirectionManager::redirect(const SymbolAddrMap &NewDests) { + std::unique_lock Lock(Mutex); + return redirectInner(NewDests); +} + +Error JITLinkRedirectionManager::redirectInner(const SymbolAddrMap &NewDests) { + std::vector> PtrWrites; + for (auto &[K, V] : NewDests) { + if (!SymbolToStubs.count(K)) + return make_error( + "Tried to redirect non-existent redirectalbe symbol", + inconvertibleErrorCode()); + StubHandle StubID = SymbolToStubs.at(K); + PtrWrites.push_back({StubPointers[StubID].getAddress(), V.getAddress()}); + } + + if (DL.getPointerSize() == 8) { + std::vector NativeWrites; + for (auto &[Ptr, Target] : PtrWrites) + NativeWrites.push_back(tpctypes::UInt64Write(Ptr, Target.getValue())); + if (auto Err = + ES.getExecutorProcessControl().getMemoryAccess().writeUInt64s( + NativeWrites)) + return Err; + } else { + assert(DL.getPointerSize() == 4 && "Unsupported pointer size"); + std::vector NativeWrites; + for (auto &[Ptr, Target] : PtrWrites) + NativeWrites.push_back(tpctypes::UInt32Write(Ptr, Target.getValue())); + if (auto Err = + ES.getExecutorProcessControl().getMemoryAccess().writeUInt32s( + NativeWrites)) + return Err; + } + return Error::success(); +} + +Error JITLinkRedirectionManager::grow(unsigned Need) { + unsigned OldSize = JumpStubs.size(); + unsigned NumNewStubs = alignTo(Need, StubBlockSize); + unsigned NewSize = OldSize + NumNewStubs; + + JumpStubs.resize(NewSize); + StubPointers.resize(NewSize); + AvailbleStubs.reserve(NewSize); + + SymbolLookupSet LookupSymbols; + DenseMap NewDefsMap; + + Triple TT = ES.getTargetTriple(); + auto G = std::make_unique( + "", TT, DL.getPointerSize(), + DL.isBigEndian() ?
support::big : support::little, + jitlink::getGenericEdgeKindName); + auto &PointerSection = + G->createSection(StubPtrTableName, MemProt::Write | MemProt::Read); + auto &StubsSection = + G->createSection(JumpStubTableName, MemProt::Exec | MemProt::Read); + + for (size_t I = OldSize; I < NewSize; I++) { + auto Pointer = jitlink::createAnonymousPointer(*G, PointerSection); + if (auto Err = Pointer.takeError()) + return Err; + + StringRef PtrSymName = StubPtrSymbolName(I); + Pointer->setName(PtrSymName); + Pointer->setScope(jitlink::Scope::Default); + LookupSymbols.add(ES.intern(PtrSymName)); + NewDefsMap[ES.intern(PtrSymName)] = &StubPointers[I]; + + auto Stub = + jitlink::createAnonymousPointerJumpStub(*G, StubsSection, *Pointer); + if (auto Err = Stub.takeError()) + return Err; + + StringRef JumpStubSymName = JumpStubSymbolName(I); + Stub->setName(JumpStubSymName); + Stub->setScope(jitlink::Scope::Default); + LookupSymbols.add(ES.intern(JumpStubSymName)); + NewDefsMap[ES.intern(JumpStubSymName)] = &JumpStubs[I]; + } + + if (auto Err = ObjLinkingLayer.add(JD, std::move(G))) + return Err; + + auto LookupResult = ES.lookup(makeJITDylibSearchOrder(&JD), LookupSymbols); + if (auto Err = LookupResult.takeError()) + return Err; + + for (auto &[K, V] : *LookupResult) + *NewDefsMap.at(K) = V; + + for (size_t I = OldSize; I < NewSize; I++) + AvailbleStubs.push_back(I); + + return Error::success(); +} + Expected JITCompileCallbackManager::getCompileCallback(CompileFunction Compile) { if (auto TrampolineAddr = TP->getTrampoline()) { @@ -251,9 +429,9 @@ GlobalVariable* createImplPointer(PointerType &PT, Module &M, const Twine &Name, Constant *Initializer) { - auto IP = new GlobalVariable(M, &PT, false, GlobalValue::ExternalLinkage, - Initializer, Name, nullptr, - GlobalValue::NotThreadLocal, 0, true); + auto *IP = new GlobalVariable(M, &PT, false, GlobalValue::ExternalLinkage, + Initializer, Name, nullptr, + GlobalValue::NotThreadLocal, 0, true);
IP->setVisibility(GlobalValue::HiddenVisibility); return IP; } @@ -316,7 +494,7 @@ if (VMap) { (*VMap)[&F] = NewF; - auto NewArgI = NewF->arg_begin(); + auto *NewArgI = NewF->arg_begin(); for (auto ArgI = F.arg_begin(), ArgE = F.arg_end(); ArgI != ArgE; ++ArgI, ++NewArgI) (*VMap)[&*ArgI] = &*NewArgI; @@ -411,7 +589,7 @@ auto &B = Sym.getBlock(); assert(!B.isZeroFill() && "expected content block"); auto SymAddress = Sym.getAddress(); - auto SymStartInBlock = + auto *SymStartInBlock = (const uint8_t *)B.getContent().data() + Sym.getOffset(); auto SymSize = Sym.getSize() ? Sym.getSize() : B.getSize() - Sym.getOffset(); auto Content = ArrayRef(SymStartInBlock, SymSize); diff --git a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt --- a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -40,6 +40,7 @@ TaskDispatchTest.cpp ThreadSafeModuleTest.cpp WrapperFunctionUtilsTest.cpp + JITLinkRedirectionManagerTest.cpp ) target_link_libraries(OrcJITTests PRIVATE diff --git a/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp b/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp @@ -0,0 +1,97 @@ +#include "OrcTestCommon.h" +#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::orc; +using namespace llvm::jitlink; + +static int initialTarget() { return 42; } +static int middleTarget() { return 13; } +static int finalTarget() { return 53; } + +class
JITLinkRedirectionManagerTest : public testing::Test { +public: + ~JITLinkRedirectionManagerTest() override { + if (ES) + if (auto Err = ES->endSession()) + ES->reportError(std::move(Err)); + } + +protected: + void SetUp() override { + auto JTMB = JITTargetMachineBuilder::detectHost(); + // Bail out if we cannot detect the host. + if (!JTMB) { + consumeError(JTMB.takeError()); + GTEST_SKIP(); + } + + ES = std::make_unique( + std::make_unique( + nullptr, nullptr, JTMB->getTargetTriple().getTriple())); + JD = &ES->createBareJITDylib("main"); + ObjLinkingLayer = std::make_unique( + *ES, std::make_unique(4096)); + DL = std::make_unique( + cantFail(JTMB->getDefaultDataLayoutForTarget())); + } + JITDylib *JD{nullptr}; + std::unique_ptr ES; + std::unique_ptr ObjLinkingLayer; + std::unique_ptr DL; +}; + +TEST_F(JITLinkRedirectionManagerTest, BasicRedirectionOperation) { + auto RM = JITLinkRedirectionManager::Create(*ES, *ObjLinkingLayer, *DL, *JD); + // Bail out if the redirection manager cannot be created. + if (!RM) { + consumeError(RM.takeError()); + GTEST_SKIP(); + } + + auto DefineTarget = [&](StringRef TargetName, ExecutorAddr Addr) { + SymbolStringPtr Target = ES->intern(TargetName); + cantFail(JD->define(std::make_unique( + SymbolFlagsMap({{Target, JITSymbolFlags::Exported}}), + [=](std::unique_ptr R) -> void { + // No dependencies registered, can't fail.
+ cantFail( + R->notifyResolved({{Target, {Addr, JITSymbolFlags::Exported}}})); + cantFail(R->notifyEmitted()); + }))); + return cantFail(ES->lookup({JD}, TargetName)); + }; + + auto InitialTarget = + DefineTarget("InitialTarget", ExecutorAddr::fromPtr(&initialTarget)); + auto MiddleTarget = + DefineTarget("MiddleTarget", ExecutorAddr::fromPtr(&middleTarget)); + auto FinalTarget = + DefineTarget("FinalTarget", ExecutorAddr::fromPtr(&finalTarget)); + + auto RedirectableSymbol = ES->intern("RedirectableTarget"); + ASSERT_THAT_ERROR( + (*RM)->createRedirectableSymbols({{RedirectableSymbol, InitialTarget}}), + Succeeded()); + auto RTDef = cantFail(ES->lookup({JD}, RedirectableSymbol)); + + auto RTPtr = RTDef.getAddress().toPtr(); + auto Result = RTPtr(); + EXPECT_EQ(Result, 42) << "Failed to call initial target"; + + EXPECT_THAT_ERROR((*RM)->redirect({{RedirectableSymbol, MiddleTarget}}), + Succeeded()); + Result = RTPtr(); + EXPECT_EQ(Result, 13) << "Failed to call middle redirected target"; + + EXPECT_THAT_ERROR((*RM)->redirect({{RedirectableSymbol, FinalTarget}}), + Succeeded()); + Result = RTPtr(); + EXPECT_EQ(Result, 53) << "Failed to call redirected target"; +}