Index: unittests/ExecutionEngine/Orc/CMakeLists.txt
===================================================================
--- unittests/ExecutionEngine/Orc/CMakeLists.txt
+++ unittests/ExecutionEngine/Orc/CMakeLists.txt
@@ -2,6 +2,7 @@
 set(LLVM_LINK_COMPONENTS
   Core
   ExecutionEngine
+  IRReader
   Object
   OrcJIT
   RuntimeDyld
@@ -16,6 +17,7 @@
   GlobalMappingLayerTest.cpp
   LazyEmittingLayerTest.cpp
   LegacyAPIInteropTest.cpp
+  LLJITTest.cpp
   ObjectTransformLayerTest.cpp
   OrcCAPITest.cpp
   OrcTestCommon.cpp
Index: unittests/ExecutionEngine/Orc/LLJITTest.cpp
===================================================================
--- /dev/null
+++ unittests/ExecutionEngine/Orc/LLJITTest.cpp
@@ -0,0 +1,266 @@
+//===- LLJITTest.cpp - Unit tests for LLJIT -------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/LLJIT.h"
+#include "OrcTestCommon.h"
+
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include "gtest/gtest.h"
+
+#include <thread>
+#include <vector>
+
+using namespace llvm;
+using namespace llvm::orc;
+
+//===----------------------------------------------------------------------===//
+// Some utils shared by all tests
+//===----------------------------------------------------------------------===//
+
+#define LLVM_IR(...) #__VA_ARGS__
+
+static const char *FooIR = LLVM_IR(
+  define i32 @foo() {
+    ret i32 1
+  }
+);
+
+static const char *BarIR = LLVM_IR(
+  define i32 @bar() {
+    ret i32 1
+  }
+);
+
+static const char *BuzIR = LLVM_IR(
+  declare i32 @foo()
+  declare i32 @bar()
+
+  define i32 @buz() {
+    %1 = call i32 @foo()
+    %2 = call i32 @bar()
+    %3 = add nsw i32 %1, %2
+    ret i32 %3
+  }
+);
+
+static const MemoryBufferRef Foo{FooIR, "FooIR"};
+static const MemoryBufferRef Bar{BarIR, "BarIR"};
+static const MemoryBufferRef Buz{BuzIR, "BuzIR"};
+
+static std::unique_ptr<LLJIT> Init() {
+  InitializeNativeTarget();
+  InitializeNativeTargetAsmParser();
+  InitializeNativeTargetAsmPrinter();
+
+  auto ES = llvm::make_unique<ExecutionSession>();
+  auto TMB = cantFail(JITTargetMachineBuilder::detectHost());
+  auto TM = cantFail(TMB.createTargetMachine());
+  auto DL = TM->createDataLayout();
+
+  return cantFail(LLJIT::Create(std::move(ES), std::move(TM), DL));
+}
+
+static int Exec(JITEvaluatedSymbol FuncSym) {
+  auto FuncPtr = (int (*)())FuncSym.getAddress();
+  return FuncPtr();
+}
+
+static void JoinAll(std::vector<std::thread> &Ts) {
+  for (std::thread &T : Ts)
+    if (T.joinable())
+      T.join();
+}
+
+// Make ES materialize every MaterializationUnit on a newly spawned thread.
+//
+// Materialization may be dispatched from the runner threads that issue lookups
+// as well as from worker threads themselves. Tracking all workers centrally is
+// hard, so every dispatching thread keeps a thread_local list of the workers it
+// spawned and joins all of them when it exits.
+// A thread_local guard for a single worker is not sufficient: it is only
+// initialized by the first dispatch on a given thread, so any later worker from
+// that thread would be destroyed while still joinable, calling std::terminate.
+//
+// Never trigger materialization from the main thread: its thread_local objects
+// are only destroyed when the whole test program exits.
+static void DispatchMaterializationOnThreads(ExecutionSession &ES) {
+  std::thread::id MainThreadId = std::this_thread::get_id();
+  (void)MainThreadId; // Only used in assertions.
+  ES.setDispatchMaterialization([=](JITDylib &JD,
+                                    std::unique_ptr<MaterializationUnit> MU) {
+    assert(std::this_thread::get_id() != MainThreadId &&
+           "Don't join via thread-local scope-exit "
+           "when calling from main thread!!!");
+
+    // Joins all workers spawned by the owning thread when that thread exits.
+    struct WorkerJoiner {
+      std::vector<std::thread> Workers;
+      ~WorkerJoiner() {
+        for (std::thread &W : Workers)
+          W.join();
+      }
+    };
+    thread_local WorkerJoiner Joiner;
+
+    // Without C++14 init-capture, hand the MU to the worker via a shared_ptr.
+    auto SMU = std::shared_ptr<MaterializationUnit>(std::move(MU));
+    Joiner.Workers.emplace_back([SMU, &JD]() { SMU->doMaterialize(JD); });
+  });
+}
+
+//===----------------------------------------------------------------------===//
+
+TEST(LLJIT, UnthreadedSingleDylib) {
+  // Declared before the JIT so that the contexts outlive any module it owns.
+  std::vector<LLVMContext> Cs(3);
+  std::unique_ptr<LLJIT> Jit = Init();
+
+  SMDiagnostic E;
+  std::vector<MemoryBufferRef> Bs{Foo, Bar, Buz};
+
+  for (size_t i = 0; i < 3; i++)
+    cantFail(Jit->addIRModule(parseIR(Bs[i], E, Cs[i])));
+
+  JITEvaluatedSymbol FooSym = cantFail(Jit->lookup("foo"));
+  EXPECT_EQ(1, Exec(FooSym));
+
+  JITEvaluatedSymbol BarSym = cantFail(Jit->lookup("bar"));
+  EXPECT_EQ(1, Exec(BarSym));
+
+  JITEvaluatedSymbol BuzSym = cantFail(Jit->lookup("buz"));
+  EXPECT_EQ(2, Exec(BuzSym));
+}
+
+//===----------------------------------------------------------------------===//
+
+TEST(LLJIT, UnthreadedMultiDylib) {
+  // Declared before the JIT so that the contexts outlive any module it owns.
+  std::vector<LLVMContext> Cs(3);
+  std::unique_ptr<LLJIT> Jit = Init();
+  ExecutionSession &ES = Jit->getExecutionSession();
+
+  JITDylib &Main = Jit->getMainJITDylib();
+  JITDylib &Extra1 = ES.createJITDylib("1");
+  JITDylib &Extra2 = ES.createJITDylib("2");
+
+  SMDiagnostic E;
+  std::vector<MemoryBufferRef> Bs{Foo, Bar, Buz};
+  std::vector<JITDylib *> Ds{&Main, &Extra1, &Extra2};
+
+  for (size_t i = 0; i < 3; i++)
+    cantFail(Jit->addIRModule(*Ds[i], parseIR(Bs[i], E, Cs[i])));
+
+  JITEvaluatedSymbol FooSym = cantFail(Jit->lookup("foo"));
+  EXPECT_EQ(1, Exec(FooSym));
+
+  JITEvaluatedSymbol BarSym = cantFail(Jit->lookup(Extra1, "bar"));
+  EXPECT_EQ(1, Exec(BarSym));
+
+  // Allow buz to resolve foo and bar from the other dylibs.
+  Extra2.addToSearchOrder(Main);
+  Extra2.addToSearchOrder(Extra1);
+
+  JITEvaluatedSymbol BuzSym = cantFail(Jit->lookup(Extra2, "buz"));
+  EXPECT_EQ(2, Exec(BuzSym));
+}
+
+//===----------------------------------------------------------------------===//
+
+TEST(LLJIT, ThreadedSingleDylib) {
+  // Declared before the JIT so that the contexts outlive any module it owns.
+  std::vector<LLVMContext> Cs(3);
+  std::unique_ptr<LLJIT> Jit = Init();
+  ExecutionSession &ES = Jit->getExecutionSession();
+
+  SMDiagnostic E;
+  std::vector<MemoryBufferRef> Bs{Foo, Bar, Buz};
+
+  for (size_t i = 0; i < 3; i++)
+    cantFail(Jit->addIRModule(parseIR(Bs[i], E, Cs[i])));
+
+  // Materialize on worker threads. All lookups below are issued from runner
+  // threads, which join their workers when they exit.
+  DispatchMaterializationOnThreads(ES);
+
+  // Run functions in different threads.
+  std::vector<std::thread> Ts;
+
+  Ts.emplace_back([&Jit]() {
+    JITEvaluatedSymbol BuzSym = cantFail(Jit->lookup("buz"));
+    EXPECT_EQ(2, Exec(BuzSym));
+  });
+
+  Ts.emplace_back([&Jit]() {
+    JITEvaluatedSymbol FooSym = cantFail(Jit->lookup("foo"));
+    EXPECT_EQ(1, Exec(FooSym));
+  });
+
+  Ts.emplace_back([&Jit]() {
+    JITEvaluatedSymbol BarSym = cantFail(Jit->lookup("bar"));
+    EXPECT_EQ(1, Exec(BarSym));
+  });
+
+  JoinAll(Ts);
+}
+
+//===----------------------------------------------------------------------===//
+
+TEST(LLJIT, ThreadedMultiDylib) {
+  // Declared before the JIT so that the contexts outlive any module it owns.
+  std::vector<LLVMContext> Cs(3);
+  std::unique_ptr<LLJIT> Jit = Init();
+  ExecutionSession &ES = Jit->getExecutionSession();
+
+  JITDylib &Main = Jit->getMainJITDylib();
+  JITDylib &Extra1 = ES.createJITDylib("1");
+  JITDylib &Extra2 = ES.createJITDylib("2");
+  Main.addToSearchOrder(Extra1);
+  Main.addToSearchOrder(Extra2);
+
+  SMDiagnostic E;
+  std::vector<MemoryBufferRef> Bs{Foo, Bar, Buz};
+  std::vector<JITDylib *> Ds{&Main, &Extra1, &Extra2};
+
+  for (size_t i = 0; i < 3; i++)
+    cantFail(Jit->addIRModule(*Ds[i], parseIR(Bs[i], E, Cs[i])));
+
+  // Materialize on worker threads. All lookups below are issued from runner
+  // threads, which join their workers when they exit.
+  DispatchMaterializationOnThreads(ES);
+
+  // Allow buz to resolve foo and bar from the other dylibs.
+  Extra2.addToSearchOrder(Main);
+  Extra2.addToSearchOrder(Extra1);
+
+  // Run functions in different threads.
+  std::vector<std::thread> Ts;
+
+  Ts.emplace_back([&]() {
+    JITEvaluatedSymbol BuzSym = cantFail(Jit->lookup(Extra2, "buz"));
+    EXPECT_EQ(2, Exec(BuzSym));
+  });
+
+  Ts.emplace_back([&]() {
+    JITEvaluatedSymbol FooSym = cantFail(Jit->lookup("foo"));
+    EXPECT_EQ(1, Exec(FooSym));
+  });
+
+  Ts.emplace_back([&]() {
+    JITEvaluatedSymbol BarSym = cantFail(Jit->lookup(Extra1, "bar"));
+    EXPECT_EQ(1, Exec(BarSym));
+  });
+
+  JoinAll(Ts);
+}