diff --git a/llvm/include/llvm/Support/raw_ostream.h b/llvm/include/llvm/Support/raw_ostream.h
--- a/llvm/include/llvm/Support/raw_ostream.h
+++ b/llvm/include/llvm/Support/raw_ostream.h
@@ -309,6 +309,14 @@
   // Subclass Interface
   //===--------------------------------------------------------------------===//
 
+  /// Class kinds to support LLVM-style RTTI (see get_kind()).
+  enum class Kind {
+    OStream = 0,
+    FDStream = 1,
+  };
+
+  virtual Kind get_kind() const;
+
 private:
   /// The is the piece of the class that is implemented by subclasses.  This
   /// writes the \p Size bytes starting at
@@ -436,10 +444,17 @@
   /// Determine an efficient buffer size.
   size_t preferred_buffer_size() const override;
 
+  void anchor() override;
+
+protected:
   /// Set the flag indicating that an output error has been encountered.
   void error_detected(std::error_code EC) { this->EC = EC; }
 
-  void anchor() override;
+  /// Return the file descriptor.
+  int get_fd() const { return FD; }
+
+  /// Update the file position by increasing \p Delta.
+  void inc_pos(uint64_t Delta) { pos += Delta; }
 
 public:
   /// Open the specified file for writing. If an error occurs, information
@@ -548,6 +563,36 @@
 /// This returns a reference to a raw_ostream which simply discards output.
 raw_ostream &nulls();
 
+//===----------------------------------------------------------------------===//
+// File Streams
+//===----------------------------------------------------------------------===//
+
+/// A raw_ostream that reads from, writes to and seeks a file descriptor.
+///
+class raw_fd_stream : public raw_fd_ostream {
+public:
+  /// Open the specified file for reading/writing/seeking. If an error occurs,
+  /// information about the error is put into EC, and the stream should be
+  /// immediately destroyed.
+  raw_fd_stream(StringRef Filename, std::error_code &EC);
+
+  /// This reads up to \p Size bytes into a buffer pointed to by \p Ptr.
+  /// Any buffered output is flushed first, so reading starts at tell().
+  ///
+  /// \param Ptr The start of the buffer to hold data to be read.
+  ///
+  /// \param Size The number of bytes to be read.
+  ///
+  /// On success, the number of bytes read is returned, and the file position is
+  /// advanced by this number. On error, -1 is returned, and EC is set.
+  ssize_t read(char *Ptr, size_t Size);
+
+  /// Check if \p OS is a pointer of type raw_fd_stream*.
+  static bool classof(const raw_ostream *OS);
+
+  Kind get_kind() const override;
+};
+
 //===----------------------------------------------------------------------===//
 // Output Stream Adaptors
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Support/raw_ostream.cpp b/llvm/lib/Support/raw_ostream.cpp
--- a/llvm/lib/Support/raw_ostream.cpp
+++ b/llvm/lib/Support/raw_ostream.cpp
@@ -554,6 +554,8 @@
 
 void raw_ostream::anchor() {}
 
+raw_ostream::Kind raw_ostream::get_kind() const { return Kind::OStream; }
+
 //===----------------------------------------------------------------------===//
 //  Formatted Output
 //===----------------------------------------------------------------------===//
@@ -904,6 +906,42 @@
   return S;
 }
 
+//===----------------------------------------------------------------------===//
+//  File Streams
+//===----------------------------------------------------------------------===//
+
+raw_fd_stream::raw_fd_stream(StringRef Filename, std::error_code &EC)
+    : raw_fd_ostream(Filename, EC, sys::fs::CD_CreateAlways,
+                     sys::fs::FA_Write | sys::fs::FA_Read, sys::fs::OF_None) {
+  if (EC)
+    return;
+
+  // Do not support STDOUT_FILENO and non-seekable files.
+  if (Filename == "-" || !supportsSeeking()) {
+    EC = std::make_error_code(std::errc::invalid_argument);
+  }
+}
+
+ssize_t raw_fd_stream::read(char *Ptr, size_t Size) {
+  assert(get_fd() >= 0 && "File already closed.");
+  // Write out any buffered output first. Otherwise ::read() would start at
+  // the last flushed offset rather than at tell(), and the pending bytes
+  // would later be written after the bytes read here.
+  flush();
+  ssize_t Ret = ::read(get_fd(), Ptr, Size);
+  if (Ret >= 0)
+    inc_pos(Ret);
+  else
+    error_detected(std::error_code(errno, std::generic_category()));
+  return Ret;
+}
+
+raw_ostream::Kind raw_fd_stream::get_kind() const { return Kind::FDStream; }
+
+bool raw_fd_stream::classof(const raw_ostream *OS) {
+  return OS->get_kind() == Kind::FDStream;
+}
+
 //===----------------------------------------------------------------------===//
 //  raw_string_ostream
 //===----------------------------------------------------------------------===//
