diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h --- a/llvm/include/llvm/BinaryFormat/DXContainer.h +++ b/llvm/include/llvm/BinaryFormat/DXContainer.h @@ -49,14 +49,14 @@ uint32_t Flags; // DxilShaderHashFlags uint8_t Digest[16]; - void byteSwap() { sys::swapByteOrder(Flags); } + void swapBytes() { sys::swapByteOrder(Flags); } }; struct ContainerVersion { uint16_t Major; uint16_t Minor; - void byteSwap() { + void swapBytes() { sys::swapByteOrder(Major); sys::swapByteOrder(Minor); } @@ -69,8 +69,8 @@ uint32_t FileSize; uint32_t PartCount; - void byteSwap() { - Version.byteSwap(); + void swapBytes() { + Version.swapBytes(); sys::swapByteOrder(FileSize); sys::swapByteOrder(PartCount); } @@ -82,6 +82,8 @@ struct PartHeader { uint8_t Name[4]; uint32_t Size; + + void swapBytes() { sys::swapByteOrder(Size); } // Structure is followed directly by part data: uint8_t PartData[PartSize]. }; diff --git a/llvm/include/llvm/Object/DXContainer.h b/llvm/include/llvm/Object/DXContainer.h --- a/llvm/include/llvm/Object/DXContainer.h +++ b/llvm/include/llvm/Object/DXContainer.h @@ -15,6 +15,7 @@ #ifndef LLVM_OBJECT_DXCONTAINER_H #define LLVM_OBJECT_DXCONTAINER_H +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/BinaryFormat/DXContainer.h" #include "llvm/Support/Error.h" @@ -28,10 +29,80 @@ MemoryBufferRef Data; dxbc::Header Header; + SmallVector PartOffsets; Error parseHeader(); + Error parsePartOffsets(); + friend class PartIterator; public: + // The PartIterator is a wrapper around the iterator for the PartOffsets + // member of the DXContainer. It contains a reference to the container, and the + // current iterator value, as well as storage for a parsed part header. NOTE(review): the constructor reads PartOffsets.back() when given the end iterator; confirm this is safe for a container with no parts.
+ class PartIterator { + const DXContainer &Container; + SmallVectorImpl::const_iterator OffsetIt; + struct PartData { + dxbc::PartHeader Part; + StringRef Data; + } IteratorState; + + friend class DXContainer; + + PartIterator(const DXContainer &C, + SmallVectorImpl::const_iterator It) + : Container(C), OffsetIt(It) { + if (OffsetIt == Container.PartOffsets.end()) + updateIteratorImpl(Container.PartOffsets.back()); + else + updateIterator(); + } + + // Updates the iterator's state data. This results in copying the part + // header into the iterator and handling any required byte swapping. This is + // called on construction and when incrementing the iterator. + void updateIterator() { + if (OffsetIt != Container.PartOffsets.end()) + updateIteratorImpl(*OffsetIt); + } + + // Implementation for updating the iterator state based on a specified + // offset. + void updateIteratorImpl(const uint32_t Offset); + + public: + PartIterator &operator++() { + if (OffsetIt == Container.PartOffsets.end()) + return *this; + ++OffsetIt; + updateIterator(); + return *this; + } + + PartIterator operator++(int) { + PartIterator Tmp = *this; + ++(*this); + return Tmp; + } + + bool operator==(const PartIterator &RHS) const { + return OffsetIt == RHS.OffsetIt; + } + + bool operator!=(const PartIterator &RHS) const { + return OffsetIt != RHS.OffsetIt; + } + + const PartData &operator*() { return IteratorState; } + const PartData *operator->() { return &IteratorState; } + }; + + PartIterator begin() const { + return PartIterator(*this, PartOffsets.begin()); + } + + PartIterator end() const { return PartIterator(*this, PartOffsets.end()); } + StringRef getData() const { return Data.getBuffer(); } static Expected create(MemoryBufferRef Object); diff --git a/llvm/lib/Object/DXContainer.cpp b/llvm/lib/Object/DXContainer.cpp --- a/llvm/lib/Object/DXContainer.cpp +++ b/llvm/lib/Object/DXContainer.cpp @@ -18,15 +18,32 @@ } template -static Error readStruct(StringRef Buffer, const char *P, T &Struct) 
{ +static Error readStruct(StringRef Buffer, const char *Src, T &Struct) { // Don't read before the beginning or past the end of the file - if (P < Buffer.begin() || P + sizeof(T) > Buffer.end()) + if (Src < Buffer.begin() || Src + sizeof(T) > Buffer.end()) return parseFailed("Reading structure out of file bounds"); - memcpy(&Struct, P, sizeof(T)); + memcpy(&Struct, Src, sizeof(T)); // DXContainer is always little endian if (sys::IsBigEndianHost) - Struct.byteSwap(); + Struct.swapBytes(); + return Error::success(); +} + +template +static Error readInteger(StringRef Buffer, const char *Src, T &Val) { + static_assert(std::is_integral::value, + "Cannot call readInteger on non-integral type."); + assert(reinterpret_cast(Src) % alignof(T) == 0 && + "Unaligned read of value from buffer!"); + // Don't read before the beginning or past the end of the file + if (Src < Buffer.begin() || Src + sizeof(T) > Buffer.end()) + return parseFailed("Reading structure out of file bounds"); + + Val = *reinterpret_cast(Src); + // DXContainer is always little endian + if (sys::IsBigEndianHost) + sys::swapByteOrder(Val); return Error::success(); } @@ -36,9 +53,35 @@ return readStruct(Data.getBuffer(), Data.getBuffer().data(), Header); } +Error DXContainer::parsePartOffsets() { + const char *Current = Data.getBuffer().data() + sizeof(dxbc::Header); + for (uint32_t Part = 0; Part < Header.PartCount; ++Part) { + uint32_t PartOffset; + if (Error Err = readInteger(Data.getBuffer(), Current, PartOffset)) + return Err; + Current += sizeof(uint32_t); + if (PartOffset + sizeof(dxbc::PartHeader) > Data.getBufferSize()) + return parseFailed("Part offset points beyond boundary of the file"); + PartOffsets.push_back(PartOffset); + } + return Error::success(); +} + Expected DXContainer::create(MemoryBufferRef Object) { DXContainer Container(Object); if (Error Err = Container.parseHeader()) return std::move(Err); + if (Error Err = Container.parsePartOffsets()) + return std::move(Err); return Container; } 
+ +void DXContainer::PartIterator::updateIteratorImpl(const uint32_t Offset) { + StringRef Buffer = Container.Data.getBuffer(); + const char *Current = Buffer.data() + Offset; + // Offsets are validated during parsing, so all offsets in the container are + // valid and contain enough readable data to read a header. NOTE(review): Part.Size is read from the file and is not bounds-checked here, so Data may extend past the end of the buffer; confirm it is validated elsewhere. + cantFail(readStruct(Buffer, Current, IteratorState.Part)); + IteratorState.Data = + StringRef(Current + sizeof(dxbc::PartHeader), IteratorState.Part.Size); +} diff --git a/llvm/unittests/Object/DXContainerTest.cpp b/llvm/unittests/Object/DXContainerTest.cpp --- a/llvm/unittests/Object/DXContainerTest.cpp +++ b/llvm/unittests/Object/DXContainerTest.cpp @@ -39,11 +39,17 @@ FailedWithMessage("Reading structure out of file bounds")); } +TEST(DXCFile, EmptyFile) { + EXPECT_THAT_EXPECTED( + DXContainer::create(MemoryBufferRef(StringRef("", 0), "")), + FailedWithMessage("Reading structure out of file bounds")); +} + TEST(DXCFile, ParseHeader) { uint8_t Buffer[] = {0x44, 0x58, 0x42, 0x43, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x70, 0x0D, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00}; + 0x70, 0x0D, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; DXContainer C = llvm::cantFail(DXContainer::create(getMemoryBuffer<32>(Buffer))); EXPECT_TRUE(memcmp(C.getHeader().Magic, "DXBC", 4) == 0); @@ -52,3 +58,71 @@ EXPECT_EQ(C.getHeader().Version.Major, 1u); EXPECT_EQ(C.getHeader().Version.Minor, 0u); } + +TEST(DXCFile, ParsePartMissingOffsets) { + uint8_t Buffer[] = { + 0x44, 0x58, 0x42, 0x43, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x70, 0x0D, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + }; + EXPECT_THAT_EXPECTED( + DXContainer::create(getMemoryBuffer<32>(Buffer)), + FailedWithMessage("Reading structure out of file bounds")); +} + +TEST(DXCFile, ParsePartInvalidOffsets) { + uint8_t Buffer[] = { + 0x44, 0x58, 0x42, 0x43, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x70, 0x0D, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, + }; + EXPECT_THAT_EXPECTED( + DXContainer::create(getMemoryBuffer<36>(Buffer)), + FailedWithMessage("Part offset points beyond boundary of the file")); +} + +TEST(DXCFile, ParseEmptyParts) { + uint8_t Buffer[] = { + 0x44, 0x58, 0x42, 0x43, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x70, 0x0D, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3C, 0x00, 0x00, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x4C, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, + 0x5C, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x6C, 0x00, 0x00, 0x00, + 0x53, 0x46, 0x49, 0x30, 0x00, 0x00, 0x00, 0x00, 0x49, 0x53, 0x47, 0x31, + 0x00, 0x00, 0x00, 0x00, 0x4F, 0x53, 0x47, 0x31, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x53, 0x56, 0x30, 0x00, 0x00, 0x00, 0x00, 0x53, 0x54, 0x41, 0x54, + 0x00, 0x00, 0x00, 0x00, 0x44, 0x58, 0x49, 0x4C, 0x00, 0x00, 0x00, 0x00, + 0x44, 0x45, 0x41, 0x44, 0x00, 0x00, 0x00, 0x00, + }; + DXContainer C = + llvm::cantFail(DXContainer::create(getMemoryBuffer<116>(Buffer))); + EXPECT_EQ(C.getHeader().PartCount, 7u); + + // All the part sizes are 0, which makes a nice test of the range based for + int ElementsVisited = 0; + for (auto Part : C) { + EXPECT_EQ(Part.Part.Size, 0u); + EXPECT_EQ(Part.Data.size(), 0u); + ++ElementsVisited; + } + EXPECT_EQ(ElementsVisited, 7); + + { + auto It = C.begin(); + EXPECT_TRUE(memcmp(It->Part.Name, "SFI0", 4) == 0); + ++It; + EXPECT_TRUE(memcmp(It->Part.Name, "ISG1", 4) == 0); + ++It; + EXPECT_TRUE(memcmp(It->Part.Name, "OSG1", 4) == 0); + ++It; + EXPECT_TRUE(memcmp(It->Part.Name, "PSV0", 4) == 0); + ++It; + EXPECT_TRUE(memcmp(It->Part.Name, "STAT", 4) == 0); + ++It; + EXPECT_TRUE(memcmp(It->Part.Name, "DXIL", 4) == 0); + ++It; + EXPECT_TRUE(memcmp(It->Part.Name, "DEAD", 4) == 0); + ++It; // Don't increment past the 
end + EXPECT_TRUE(memcmp(It->Part.Name, "DEAD", 4) == 0); + } +}