Index: llvm/include/llvm/Support/MemoryBuffer.h
===================================================================
--- llvm/include/llvm/Support/MemoryBuffer.h
+++ llvm/include/llvm/Support/MemoryBuffer.h
@@ -15,6 +15,7 @@
 #define LLVM_SUPPORT_MEMORYBUFFER_H
 
 #include "llvm-c/Types.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/CBindingWrapping.h"
@@ -179,6 +180,67 @@
 // Create wrappers for C Binding types (see CBindingWrapping.h).
 DEFINE_SIMPLE_CONVERSION_FUNCTIONS(MemoryBuffer, LLVMMemoryBufferRef)
 
+/// This interface is similar to MemoryBuffer, but allows writing to the
+/// underlying contents. It only supports creation methods that are guaranteed
+/// to produce a writable buffer. For example, mapping a file read-only is not
+/// supported.
+///
+/// Modifications to a buffer created from a file are never written back to
+/// that file.
+class WritableMemoryBuffer {
+  std::unique_ptr<MemoryBuffer> Buffer;
+
+  explicit WritableMemoryBuffer(std::unique_ptr<MemoryBuffer> Buffer)
+      : Buffer(std::move(Buffer)) {}
+
+public:
+  uint8_t *getBufferStart() { return getBuffer().begin(); }
+  uint8_t *getBufferEnd() { return getBuffer().end(); }
+
+  const uint8_t *getBufferStart() const { return getBuffer().begin(); }
+  const uint8_t *getBufferEnd() const { return getBuffer().end(); }
+
+  size_t getBufferSize() const { return Buffer->getBufferSize(); }
+
+  /// Return a read-only view of the buffer contents.
+  ArrayRef<uint8_t> getBuffer() const {
+    const uint8_t *Start =
+        reinterpret_cast<const uint8_t *>(Buffer->getBufferStart());
+    const uint8_t *End =
+        reinterpret_cast<const uint8_t *>(Buffer->getBufferEnd());
+    return makeArrayRef(Start, End);
+  }
+
+  /// Return a mutable view of the buffer contents.
+  MutableArrayRef<uint8_t> getBuffer() {
+    // The const_cast is well-defined because the factory functions below only
+    // wrap MemoryBuffers whose storage was created mutable.
+    ArrayRef<uint8_t> B =
+        static_cast<const WritableMemoryBuffer *>(this)->getBuffer();
+    return MutableArrayRef<uint8_t>(const_cast<uint8_t *>(B.begin()),
+                                    const_cast<uint8_t *>(B.end()));
+  }
+
+  StringRef getBufferIdentifier() const {
+    return Buffer->getBufferIdentifier();
+  }
+  MemoryBuffer::BufferKind getBufferKind() const {
+    return Buffer->getBufferKind();
+  }
+
+  /// Open the specified file as a WritableMemoryBuffer. The file itself is
+  /// only read; modifications to the returned buffer are not written back.
+  static ErrorOr<std::unique_ptr<WritableMemoryBuffer>>
+  getFile(const Twine &Filename, int64_t FileSize = -1,
+          bool RequiresNullTerminator = true, bool IsVolatile = false);
+
+  /// Map a subrange of the specified file as a WritableMemoryBuffer. As with
+  /// getFile, modifications are not written back to the file.
+  static ErrorOr<std::unique_ptr<WritableMemoryBuffer>>
+  getFileSlice(const Twine &Filename, uint64_t MapSize, uint64_t Offset,
+               bool IsVolatile = false);
+};
+
 } // end namespace llvm
 
 #endif // LLVM_SUPPORT_MEMORYBUFFER_H
