diff --git a/openmp/libomptarget/DeviceRTL/include/Configuration.h b/openmp/libomptarget/DeviceRTL/include/Configuration.h
--- a/openmp/libomptarget/DeviceRTL/include/Configuration.h
+++ b/openmp/libomptarget/DeviceRTL/include/Configuration.h
@@ -40,6 +40,12 @@
 /// Returns the cycles per second of the device's fixed frequency clock.
 uint64_t getClockFrequency();
 
+/// Returns the pointer to the beginning of the indirect call table.
+void *getIndirectCallTablePtr();
+
+/// Returns the size of the indirect call table.
+uint64_t getIndirectCallTableSize();
+
 /// Return if debugging is enabled for the given debug kind.
 bool isDebugMode(DebugKind Level);
 
diff --git a/openmp/libomptarget/DeviceRTL/src/Configuration.cpp b/openmp/libomptarget/DeviceRTL/src/Configuration.cpp
--- a/openmp/libomptarget/DeviceRTL/src/Configuration.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/Configuration.cpp
@@ -50,6 +50,15 @@
   return __omp_rtl_device_environment.ClockFrequency;
 }
 
+void *config::getIndirectCallTablePtr() {
+  return reinterpret_cast<void *>(
+      __omp_rtl_device_environment.IndirectCallTable);
+}
+
+uint64_t config::getIndirectCallTableSize() {
+  return __omp_rtl_device_environment.IndirectCallTableSize;
+}
+
 bool config::isDebugMode(config::DebugKind Kind) {
   return config::getDebugKind() & Kind;
 }
diff --git a/openmp/libomptarget/DeviceRTL/src/Misc.cpp b/openmp/libomptarget/DeviceRTL/src/Misc.cpp
--- a/openmp/libomptarget/DeviceRTL/src/Misc.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/Misc.cpp
@@ -69,6 +69,36 @@
 
 #pragma omp end declare variant
 
+/// Lookup a device-side function using a host pointer \p HstPtr using the table
+/// provided by the device plugin. The table is an array of (host, device)
+/// pointer pairs sorted on the value of the host pointer.
+void *indirectCallLookup(void *HstPtr) {
+  struct IndirectCallTable {
+    void *HstPtr;
+    void *DevPtr;
+  };
+  IndirectCallTable *Table =
+      reinterpret_cast<IndirectCallTable *>(config::getIndirectCallTablePtr());
+  uint64_t TableSize = config::getIndirectCallTableSize();
+
+  if (!Table || TableSize == 0)
+    return nullptr;
+
+  uint32_t Left = 0;
+  uint32_t Right = TableSize;
+  while (Left != Right) {
+    uint32_t Current = Left + (Right - Left) / 2;
+    if (Table[Current].HstPtr == HstPtr)
+      return Table[Current].DevPtr;
+
+    if (HstPtr < Table[Current].HstPtr)
+      Right = Current;
+    else
+      Left = Current + 1; // Exclude the probed slot so the range shrinks.
+  }
+  return nullptr;
+}
+
 } // namespace impl
 } // namespace ompx
 
@@ -84,6 +114,10 @@
 double omp_get_wtick(void) { return ompx::impl::getWTick(); }
 
 double omp_get_wtime(void) { return ompx::impl::getWTime(); }
+
+void *__llvm_omp_indirect_call_lookup(void *HstPtr) {
+  return ompx::impl::indirectCallLookup(HstPtr);
+}
 }
 
 ///}
diff --git a/openmp/libomptarget/include/Environment.h b/openmp/libomptarget/include/Environment.h
--- a/openmp/libomptarget/include/Environment.h
+++ b/openmp/libomptarget/include/Environment.h
@@ -31,6 +31,8 @@
   uint32_t DeviceNum;
   uint32_t DynamicMemSize;
   uint64_t ClockFrequency;
+  uintptr_t IndirectCallTable;
+  uint64_t IndirectCallTableSize;
 };
 
 // NOTE: Please don't change the order of those members as their indices are
diff --git a/openmp/libomptarget/include/omptarget.h b/openmp/libomptarget/include/omptarget.h
--- a/openmp/libomptarget/include/omptarget.h
+++ b/openmp/libomptarget/include/omptarget.h
@@ -83,13 +83,16 @@
   OMP_TGT_MAPTYPE_MEMBER_OF = 0xffff000000000000
 };
 
+/// Flags for offload entries.
 enum OpenMPOffloadingDeclareTargetFlags {
-  /// Mark the entry as having a 'link' attribute.
+  /// Mark the entry global as having a 'link' attribute.
   OMP_DECLARE_TARGET_LINK = 0x01,
-  /// Mark the entry as being a global constructor.
+  /// Mark the entry kernel as being a global constructor.
   OMP_DECLARE_TARGET_CTOR = 0x02,
-  /// Mark the entry as being a global destructor.
-  OMP_DECLARE_TARGET_DTOR = 0x04
+  /// Mark the entry kernel as being a global destructor.
+  OMP_DECLARE_TARGET_DTOR = 0x04,
+  /// Mark the entry global as being an indirectly callable function.
+  OMP_DECLARE_TARGET_INDIRECT = 0x08
 };
 
 enum OpenMPOffloadingRequiresDirFlags {
diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
@@ -267,6 +267,53 @@
 
 } RecordReplay;
 
