Index: test/tools/llvm-cfi-verify/AArch64/call.s
===================================================================
--- /dev/null
+++ test/tools/llvm-cfi-verify/AArch64/call.s
@@ -0,0 +1,9 @@
+# RUN: llvm-cfi-verify -search-length-undef=6 -trap-function=__cfi_slowpath@plt %S/Inputs/call | FileCheck %s
+
+# CHECK-LABEL: {{^Instruction: .* \(PROTECTED\)}}
+# CHECK-NEXT: tiny.cc:9
+
+# CHECK: Expected Protected: 1 (100.00%)
+# CHECK: Unexpected Protected: 0 (0.00%)
+# CHECK: Expected Unprotected: 0 (0.00%)
+# CHECK: Unexpected Unprotected (BAD): 0 (0.00%)
Index: tools/llvm-cfi-verify/lib/FileAnalysis.h
===================================================================
--- tools/llvm-cfi-verify/lib/FileAnalysis.h
+++ tools/llvm-cfi-verify/lib/FileAnalysis.h
@@ -11,6 +11,7 @@
 #define LLVM_CFI_VERIFY_FILE_ANALYSIS_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/BinaryFormat/ELF.h"
 #include "llvm/DebugInfo/Symbolize/Symbolize.h"
 #include "llvm/MC/MCAsmInfo.h"
@@ -110,6 +111,10 @@
   // Returns whether this instruction is used by CFI to trap the program.
   bool isCFITrap(const Instr &InstrMeta) const;
 
+  // Returns whether this instruction is a direct call to the function named by
+  // -trap-function.
+  bool isTrapCall(const Instr &InstrMeta) const;
+
   // Returns whether this function can fall through to the next instruction.
   // Undefined (and bad) instructions cannot fall through, and instruction that
   // modify the control flow can only fall through if they are conditional
@@ -183,6 +188,10 @@
   // internal members. Should only be called once by Create().
   Error parseCodeSections();
 
+  // Parses the symbol table to look for the address of the function that
+  // represents a trap, if any.
+  Error parseSymbolTable();
+
 private:
   // Members that describe the input file.
   object::OwningBinary<object::Binary> Binary;
@@ -218,6 +227,9 @@
 
   // A list of addresses of indirect control flow instructions.
   std::set<uint64_t> IndirectInstructions;
+
+  // The address of the function that represents a trap.
+  Optional<uint64_t> TrapFunctionAddress;
 };
 
 class UnsupportedDisassembly : public ErrorInfo<UnsupportedDisassembly> {
Index: tools/llvm-cfi-verify/lib/FileAnalysis.cpp
===================================================================
--- tools/llvm-cfi-verify/lib/FileAnalysis.cpp
+++ tools/llvm-cfi-verify/lib/FileAnalysis.cpp
@@ -52,6 +52,11 @@
         "will result in false positives for 'CFI unprotected' instructions."),
     cl::location(IgnoreDWARFFlag), cl::init(false));
 
+static cl::opt<std::string> TrapFunction(
+    "trap-function",
+    cl::desc("The name of the function that represents a trap."),
+    cl::init(""));
+
 StringRef stringCFIProtectionStatus(CFIProtectionStatus Status) {
   switch (Status) {
   case CFIProtectionStatus::PROTECTED:
@@ -105,6 +110,9 @@
   if (auto SectionParseResponse = Analysis.parseCodeSections())
     return std::move(SectionParseResponse);
 
+  if (auto SymbolTableParseResponse = Analysis.parseSymbolTable())
+    return std::move(SymbolTableParseResponse);
+
   return std::move(Analysis);
 }
 
@@ -165,7 +173,20 @@
 
 bool FileAnalysis::isCFITrap(const Instr &InstrMeta) const {
   const auto &InstrDesc = MII->get(InstrMeta.Instruction.getOpcode());
-  return InstrDesc.isTrap();
+  return InstrDesc.isTrap() || isTrapCall(InstrMeta);
+}
+
+bool FileAnalysis::isTrapCall(const Instr &InstrMeta) const {
+  if (!TrapFunctionAddress)
+    return false;
+  const auto &InstrDesc = MII->get(InstrMeta.Instruction.getOpcode());
+  if (!InstrDesc.isCall())
+    return false;
+  uint64_t Target;
+  if (!MIA->evaluateBranch(InstrMeta.Instruction, InstrMeta.VMAddress,
+                           InstrMeta.InstructionSize, Target))
+    return false;
+  return Target == TrapFunctionAddress.getValue();
 }
 
 bool FileAnalysis::canFallThrough(const Instr &InstrMeta) const {
@@ -518,6 +539,85 @@
   }
 }
 
