diff --git a/llvm/include/llvm/Support/raw_ostream.h b/llvm/include/llvm/Support/raw_ostream.h
--- a/llvm/include/llvm/Support/raw_ostream.h
+++ b/llvm/include/llvm/Support/raw_ostream.h
@@ -714,6 +714,23 @@
   ~buffer_unique_ostream() override { *OS << str(); }
 };
 
+class Error;
+
+/// Runs \p Write on an output stream chosen by \p OutputFileName:
+/// llvm::outs() for "-", a raw_null_ostream for "/dev/null", and otherwise a
+/// raw_fd_ostream on a temporary file created next to \p OutputFileName (in
+/// the same directory). If \p Write succeeds and the stream is flushed
+/// without error, the temporary file is kept as \p OutputFileName; otherwise
+/// it is discarded and the error is returned.
+///
+/// If \p KeepOwnership is set and the temporary file is owned by root
+/// (normally because the tool runs as root), its owner and group are changed
+/// to \p UserID and \p GroupID on a best-effort basis. Ignored on Windows.
+Error writeToStream(StringRef OutputFileName,
+                    std::function<Error(raw_ostream &)> Write,
+                    bool KeepOwnership = false, unsigned UserID = 0,
+                    unsigned GroupID = 0);
+
 } // end namespace llvm
 
 #endif // LLVM_SUPPORT_RAW_OSTREAM_H
diff --git a/llvm/lib/Support/raw_ostream.cpp b/llvm/lib/Support/raw_ostream.cpp
--- a/llvm/lib/Support/raw_ostream.cpp
+++ b/llvm/lib/Support/raw_ostream.cpp
@@ -989,3 +989,53 @@
 void buffer_ostream::anchor() {}
 
 void buffer_unique_ostream::anchor() {}
+
+Error llvm::writeToStream(StringRef OutputFileName,
+                          std::function<Error(raw_ostream &)> Write,
+                          bool KeepOwnership, unsigned UserID,
+                          unsigned GroupID) {
+  if (OutputFileName == "-")
+    return Write(outs());
+
+  if (OutputFileName == "/dev/null") {
+    raw_null_ostream Out;
+    return Write(Out);
+  }
+
+  unsigned Mode = sys::fs::all_read | sys::fs::all_write | sys::fs::all_exe;
+  Expected<sys::fs::TempFile> Temp =
+      sys::fs::TempFile::create(OutputFileName + ".temp-stream-%%%%%%", Mode);
+  if (!Temp)
+    return createFileError(OutputFileName, Temp.takeError());
+
+#ifndef _WIN32
+  // Try to preserve file ownership if requested.
+  if (KeepOwnership) {
+    sys::fs::file_status Stat;
+    if (!sys::fs::status(Temp->FD, Stat) && Stat.getUser() == 0)
+      sys::fs::changeFileOwnership(Temp->FD, UserID, GroupID);
+  }
+#endif
+
+  raw_fd_ostream Out(Temp->FD, /*shouldClose=*/false);
+  Error WriteErr = Write(Out);
+
+  // Drain the buffer while Temp->FD is still open: keep() and discard() both
+  // close the descriptor, and the stream's destructor would otherwise flush
+  // leftover bytes into a stale (possibly reused) descriptor.
+  Out.flush();
+  std::error_code FlushEC = Out.error();
+  // ~raw_fd_ostream() reports a fatal error if an I/O error is still pending;
+  // return it to the caller instead.
+  Out.clear_error();
+
+  if (!WriteErr && FlushEC)
+    WriteErr = createFileError(OutputFileName, FlushEC);
+  if (WriteErr) {
+    if (Error DiscardError = Temp->discard())
+      return joinErrors(std::move(WriteErr), std::move(DiscardError));
+    return WriteErr;
+  }
+
+  return Temp->keep(OutputFileName);
+}
diff --git a/llvm/tools/llvm-objcopy/ELF/ELFObjcopy.cpp b/llvm/tools/llvm-objcopy/ELF/ELFObjcopy.cpp
--- a/llvm/tools/llvm-objcopy/ELF/ELFObjcopy.cpp
+++ b/llvm/tools/llvm-objcopy/ELF/ELFObjcopy.cpp
@@ -190,7 +190,7 @@
     (*DWOFile)->OSABI = Config.OutputArch.getValue().OSABI;
   }
 
