diff --git a/llvm/include/llvm/Support/MemoryBuffer.h b/llvm/include/llvm/Support/MemoryBuffer.h
--- a/llvm/include/llvm/Support/MemoryBuffer.h
+++ b/llvm/include/llvm/Support/MemoryBuffer.h
@@ -22,6 +22,7 @@
 #include "llvm/Support/FileSystem.h"
 #include <cstddef>
 #include <cstdint>
+#include <functional>
 #include <memory>
 
 namespace llvm {
@@ -126,6 +127,15 @@
   getFileOrSTDIN(const Twine &Filename, int64_t FileSize = -1,
                  bool RequiresNullTerminator = true);
 
+  /// Variant of getFileOrSTDIN(const Twine &, int64_t, bool) that refuses to
+  /// read from an interactive terminal. If \p Filename is "-" and stdin
+  /// refers to a terminal, i.e. input has not been redirected in, \p TtyStdin
+  /// is called (when it is non-empty) and the process exits with code 1.
+  /// Otherwise behaves exactly like the three-argument overload.
+  static ErrorOr<std::unique_ptr<MemoryBuffer>>
+  getFileOrSTDIN(const Twine &Filename, std::function<void()> TtyStdin,
+                 int64_t FileSize = -1, bool RequiresNullTerminator = true);
+
   /// Map a subrange of the specified file as a MemoryBuffer.
   static ErrorOr<std::unique_ptr<MemoryBuffer>>
   getFileSlice(const Twine &Filename, uint64_t MapSize, uint64_t Offset,
diff --git a/llvm/lib/Support/MemoryBuffer.cpp b/llvm/lib/Support/MemoryBuffer.cpp
--- a/llvm/lib/Support/MemoryBuffer.cpp
+++ b/llvm/lib/Support/MemoryBuffer.cpp
@@ -23,6 +23,7 @@
 #include "llvm/Support/SmallVectorMemoryBuffer.h"
 #include <cassert>
 #include <cerrno>
+#include <cstdlib>
 #include <cstring>
 #include <new>
 #include <sys/types.h>
@@ -150,6 +151,28 @@
   return getFile(Filename, FileSize, RequiresNullTerminator);
 }
 
+/// Like the plain getFileOrSTDIN overload, except that reading "-" while stdin
+/// is a terminal invokes TtyStdin and terminates the process with exit code 1.
+ErrorOr<std::unique_ptr<MemoryBuffer>>
+MemoryBuffer::getFileOrSTDIN(const Twine &Filename,
+                             std::function<void()> TtyStdin,
+                             int64_t FileSize, bool RequiresNullTerminator) {
+  SmallString<256> NameBuf;
+  StringRef NameRef = Filename.toStringRef(NameBuf);
+
+  if (NameRef == "-") {
+    if (sys::fs::is_tty(0)) {
+      // An empty std::function must not be invoked; skip the callback but
+      // still refuse to read from the terminal.
+      if (TtyStdin)
+        TtyStdin();
+      std::exit(1);
+    }
+    return getSTDIN();
+  }
+  return getFile(Filename, FileSize, RequiresNullTerminator);
+}
+
 ErrorOr<std::unique_ptr<MemoryBuffer>>
 MemoryBuffer::getFileSlice(const Twine &FilePath, uint64_t MapSize,
                            uint64_t Offset, bool IsVolatile) {
diff --git a/llvm/unittests/Support/MemoryBufferTest.cpp b/llvm/unittests/Support/MemoryBufferTest.cpp
--- a/llvm/unittests/Support/MemoryBufferTest.cpp
+++ b/llvm/unittests/Support/MemoryBufferTest.cpp
@@ -16,6 +16,9 @@
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Testing/Support/Error.h"
 #include "gtest/gtest.h"
+#ifndef _WIN32
+#include <unistd.h> // for dup(2), dup2(2), close(2)
+#endif
 
 using namespace llvm;
 
@@ -288,4 +291,47 @@
   ASSERT_EQ(16u, MB.getBufferSize());
   EXPECT_EQ("xxxxxxxxxxxxxxxx", MB.getBuffer());
 }
+
+#ifndef _WIN32
+// Exercises the getFileOrSTDIN overload that takes a terminal callback.
+TEST_F(MemoryBufferTest, getFileOrSTDIN) {
+  int TestForChange = 0;
+  // Write to stderr: gtest death tests match their regex against the child's
+  // stderr only, so output sent to stdout would never be seen.
+  auto ChangeFunc = [&TestForChange] {
+    errs() << "message";
+    TestForChange = 1;
+  };
+
+  // The terminal path can only be exercised when stdin really is a tty. The
+  // statement runs in a child process, so TestForChange is not updated in
+  // this process; the "message" match is what proves the callback ran.
+  if (sys::fs::is_tty(0)) {
+    EXPECT_EXIT(MemoryBuffer::getFileOrSTDIN("-", ChangeFunc),
+                ::testing::ExitedWithCode(1), "message")
+        << "Reading from tty should have exited with exit code 1";
+  }
+
+  // Redirect stdin to an empty temporary file and check the callback is not
+  // invoked when input does not come from a terminal.
+  auto TempFile = sys::fs::TempFile::create("%%%%%%%%");
+  ASSERT_THAT_EXPECTED(TempFile, Succeeded());
+
+  int SavedStdin = ::dup(0);
+  bool Redirected = SavedStdin != -1 && ::dup2(TempFile->FD, 0) != -1;
+  if (Redirected && !sys::fs::is_tty(0)) {
+    auto MB = MemoryBuffer::getFileOrSTDIN("-", ChangeFunc);
+    EXPECT_FALSE(MB.getError());
+    EXPECT_EQ(TestForChange, 0) << "Callback should not have been called";
+  }
+
+  // Restore the original stdin so later tests are unaffected, then release
+  // the temporary file on every path.
+  if (SavedStdin != -1) {
+    ::dup2(SavedStdin, 0);
+    ::close(SavedStdin);
+  }
+  EXPECT_THAT_ERROR(TempFile->discard(), Succeeded());
+}
+#endif // !_WIN32
 }