diff --git a/openmp/libomptarget/include/device.h b/openmp/libomptarget/include/device.h
--- a/openmp/libomptarget/include/device.h
+++ b/openmp/libomptarget/include/device.h
@@ -479,6 +479,11 @@
   /// RTLs identified on the host
   RTLsTy RTLs;
 
+  /// Executable images and information extracted from the input images passed
+  /// to the runtime. A std::list keeps pointers to existing entries valid when
+  /// more images are appended. Accessed while holding RTLsMtx.
+  std::list<std::pair<__tgt_device_image, __tgt_image_info>> Images;
+
   /// Devices associated with RTLs
   std::vector<std::unique_ptr<DeviceTy>> Devices;
   std::mutex RTLsMtx; ///< For RTLs and Devices
diff --git a/openmp/libomptarget/include/omptarget.h b/openmp/libomptarget/include/omptarget.h
--- a/openmp/libomptarget/include/omptarget.h
+++ b/openmp/libomptarget/include/omptarget.h
@@ -143,6 +143,11 @@
   __tgt_offload_entry *EntriesEnd;   // End of table (non inclusive)
 };
 
+/// This struct contains information about a given image.
+struct __tgt_image_info {
+  const char *Arch; // Target architecture, or nullptr if it is not known
+};
+
 /// This struct is a record of all the host code that may be offloaded to a
 /// target.
 struct __tgt_bin_desc {
diff --git a/openmp/libomptarget/include/omptargetplugin.h b/openmp/libomptarget/include/omptargetplugin.h
--- a/openmp/libomptarget/include/omptargetplugin.h
+++ b/openmp/libomptarget/include/omptargetplugin.h
@@ -31,6 +31,14 @@
 // having to load the library, which can be expensive.
 int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *Image);
 
+// This provides the same functionality as __tgt_rtl_is_valid_binary except we
+// also use additional information to determine if the image is valid. This
+// allows us to determine if an image has a compatible architecture. Fields of
+// Info may be null when the information is not available. This entry point is
+// optional; without it the runtime uses __tgt_rtl_is_valid_binary.
+int32_t __tgt_rtl_is_valid_binary_info(__tgt_device_image *Image,
+                                       __tgt_image_info *Info);
+
 // Return an integer other than zero if the data can be exchaned from SrcDevId
 // to DstDevId. If it is data exchangable, the device plugin should provide
 // function to move data from source device to destination device directly.
diff --git a/openmp/libomptarget/include/rtl.h b/openmp/libomptarget/include/rtl.h
--- a/openmp/libomptarget/include/rtl.h
+++ b/openmp/libomptarget/include/rtl.h
@@ -26,6 +26,7 @@
 
 struct RTLInfoTy {
   typedef int32_t(is_valid_binary_ty)(void *);
+  typedef int32_t(is_valid_binary_info_ty)(void *, void *);
   typedef int32_t(is_data_exchangable_ty)(int32_t, int32_t);
   typedef int32_t(number_of_devices_ty)();
   typedef int32_t(init_device_ty)(int32_t);
@@ -82,6 +83,7 @@
 
   // Functions implemented in the RTL.
   is_valid_binary_ty *is_valid_binary = nullptr;
+  is_valid_binary_info_ty *is_valid_binary_info = nullptr;
   is_data_exchangable_ty *is_data_exchangable = nullptr;
   number_of_devices_ty *number_of_devices = nullptr;
   init_device_ty *init_device = nullptr;
diff --git a/openmp/libomptarget/src/rtl.cpp b/openmp/libomptarget/src/rtl.cpp
--- a/openmp/libomptarget/src/rtl.cpp
+++ b/openmp/libomptarget/src/rtl.cpp
@@ -14,6 +14,8 @@
 #include "device.h"
 #include "private.h"
 
+#include "llvm/Object/OffloadBinary.h"
+
 #include <cassert>
 #include <cstdlib>
 #include <cstring>
@@ -165,6 +167,8 @@
        R.NumberOfDevices);
 
     // Optional functions
+    *((void **)&R.is_valid_binary_info) =
+        dlsym(DynlibHandle, "__tgt_rtl_is_valid_binary_info");
     *((void **)&R.deinit_device) =
         dlsym(DynlibHandle, "__tgt_rtl_deinit_device");
     *((void **)&R.init_requires) =
@@ -275,6 +279,47 @@
   }
 }
 
