Index: test/tools/llvm-cfi-verify/AArch64/call.s
===================================================================
--- /dev/null
+++ test/tools/llvm-cfi-verify/AArch64/call.s
@@ -0,0 +1,9 @@
+# RUN: llvm-cfi-verify -search-length-undef=6 %S/Inputs/call | FileCheck %s
+
+# CHECK-LABEL: {{^Instruction: .* \(PROTECTED\)}}
+# CHECK-NEXT: tiny.cc:9
+
+# CHECK: Expected Protected: 1 (100.00%)
+# CHECK: Unexpected Protected: 0 (0.00%)
+# CHECK: Expected Unprotected: 0 (0.00%)
+# CHECK: Unexpected Unprotected (BAD): 0 (0.00%)
Index: tools/llvm-cfi-verify/lib/FileAnalysis.h
===================================================================
--- tools/llvm-cfi-verify/lib/FileAnalysis.h
+++ tools/llvm-cfi-verify/lib/FileAnalysis.h
@@ -11,6 +11,7 @@
 #define LLVM_CFI_VERIFY_FILE_ANALYSIS_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/BinaryFormat/ELF.h"
 #include "llvm/DebugInfo/Symbolize/Symbolize.h"
 #include "llvm/MC/MCAsmInfo.h"
@@ -110,6 +111,10 @@
   // Returns whether this instruction is used by CFI to trap the program.
   bool isCFITrap(const Instr &InstrMeta) const;
 
+  // Returns whether this instruction is a call to a function that will trap on
+  // CFI violations (i.e., it serves as a trap in this instance).
+  bool willTrapOnCFIViolation(const Instr &InstrMeta) const;
+
   // Returns whether this function can fall through to the next instruction.
   // Undefined (and bad) instructions cannot fall through, and instruction that
   // modify the control flow can only fall through if they are conditional
@@ -183,6 +188,10 @@
   // internal members. Should only be called once by Create().
   Error parseCodeSections();
 
+  // Parses the symbol table to look for the addresses of functions that will
+  // trap on CFI violations.
+  Error parseSymbolTable();
+
 private:
   // Members that describe the input file.
   object::OwningBinary<object::Binary> Binary;
@@ -218,6 +227,9 @@
 
   // A list of addresses of indirect control flow instructions.
   std::set<uint64_t> IndirectInstructions;
+
+  // The addresses of functions that will trap on CFI violations.
+  SmallSet<uint64_t, 4> TrapOnFailFunctionAddresses;
 };
 
 class UnsupportedDisassembly : public ErrorInfo<UnsupportedDisassembly> {
Index: tools/llvm-cfi-verify/lib/FileAnalysis.cpp
===================================================================
--- tools/llvm-cfi-verify/lib/FileAnalysis.cpp
+++ tools/llvm-cfi-verify/lib/FileAnalysis.cpp
@@ -105,6 +105,9 @@
   if (auto SectionParseResponse = Analysis.parseCodeSections())
     return std::move(SectionParseResponse);
 
+  if (auto SymbolTableParseResponse = Analysis.parseSymbolTable())
+    return std::move(SymbolTableParseResponse);
+
   return std::move(Analysis);
 }
 
@@ -165,7 +168,18 @@
 
 bool FileAnalysis::isCFITrap(const Instr &InstrMeta) const {
   const auto &InstrDesc = MII->get(InstrMeta.Instruction.getOpcode());
-  return InstrDesc.isTrap();
+  return InstrDesc.isTrap() || willTrapOnCFIViolation(InstrMeta);
+}
+
+bool FileAnalysis::willTrapOnCFIViolation(const Instr &InstrMeta) const {
+  const auto &InstrDesc = MII->get(InstrMeta.Instruction.getOpcode());
+  if (!InstrDesc.isCall())
+    return false;
+  uint64_t Target;
+  if (!MIA->evaluateBranch(InstrMeta.Instruction, InstrMeta.VMAddress,
+                           InstrMeta.InstructionSize, Target))
+    return false;
+  return TrapOnFailFunctionAddresses.count(Target) > 0;
 }
 
 bool FileAnalysis::canFallThrough(const Instr &InstrMeta) const {
@@ -518,6 +532,85 @@
   }
 }
 