diff --git a/llvm/unittests/Support/raw_fd_stream_test.cpp b/llvm/unittests/Support/raw_fd_stream_test.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Support/raw_fd_stream_test.cpp
@@ -0,0 +1,94 @@
+//===- llvm/unittest/Support/raw_fd_stream_test.cpp - raw_fd_stream tests -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/Process.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+#define ASSERT_NO_ERROR(x)                                                     \
+  if (std::error_code ASSERT_NO_ERROR_ec = x) {                                \
+    SmallString<128> MessageStorage;                                           \
+    raw_svector_ostream Message(MessageStorage);                               \
+    Message << #x ": did not return errc::success.\n"                          \
+            << "error number: " << ASSERT_NO_ERROR_ec.value() << "\n"          \
+            << "error message: " << ASSERT_NO_ERROR_ec.message() << "\n";      \
+    GTEST_FATAL_FAILURE_(MessageStorage.c_str());                              \
+  } else {                                                                     \
+  }
+
+namespace {
+
+// Data written through the stream can be read back after seeking.
+TEST(raw_fd_streamTest, ReadAfterWrite) {
+  SmallString<64> Path;
+  int FD;
+  ASSERT_NO_ERROR(sys::fs::createTemporaryFile("foo", "bar", FD, Path));
+  FileRemover Cleanup(Path);
+  // raw_fd_stream reopens the file by name, so release this descriptor.
+  ASSERT_NO_ERROR(sys::Process::SafelyCloseFileDescriptor(FD));
+  std::error_code ec;
+  raw_fd_stream OS(Path, ec);
+  ASSERT_FALSE(ec);
+
+  char Bytes[8];
+
+  OS.write("01234567", 8);
+
+  OS.seek(3);
+  EXPECT_EQ(OS.read(Bytes, 2), 2);
+  EXPECT_EQ(Bytes[0], '3');
+  EXPECT_EQ(Bytes[1], '4');
+
+  OS.seek(4);
+  OS.write("xyz", 3);
+
+  OS.seek(0);
+  EXPECT_EQ(OS.read(Bytes, 8), 8);
+  EXPECT_EQ(Bytes[0], '0');
+  EXPECT_EQ(Bytes[1], '1');
+  EXPECT_EQ(Bytes[2], '2');
+  EXPECT_EQ(Bytes[3], '3');
+  EXPECT_EQ(Bytes[4], 'x');
+  EXPECT_EQ(Bytes[5], 'y');
+  EXPECT_EQ(Bytes[6], 'z');
+  EXPECT_EQ(Bytes[7], '7');
+  EXPECT_EQ(OS.tell(), 8U);
+  // Reading at end of file returns 0 bytes and is not an error.
+  EXPECT_EQ(OS.read(Bytes, 1), 0);
+  EXPECT_FALSE(OS.has_error());
+}
+
+// Only raw_fd_stream itself satisfies raw_fd_stream::classof.
+TEST(raw_fd_streamTest, DynCast) {
+  {
+    std::error_code ec;
+    raw_fd_stream OS("-", ec);
+    EXPECT_TRUE(dyn_cast<raw_fd_stream>(&OS));
+  }
+  {
+    std::error_code ec;
+    raw_fd_ostream OS("-", ec);
+    EXPECT_FALSE(dyn_cast<raw_fd_stream>(&OS));
+  }
+}
+
+// raw_fd_stream needs a seekable file, so "-" (stdout) is rejected.
+TEST(raw_fd_streamTest, DontSupportStdout) {
+  std::error_code ec;
+  raw_fd_stream OS("-", ec);
+  EXPECT_EQ(static_cast<std::errc>(ec.value()), std::errc::invalid_argument);
+}
+
+} // namespace