+/// Return the executable device image held by \p Image. If \p Image is an
+/// offload binary, the result spans the device image embedded in it and keeps
+/// the original entry table; otherwise a copy of \p Image is returned.
+static __tgt_device_image getExecutableImage(__tgt_device_image *Image) {
+  llvm::StringRef ImageStr(static_cast<char *>(Image->ImageStart),
+                           static_cast<char *>(Image->ImageEnd) -
+                               static_cast<char *>(Image->ImageStart));
+  auto BinaryOrErr =
+      llvm::object::OffloadBinary::create(llvm::MemoryBufferRef(ImageStr, ""));
+  if (!BinaryOrErr) {
+    // Not a valid offload binary; treat the input as a plain device image.
+    llvm::consumeError(BinaryOrErr.takeError());
+    return *Image;
+  }
+
+  void *Begin = const_cast<void *>(
+      static_cast<const void *>((*BinaryOrErr)->getImage().bytes_begin()));
+  void *End = const_cast<void *>(
+      static_cast<const void *>((*BinaryOrErr)->getImage().bytes_end()));
+
+  return {Begin, End, Image->EntriesBegin, Image->EntriesEnd};
+}
+
+/// Return the extra information stored in \p Image if it is an offload binary,
+/// or a value-initialized __tgt_image_info (null Arch) otherwise.
+// NOTE(review): Arch is read through a temporary OffloadBinary; this assumes
+// the string lives in \p Image's buffer and is NUL-terminated - confirm.
+static __tgt_image_info getImageInfo(__tgt_device_image *Image) {
+  llvm::StringRef ImageStr(static_cast<char *>(Image->ImageStart),
+                           static_cast<char *>(Image->ImageEnd) -
+                               static_cast<char *>(Image->ImageStart));
+  auto BinaryOrErr =
+      llvm::object::OffloadBinary::create(llvm::MemoryBufferRef(ImageStr, ""));
+  if (!BinaryOrErr) {
+    llvm::consumeError(BinaryOrErr.takeError());
+    return __tgt_image_info{};
+  }
+
+  return __tgt_image_info{(*BinaryOrErr)->getArch().data()};
+}
+
 void RTLsTy::registerRequires(int64_t Flags) {
   // TODO: add more elaborate check.
   // Minimal check: only set requires flags if previous value
@@ -350,17 +395,33 @@
 
 void RTLsTy::registerLib(__tgt_bin_desc *Desc) {
   PM->RTLsMtx.lock();
+
+  // Extract the executable image and extra information, if available. Images
+  // of previously registered libraries are already in PM->Images, so remember
+  // where this descriptor's images start and register only those below.
+  auto FirstNewImage = PM->Images.end();
+  for (int32_t I = 0; I < Desc->NumDeviceImages; ++I) {
+    auto Inserted = PM->Images.emplace(
+        PM->Images.end(), getExecutableImage(&Desc->DeviceImages[I]),
+        getImageInfo(&Desc->DeviceImages[I]));
+    if (FirstNewImage == PM->Images.end())
+      FirstNewImage = Inserted;
+  }
+
   // Register the images with the RTLs that understand them, if any.
-  for (int32_t I = 0; I < Desc->NumDeviceImages; ++I) {
-    // Obtain the image.
-    __tgt_device_image *Img = &Desc->DeviceImages[I];
+  for (auto ImageIt = FirstNewImage; ImageIt != PM->Images.end(); ++ImageIt) {
+    // Obtain the image and information that was previously extracted.
+    __tgt_device_image *Img = &ImageIt->first;
+    __tgt_image_info *Info = &ImageIt->second;
 
     RTLInfoTy *FoundRTL = nullptr;
 
     // Scan the RTLs that have associated images until we find one that supports
     // the current image.
     for (auto &R : AllRTLs) {
-      if (!R.is_valid_binary(Img)) {
+      bool IsValid = R.is_valid_binary_info ? R.is_valid_binary_info(Img, Info)
+                                            : R.is_valid_binary(Img);
+      if (!IsValid) {
         DP("Image " DPxMOD " is NOT compatible with RTL %s!\n",
            DPxPTR(Img->ImageStart), R.RTLName.c_str());
         continue;
@@ -412,9 +473,27 @@
 
   PM->RTLsMtx.lock();
   // Find which RTL understands each image, if any.
-  for (int32_t I = 0; I < Desc->NumDeviceImages; ++I) {
-    // Obtain the image.
-    __tgt_device_image *Img = &Desc->DeviceImages[I];
+  for (auto &ImageAndInfo : PM->Images) {
+    // Obtain the image and information that was previously extracted.
+    __tgt_device_image *Img = &ImageAndInfo.first;
+    __tgt_image_info *Info = &ImageAndInfo.second;
+
+    // PM->Images holds the images of every library registered so far, not
+    // just those of Desc. getExecutableImage returns either a copy of a device
+    // image of Desc or a sub-range of one, so use that to skip the others
+    // without reading their contents.
+    // NOTE(review): nothing here erases Desc's entries from PM->Images, so they
+    // may outlive the library - confirm this is intended.
+    bool IsFromDesc = false;
+    for (int32_t I = 0; I < Desc->NumDeviceImages && !IsFromDesc; ++I) {
+      const __tgt_device_image &DescImg = Desc->DeviceImages[I];
+      IsFromDesc = static_cast<char *>(Img->ImageStart) >=
+                       static_cast<char *>(DescImg.ImageStart) &&
+                   static_cast<char *>(Img->ImageEnd) <=
+                       static_cast<char *>(DescImg.ImageEnd);
+    }
+    if (!IsFromDesc)
+      continue;
 
     RTLInfoTy *FoundRTL = NULL;
 
@@ -424,9 +503,12 @@
 
       assert(R->IsUsed && "Expecting used RTLs.");
 
-      if (!R->is_valid_binary(Img)) {
-        DP("Image " DPxMOD " is NOT compatible with RTL " DPxMOD "!\n",
-           DPxPTR(Img->ImageStart), DPxPTR(R->LibraryHandler));
+      bool IsValid = R->is_valid_binary_info
+                         ? R->is_valid_binary_info(Img, Info)
+                         : R->is_valid_binary(Img);
+      if (!IsValid) {
+        DP("Image " DPxMOD " is NOT compatible with RTL %s!\n",
+           DPxPTR(Img->ImageStart), R->RTLName.c_str());
         continue;
       }
 