Index: llvm/include/llvm/Support/VirtualFileSystem.h =================================================================== --- llvm/include/llvm/Support/VirtualFileSystem.h +++ llvm/include/llvm/Support/VirtualFileSystem.h @@ -193,14 +193,22 @@ class FileSystem; +namespace detail { + +/// Keeps state for the recursive_directory_iterator. +struct RecDirIterState { + std::stack<directory_iterator, std::vector<directory_iterator>> Stack; + bool HasNoPushRequest = false; +}; + +} // end namespace detail + /// An input iterator over the recursive contents of a virtual path, /// similar to llvm::sys::fs::recursive_directory_iterator. class recursive_directory_iterator { - using IterState = - std::stack<directory_iterator, std::vector<directory_iterator>>; - FileSystem *FS; - std::shared_ptr<IterState> State; // Input iterator semantics on copy. + std::shared_ptr<detail::RecDirIterState> + State; // Input iterator semantics on copy. public: recursive_directory_iterator(FileSystem &FS, const Twine &Path, @@ -212,8 +220,8 @@ /// Equivalent to operator++, with an error code. recursive_directory_iterator &increment(std::error_code &EC); - const directory_entry &operator*() const { return *State->top(); } - const directory_entry *operator->() const { return &*State->top(); } + const directory_entry &operator*() const { return *State->Stack.top(); } + const directory_entry *operator->() const { return &*State->Stack.top(); } bool operator==(const recursive_directory_iterator &Other) const { return State == Other.State; // identity @@ -224,9 +232,12 @@ /// Gets the current level. Starting path is at level 0. int level() const { - assert(!State->empty() && "Cannot get level without any iteration state"); - return State->size() - 1; + assert(!State->Stack.empty() && + "Cannot get level without any iteration state"); + return State->Stack.size() - 1; } + + void no_push() { State->HasNoPushRequest = true; } }; /// The virtual file system interface. 
Index: llvm/lib/Support/VirtualFileSystem.cpp =================================================================== --- llvm/lib/Support/VirtualFileSystem.cpp +++ llvm/lib/Support/VirtualFileSystem.cpp @@ -2093,28 +2093,33 @@ : FS(&FS_) { directory_iterator I = FS->dir_begin(Path, EC); if (I != directory_iterator()) { - State = std::make_shared<IterState>(); - State->push(I); + State = std::make_shared<detail::RecDirIterState>(); + State->Stack.push(I); } } vfs::recursive_directory_iterator & recursive_directory_iterator::increment(std::error_code &EC) { - assert(FS && State && !State->empty() && "incrementing past end"); - assert(!State->top()->path().empty() && "non-canonical end iterator"); + assert(FS && State && !State->Stack.empty() && "incrementing past end"); + assert(!State->Stack.top()->path().empty() && "non-canonical end iterator"); vfs::directory_iterator End; - if (State->top()->type() == sys::fs::file_type::directory_file) { - vfs::directory_iterator I = FS->dir_begin(State->top()->path(), EC); - if (I != End) { - State->push(I); - return *this; + + if (State->HasNoPushRequest) + State->HasNoPushRequest = false; + else { + if (State->Stack.top()->type() == sys::fs::file_type::directory_file) { + vfs::directory_iterator I = FS->dir_begin(State->Stack.top()->path(), EC); + if (I != End) { + State->Stack.push(I); + return *this; + } } } - while (!State->empty() && State->top().increment(EC) == End) - State->pop(); + while (!State->Stack.empty() && State->Stack.top().increment(EC) == End) + State->Stack.pop(); - if (State->empty()) + if (State->Stack.empty()) State.reset(); // end iterator return *this; Index: llvm/unittests/Support/VirtualFileSystemTest.cpp =================================================================== --- llvm/unittests/Support/VirtualFileSystemTest.cpp +++ llvm/unittests/Support/VirtualFileSystemTest.cpp @@ -452,6 +452,9 @@ ScopedDir _ab(TestDirectory + "/a/b"); ScopedDir _c(TestDirectory + "/c"); ScopedDir _cd(TestDirectory + "/c/d"); + ScopedDir _e(TestDirectory + 
"/e"); + ScopedDir _ef(TestDirectory + "/e/f"); + ScopedDir _g(TestDirectory + "/g"); I = vfs::recursive_directory_iterator(*FS, Twine(TestDirectory), EC); ASSERT_FALSE(EC); @@ -460,22 +463,27 @@ std::vector<std::string> Contents; for (auto E = vfs::recursive_directory_iterator(); !EC && I != E; I.increment(EC)) { + if (I->path().endswith("/e")) + I.no_push(); Contents.push_back(I->path()); } // Check contents, which may be in any order - EXPECT_EQ(4U, Contents.size()); - int Counts[4] = {0, 0, 0, 0}; + EXPECT_EQ(6U, Contents.size()); + int Counts[7] = {0, 0, 0, 0, 0, 0, 0}; for (const std::string &Name : Contents) { ASSERT_FALSE(Name.empty()); int Index = Name[Name.size() - 1] - 'a'; - ASSERT_TRUE(Index >= 0 && Index < 4); + ASSERT_TRUE(Index >= 0 && Index < 7); Counts[Index]++; } EXPECT_EQ(1, Counts[0]); // a EXPECT_EQ(1, Counts[1]); // b EXPECT_EQ(1, Counts[2]); // c EXPECT_EQ(1, Counts[3]); // d + EXPECT_EQ(1, Counts[4]); // e + EXPECT_EQ(0, Counts[5]); // f + EXPECT_EQ(1, Counts[6]); // g } #ifdef LLVM_ON_UNIX