+// Extract the mapping of host function pointers to device function pointers
+// from the entry table. Functions marked as 'indirect' in OpenMP will have
+// offloading entries generated for them which map the host's function pointer
+// to a global containing the corresponding function pointer on the device.
+static Expected<std::pair<void *, uint64_t>>
+setupIndirectCallTable(GenericPluginTy &Plugin, GenericDeviceTy &Device,
+                       DeviceImageTy &Image) {
+  GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
+
+  llvm::ArrayRef<__tgt_offload_entry> Entries(Image.getTgtImage()->EntriesBegin,
+                                              Image.getTgtImage()->EntriesEnd);
+  llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
+  for (const auto &Entry : Entries) {
+    if (Entry.size == 0 || !(Entry.flags & OMP_DECLARE_TARGET_INDIRECT))
+      continue;
+
+    assert(Entry.size == sizeof(void *) && "Global not a function pointer?");
+    auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
+
+    GlobalTy DeviceGlobal(Entry.name, Entry.size);
+    if (auto Err =
+            Handler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal))
+      return std::move(Err);
+
+    HstPtr = Entry.addr;
+    if (auto Err = Device.dataRetrieve(&DevPtr, DeviceGlobal.getPtr(),
+                                       Entry.size, nullptr))
+      return std::move(Err);
+  }
+
+  // If we do not have any indirect globals we exit early.
+  if (IndirectCallTable.empty())
+    return std::pair{nullptr, 0};
+
+  // Sort the array to allow for more efficient lookup of device pointers.
+  llvm::sort(IndirectCallTable,
+             [](const auto &x, const auto &y) { return x.first < y.first; });
+
+  uint64_t TableSize =
+      IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
+  void *DevicePtr = Device.allocate(TableSize, nullptr, TARGET_ALLOC_DEVICE);
+  if (auto Err = Device.dataSubmit(DevicePtr, IndirectCallTable.data(),
+                                   TableSize, nullptr))
+    return std::move(Err);
+  return std::pair(DevicePtr, IndirectCallTable.size());
+}
+
 AsyncInfoWrapperTy::AsyncInfoWrapperTy(GenericDeviceTy &Device,
                                        __tgt_async_info *AsyncInfoPtr)
     : Device(Device),
@@ -626,6 +673,11 @@
   if (!shouldSetupDeviceEnvironment())
     return Plugin::success();
 
+  // Obtain a table mapping host function pointers to device function pointers.
+  auto CallTablePairOrErr = setupIndirectCallTable(Plugin, *this, Image);
+  if (!CallTablePairOrErr)
+    return CallTablePairOrErr.takeError();
+
   DeviceEnvironmentTy DeviceEnvironment;
   DeviceEnvironment.DebugKind = OMPX_DebugKind;
   DeviceEnvironment.NumDevices = Plugin.getNumDevices();
@@ -633,6 +685,9 @@
   DeviceEnvironment.DeviceNum = DeviceId;
   DeviceEnvironment.DynamicMemSize = OMPX_SharedMemorySize;
   DeviceEnvironment.ClockFrequency = getClockFrequency();
+  DeviceEnvironment.IndirectCallTable =
+      reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
+  DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
 
   // Create the metainfo of the device environment global.
   GlobalTy DevEnvGlobal("__omp_rtl_device_environment",
diff --git a/openmp/libomptarget/src/rtl.cpp b/openmp/libomptarget/src/rtl.cpp
--- a/openmp/libomptarget/src/rtl.cpp
+++ b/openmp/libomptarget/src/rtl.cpp
@@ -303,6 +303,10 @@
     Device.HasPendingGlobals = true;
     for (__tgt_offload_entry *Entry = Img->EntriesBegin;
          Entry != Img->EntriesEnd; ++Entry) {
+      // Globals are not callable and use a different set of flags.
+      if (Entry->size != 0)
+        continue;
+
       if (Entry->flags & OMP_DECLARE_TARGET_CTOR) {
         DP("Adding ctor " DPxMOD " to the pending list.\n",
            DPxPTR(Entry->addr));
diff --git a/openmp/libomptarget/test/api/omp_indirect_call.c b/openmp/libomptarget/test/api/omp_indirect_call.c
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/test/api/omp_indirect_call.c
@@ -0,0 +1,47 @@
+// RUN: %libomptarget-compile-run-and-check-generic
+
+#include <assert.h>
+#include <stdio.h>
+
+#pragma omp begin declare variant match(device = {kind(gpu)})
+// Provided by the runtime.
+void *__llvm_omp_indirect_call_lookup(void *host_ptr);
+#pragma omp declare target to(__llvm_omp_indirect_call_lookup)                 \
+    device_type(nohost)
+#pragma omp end declare variant
+
+#pragma omp begin declare variant match(device = {kind(cpu)})
+// We assume unified addressing on the CPU target.
+void *__llvm_omp_indirect_call_lookup(void *host_ptr) { return host_ptr; }
+#pragma omp end declare variant
+
+#pragma omp begin declare target indirect
+void foo(int *x) { *x = *x + 1; }
+void bar(int *x) { *x = *x + 1; }
+void baz(int *x) { *x = *x + 1; }
+#pragma omp end declare target
+
+int main() {
+  void *foo_ptr = foo;
+  void *bar_ptr = bar;
+  void *baz_ptr = baz;
+
+  int count = 0;
+  void *foo_res;
+  void *bar_res;
+  void *baz_res;
+#pragma omp target map(to : foo_ptr, bar_ptr, baz_ptr) map(tofrom : count)
+  {
+    foo_res = __llvm_omp_indirect_call_lookup(foo_ptr);
+    ((void (*)(int *))foo_res)(&count);
+    bar_res = __llvm_omp_indirect_call_lookup(bar_ptr);
+    ((void (*)(int *))bar_res)(&count);
+    baz_res = __llvm_omp_indirect_call_lookup(baz_ptr);
+    ((void (*)(int *))baz_res)(&count);
+  }
+
+  assert(count == 3 && "Calling failed");
+
+  // CHECK: PASS
+  printf("PASS\n");
+}