+Error FileAnalysis::parseSymbolTable() {
+  if (TrapFunction.empty())
+    return Error::success();
+  StringRef TrapName = TrapFunction;
+  auto TrapNotFound = [] {
+    return make_error<StringError>("Could not find trap function",
+                                   inconvertibleErrorCode());
+  };
+
+  // Look for the trap function in the list of symbols.
+  for (auto &Sym : Object->symbols()) {
+    auto SymNameOrErr = Sym.getName();
+    if (!SymNameOrErr)
+      consumeError(SymNameOrErr.takeError());
+    else if (*SymNameOrErr == TrapName) {
+      auto AddrOrErr = Sym.getAddress();
+      if (auto Err = AddrOrErr.takeError())
+        return Err;
+      TrapFunctionAddress = *AddrOrErr;
+      return Error::success();
+    }
+  }
+
+  // If we could not find the trap function in the object's list of symbols, it
+  // might still be in the PLT. Unfortunately, LLVM does not make that
+  // information easily accessible, so we have to compute it. Only names of
+  // the form "<function>@plt" are looked up there.
+  const StringRef PLTSuffix = "@plt";
+  if (!TrapName.endswith(PLTSuffix))
+    return TrapNotFound();
+  const StringRef TargetName = TrapName.drop_back(PLTSuffix.size());
+
+  // Returns the section with the given name, or None if there is none. The
+  // absence must be checked explicitly: dereferencing the end iterator of a
+  // failed search is undefined behavior.
+  auto FindSection = [&](StringRef Name) -> Optional<object::SectionRef> {
+    for (const object::SectionRef &Section : Object->sections()) {
+      StringRef SectionName;
+      if (!Section.getName(SectionName) && SectionName == Name)
+        return Section;
+    }
+    return None;
+  };
+
+  // The dynamic relocation section lists the PLT stubs in the same order as
+  // the PLT itself, so the index of a relocation gives the index of its stub.
+  Optional<object::SectionRef> PLT = FindSection(".plt");
+  Optional<object::SectionRef> RelaPLT = FindSection(".rela.plt");
+  if (!PLT || !RelaPLT)
+    return TrapNotFound();
+
+  // This assumes 16-byte PLT stubs (as on AArch64 and x86-64). The PLT starts
+  // with a header that does not correspond to any function, so the first
+  // function stub is found by counting back from the end of the section.
+  const uint64_t PLTEntrySize = 0x10;
+  const uint64_t NumEntries =
+      std::distance(RelaPLT->relocation_begin(), RelaPLT->relocation_end());
+  if (NumEntries > PLT->getSize() / PLTEntrySize)
+    return TrapNotFound();
+  uint64_t SymbolAddr =
+      PLT->getAddress() + PLT->getSize() - NumEntries * PLTEntrySize;
+  for (const object::RelocationRef &Relocation : RelaPLT->relocations()) {
+    // A relocation without a symbol yields symbol_end(), which must not be
+    // dereferenced; it is skipped but still counted to keep indices aligned.
+    const object::symbol_iterator Symbol = Relocation.getSymbol();
+    if (Symbol != Object->symbol_end()) {
+      auto SymNameOrErr = Symbol->getName();
+      if (!SymNameOrErr)
+        consumeError(SymNameOrErr.takeError());
+      else if (*SymNameOrErr == TargetName) {
+        TrapFunctionAddress = SymbolAddr;
+        return Error::success();
+      }
+    }
+    SymbolAddr += PLTEntrySize;
+  }
+  return TrapNotFound();
+}
+
 UnsupportedDisassembly::UnsupportedDisassembly(StringRef Text) : Text(Text) {}
 
 char UnsupportedDisassembly::ID;