diff --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
--- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
@@ -10,6 +10,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPGridValues.h"
+#include "llvm/Object/ELF.h"
+#include "llvm/Object/ELFObjectFile.h"
+
 #include <algorithm>
 #include <assert.h>
 #include <cstdio>
@@ -24,6 +31,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include "ELFSymbols.h"
 #include "impl_runtime.h"
 #include "interop_hsa.h"
 
@@ -35,12 +43,8 @@
 #include "omptargetplugin.h"
 #include "print_tracing.h"
 
-#include "llvm/ADT/StringMap.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
-#include "llvm/Frontend/OpenMP/OMPGridValues.h"
-
 using namespace llvm;
+using namespace llvm::object;
 
 // hostrpc interface, FIXME: consider moving to its own include these are
 // statically linked into amdgpu/plugin if present from hostrpc_services.a,
@@ -1600,128 +1604,66 @@
   return Changed;
 }
 
-Elf64_Shdr *findOnlyShtHash(Elf *Elf) {
-  size_t N;
-  int Rc = elf_getshdrnum(Elf, &N);
-  if (Rc != 0) {
-    return nullptr;
-  }
-
-  Elf64_Shdr *Result = nullptr;
-  for (size_t I = 0; I < N; I++) {
-    Elf_Scn *Scn = elf_getscn(Elf, I);
-    if (Scn) {
-      Elf64_Shdr *Shdr = elf64_getshdr(Scn);
-      if (Shdr) {
-        if (Shdr->sh_type == SHT_HASH) {
-          if (Result == nullptr) {
-            Result = Shdr;
-          } else {
-            // multiple SHT_HASH sections not handled
-            return nullptr;
-          }
-        }
-      }
-    }
-  }
-  return Result;
-}
-
-const Elf64_Sym *elfLookup(Elf *Elf, char *Base, Elf64_Shdr *SectionHash,
-                           const char *Symname) {
-
-  assert(SectionHash);
-  size_t SectionSymtabIndex = SectionHash->sh_link;
-  Elf64_Shdr *SectionSymtab =
-      elf64_getshdr(elf_getscn(Elf, SectionSymtabIndex));
-  size_t SectionStrtabIndex = SectionSymtab->sh_link;
-
-  const Elf64_Sym *Symtab =
-      reinterpret_cast<const Elf64_Sym *>(Base + SectionSymtab->sh_offset);
-
-  const uint32_t *Hashtab =
-      reinterpret_cast<const uint32_t *>(Base + SectionHash->sh_offset);
-
-  // Layout:
-  // nbucket
-  // nchain
-  // bucket[nbucket]
-  // chain[nchain]
-  uint32_t Nbucket = Hashtab[0];
-  const uint32_t *Bucket = &Hashtab[2];
-  const uint32_t *Chain = &Hashtab[Nbucket + 2];
-
-  const size_t Max = strlen(Symname) + 1;
-  const uint32_t Hash = elf_hash(Symname);
-  for (uint32_t I = Bucket[Hash % Nbucket]; I != 0; I = Chain[I]) {
-    char *N = elf_strptr(Elf, SectionStrtabIndex, Symtab[I].st_name);
-    if (strncmp(Symname, N, Max) == 0) {
-      return &Symtab[I];
-    }
-  }
-
-  return nullptr;
-}
-
 struct SymbolInfo {
-  void *Addr = nullptr;
+  const void *Addr = nullptr;
   uint32_t Size = UINT32_MAX;
   uint32_t ShType = SHT_NULL;
 };
 
