Index: include/llvm/DebugInfo/PDB/Raw/TpiStream.h =================================================================== --- include/llvm/DebugInfo/PDB/Raw/TpiStream.h +++ include/llvm/DebugInfo/PDB/Raw/TpiStream.h @@ -52,6 +52,8 @@ iterator_range types(bool *HadError) const; private: + Error verifyHashValues(); + const PDBFile &Pdb; std::unique_ptr Stream; HashFunctionType HashFunction; Index: lib/DebugInfo/PDB/Raw/TpiStream.cpp =================================================================== --- lib/DebugInfo/PDB/Raw/TpiStream.cpp +++ lib/DebugInfo/PDB/Raw/TpiStream.cpp @@ -9,6 +9,7 @@ #include "llvm/DebugInfo/PDB/Raw/TpiStream.h" +#include "llvm/DebugInfo/CodeView/CVTypeVisitor.h" #include "llvm/DebugInfo/CodeView/CodeView.h" #include "llvm/DebugInfo/CodeView/StreamReader.h" #include "llvm/DebugInfo/CodeView/TypeIndex.h" @@ -24,6 +25,7 @@ #include "llvm/Support/Endian.h" using namespace llvm; +using namespace llvm::codeview; using namespace llvm::support; using namespace llvm::pdb; @@ -70,77 +72,76 @@ TpiStream::~TpiStream() {} // Computes a hash for a given TPI record. -template -static Error getTpiHash(const codeview::CVType &Rec, uint32_t &Hash) { - ArrayRef Data = Rec.Data; - ErrorOr Obj = T::deserialize(K, Data); - if (Obj.getError()) - return llvm::make_error( - codeview::cv_error_code::corrupt_record); - - auto Opts = static_cast(Obj->getOptions()); - if (Opts & static_cast(codeview::ClassOptions::ForwardReference)) { - // We don't know how to calculate a hash value for this yet. - // Currently we just skip it. - Hash = 0; - return Error::success(); +template static uint32_t getTpiHash(T &Rec) { + auto Opts = static_cast(Rec.getOptions()); + + // We don't know how to calculate a hash value for this yet. + // Currently we just skip it. + if (Opts & static_cast(codeview::ClassOptions::ForwardReference)) + return 0; + + if (!(Opts & static_cast(codeview::ClassOptions::Scoped))) + return hashStringV1(Rec.getName()); + + if (Opts & static_cast(codeview::ClassOptions::HasUniqueName)) + return hashStringV1(Rec.getUniqueName()); + + // This case is not implemented yet. + return 0; +} + +namespace { +class TpiHashVerifier : public CVTypeVisitor { +public: + TpiHashVerifier(FixedStreamArray &HashValues, + uint32_t NumHashBuckets) : + HashValues(HashValues), NumHashBuckets(NumHashBuckets) {} + + void visitUdtSourceLine(TypeLeafKind, UdtSourceLineRecord &Rec) { + verifySourceLine(Rec); } - if (!(Opts & static_cast(codeview::ClassOptions::Scoped))) { - Hash = hashStringV1(Obj->getName()); - return Error::success(); + void visitUdtModSourceLine(TypeLeafKind, UdtModSourceLineRecord &Rec) { + verifySourceLine(Rec); } - if (Opts & static_cast(codeview::ClassOptions::HasUniqueName)) { - Hash = hashStringV1(Obj->getUniqueName()); - return Error::success(); + void visitClass(TypeLeafKind, ClassRecord &Rec) { verify(Rec); } + void visitEnum(TypeLeafKind, EnumRecord &Rec) { verify(Rec); } + void visitInterface(TypeLeafKind, ClassRecord &Rec) { verify(Rec); } + void visitStruct(TypeLeafKind, ClassRecord &Rec) { verify(Rec); } + void visitUnion(TypeLeafKind, UnionRecord &Rec) { verify(Rec); } + + void visitTypeEnd(TypeLeafKind Leaf, ArrayRef RecordData) { + ++Index; } - // This case is not implemented yet. - Hash = 0; - return Error::success(); +private: + template void verify(T &Rec) { + uint32_t Hash = getTpiHash(Rec); + if (Hash && Hash % NumHashBuckets != HashValues[Index]) + parseError(); + } + + template void verifySourceLine(T &Rec) { + char Buf[4]; + support::endian::write32le(Buf, Rec.getUDT().getIndex()); + uint32_t Hash = hashStringV1(StringRef(Buf, 4)); + if (Hash % NumHashBuckets != HashValues[Index]) + parseError(); + } + + FixedStreamArray HashValues; + uint32_t NumHashBuckets; + uint32_t Index = 0; +}; } // Verifies that a given type record matches with a given hash value. // Currently we only verify SRC_LINE records. -static Error verifyTIHash(const codeview::CVType &Rec, uint32_t Expected, - uint32_t NumHashBuckets) { - using namespace codeview; - - ArrayRef D = Rec.Data; - uint32_t Hash; - - switch (Rec.Type) { - case LF_UDT_SRC_LINE: - case LF_UDT_MOD_SRC_LINE: - Hash = hashStringV1(StringRef((const char *)D.data(), 4)); - break; - case LF_CLASS: - if (auto EC = getTpiHash(Rec, Hash)) - return EC; - break; - case LF_ENUM: - if (auto EC = getTpiHash(Rec, Hash)) - return EC; - break; - case LF_INTERFACE: - if (auto EC = getTpiHash(Rec, Hash)) - return EC; - break; - case LF_STRUCTURE: - if (auto EC = getTpiHash(Rec, Hash)) - return EC; - break; - case LF_UNION: - if (auto EC = getTpiHash(Rec, Hash)) - return EC; - break; - default: - // This pattern is not implemented yet. - return Error::success(); - } - - if (Hash && (Hash % NumHashBuckets) != Expected) +Error TpiStream::verifyHashValues() { + TpiHashVerifier Verifier(HashValues, Header->NumHashBuckets); + Verifier.visitTypeStream(TypeRecords); + if (Verifier.hadError()) return make_error(raw_error_code::corrupt_file, "Corrupt TPI hash table."); return Error::success(); @@ -216,13 +217,8 @@ // TPI hash table is a parallel array for the type records. // Verify that the hash values match with type records. - size_t I = 0; - bool HasError; - for (const codeview::CVType &Rec : types(&HasError)) { - if (auto EC = verifyTIHash(Rec, HashValues[I], Header->NumHashBuckets)) - return EC; - ++I; - } + if (auto EC = verifyHashValues()) + return EC; return Error::success(); }