diff --git a/llvm/include/llvm/Support/raw_ostream.h b/llvm/include/llvm/Support/raw_ostream.h
--- a/llvm/include/llvm/Support/raw_ostream.h
+++ b/llvm/include/llvm/Support/raw_ostream.h
@@ -47,7 +47,17 @@
 /// buffered disciplines etc. It is a simple buffer that outputs
 /// a chunk at a time.
 class raw_ostream {
+public:
+  /// Class kinds to support LLVM-style RTTI.
+  enum class OStreamKind {
+    OK_OStream,  ///< Default tag for streams without a more specific kind.
+    OK_FDStream, ///< Tag passed by raw_fd_stream's constructor.
+  };
+
 private:
+  /// Kind tag set by the constructor and reported by get_kind().
+  OStreamKind Kind;
+
   /// The buffer is handled in such a way that the buffer is
   /// uninitialized, unbuffered, or out of space when OutBufCur >=
   /// OutBufEnd. Thus a single comparison suffices to determine if we
@@ -105,9 +115,10 @@
   static constexpr Colors SAVEDCOLOR = Colors::SAVEDCOLOR;
   static constexpr Colors RESET = Colors::RESET;
 
-  explicit raw_ostream(bool unbuffered = false)
-      : BufferMode(unbuffered ? BufferKind::Unbuffered
-                              : BufferKind::InternalBuffer) {
+  explicit raw_ostream(bool unbuffered = false,
+                       OStreamKind K = OStreamKind::OK_OStream)
+      : Kind(K), BufferMode(unbuffered ? BufferKind::Unbuffered
+                                       : BufferKind::InternalBuffer) {
     // Start out ready to flush.
     OutBufStart = OutBufEnd = OutBufCur = nullptr;
   }
@@ -120,6 +131,9 @@
   /// tell - Return the current offset with the file.
   uint64_t tell() const { return current_pos() + GetNumBytesInBuffer(); }
 
+  /// Return the kind tag of this stream, for use by classof().
+  OStreamKind get_kind() const { return Kind; }
+
   //===--------------------------------------------------------------------===//
   // Configuration Interface
   //===--------------------------------------------------------------------===//
@@ -388,8 +402,9 @@
   void anchor() override;
 
 public:
-  explicit raw_pwrite_stream(bool Unbuffered = false)
-      : raw_ostream(Unbuffered) {}
+  explicit raw_pwrite_stream(bool Unbuffered = false,
+                             OStreamKind K = OStreamKind::OK_OStream)
+      : raw_ostream(Unbuffered, K) {}
   void pwrite(const char *Ptr, size_t Size, uint64_t Offset) {
 #ifndef NDEBUG
     uint64_t Pos = tell();
@@ -436,10 +451,17 @@
   /// Determine an efficient buffer size.
   size_t preferred_buffer_size() const override;
 
+  void anchor() override;
+
+protected:
   /// Set the flag indicating that an output error has been encountered.
   void error_detected(std::error_code EC) { this->EC = EC; }
 
-  void anchor() override;
+  /// Return the file descriptor.
+  int get_fd() const { return FD; }
+
+  /// Advance the recorded file position by \p Delta.
+  void inc_pos(uint64_t Delta) { pos += Delta; }
 
 public:
   /// Open the specified file for writing. If an error occurs, information
@@ -464,7 +486,8 @@
   /// FD is the file descriptor that this writes to. If ShouldClose is true,
   /// this closes the file when the stream is destroyed. If FD is for stdout or
   /// stderr, it will not be closed.
-  raw_fd_ostream(int fd, bool shouldClose, bool unbuffered=false);
+  raw_fd_ostream(int fd, bool shouldClose, bool unbuffered = false,
+                 OStreamKind K = OStreamKind::OK_OStream);
 
   ~raw_fd_ostream() override;
 
@@ -548,6 +571,34 @@
 /// This returns a reference to a raw_ostream which simply discards output.
 raw_ostream &nulls();
 
+//===----------------------------------------------------------------------===//
+// File Streams
+//===----------------------------------------------------------------------===//
+
+/// A raw_ostream of a file for reading/writing/seeking.
+///
+class raw_fd_stream : public raw_fd_ostream {
+public:
+  /// Open the specified file for reading/writing/seeking. If an error occurs,
+  /// information about the error is put into EC, and the stream should be
+  /// immediately destroyed.
+  raw_fd_stream(StringRef Filename, std::error_code &EC);
+
+  /// Read up to \p Size bytes into the buffer pointed to by \p Ptr.
+  ///
+  /// \param Ptr The start of the buffer to hold data to be read.
+  ///
+  /// \param Size The number of bytes to be read.
+  ///
+  /// On success, the number of bytes read is returned, and the file position is
+  /// advanced by this number. On error, -1 is returned; use error() to get the
+  /// error code.
+  ssize_t read(char *Ptr, size_t Size);
+
+  /// Check if \p OS is a pointer of type raw_fd_stream*.
+  static bool classof(const raw_ostream *OS);
+};
+
 //===----------------------------------------------------------------------===//
 // Output Stream Adaptors
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Support/raw_ostream.cpp b/llvm/lib/Support/raw_ostream.cpp
--- a/llvm/lib/Support/raw_ostream.cpp
+++ b/llvm/lib/Support/raw_ostream.cpp
@@ -620,8 +620,9 @@
 
 /// FD is the file descriptor that this writes to. If ShouldClose is true, this
 /// closes the file when the stream is destroyed.
-raw_fd_ostream::raw_fd_ostream(int fd, bool shouldClose, bool unbuffered)
-    : raw_pwrite_stream(unbuffered), FD(fd), ShouldClose(shouldClose) {
+raw_fd_ostream::raw_fd_ostream(int fd, bool shouldClose, bool unbuffered,
+                               OStreamKind K)
+    : raw_pwrite_stream(unbuffered, K), FD(fd), ShouldClose(shouldClose) {
   if (FD < 0 ) {
     ShouldClose = false;
     return;
@@ -904,6 +905,39 @@
   return S;
 }
 
+//===----------------------------------------------------------------------===//
+// File Streams
+//===----------------------------------------------------------------------===//
+
+raw_fd_stream::raw_fd_stream(StringRef Filename, std::error_code &EC)
+    : raw_fd_ostream(getFD(Filename, EC, sys::fs::CD_CreateAlways,
+                           sys::fs::FA_Write | sys::fs::FA_Read,
+                           sys::fs::OF_None),
+                     /*shouldClose=*/true, /*unbuffered=*/false,
+                     OStreamKind::OK_FDStream) {
+  if (EC)
+    return;
+
+  // Do not support non-seekable files.
+  if (!supportsSeeking())
+    EC = std::make_error_code(std::errc::invalid_argument);
+}
+
+ssize_t raw_fd_stream::read(char *Ptr, size_t Size) {
+  assert(get_fd() >= 0 && "File already closed.");
+  ssize_t Ret = ::read(get_fd(), Ptr, Size);
+  // Advance the recorded position on success; otherwise record errno.
+  if (Ret >= 0)
+    inc_pos(Ret);
+  else
+    error_detected(std::error_code(errno, std::generic_category()));
+  return Ret;
+}
+
+bool raw_fd_stream::classof(const raw_ostream *OS) {
+  return OS->get_kind() == OStreamKind::OK_FDStream;
+}
+
 //===----------------------------------------------------------------------===//
 // raw_string_ostream
 //===----------------------------------------------------------------------===//
diff --git a/llvm/unittests/Support/raw_fd_stream_test.cpp b/llvm/unittests/Support/raw_fd_stream_test.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Support/raw_fd_stream_test.cpp
@@ -0,0 +1,67 @@
+//===- llvm/unittest/Support/raw_fd_stream_test.cpp - raw_fd_stream tests -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+TEST(raw_fd_streamTest, ReadAfterWrite) {
+  SmallString<64> Path;
+  // Only the path is needed; raw_fd_stream opens the file itself.
+  ASSERT_FALSE(sys::fs::createTemporaryFile("foo", "bar", Path));
+  FileRemover Cleanup(Path);
+  std::error_code EC;
+  raw_fd_stream OS(Path, EC);
+  ASSERT_FALSE(EC);
+
+  char Bytes[8];
+
+  OS.write("01234567", 8);
+
+  OS.seek(3);
+  EXPECT_EQ(OS.read(Bytes, 2), 2);
+  EXPECT_EQ(Bytes[0], '3');
+  EXPECT_EQ(Bytes[1], '4');
+
+  OS.seek(4);
+  OS.write("xyz", 3);
+
+  OS.seek(0);
+  EXPECT_EQ(OS.read(Bytes, 8), 8);
+  EXPECT_EQ(Bytes[0], '0');
+  EXPECT_EQ(Bytes[1], '1');
+  EXPECT_EQ(Bytes[2], '2');
+  EXPECT_EQ(Bytes[3], '3');
+  EXPECT_EQ(Bytes[4], 'x');
+  EXPECT_EQ(Bytes[5], 'y');
+  EXPECT_EQ(Bytes[6], 'z');
+  EXPECT_EQ(Bytes[7], '7');
+}
+
+TEST(raw_fd_streamTest, DynCast) {
+  {
+    std::error_code EC;
+    raw_fd_stream OS("-", EC);
+    EXPECT_TRUE(dyn_cast<raw_fd_stream>(&OS));
+  }
+  {
+    std::error_code EC;
+    raw_fd_ostream OS("-", EC);
+    EXPECT_FALSE(dyn_cast<raw_fd_stream>(&OS));
+  }
+}
+
+} // namespace