+Error FileAnalysis::parseSymbolTable() {
+  // Functions that will trap on CFI violations.
+  SmallSet<StringRef, 4> TrapOnFailFunctions;
+  TrapOnFailFunctions.insert("__cfi_slowpath");
+  TrapOnFailFunctions.insert("abort");
+  TrapOnFailFunctions.insert("__cfi_slowpath@plt");
+  TrapOnFailFunctions.insert("abort@plt");
+
+  // Look through the list of symbols for functions that will trap on CFI
+  // violations.
+  for (auto &Sym : Object->symbols()) {
+    auto SymNameOrErr = Sym.getName();
+    if (!SymNameOrErr)
+      consumeError(SymNameOrErr.takeError());
+    else if (TrapOnFailFunctions.count(*SymNameOrErr) > 0) {
+      auto AddrOrErr = Sym.getAddress();
+      if (!AddrOrErr)
+        consumeError(AddrOrErr.takeError());
+      else
+        TrapOnFailFunctionAddresses.insert(*AddrOrErr);
+    }
+  }
+
+  // Look for these functions in the PLT. Unfortunately, LLVM does not make
+  // that information easily accessible, so we have to compute it.
+  // Sections are held as iterators so that a missing section is detected by
+  // comparing against section_end(); taking the address of a dereferenced
+  // find_if() result is never null and is invalid when nothing matches.
+  const auto FindSection = [&](StringRef Name) {
+    return find_if(Object->sections(),
+                   [&](const object::SectionRef &Section) {
+                     StringRef Str;
+                     return !Section.getName(Str) && Str == Name;
+                   });
+  };
+
+  // We begin by finding the PLT itself.
+  const object::section_iterator PLT = FindSection(".plt");
+  if (PLT == Object->section_end())
+    return Error::success();
+  // We now search for the functions in the dynamic relocation section.
+  // This section has the same order as the PLT, so if we find a function, we
+  // can compute its offset from the base address of the PLT.
+  const object::section_iterator RelaPLT = FindSection(".rela.plt");
+  if (RelaPLT == Object->section_end())
+    return Error::success();
+
+  // We now find the functions in the dynamic relocation section and use
+  // their index there to compute their offset from the base of the PLT.
+  // NOTE(review): every PLT entry is assumed to be 16 bytes; confirm this
+  // for each architecture the tool supports.
+  static const uint64_t PLTEntrySize = 0x10;
+  const uint64_t NumEntries =
+      std::distance(RelaPLT->relocation_begin(), RelaPLT->relocation_end());
+  // The function entries are taken to be the last NumEntries slots of the
+  // PLT, which skips whatever precedes them. If the section is too small to
+  // hold that many slots the layout is not the one assumed here, so stop
+  // rather than let the subtraction below wrap around.
+  const uint64_t EntriesSize = NumEntries * PLTEntrySize;
+  if (EntriesSize > PLT->getSize())
+    return Error::success();
+  uint64_t SymbolAddr = PLT->getAddress() + PLT->getSize() - EntriesSize;
+  for (const auto &Relocation : RelaPLT->relocations()) {
+    const auto Symbol = Relocation.getSymbol();
+    // A relocation with no symbol still occupies a slot: skip only the name
+    // lookup (getSymbol() yields symbol_end() in that case) and keep
+    // advancing the address.
+    if (Symbol != Object->symbol_end()) {
+      auto SymNameOrErr = Symbol->getName();
+      if (!SymNameOrErr)
+        consumeError(SymNameOrErr.takeError());
+      else if (TrapOnFailFunctions.count((*SymNameOrErr + "@plt").str()) > 0)
+        TrapOnFailFunctionAddresses.insert(SymbolAddr);
+    }
+    SymbolAddr += PLTEntrySize;
+  }
+  return Error::success();
+}
+
 UnsupportedDisassembly::UnsupportedDisassembly(StringRef Text) : Text(Text) {}
 
 char UnsupportedDisassembly::ID;