diff --git a/bolt/include/bolt/Core/BinaryContext.h b/bolt/include/bolt/Core/BinaryContext.h --- a/bolt/include/bolt/Core/BinaryContext.h +++ b/bolt/include/bolt/Core/BinaryContext.h @@ -284,15 +284,11 @@ std::unordered_set UndefinedSymbols; /// [name] -> [BinaryData*] map used for global symbol resolution. - using SymbolMapType = StringMap; - SymbolMapType GlobalSymbols; + StringMap GlobalSymbols; + StringMap NameToJumpTable; /// [address] -> [BinaryData], ... /// Addresses never change. - /// Note: it is important that clients do not hold on to instances of - /// BinaryData* while the map is still being modified during BinaryFunction - /// disassembly. This is because of the possibility that a regular - /// BinaryData is later discovered to be a JumpTable. using BinaryDataMapType = std::map; using binary_data_iterator = BinaryDataMapType::iterator; using binary_data_const_iterator = BinaryDataMapType::const_iterator; @@ -382,6 +378,14 @@ return nullptr; } + /// Return JumpTable starting at a specific \p Address. + JumpTable *getJumpTableAtAddress(uint64_t Address) { + auto JTI = JumpTables.find(Address); + if (JTI != JumpTables.end()) + return JTI->second; + return nullptr; + } + unsigned getDWARFEncodingSize(unsigned Encoding) { switch (Encoding & 0x0f) { default: @@ -768,11 +772,19 @@ /// was not set. ErrorOr getSymbolValue(const MCSymbol &Symbol) const { const BinaryData *BD = getBinaryDataByName(Symbol.getName()); - if (!BD) + if (!BD) { + if (const JumpTable *JT = getJumpTableByName(Symbol.getName())) + return JT->getAddress(); return std::make_error_code(std::errc::bad_address); + } return BD->getAddress(); } + const JumpTable *getJumpTableByName(StringRef Name) const { + auto Itr = NameToJumpTable.find(Name); + return Itr != NameToJumpTable.end() ? Itr->second : nullptr; + } + /// Return a global symbol registered at a given \p Address and \p Size. /// If no symbol exists, create one with unique name using \p Prefix. /// If there are multiple symbols registered at the \p Address, then diff --git a/bolt/lib/Core/BinaryContext.cpp b/bolt/lib/Core/BinaryContext.cpp --- a/bolt/lib/Core/BinaryContext.cpp +++ b/bolt/lib/Core/BinaryContext.cpp @@ -807,18 +807,9 @@ return JT->getFirstLabel(); } - // Re-use the existing symbol if possible. - MCSymbol *JTLabel = nullptr; - if (BinaryData *Object = getBinaryDataAtAddress(Address)) { - if (!isInternalSymbolName(Object->getSymbol()->getName())) - JTLabel = Object->getSymbol(); - } - const uint64_t EntrySize = getJumpTableEntrySize(Type); - if (!JTLabel) { - const std::string JumpTableName = generateJumpTableName(Function, Address); - JTLabel = registerNameAtAddress(JumpTableName, Address, 0, EntrySize); - } + const std::string JumpTableName = generateJumpTableName(Function, Address); + MCSymbol *JTLabel = Ctx->getOrCreateSymbol(JumpTableName); LLVM_DEBUG(dbgs() << "BOLT-DEBUG: creating jump table " << JTLabel->getName() << " in function " << Function << '\n'); @@ -830,6 +821,7 @@ if (opts::Verbosity > 2) JT->print(outs()); JumpTables.emplace(Address, JT); + NameToJumpTable[JTLabel->getName()] = JT; // Duplicate the entry for the parent function for easy access. Function.JumpTables.emplace(Address, JT); @@ -864,6 +856,7 @@ // addresses in the input binary memory space JumpTableID = ~JumpTableID; JumpTables.emplace(JumpTableID, NewJT); + NameToJumpTable[NewLabel->getName()] = NewJT; Function.JumpTables.emplace(JumpTableID, NewJT); return std::make_pair(JumpTableID, NewLabel); } diff --git a/bolt/lib/Core/BinaryFunction.cpp b/bolt/lib/Core/BinaryFunction.cpp --- a/bolt/lib/Core/BinaryFunction.cpp +++ b/bolt/lib/Core/BinaryFunction.cpp @@ -1683,15 +1683,6 @@ } } } - - const uint64_t BDSize = - BC.getBinaryDataAtAddress(JT.getAddress())->getSize(); - if (!BDSize) { - BC.setBinaryDataSize(JT.getAddress(), JT.getSize()); - } else { - assert(BDSize >= JT.getSize() && - "jump table cannot be larger than the containing object"); - } } // Add TakenBranches from JumpTables. diff --git a/bolt/lib/Passes/IdenticalCodeFolding.cpp b/bolt/lib/Passes/IdenticalCodeFolding.cpp --- a/bolt/lib/Passes/IdenticalCodeFolding.cpp +++ b/bolt/lib/Passes/IdenticalCodeFolding.cpp @@ -247,30 +247,14 @@ return false; // Check if symbols are jump tables. - const BinaryData *SIA = BC.getBinaryDataByName(SymbolA->getName()); - if (!SIA) - return false; - const BinaryData *SIB = BC.getBinaryDataByName(SymbolB->getName()); - if (!SIB) - return false; - - assert((SIA->getAddress() != SIB->getAddress()) && - "different symbols should not have the same value"); - - const JumpTable *JumpTableA = - A.getJumpTableContainingAddress(SIA->getAddress()); + const JumpTable *JumpTableA = BC.getJumpTableByName(SymbolA->getName()); if (!JumpTableA) return false; - const JumpTable *JumpTableB = - B.getJumpTableContainingAddress(SIB->getAddress()); + const JumpTable *JumpTableB = BC.getJumpTableByName(SymbolB->getName()); if (!JumpTableB) return false; - if ((SIA->getAddress() - JumpTableA->getAddress()) != - (SIB->getAddress() - JumpTableB->getAddress())) - return false; - return equalJumpTables(*JumpTableA, *JumpTableB, A, B); }; diff --git a/bolt/lib/Profile/DataAggregator.cpp b/bolt/lib/Profile/DataAggregator.cpp --- a/bolt/lib/Profile/DataAggregator.cpp +++ b/bolt/lib/Profile/DataAggregator.cpp @@ -1707,7 +1707,10 @@ PC -= Func->getAddress(); // Try to resolve symbol for memory load - if (BinaryData *BD = BC->getBinaryDataContainingAddress(Addr)) { + if (JumpTable *JT = BC->getJumpTableContainingAddress(Addr)) { + MemName = JT->getName(); + Addr -= JT->getAddress(); + } else if (BinaryData *BD = BC->getBinaryDataContainingAddress(Addr)) { MemName = BD->getName(); Addr -= BD->getAddress(); } else if (opts::FilterMemProfile) { diff --git a/bolt/lib/Profile/DataReader.cpp b/bolt/lib/Profile/DataReader.cpp --- a/bolt/lib/Profile/DataReader.cpp +++ b/bolt/lib/Profile/DataReader.cpp @@ -305,8 +305,12 @@ BC.MIB->getOrCreateAnnotationAs( II->second, "MemoryAccessProfile"); BinaryData *BD = nullptr; - if (MI.Addr.IsSymbol) - BD = BC.getBinaryDataByName(MI.Addr.Name); + if (MI.Addr.IsSymbol) { + if (const JumpTable *JT = BC.getJumpTableByName(MI.Addr.Name)) + BD = BC.getJumpTableAtAddress(JT->getAddress()); + else + BD = BC.getBinaryDataByName(MI.Addr.Name); + } MemAccessProfile.AddressAccessInfo.push_back( {BD, MI.Addr.Offset, MI.Count}); auto NextII = std::next(II); diff --git a/bolt/lib/Rewrite/RewriteInstance.cpp b/bolt/lib/Rewrite/RewriteInstance.cpp --- a/bolt/lib/Rewrite/RewriteInstance.cpp +++ b/bolt/lib/Rewrite/RewriteInstance.cpp @@ -3072,6 +3072,10 @@ AllResults[Symbol] = JITEvaluatedSymbol(Address, JITSymbolFlags()); continue; } + if (const JumpTable *JT = BC.getJumpTableByName(SymName)) { + AllResults[Symbol] = JITEvaluatedSymbol(JT->getAddress(), JITSymbolFlags()); + continue; + } LLVM_DEBUG(dbgs() << "Resolved to address 0x0\n"); AllResults[Symbol] = JITEvaluatedSymbol(0, JITSymbolFlags()); } diff --git a/bolt/test/X86/split-func-jump-table-fragment-bidirection.s b/bolt/test/X86/split-func-jump-table-fragment-bidirection.s --- a/bolt/test/X86/split-func-jump-table-fragment-bidirection.s +++ b/bolt/test/X86/split-func-jump-table-fragment-bidirection.s @@ -11,7 +11,7 @@ # RUN: llvm-bolt -print-cfg -v=3 %t.exe -o %t.out 2>&1 | FileCheck %s # CHECK: BOLT-INFO: Multiple fragments access same jump table: main; main.cold.1 -# CHECK: PIC Jump table JUMP_TABLE1 for function main, main.cold.1 at {{.*}} with a total count of 0: +# CHECK: PIC Jump table {{.*}} for function main, main.cold.1 at {{.*}} with a total count of 0: .text .globl main diff --git a/bolt/test/runtime/X86/jt-symbol-disambiguation.s b/bolt/test/runtime/X86/jt-symbol-disambiguation.s new file mode 100644 --- /dev/null +++ b/bolt/test/runtime/X86/jt-symbol-disambiguation.s @@ -0,0 +1,91 @@ +# In this test case, the symbol that represents the end of a table +# in .rodata is being colocated with the start of a jump table from +# another function, and BOLT moves that jump table. This should not +# cause the symbol representing the end of the table to be moved as +# well. +# Bug reported in https://github.com/llvm/llvm-project/issues/55004 + +# REQUIRES: system-linux + +# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o +# RUN: llvm-strip --strip-unneeded %t.o +# RUN: %clang %cflags -no-pie -nostartfiles -nostdlib -lc %t.o -o %t.exe -Wl,-q + +# RUN: llvm-bolt %t.exe -o %t.exe.bolt --relocs=1 --lite=0 \ +# RUN: --reorder-blocks=reverse -jump-tables=move + +# RUN: %t.exe.bolt 1 2 3 + + .file "jt-symbol-disambiguation.s" + .text + +# ---- +# Func foo contains a jump table whose start is colocated with a +# symbol marking the end of a data table +# ---- + .globl foo + .type foo, @function +foo: + .cfi_startproc + xor %rax,%rax + and $0x3,%rdi + leaq .JT1(%rip), %rax + movslq (%rax, %rdi, 4), %rdi + addq %rax, %rdi + jmpq *%rdi +.LBB1: + movl $0x1,%eax + jmp .LBB5 +.LBB2: + movl $0x2,%eax + jmp .LBB5 +.LBB3: + movl $0x3,%eax + jmp .LBB5 +.LBB4: + movl $0x4,%eax +.LBB5: + retq + .cfi_endproc + .size foo, .-foo + +# ---- +# Func _start scans a table using begin/end pointers. End pointer is colocated +# with the start of a jump table of function foo. When that jump +# table moves, end pointer in _start should not be affected. +# ---- + .globl _start + .type _start, @function +_start: + .cfi_startproc + movq (%rsp), %rdi + callq foo + leaq .start_of_table(%rip), %rsi # iterator + leaq .end_of_table(%rip), %rdi # iterator end +.LBB6: + cmpq %rsi, %rdi + je .LBB7 + movq (%rsi), %rbx + leaq 8(%rsi), %rsi # ++iterator + jmp .LBB6 +.LBB7: + xor %rdi, %rdi + callq exit@PLT + .cfi_endproc + .size _start, .-_start + +# ---- +# Data section +# ---- + .section .rodata,"a",@progbits + .p2align 3 +.start_of_table: + .quad 123 + .quad 456 + .quad 789 +.end_of_table: +.JT1: + .long .LBB1 - .JT1 + .long .LBB2 - .JT1 + .long .LBB3 - .JT1 + .long .LBB4 - .JT1