-int getSymbolInfoWithoutLoading(Elf *Elf, char *Base, const char *Symname,
-                                SymbolInfo *Res) {
-  if (elf_kind(Elf) != ELF_K_ELF) {
-    return 1;
-  }
-
-  Elf64_Shdr *SectionHash = findOnlyShtHash(Elf);
-  if (!SectionHash) {
-    return 1;
-  }
-
-  const Elf64_Sym *Sym = elfLookup(Elf, Base, SectionHash, Symname);
-  if (!Sym) {
+/// Looks up \p SymName in \p ELFObj without loading the image. On success,
+/// fills \p Res with the symbol's st_value applied as an offset from the image
+/// base, its st_size, and the sh_type of its section, and returns 0.
+/// Returns 1 if the symbol is missing, undefined, too large, or unreadable.
+int getSymbolInfoWithoutLoading(const ELFObjectFile<ELF64LE> &ELFObj,
+                                StringRef SymName, SymbolInfo *Res) {
+  auto SymOrErr = getELFSymbol(ELFObj, SymName);
+  if (!SymOrErr) {
+    std::string ErrorString = toString(SymOrErr.takeError());
+    DP("Failed ELF lookup: %s\n", ErrorString.c_str());
     return 1;
   }
-
-  if (Sym->st_size > UINT32_MAX) {
-    return 1;
-  }
-
-  if (Sym->st_shndx == SHN_UNDEF) {
-    return 1;
-  }
-
-  Elf_Scn *Section = elf_getscn(Elf, Sym->st_shndx);
-  if (!Section) {
+  if (!*SymOrErr)
     return 1;
-  }
 
-  Elf64_Shdr *Header = elf64_getshdr(Section);
-  if (!Header) {
+  // Undefined symbols have no storage in the image, and SymbolInfo::Size is
+  // only 32 bits wide.
+  if ((*SymOrErr)->st_shndx == SHN_UNDEF ||
+      (*SymOrErr)->st_size > UINT32_MAX)
+    return 1;
+
+  auto SymSecOrErr = ELFObj.getELFFile().getSection((*SymOrErr)->st_shndx);
+  if (!SymSecOrErr) {
+    std::string ErrorString = toString(SymSecOrErr.takeError());
+    DP("Failed ELF lookup: %s\n", ErrorString.c_str());
     return 1;
   }
 
-  Res->Addr = Sym->st_value + Base;
-  Res->Size = static_cast<uint32_t>(Sym->st_size);
-  Res->ShType = Header->sh_type;
+  Res->Addr = (*SymOrErr)->st_value + ELFObj.getELFFile().base();
+  Res->Size = static_cast<uint32_t>((*SymOrErr)->st_size);
+  Res->ShType = static_cast<uint32_t>((*SymSecOrErr)->sh_type);
   return 0;
 }
 
-int getSymbolInfoWithoutLoading(char *Base, size_t ImgSize, const char *Symname,
+/// Parses the \p ImgSize bytes at \p Base as an ELF object and looks up
+/// \p SymName in it. Only 64-bit little-endian images are supported; returns 1
+/// for anything else or on failure, 0 on success.
+int getSymbolInfoWithoutLoading(char *Base, size_t ImgSize, const char *SymName,
                                 SymbolInfo *Res) {
-  Elf *Elf = elf_memory(Base, ImgSize);
-  if (Elf) {
-    int Rc = getSymbolInfoWithoutLoading(Elf, Base, Symname, Res);
-    elf_end(Elf);
-    return Rc;
+  StringRef Buffer = StringRef(Base, ImgSize);
+  auto ElfOrErr = ObjectFile::createELFObjectFile(MemoryBufferRef(Buffer, ""),
+                                                  /*InitContent=*/false);
+  if (!ElfOrErr) {
+    REPORT("Failed to load ELF: %s\n", toString(ElfOrErr.takeError()).c_str());
+    return 1;
   }
+
+  if (const auto *ELFObj = dyn_cast<ELF64LEObjectFile>(ElfOrErr->get()))
+    return getSymbolInfoWithoutLoading(*ELFObj, SymName, Res);
   return 1;
 }
 
 hsa_status_t interopGetSymbolInfo(char *Base, size_t ImgSize,
-                                  const char *SymName, void **VarAddr,
+                                  const char *SymName, const void **VarAddr,
                                   uint32_t *VarSize) {
   SymbolInfo SI;
   int Rc = getSymbolInfoWithoutLoading(Base, ImgSize, SymName, &SI);
@@ -2492,7 +2434,7 @@
     KernDescNameStr += "_kern_desc";
     const char *KernDescName = KernDescNameStr.c_str();
 
-    void *KernDescPtr;
+    const void *KernDescPtr;
     uint32_t KernDescSize;
     void *CallStackAddr = nullptr;
     Err = interopGetSymbolInfo((char *)Image->ImageStart, ImgSize, KernDescName,
@@ -2531,7 +2473,7 @@
     WGSizeNameStr += "_wg_size";
     const char *WGSizeName = WGSizeNameStr.c_str();
 
-    void *WGSizePtr;
+    const void *WGSizePtr;
     uint32_t WGSize;
     Err = interopGetSymbolInfo((char *)Image->ImageStart, ImgSize, WGSizeName,
                                &WGSizePtr, &WGSize);
@@ -2570,7 +2512,7 @@
     ExecModeNameStr += "_exec_mode";
     const char *ExecModeName = ExecModeNameStr.c_str();
 
-    void *ExecModePtr;
+    const void *ExecModePtr;
     uint32_t VarSize;
     Err = interopGetSymbolInfo((char *)Image->ImageStart, ImgSize,
                                ExecModeName, &ExecModePtr, &VarSize);
diff --git a/openmp/libomptarget/plugins/common/elf_common/CMakeLists.txt b/openmp/libomptarget/plugins/common/elf_common/CMakeLists.txt
--- a/openmp/libomptarget/plugins/common/elf_common/CMakeLists.txt
+++ b/openmp/libomptarget/plugins/common/elf_common/CMakeLists.txt
@@ -10,7 +10,7 @@
 #
 ##===----------------------------------------------------------------------===##
 
-add_library(elf_common OBJECT elf_common.cpp)
+add_library(elf_common OBJECT elf_common.cpp ELFSymbols.cpp)
 
 # Build elf_common with PIC to be able to link it with plugin shared libraries.
 set_property(TARGET elf_common PROPERTY POSITION_INDEPENDENT_CODE ON)
diff --git a/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.h b/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.h
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.h
@@ -0,0 +1,27 @@
+//===-- ELFSymbols.h - ELF Symbol look-up functionality ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ELF routines for obtaining a symbol from an Elf file without loading it.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_ELF_COMMON_ELF_SYMBOLS_H
+#define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_ELF_COMMON_ELF_SYMBOLS_H
+
+#include "llvm/Object/ELF.h"
+#include "llvm/Object/ELFObjectFile.h"
+
+/// Returns the symbol named \p Name in \p ELFObj, nullptr if it is not found,
+/// or an error if the ELF data is malformed. If \p ELFObj has a SHT_HASH or
+/// SHT_GNU_HASH section, only the first such section is searched; otherwise
+/// this falls back to a linear search of the first SHT_SYMTAB section.
+llvm::Expected<const typename llvm::object::ELF64LE::Sym *>
+getELFSymbol(const llvm::object::ELFObjectFile<llvm::object::ELF64LE> &ELFObj,
+             llvm::StringRef Name);
+
+#endif
diff --git a/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.cpp b/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.cpp
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.cpp
@@ -0,0 +1,261 @@
+//===-- ELFSymbols.cpp - ELF Symbol look-up functionality -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ELF routines for obtaining a symbol from an Elf file without loading it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ELFSymbols.h"
+
+using namespace llvm;
+using namespace llvm::object;
+using namespace llvm::ELF;
+
+/// Returns the name that \p Sym refers to in \p StrTab, or an error if its
+/// st_name offset lies outside the string table. \p Index is the position of
+/// \p Sym in its symbol table and is only used in the error message.
+template <class ELFT>
+static Expected<StringRef> getSymbolName(const typename ELFT::Sym &Sym,
+                                         StringRef StrTab, uint64_t Index) {
+  if (Sym.st_name >= StrTab.size())
+    return createError("symbol [index " + Twine(Index) +
+                       "] has invalid st_name: " + Twine(Sym.st_name));
+  // The name runs up to the next NUL. ELFFile::getStringTable rejects string
+  // tables whose last byte is not NUL, so this cannot read past StrTab.
+  return StringRef(StrTab.data() + Sym.st_name);
+}
+
+/// Looks up \p Name through a GNU-style hash table. The caller must have
+/// checked that \p HashTab fits in its section, that nbuckets and maskwords
+/// are non-zero, and that symndx does not exceed the size of \p SymTab.
+/// Returns nullptr if the symbol is not present.
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getSymbolFromGnuHashTable(StringRef Name, const typename ELFT::GnuHash &HashTab,
+                          ArrayRef<typename ELFT::Sym> SymTab,
+                          StringRef StrTab) {
+  const uint32_t NameHash = hashGnu(Name);
+  const typename ELFT::Word NBucket = HashTab.nbuckets;
+  const typename ELFT::Word SymOffset = HashTab.symndx;
+  ArrayRef<typename ELFT::Off> Filter = HashTab.filter();
+  ArrayRef<typename ELFT::Word> Bucket = HashTab.buckets();
+  ArrayRef<typename ELFT::Word> Chain = HashTab.values(SymTab.size());
+
+  // Check the bloom filter and exit early if the symbol is not present.
+  uint64_t ElfClassBits = ELFT::Is64Bits ? 64 : 32;
+  typename ELFT::Off Word =
+      Filter[(NameHash / ElfClassBits) % HashTab.maskwords];
+  uint64_t Mask = (0x1ull << (NameHash % ElfClassBits)) |
+                  (0x1ull << ((NameHash >> HashTab.shift2) % ElfClassBits));
+  if ((Word & Mask) != Mask)
+    return nullptr;
+
+  // The symbol may or may not be present, walk this bucket's chain.
+  for (typename ELFT::Word I = Bucket[NameHash % NBucket];
+       I >= SymOffset && I < SymTab.size(); I = I + 1) {
+    const uint32_t ChainHash = Chain[I - SymOffset];
+
+    // The low bit of a chain value is the end-of-chain marker, so it is
+    // ignored when comparing hashes.
+    if ((NameHash | 0x1) == (ChainHash | 0x1)) {
+      Expected<StringRef> SymName = getSymbolName<ELFT>(SymTab[I], StrTab, I);
+      if (!SymName)
+        return SymName.takeError();
+      if (*SymName == Name)
+        return &SymTab[I];
+    }
+
+    // Stop at the end of this bucket's chain, whether or not the hash of the
+    // last entry matched; later entries belong to other buckets.
+    if (ChainHash & 0x1)
+      return nullptr;
+  }
+  return nullptr;
+}
+
+/// Looks up \p Name through a SysV-style hash table. The caller must have
+/// checked that \p HashTab fits in its section and that nbucket is non-zero.
+/// Returns nullptr if the symbol is not present.
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getSymbolFromSysVHashTable(StringRef Name, const typename ELFT::Hash &HashTab,
+                           ArrayRef<typename ELFT::Sym> SymTab,
+                           StringRef StrTab) {
+  const uint32_t Hash = hashSysV(Name);
+  const typename ELFT::Word NBucket = HashTab.nbucket;
+  ArrayRef<typename ELFT::Word> Bucket = HashTab.buckets();
+  ArrayRef<typename ELFT::Word> Chain = HashTab.chains();
+  // A well-formed chain visits each index at most once, so more than
+  // Chain.size() steps means the chain contains a cycle.
+  size_t Steps = 0;
+  for (typename ELFT::Word I = Bucket[Hash % NBucket]; I != ELF::STN_UNDEF;
+       I = Chain[I]) {
+    if (++Steps > Chain.size())
+      return createError("hash chain for symbol '" + Name +
+                         "' does not terminate");
+    if (I >= SymTab.size())
+      return createError(
+          "symbol [index " + Twine(I) +
+          "] is greater than the number of symbols: " + Twine(SymTab.size()));
+    if (I >= Chain.size())
+      return createError("symbol [index " + Twine(I) +
+                         "] has no hash chain entry, nchain: " +
+                         Twine(Chain.size()));
+    Expected<StringRef> SymName = getSymbolName<ELFT>(SymTab[I], StrTab, I);
+    if (!SymName)
+      return SymName.takeError();
+    if (*SymName == Name)
+      return &SymTab[I];
+  }
+  return nullptr;
+}
+
+/// Looks up \p Name through the hash table in section \p Sec, after checking
+/// that the table lies inside the file and is large enough for its header.
+/// Returns nullptr if the symbol is not present.
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getHashTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
+                   StringRef Name) {
+  if (Sec.sh_type != ELF::SHT_HASH && Sec.sh_type != ELF::SHT_GNU_HASH)
+    return createError(
+        "invalid sh_type for hash table, expected SHT_HASH or SHT_GNU_HASH");
+  Expected<typename ELFT::ShdrRange> SectionsOrError = Elf.sections();
+  if (!SectionsOrError)
+    return SectionsOrError.takeError();
+
+  auto SymTabOrErr = getSection<ELFT>(*SectionsOrError, Sec.sh_link);
+  if (!SymTabOrErr)
+    return SymTabOrErr.takeError();
+
+  auto StrTabOrErr =
+      Elf.getStringTableForSymtab(**SymTabOrErr, *SectionsOrError);
+  if (!StrTabOrErr)
+    return StrTabOrErr.takeError();
+  StringRef StrTab = *StrTabOrErr;
+
+  auto SymsOrErr = Elf.symbols(*SymTabOrErr);
+  if (!SymsOrErr)
+    return SymsOrErr.takeError();
+  ArrayRef<typename ELFT::Sym> SymTab = *SymsOrErr;
+
+  // Make sure the whole section lies inside the buffer before reading any of
+  // the table. A section may end exactly at the end of the buffer, and the
+  // check is written so that sh_offset + sh_size cannot overflow.
+  const uint64_t BufSize = Elf.getBufSize();
+  if (Sec.sh_offset > BufSize || Sec.sh_size > BufSize - Sec.sh_offset)
+    return createError("section has invalid sh_offset: " +
+                       Twine(Sec.sh_offset));
+
+  // If this is a GNU hash table we verify its size and search the symbol
+  // table using the GNU hash table format.
+  if (Sec.sh_type == ELF::SHT_GNU_HASH) {
+    if (Sec.sh_size < sizeof(typename ELFT::GnuHash))
+      return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
+    const auto *HashTab = reinterpret_cast<const typename ELFT::GnuHash *>(
+        Elf.base() + Sec.sh_offset);
+    // The lookup divides by nbuckets and maskwords, and values() requires
+    // symndx to be no larger than the number of symbols.
+    if (HashTab->nbuckets == 0 || HashTab->maskwords == 0 ||
+        HashTab->symndx > SymTab.size())
+      return createError("section has an invalid GNU hash table header");
+    // Bloom filter words are ELFT::Off-sized; buckets and chain values are
+    // ELFT::Word-sized.
+    if (Sec.sh_size <
+        sizeof(typename ELFT::GnuHash) +
+            sizeof(typename ELFT::Off) * HashTab->maskwords +
+            sizeof(typename ELFT::Word) * HashTab->nbuckets +
+            sizeof(typename ELFT::Word) * (SymTab.size() - HashTab->symndx))
+      return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
+    return getSymbolFromGnuHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+  }
+
+  // If this is a Sys-V hash table we verify its size and search the symbol
+  // table using the Sys-V hash table format.
+  if (Sec.sh_type == ELF::SHT_HASH) {
+    if (Sec.sh_size < sizeof(typename ELFT::Hash))
+      return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
+    const auto *HashTab = reinterpret_cast<const typename ELFT::Hash *>(
+        Elf.base() + Sec.sh_offset);
+    // The lookup divides by nbucket.
+    if (HashTab->nbucket == 0)
+      return createError("section has an invalid SysV hash table header");
+    if (Sec.sh_size < sizeof(typename ELFT::Hash) +
+                          sizeof(typename ELFT::Word) * HashTab->nbucket +
+                          sizeof(typename ELFT::Word) * HashTab->nchain)
+      return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
+
+    return getSymbolFromSysVHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+  }
+
+  return nullptr;
+}
+
+/// Looks up \p Name by a linear scan of the SHT_SYMTAB or SHT_DYNSYM section
+/// \p Sec. Returns nullptr if the symbol is not present.
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getSymTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
+                  StringRef Name) {
+  if (Sec.sh_type != ELF::SHT_SYMTAB && Sec.sh_type != ELF::SHT_DYNSYM)
+    return createError(
+        "invalid sh_type for symbol table, expected SHT_SYMTAB or SHT_DYNSYM");
+  Expected<typename ELFT::ShdrRange> SectionsOrError = Elf.sections();
+  if (!SectionsOrError)
+    return SectionsOrError.takeError();
+
+  auto StrTabOrErr = Elf.getStringTableForSymtab(Sec, *SectionsOrError);
+  if (!StrTabOrErr)
+    return StrTabOrErr.takeError();
+  StringRef StrTab = *StrTabOrErr;
+
+  auto SymsOrErr = Elf.symbols(&Sec);
+  if (!SymsOrErr)
+    return SymsOrErr.takeError();
+  ArrayRef<typename ELFT::Sym> SymTab = *SymsOrErr;
+
+  for (size_t I = 0, E = SymTab.size(); I != E; ++I) {
+    Expected<StringRef> SymName = getSymbolName<ELFT>(SymTab[I], StrTab, I);
+    if (!SymName)
+      return SymName.takeError();
+    if (*SymName == Name)
+      return &SymTab[I];
+  }
+
+  return nullptr;
+}
+
+Expected<const typename ELF64LE::Sym *>
+getELFSymbol(const ELFObjectFile<ELF64LE> &ELFObj, StringRef Name) {
+  // First try to look up the symbol via the first hash table section. Its
+  // result is final: a miss there is not retried with a linear search.
+  for (ELFSectionRef Sec : ELFObj.sections()) {
+    if (Sec.getType() != SHT_HASH && Sec.getType() != SHT_GNU_HASH)
+      continue;
+
+    auto HashTabOrErr = ELFObj.getELFFile().getSection(Sec.getIndex());
+    if (!HashTabOrErr)
+      return HashTabOrErr.takeError();
+    return getHashTableSymbol<ELF64LE>(ELFObj.getELFFile(), **HashTabOrErr,
+                                       Name);
+  }
+
+  // Without a hash table, fall back to a linear search of the first
+  // SHT_SYMTAB section.
+  for (ELFSectionRef Sec : ELFObj.sections()) {
+    if (Sec.getType() != SHT_SYMTAB)
+      continue;
+
+    auto SymTabOrErr = ELFObj.getELFFile().getSection(Sec.getIndex());
+    if (!SymTabOrErr)
+      return SymTabOrErr.takeError();
+    return getSymTableSymbol<ELF64LE>(ELFObj.getELFFile(), **SymTabOrErr, Name);
+  }
+
+  return nullptr;
+}