Index: llvm/lib/Support/MemoryBuffer.cpp
===================================================================
--- llvm/lib/Support/MemoryBuffer.cpp
+++ llvm/lib/Support/MemoryBuffer.cpp
@@ -103,7 +103,8 @@
 
 static ErrorOr<std::unique_ptr<MemoryBuffer>>
 getFileAux(const Twine &Filename, int64_t FileSize, uint64_t MapSize,
-           uint64_t Offset, bool RequiresNullTerminator, bool IsVolatile);
+           uint64_t Offset, bool RequiresNullTerminator, bool IsVolatile,
+           bool Writable);
 
 std::unique_ptr<MemoryBuffer>
 MemoryBuffer::getMemBuffer(StringRef InputData, StringRef BufferName,
@@ -179,7 +180,8 @@
 ErrorOr<std::unique_ptr<MemoryBuffer>>
 MemoryBuffer::getFileSlice(const Twine &FilePath, uint64_t MapSize,
                            uint64_t Offset, bool IsVolatile) {
-  return getFileAux(FilePath, -1, MapSize, Offset, false, IsVolatile);
+  return getFileAux(FilePath, -1, MapSize, Offset, false, IsVolatile,
+                    /*Writable=*/false);
 }
 
 
@@ -208,8 +210,12 @@
 
 public:
+  /// When \p Writable is true the region is mapped privately, so the buffer
+  /// may be modified without the changes reaching the underlying file.
   MemoryBufferMMapFile(bool RequiresNullTerminator, int FD, uint64_t Len,
-                       uint64_t Offset, std::error_code &EC)
-      : MFR(FD, sys::fs::mapped_file_region::readonly,
+                       uint64_t Offset, bool Writable, std::error_code &EC)
+      : MFR(FD,
+            Writable ? sys::fs::mapped_file_region::priv
+                     : sys::fs::mapped_file_region::readonly,
             getLegalMapSize(Len, Offset), getLegalMapOffset(Offset), EC) {
     if (!EC) {
       const char *Start = getStart(Len, Offset);
@@ -254,29 +260,57 @@
 MemoryBuffer::getFile(const Twine &Filename, int64_t FileSize,
                       bool RequiresNullTerminator, bool IsVolatile) {
   return getFileAux(Filename, FileSize, FileSize, 0,
-                    RequiresNullTerminator, IsVolatile);
+                    RequiresNullTerminator, IsVolatile, /*Writable=*/false);
 }
 
 static ErrorOr<std::unique_ptr<MemoryBuffer>>
 getOpenFileImpl(int FD, const Twine &Filename, uint64_t FileSize,
                 uint64_t MapSize, int64_t Offset, bool RequiresNullTerminator,
-                bool IsVolatile);
+                bool IsVolatile, bool Writable);
 
 static ErrorOr<std::unique_ptr<MemoryBuffer>>
 getFileAux(const Twine &Filename, int64_t FileSize, uint64_t MapSize,
-           uint64_t Offset, bool RequiresNullTerminator, bool IsVolatile) {
+           uint64_t Offset, bool RequiresNullTerminator, bool IsVolatile,
+           bool Writable) {
+  // The file is opened read-only even when Writable is set: a writable
+  // mapping is created as mapped_file_region::priv, whose modifications are
+  // never written back to the file, so no write access is needed.
   int FD;
   std::error_code EC = sys::fs::openFileForRead(Filename, FD);
   if (EC)
     return EC;
 
   ErrorOr<std::unique_ptr<MemoryBuffer>> Ret =
       getOpenFileImpl(FD, Filename, FileSize, MapSize, Offset,
-                      RequiresNullTerminator, IsVolatile);
+                      RequiresNullTerminator, IsVolatile, Writable);
   close(FD);
   return Ret;
 }
 
+ErrorOr<std::unique_ptr<WritableMemoryBuffer>>
+WritableMemoryBuffer::getFile(const Twine &Filename, int64_t FileSize,
+                              bool RequiresNullTerminator, bool IsVolatile) {
+  auto Result = getFileAux(Filename, FileSize, FileSize, 0,
+                           RequiresNullTerminator, IsVolatile,
+                           /*Writable=*/true);
+  if (!Result)
+    return Result.getError();
+  return std::unique_ptr<WritableMemoryBuffer>(
+      new WritableMemoryBuffer(std::move(*Result)));
+}
+
+ErrorOr<std::unique_ptr<WritableMemoryBuffer>>
+WritableMemoryBuffer::getFileSlice(const Twine &Filename, uint64_t MapSize,
+                                   uint64_t Offset, bool IsVolatile) {
+  auto Result = getFileAux(Filename, -1, MapSize, Offset,
+                           /*RequiresNullTerminator=*/false, IsVolatile,
+                           /*Writable=*/true);
+  if (!Result)
+    return Result.getError();
+  return std::unique_ptr<WritableMemoryBuffer>(
+      new WritableMemoryBuffer(std::move(*Result)));
+}
+
 static bool shouldUseMmap(int FD,
                           size_t FileSize,
                           size_t MapSize,
@@ -335,7 +369,7 @@
 static ErrorOr<std::unique_ptr<MemoryBuffer>>
 getOpenFileImpl(int FD, const Twine &Filename, uint64_t FileSize,
                 uint64_t MapSize, int64_t Offset, bool RequiresNullTerminator,
-                bool IsVolatile) {
+                bool IsVolatile, bool Writable) {
   static int PageSize = sys::Process::getPageSize();
 
   // Default is to map the full file.
@@ -366,7 +400,8 @@
     std::error_code EC;
     std::unique_ptr<MemoryBuffer> Result(
         new (NamedBufferAlloc(Filename))
-        MemoryBufferMMapFile(RequiresNullTerminator, FD, MapSize, Offset, EC));
+        MemoryBufferMMapFile(RequiresNullTerminator, FD, MapSize, Offset,
+                             Writable, EC));
     if (!EC)
       return std::move(Result);
   }