-  return writeToFile(File, [&](raw_ostream &OutFile) -> Error {
+  return writeToStream(File, [&](raw_ostream &OutFile) -> Error {
     std::unique_ptr<Writer> Writer =
         createWriter(Config, **DWOFile, OutFile, OutputElfType);
     if (Error E = Writer->finalize())
diff --git a/llvm/tools/llvm-objcopy/llvm-objcopy.h b/llvm/tools/llvm-objcopy/llvm-objcopy.h
--- a/llvm/tools/llvm-objcopy/llvm-objcopy.h
+++ b/llvm/tools/llvm-objcopy/llvm-objcopy.h
@@ -27,18 +27,6 @@
 Expected<std::vector<NewArchiveMember>>
 createNewArchiveMembers(CopyConfig &Config, const object::Archive &Ar);
 
-/// A writeToFile helper creates an output stream, based on the specified
-/// \p OutputFileName: std::outs for the "-", raw_null_ostream for
-/// the "/dev/null", temporary file in the same directory as the final output
-/// file for other names. The final output file is atomically replaced with
-/// the temporary file after \p Write handler is finished. \p KeepOwnership
-/// used to setting specified \p UserID and \p GroupID for the resulting file
-/// if writeToFile is called under /root.
-Error writeToFile(StringRef OutputFileName,
-                  std::function<Error(raw_ostream &)> Write,
-                  bool KeepOwnership = false, unsigned UserID = 0,
-                  unsigned GroupID = 0);
-
 } // end namespace objcopy
 } // end namespace llvm
 
diff --git a/llvm/tools/llvm-objcopy/llvm-objcopy.cpp b/llvm/tools/llvm-objcopy/llvm-objcopy.cpp
--- a/llvm/tools/llvm-objcopy/llvm-objcopy.cpp
+++ b/llvm/tools/llvm-objcopy/llvm-objcopy.cpp
@@ -57,44 +57,6 @@
 namespace llvm {
 namespace objcopy {
 
-Error writeToFile(StringRef OutputFileName,
-                  std::function<Error(raw_ostream &)> Write, bool KeepOwnership,
-                  unsigned UserID, unsigned GroupID) {
-  if (OutputFileName == "-")
-    return Write(outs());
-
-  if (OutputFileName == "/dev/null") {
-    raw_null_ostream Out;
-    return Write(Out);
-  }
-
-  unsigned Mode = sys::fs::all_read | sys::fs::all_write | sys::fs::all_exe;
-  Expected<sys::fs::TempFile> Temp =
-      sys::fs::TempFile::create(OutputFileName + ".temp-objcopy-%%%%%%", Mode);
-  if (!Temp)
-    return createFileError(OutputFileName, Temp.takeError());
-
-#ifndef _WIN32
-  // Try to preserve file ownership if requested.
-  if (KeepOwnership) {
-    sys::fs::file_status Stat;
-    if (!sys::fs::status(Temp->FD, Stat) && Stat.getUser() == 0)
-      sys::fs::changeFileOwnership(Temp->FD, UserID, GroupID);
-  }
-#endif
-
-  raw_fd_ostream Out(Temp->FD, false);
-
-  if (Error E = Write(Out)) {
-    if (Error DiscardError = Temp->discard())
-      return joinErrors(std::move(E), std::move(DiscardError));
-    return E;
-  }
-  Out.flush();
-
-  return Temp->keep(OutputFileName);
-}
-
 // The name this program was invoked as.
 StringRef ToolName;
 
@@ -344,7 +306,7 @@
     if (!BufOrErr)
       return createFileError(Config.InputFilename, BufOrErr.getError());
 
-    if (Error E = writeToFile(
+    if (Error E = writeToStream(
             Config.OutputFilename, [&](raw_ostream &OutFile) -> Error {
               return ProcessRaw(Config, *BufOrErr->get(), OutFile);
             }))
@@ -359,6 +321,6 @@
       if (Error E = executeObjcopyOnArchive(Config, *Ar))
         return E;
     } else {
-      if (Error E = writeToFile(
+      if (Error E = writeToStream(
               Config.OutputFilename, [&](raw_ostream &OutFile) -> Error {
                 return executeObjcopyOnBinary(
diff --git a/llvm/unittests/Support/raw_ostream_test.cpp b/llvm/unittests/Support/raw_ostream_test.cpp
--- a/llvm/unittests/Support/raw_ostream_test.cpp
+++ b/llvm/unittests/Support/raw_ostream_test.cpp
@@ -8,7 +8,9 @@
 
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FileUtilities.h"
 #include "llvm/Support/Format.h"
+#include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/raw_ostream.h"
 #include "gtest/gtest.h"
 
@@ -469,4 +471,100 @@
   OS.flush();
   EXPECT_EQ("11111111111111111111hello1world", Str);
 }
+
+static void checkFileData(StringRef FileName, StringRef GoldenData) {
+  ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
+      MemoryBuffer::getFileOrSTDIN(FileName);
+  // Bail out on failure: dereferencing an ErrorOr in the error state is UB.
+  ASSERT_FALSE(BufOrErr.getError());
+  EXPECT_EQ((*BufOrErr)->getBuffer(), GoldenData);
+}
+
+TEST(raw_ostreamTest, writeToStream) {
+  SmallString<64> Path;
+  ASSERT_FALSE(sys::fs::createTemporaryFile("foo", "bar", Path));
+  FileRemover Cleanup(Path);
+
+  if (Error Err = writeToStream(Path, [](raw_ostream &Out) -> Error {
+        Out << "HelloWorld";
+        return Error::success();
+      })) {
+    // toString() consumes Err; an unconsumed failure aborts in debug builds.
+    ADD_FAILURE() << "Error while writing to \'" << Path
+                  << "\': " << toString(std::move(Err));
+  } else {
+    checkFileData(Path, "HelloWorld");
+  }
+}
+
+TEST(raw_ostreamTest, writeToFileWithKeepingOwnership) {
+  SmallString<64> Path;
+  ASSERT_FALSE(sys::fs::createTemporaryFile("foo", "bar", Path));
+  FileRemover Cleanup(Path);
+
+  // status() returns a std::error_code, so a false result means success.
+  sys::fs::file_status Stat;
+  ASSERT_FALSE(sys::fs::status(Path, Stat));
+  uint32_t User = Stat.getUser();
+  uint32_t Group = Stat.getGroup();
+
+  if (Error Err = writeToStream(
+          Path,
+          [](raw_ostream &Out) -> Error {
+            Out << "HelloWorld";
+            return Error::success();
+          },
+          true, User, Group)) {
+    ADD_FAILURE() << "Error while writing to \'" << Path
+                  << "\': " << toString(std::move(Err));
+  } else {
+    checkFileData(Path, "HelloWorld");
+  }
+}
+
+TEST(raw_ostreamTest, writeToStreamCallbackError) {
+  SmallString<64> Path;
+  ASSERT_FALSE(sys::fs::createTemporaryFile("foo", "bar", Path));
+  FileRemover Cleanup(Path);
+
+  // A failing callback must propagate its error, leave the existing file
+  // untouched, and not trip over the bytes it had already buffered.
+  Error Err = writeToStream(Path, [](raw_ostream &Out) -> Error {
+    Out << "HelloWorld";
+    return make_error<StringError>("callback failed",
+                                   inconvertibleErrorCode());
+  });
+  EXPECT_TRUE(errorToBool(std::move(Err)));
+  checkFileData(Path, "");
+}
+
+TEST(raw_ostreamTest, writeToNonexistingPath) {
+  Error Err = writeToStream("/_bad/_path", [](raw_ostream &Out) -> Error {
+    Out << "HelloWorld";
+    return Error::success();
+  });
+  // errorToBool() marks Err as checked in both the success and error states.
+  EXPECT_TRUE(errorToBool(std::move(Err)))
+      << "The test should be finished with error.";
+}
+
+TEST(raw_ostreamTest, writeToDevNull) {
+  if (Error Err = writeToStream("/dev/null", [](raw_ostream &Out) -> Error {
+        Out << "HelloWorld";
+        return Error::success();
+      })) {
+    ADD_FAILURE() << "Error while writing to /dev/null: "
+                  << toString(std::move(Err));
+  }
+}
+
+TEST(raw_ostreamTest, writeToStdOut) {
+  if (Error Err = writeToStream("-", [](raw_ostream &Out) -> Error {
+        Out << "HelloWorld";
+        return Error::success();
+      })) {
+    ADD_FAILURE() << "Error while writing to std out: "
+                  << toString(std::move(Err));
+  }
+}
 }