diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -755,26 +755,25 @@ using type = std::tuple())...>; }; -template +template using zip_traits = iterator_facade_base< ZipType, std::common_type_t< std::bidirectional_iterator_tag, typename std::iterator_traits::iterator_category...>, // ^ TODO: Implement random access methods. - typename ZipTupleType::type, + ReferenceTupleType, typename std::iterator_traits< std::tuple_element_t<0, std::tuple>>::difference_type, // ^ FIXME: This follows boost::make_zip_iterator's assumption that all // inner iterators have the same difference_type. It would fail if, for // instance, the second field's difference_type were non-numeric while the // first is. - typename ZipTupleType::type *, - typename ZipTupleType::type>; + ReferenceTupleType *, ReferenceTupleType>; -template -struct zip_common : public zip_traits { - using Base = zip_traits; +template +struct zip_common : public zip_traits { + using Base = zip_traits; using IndexSequence = std::index_sequence_for; using value_type = typename Base::value_type; @@ -824,8 +823,10 @@ }; template -struct zip_first : zip_common, Iters...> { - using zip_common::zip_common; +struct zip_first : zip_common, + typename ZipTupleType::type, Iters...> { + using zip_common::type, + Iters...>::zip_common; bool operator==(const zip_first &other) const { return std::get<0>(this->iterators) == std::get<0>(other.iterators); @@ -833,8 +834,11 @@ }; template -struct zip_shortest : zip_common, Iters...> { - using zip_common::zip_common; +struct zip_shortest + : zip_common, typename ZipTupleType::type, + Iters...> { + using zip_common::type, + Iters...>::zip_common; bool operator==(const zip_shortest &other) const { return any_iterator_equals(other, std::index_sequence_for{}); @@ -2182,113 +2186,174 @@ namespace detail { -template class enumerator_iter; +/// Tuple-like type for `zip_enumerator` 
dereference. +template struct enumerator_result; -template struct result_pair { - using value_reference = - typename std::iterator_traits>::reference; - - friend class enumerator_iter; - - result_pair(std::size_t Index, IterOfRange Iter) - : Index(Index), Iter(Iter) {} +template struct EnumeratorTupleType { + using type = enumerator_result())...>; +}; - std::size_t index() const { return Index; } - value_reference value() const { return *Iter; } +/// Zippy iterator that uses the second iterator for comparisons. For the +/// increment to be safe, the second range has to be the shortest. +/// Returns `enumerator_result` on dereference to provide `.index()` and +/// `.value()` member functions. +template +struct zip_enumerator + : zip_common, + typename EnumeratorTupleType::type, Iters...> { + static_assert(sizeof...(Iters) >= 2, "Expected at least two iteratees"); + using zip_common, + typename EnumeratorTupleType::type, + Iters...>::zip_common; -private: - std::size_t Index = std::numeric_limits::max(); - IterOfRange Iter; + bool operator==(const zip_enumerator &Other) const { + return std::get<1>(this->iterators) == std::get<1>(Other.iterators); + } }; -template -decltype(auto) get(const result_pair &Pair) { - static_assert(i < 2); - if constexpr (i == 0) { - return Pair.index(); - } else { - return Pair.value(); +template struct enumerator_result { + static constexpr std::size_t NumRefs = sizeof...(Refs); + static_assert(NumRefs != 0); + // `NumValues` includes the index. + static constexpr std::size_t NumValues = NumRefs + 1; + + // Tuple type whose element types are references for each `Ref`. + using range_reference_tuple = std::tuple; + // Tuple type whose elements are references to all values, including both + // the index and `Refs` reference types. + using value_reference_tuple = std::tuple; + + enumerator_result(std::size_t Index, Refs &&...Rs) + : Idx(Index), Storage(std::forward(Rs)...)
{} + + /// Returns the 0-based index of the current position within the original + /// input range(s). + std::size_t index() const { return Idx; } + + /// Returns the value(s) for the current iterator. This does not include the + /// index. + decltype(auto) value() const { + if constexpr (NumRefs == 1) { + return std::get<0>(Storage); + } else { + return Storage; + } } -} - -template -class enumerator_iter - : public iterator_facade_base, std::forward_iterator_tag, - const result_pair> { - using result_type = result_pair; - -public: - explicit enumerator_iter(IterOfRange EndIter) - : Result(std::numeric_limits::max(), EndIter) {} - enumerator_iter(std::size_t Index, IterOfRange Iter) - : Result(Index, Iter) {} - - const result_type &operator*() const { return Result; } - - enumerator_iter &operator++() { - assert(Result.Index != std::numeric_limits::max()); - ++Result.Iter; - ++Result.Index; - return *this; + /// Returns the value at index `I`. This includes the index. + template + friend decltype(auto) get(const enumerator_result &Result) { + static_assert(I < NumValues, "Index out of bounds"); + if constexpr (I == 0) { + return Result.Idx; + } else { + return std::get(Result.Storage); + } } - bool operator==(const enumerator_iter &RHS) const { - // Don't compare indices here, only iterators. It's possible for an end - // iterator to have different indices depending on whether it was created - // by calling std::end() versus incrementing a valid iterator. 
- return Result.Iter == RHS.Result.Iter; + template + friend bool operator==(const enumerator_result &Result, + const std::tuple &Other) { + static_assert(NumRefs == sizeof...(Ts), "Size mismatch"); + if (Result.Idx != std::get<0>(Other)) + return false; + return Result.is_value_equal(Other, std::make_index_sequence{}); + } private: - result_type Result; + template + bool is_value_equal(const Tuple &Other, std::index_sequence) const { + return ((std::get(Storage) == std::get(Other)) && ...); + } + + std::size_t Idx; + // Make this tuple mutable to avoid casts that obfuscate const-correctness + // issues. Const-correctness of references is taken care of by `zippy` that + // defines non-const and const iterator types that will propagate down to + // `enumerator_result`'s `Refs`. + // Note that unlike the results of `zip*` functions, `enumerate`'s results are + // supposed to be modifiable even when defined as + // `const`. + mutable range_reference_tuple Storage; }; -template class enumerator { -public: - explicit enumerator(R &&Range) : TheRange(std::forward(Range)) {} +/// Infinite stream of increasing 0-based `size_t` indices.
+struct index_stream { + struct iterator : iterator_facade_base { + std::size_t operator*() const { return Index; } + iterator &operator++() { + assert(Index != std::numeric_limits::max() && + "Attempting to increment end iterator"); + ++Index; + return *this; + } - enumerator_iter begin() { - return enumerator_iter(0, adl_begin(TheRange)); - } - enumerator_iter begin() const { - return enumerator_iter(0, adl_begin(TheRange)); - } + bool operator==(const iterator &Other) const { + return Index == Other.Index; + } - enumerator_iter end() { return enumerator_iter(adl_end(TheRange)); } - enumerator_iter end() const { - return enumerator_iter(adl_end(TheRange)); - } + std::size_t Index = 0; + }; -private: - R TheRange; + iterator begin() const { return {}; } + iterator end() const { + // We approximate 'infinity' with the max size_t value, which should be good + // enough to index over any container. + iterator It; + It.Index = std::numeric_limits::max(); + return It; + } }; } // end namespace detail -/// Given an input range, returns a new range whose values are are pair (A,B) -/// such that A is the 0-based index of the item in the sequence, and B is -/// the value from the original sequence. Example: +/// Given one or more input ranges, returns a new range whose values are +/// tuples (A, B, C, ...), such that A is the 0-based index of the item in the +/// sequence, and B, C, ..., are the values from the original input ranges. All +/// input ranges are required to have equal lengths. Note that the returned +/// iterator allows for the values (B, C, ...) to be modified.
Example: +/// +/// ```c++ +/// std::vector Letters = {'A', 'B', 'C', 'D'}; +/// std::vector Vals = {10, 11, 12, 13}; /// -/// std::vector Items = {'A', 'B', 'C', 'D'}; -/// for (auto X : enumerate(Items)) { -/// printf("Item %zu - %c\n", X.index(), X.value()); +/// for (auto [Index, Letter, Value] : enumerate(Letters, Vals)) { +/// printf("Item %zu - %c: %d\n", Index, Letter, Value); +/// Value -= 10; /// } +/// ``` /// -/// or using structured bindings: +/// Output: +/// Item 0 - A: 10 +/// Item 1 - B: 11 +/// Item 2 - C: 12 +/// Item 3 - D: 13 /// -/// for (auto [Index, Value] : enumerate(Items)) { -/// printf("Item %zu - %c\n", Index, Value); +/// or using an iterator: +/// ```c++ +/// for (auto it : enumerate(Vals)) { +/// it.value() += 10; +/// printf("Item %zu: %d\n", it.index(), it.value()); /// } +/// ``` /// /// Output: -/// Item 0 - A -/// Item 1 - B -/// Item 2 - C -/// Item 3 - D +/// Item 0: 20 +/// Item 1: 21 +/// Item 2: 22 +/// Item 3: 23 /// -template detail::enumerator enumerate(R &&TheRange) { - return detail::enumerator(std::forward(TheRange)); +template +auto enumerate(FirstRange &&First, RestRanges &&...Rest) { + assert((sizeof...(Rest) == 0 || + all_equal({std::distance(adl_begin(First), adl_end(First)), + std::distance(adl_begin(Rest), adl_end(Rest))...})) && + "Ranges have different length"); + using zip = detail::zippy; + return zip(detail::index_stream{}, std::forward(First), + std::forward(Rest)...); } namespace detail { @@ -2420,15 +2485,17 @@ } // end namespace llvm namespace std { -template -struct tuple_size> - : std::integral_constant {}; +template +struct tuple_size> + : std::integral_constant {}; -template -struct tuple_element> - : std::conditional::value_reference> { -}; +template +struct tuple_element> + : std::tuple_element> {}; + +template +struct tuple_element> + : std::tuple_element> {}; } // namespace std diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp --- 
a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -13,7 +13,12 @@ #include #include +#include +#include #include +#include +#include +#include #include using namespace llvm; @@ -151,6 +156,131 @@ PairType(2u, '4'))); } +TEST(STLExtrasTest, EnumerateTwoRanges) { + using Tuple = std::tuple; + + std::vector Ints = {1, 2}; + std::vector Bools = {true, false}; + EXPECT_THAT(llvm::enumerate(Ints, Bools), + ElementsAre(Tuple(0, 1, true), Tuple(1, 2, false))); + + // Check that we can modify the values when the temporary is a const + // reference. + for (const auto &[Idx, Int, Bool] : llvm::enumerate(Ints, Bools)) { + (void)Idx; + Bool = false; + Int = -1; + } + + EXPECT_THAT(Ints, ElementsAre(-1, -1)); + EXPECT_THAT(Bools, ElementsAre(false, false)); + + // Check that we can modify the values when the result gets copied. + for (auto [Idx, Bool, Int] : llvm::enumerate(Bools, Ints)) { + (void)Idx; + Int = 3; + Bool = true; + } + + EXPECT_THAT(Ints, ElementsAre(3, 3)); + EXPECT_THAT(Bools, ElementsAre(true, true)); + + // Check that we can modify the values through `.values()`. 
+ size_t Iters = 0; + for (auto It : llvm::enumerate(Bools, Ints)) { + EXPECT_EQ(It.index(), Iters); + ++Iters; + + std::get<0>(It.value()) = false; + std::get<1>(It.value()) = 4; + } + + EXPECT_THAT(Ints, ElementsAre(4, 4)); + EXPECT_THAT(Bools, ElementsAre(false, false)); +} + +TEST(STLExtrasTest, EnumerateThreeRanges) { + using Tuple = std::tuple; + + std::vector Ints = {1, 2}; + std::vector Bools = {true, false}; + char Chars[] = {'X', 'D'}; + EXPECT_THAT(llvm::enumerate(Ints, Bools, Chars), + ElementsAre(Tuple(0, 1, true, 'X'), Tuple(1, 2, false, 'D'))); + + for (auto [Idx, Int, Bool, Char] : llvm::enumerate(Ints, Bools, Chars)) { + (void)Idx; + Int = 0; + Bool = true; + Char = '!'; + } + + EXPECT_THAT(Ints, ElementsAre(0, 0)); + EXPECT_THAT(Bools, ElementsAre(true, true)); + EXPECT_THAT(Chars, ElementsAre('!', '!')); + + // Check that we can modify the values through `.values()`. + size_t Iters = 0; + for (auto It : llvm::enumerate(Ints, Bools, Chars)) { + EXPECT_EQ(It.index(), Iters); + ++Iters; + auto [Int, Bool, Char] = It.value(); + Int = 42; + Bool = false; + Char = '$'; + } + + EXPECT_THAT(Ints, ElementsAre(42, 42)); + EXPECT_THAT(Bools, ElementsAre(false, false)); + EXPECT_THAT(Chars, ElementsAre('$', '$')); +} + +TEST(STLExtrasTest, EnumerateTemporaries) { + using Tuple = std::tuple; + + EXPECT_THAT( + llvm::enumerate(llvm::SmallVector({1, 2, 3}), + std::vector({true, false, true})), + ElementsAre(Tuple(0, 1, true), Tuple(1, 2, false), Tuple(2, 3, true))); + + size_t Iters = 0; + // This is fine from the point of view of range lifetimes because `zippy` will + // move all temporaries into its storage. No lifetime extension is necessary. + for (auto [Idx, Int, Bool] : + llvm::enumerate(llvm::SmallVector({1, 2, 3}), + std::vector({true, false, true}))) { + EXPECT_EQ(Idx, Iters); + ++Iters; + Int = 0; + Bool = true; + } + + Iters = 0; + // The same thing but with the result as a const reference. 
+ for (const auto &[Idx, Int, Bool] : + llvm::enumerate(llvm::SmallVector({1, 2, 3}), + std::vector({true, false, true}))) { + EXPECT_EQ(Idx, Iters); + ++Iters; + Int = 0; + Bool = true; + } +} + +#if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG) +TEST(STLExtrasTest, EnumerateDifferentLengths) { + std::vector Ints = {0, 1}; + bool Bools[] = {true, false, true}; + std::string Chars = "abc"; + EXPECT_DEATH(llvm::enumerate(Ints, Bools, Chars), + "Ranges have different length"); + EXPECT_DEATH(llvm::enumerate(Bools, Ints, Chars), + "Ranges have different length"); + EXPECT_DEATH(llvm::enumerate(Bools, Chars, Ints), + "Ranges have different length"); +} +#endif + template struct CanMove {}; template <> struct CanMove { CanMove(CanMove &&) = delete; @@ -188,8 +318,8 @@ template struct Range : Counted { using Counted::Counted; - int *begin() { return nullptr; } - int *end() { return nullptr; } + int *begin() const { return nullptr; } + int *end() const { return nullptr; } }; TEST(STLExtrasTest, EnumerateLifetimeSemanticsPRValue) { diff --git a/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp b/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp --- a/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp +++ b/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp @@ -89,18 +89,18 @@ TraversableEdges(MatchDag.getNumEdges()), TestablePredicates(MatchDag.getNumPredicates()) { // Number all the predicates in this DAG - for (auto &P : enumerate(MatchDag.predicates())) { + for (auto &&P : enumerate(MatchDag.predicates())) { PredicateIDs.insert(std::make_pair(P.value(), P.index())); } // Number all the predicate dependencies in this DAG and set up a bitvector // for each predicate indicating the unsatisfied dependencies. 
- for (auto &Dep : enumerate(MatchDag.predicate_edges())) { + for (auto &&Dep : enumerate(MatchDag.predicate_edges())) { PredicateDepIDs.insert(std::make_pair(Dep.value(), Dep.index())); } UnsatisfiedPredDepsForPred.resize(MatchDag.getNumPredicates(), BitVector(PredicateDepIDs.size())); - for (auto &Dep : enumerate(MatchDag.predicate_edges())) { + for (auto &&Dep : enumerate(MatchDag.predicate_edges())) { unsigned ID = PredicateIDs.lookup(Dep.value()->getPredicate()); UnsatisfiedPredDepsForPred[ID].set(Dep.index()); } @@ -134,10 +134,10 @@ // Mark the dependencies that are now satisfied as a result of this // instruction and mark any predicates whose dependencies are fully // satisfied. - for (auto &Dep : enumerate(MatchDag.predicate_edges())) { + for (auto &&Dep : enumerate(MatchDag.predicate_edges())) { if (Dep.value()->getRequiredMI() == Instr && Dep.value()->getRequiredMO() == nullptr) { - for (auto &DepsFor : enumerate(UnsatisfiedPredDepsForPred)) { + for (auto &&DepsFor : enumerate(UnsatisfiedPredDepsForPred)) { DepsFor.value().reset(Dep.index()); if (DepsFor.value().none()) TestablePredicates.set(DepsFor.index()); @@ -157,7 +157,7 @@ // When an operand becomes reachable, we potentially activate some traversals. // Record the edges that can now be followed as a result of this // instruction. - for (auto &E : enumerate(MatchDag.edges())) { + for (auto &&E : enumerate(MatchDag.edges())) { if (E.value()->getFromMI() == Instr && E.value()->getFromMO()->getIdx() == OpIdx) { TraversableEdges.set(E.index()); @@ -168,10 +168,10 @@ // Clear the dependencies that are now satisfied as a result of this // operand and activate any predicates whose dependencies are fully // satisfied. 
- for (auto &Dep : enumerate(MatchDag.predicate_edges())) { + for (auto &&Dep : enumerate(MatchDag.predicate_edges())) { if (Dep.value()->getRequiredMI() == Instr && Dep.value()->getRequiredMO() && Dep.value()->getRequiredMO()->getIdx() == OpIdx) { - for (auto &DepsFor : enumerate(UnsatisfiedPredDepsForPred)) { + for (auto &&DepsFor : enumerate(UnsatisfiedPredDepsForPred)) { DepsFor.value().reset(Dep.index()); if (DepsFor.value().none()) TestablePredicates.set(DepsFor.index()); @@ -339,7 +339,7 @@ "Must always partition into at least one partition"); TreeNode->setNumChildren(Partitioner->getNumPartitions()); - for (auto &C : enumerate(TreeNode->children())) { + for (auto &&C : enumerate(TreeNode->children())) { SubtreeBuilders.emplace_back(&C.value(), NextInstrID); Partitioner->applyForPartition(C.index(), *this, SubtreeBuilders.back()); } @@ -536,7 +536,7 @@ BitVector PossibleLeaves = getPossibleLeavesForPartition(PartitionIdx); // Consume any predicates we handled. - for (auto &EnumeratedLeaf : enumerate(Builder.getPossibleLeaves())) { + for (auto &&EnumeratedLeaf : enumerate(Builder.getPossibleLeaves())) { if (!PossibleLeaves[EnumeratedLeaf.index()]) continue; @@ -571,7 +571,7 @@ if (!InstrInfo) continue; const GIMatchDagInstr *Instr = InstrInfo->getInstrNode(); - for (auto &E : enumerate(Leaf.getMatchDag().edges())) { + for (auto &&E : enumerate(Leaf.getMatchDag().edges())) { if (E.value()->getFromMI() == Instr && E.value()->getFromMO()->getIdx() < CGI->Operands.size()) { ReferencedOperands.resize(E.value()->getFromMO()->getIdx() + 1); @@ -715,7 +715,7 @@ std::vector TraversedEdgesByNewLeaves; // Consume any edges we handled. - for (auto &EnumeratedLeaf : enumerate(Builder.getPossibleLeaves())) { + for (auto &&EnumeratedLeaf : enumerate(Builder.getPossibleLeaves())) { if (!PossibleLeaves[EnumeratedLeaf.index()]) continue;