@@ -413,14 +448,16 @@
 MemoryBuffer::getOpenFile(int FD, const Twine &Filename, uint64_t FileSize,
                           bool RequiresNullTerminator, bool IsVolatile) {
   return getOpenFileImpl(FD, Filename, FileSize, FileSize, 0,
-                         RequiresNullTerminator, IsVolatile);
+                         RequiresNullTerminator, IsVolatile,
+                         /*Writable=*/false);
 }
 
 ErrorOr<std::unique_ptr<MemoryBuffer>>
 MemoryBuffer::getOpenFileSlice(int FD, const Twine &Filename, uint64_t MapSize,
                                int64_t Offset, bool IsVolatile) {
   assert(MapSize != uint64_t(-1));
-  return getOpenFileImpl(FD, Filename, -1, MapSize, Offset, false, IsVolatile);
+  return getOpenFileImpl(FD, Filename, -1, MapSize, Offset, false, IsVolatile,
+                         /*Writable=*/false);
 }
 
 ErrorOr<std::unique_ptr<MemoryBuffer>> MemoryBuffer::getSTDIN() {
Index: llvm/unittests/Support/MemoryBufferTest.cpp
===================================================================
--- llvm/unittests/Support/MemoryBufferTest.cpp
+++ llvm/unittests/Support/MemoryBufferTest.cpp
@@ -226,4 +226,45 @@
   EXPECT_TRUE(BufData2.substr(0x1800,8).equals("abcdefgh"));
   EXPECT_TRUE(BufData2.substr(0x2FF8,8).equals("abcdefgh"));
 }
+
+TEST_F(MemoryBufferTest, writableSlice) {
+  // Create a 128 KiB file filled with a repeating pattern.
+  int FD;
+  SmallString<64> TestPath;
+  ASSERT_FALSE(sys::fs::createTemporaryFile("MemoryBufferTest_WritableSlice",
+                                            "temp", FD, TestPath));
+  FileRemover Cleanup(TestPath);
+  std::string Contents;
+  for (unsigned i = 0; i < 0x8000; ++i)
+    Contents += "0123";
+  raw_fd_ostream OF(FD, /*shouldClose=*/true);
+  OF << Contents;
+  OF.close();
+
+  // A 64 KiB slice is presumably large enough to be memory-mapped (privately)
+  // rather than read into memory; either way, writes must not reach the file.
+  auto MBOrError =
+      WritableMemoryBuffer::getFileSlice(TestPath.str(), 0x10000, 0x8000);
+  ASSERT_FALSE(MBOrError.getError());
+
+  {
+    WritableMemoryBuffer &MB = **MBOrError;
+    ASSERT_EQ(0x10000U, MB.getBufferSize());
+    ASSERT_EQ(MB.getBufferStart() + MB.getBufferSize(), MB.getBufferEnd());
+
+    // The slice must expose the file contents at the requested offset.
+    StringRef Data(reinterpret_cast<const char *>(MB.getBufferStart()),
+                   MB.getBufferSize());
+    EXPECT_TRUE(Data.equals(StringRef(Contents).substr(0x8000, 0x10000)));
+
+    // Writes must be visible through the buffer...
+    ::memset(MB.getBufferStart(), 'x', MB.getBufferSize());
+    EXPECT_TRUE(Data.equals(std::string(0x10000, 'x')));
+  }
+
+  // ...but must not modify the file on disk.
+  auto MB2OrError = MemoryBuffer::getFile(TestPath);
+  ASSERT_FALSE(MB2OrError.getError());
+  EXPECT_TRUE((*MB2OrError)->getBuffer().equals(Contents));
+}
 }