diff --git a/llvm/include/llvm/ADT/Sequence.h b/llvm/include/llvm/ADT/Sequence.h
--- a/llvm/include/llvm/ADT/Sequence.h
+++ b/llvm/include/llvm/ADT/Sequence.h
@@ -15,8 +15,9 @@
 #ifndef LLVM_ADT_SEQUENCE_H
 #define LLVM_ADT_SEQUENCE_H
 
-#include <cstddef>  // std::ptrdiff_t
-#include <iterator> // std::random_access_iterator_tag
+#include <cstddef>     // std::ptrdiff_t
+#include <iterator>    // std::random_access_iterator_tag
+#include <type_traits> // std::underlying_type, std::is_enum
 
 namespace llvm {
 
@@ -138,9 +139,6 @@
 } // namespace detail
 
 template <typename ValueT> struct iota_range {
-  static_assert(std::is_integral<ValueT>::value,
-                "ValueT must be an integral type");
-
   using value_type = ValueT;
   using reference = ValueT &;
   using const_reference = const ValueT &;
@@ -175,10 +173,89 @@
                 "ValueT must not be const nor volatile");
 };
 
+/// Wraps a value of enum type \p E so that it can be incremented, decremented,
+/// offset and subtracted like an integer, which lets iota_range iterate it.
+template <typename E> struct IterableEnum {
+private:
+  static_assert(std::is_enum<E>::value, "E must be an enum");
+  using T = typename std::underlying_type<E>::type;
+
+  T Value{};
+
+public:
+  IterableEnum() = default;
+  constexpr explicit IterableEnum(E EnumValue) : Value(T(EnumValue)) {}
+  IterableEnum(const IterableEnum &Other) = default;
+  IterableEnum(IterableEnum &&Other) = default;
+
+  // static_cast (not reinterpret_cast) avoids reading a T object through an
+  // E glvalue, and keeps the conversion usable in constant expressions.
+  constexpr operator E() const { return static_cast<E>(Value); }
+
+  // Assignment
+  IterableEnum &operator=(const IterableEnum &Other) = default;
+  IterableEnum &operator=(IterableEnum &&Other) = default;
+  IterableEnum &operator=(E RHS) {
+    Value = T(RHS);
+    return *this;
+  }
+  IterableEnum &operator++() {
+    ++Value;
+    return *this;
+  }
+  IterableEnum &operator--() {
+    --Value;
+    return *this;
+  }
+  IterableEnum operator+(int Offset) const {
+    IterableEnum Copy = *this;
+    Copy.Value += Offset;
+    return Copy;
+  }
+  IterableEnum operator-(int Offset) const {
+    IterableEnum Copy = *this;
+    Copy.Value -= Offset;
+    return Copy;
+  }
+
+  // Comparison
+  bool operator==(IterableEnum RHS) const { return Value == RHS.Value; }
+  bool operator!=(IterableEnum RHS) const { return Value != RHS.Value; }
+  bool operator<(IterableEnum RHS) const { return Value < RHS.Value; }
+  bool operator>(IterableEnum RHS) const { return Value > RHS.Value; }
+  bool operator<=(IterableEnum RHS) const { return Value <= RHS.Value; }
+  bool operator>=(IterableEnum RHS) const { return Value >= RHS.Value; }
+
+  bool operator==(E RHS) const { return Value == T(RHS); }
+  bool operator!=(E RHS) const { return Value != T(RHS); }
+  bool operator<(E RHS) const { return Value < T(RHS); }
+  bool operator>(E RHS) const { return Value > T(RHS); }
+  bool operator<=(E RHS) const { return Value <= T(RHS); }
+  bool operator>=(E RHS) const { return Value >= T(RHS); }
+
+  // Arithmetic
+  std::ptrdiff_t operator-(IterableEnum RHS) const { return Value - RHS.Value; }
+};
+
+/// Iterate over an integral type from Begin to End exclusive.
 template <typename ValueT> auto seq(ValueT Begin, ValueT End) {
+  static_assert(!std::is_enum<ValueT>::value,
+                "Use enum_seq when iterating enumerations");
+  static_assert(std::is_integral<ValueT>::value,
+                "ValueT must be an integral type");
   return iota_range<ValueT>(std::move(Begin), std::move(End));
 }
 
+/// Iterate over a typed enum from First to Last inclusive.
+/// The end of the range is formed as Last + 1, so Last must be smaller than
+/// the largest value representable by the enum's underlying type.
+template <typename ValueE> auto enum_seq(ValueE First, ValueE Last) {
+  static_assert(std::is_enum<ValueE>::value,
+                "Can't use enum_seq with non enum types");
+  using T = IterableEnum<ValueE>;
+  return iota_range<T>(T(First), T(Last) + 1);
+}
+
 } // end namespace llvm
 
 #endif // LLVM_ADT_SEQUENCE_H
diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -40,6 +40,7 @@
   public:
     constexpr EVT() = default;
     constexpr EVT(MVT::SimpleValueType SVT) : V(SVT) {}
+    constexpr EVT(IterableEnum<MVT::SimpleValueType> SVT) : V(SVT) {}
     constexpr EVT(MVT S) : V(S) {}
 
     bool operator==(EVT VT) const {
diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_SUPPORT_MACHINEVALUETYPE_H
 #define LLVM_SUPPORT_MACHINEVALUETYPE_H
 
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
@@ -305,6 +306,7 @@
     SimpleValueType SimpleTy = INVALID_SIMPLE_VALUE_TYPE;
 
     constexpr MVT() = default;
+    constexpr MVT(IterableEnum<SimpleValueType> SVT) : SimpleTy(SVT) {}
     constexpr MVT(SimpleValueType SVT) : SimpleTy(SVT) {}
 
     bool operator>(const MVT& S) const { return SimpleTy > S.SimpleTy; }
@@ -1344,84 +1346,54 @@
     /// returned as Other, otherwise they are invalid.
     static MVT getVT(Type *Ty, bool HandleUnknown = false);
 
-  private:
-    /// A simple iterator over the MVT::SimpleValueType enum.
-    struct mvt_iterator {
-      SimpleValueType VT;
-
-      mvt_iterator(SimpleValueType VT) : VT(VT) {}
-
-      MVT operator*() const { return VT; }
-      bool operator!=(const mvt_iterator &LHS) const { return VT != LHS.VT; }
-
-      mvt_iterator& operator++() {
-        VT = (MVT::SimpleValueType)((int)VT + 1);
-        assert((int)VT <= MVT::MAX_ALLOWED_VALUETYPE &&
-               "MVT iterator overflowed.");
-        return *this;
-      }
-    };
-
-    /// A range of the MVT::SimpleValueType enum.
-    using mvt_range = iterator_range<mvt_iterator>;
-
   public:
     /// SimpleValueType Iteration
     /// @{
-    static mvt_range all_valuetypes() {
-      return mvt_range(MVT::FIRST_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_VALUETYPE + 1));
+    static auto all_valuetypes() {
+      return enum_seq(MVT::FIRST_VALUETYPE, MVT::LAST_VALUETYPE);
     }
 
-    static mvt_range integer_valuetypes() {
-      return mvt_range(MVT::FIRST_INTEGER_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_INTEGER_VALUETYPE + 1));
+    static auto integer_valuetypes() {
+      return enum_seq(MVT::FIRST_INTEGER_VALUETYPE,
+                      MVT::LAST_INTEGER_VALUETYPE);
     }
 
-    static mvt_range fp_valuetypes() {
-      return mvt_range(MVT::FIRST_FP_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_FP_VALUETYPE + 1));
+    static auto fp_valuetypes() {
+      return enum_seq(MVT::FIRST_FP_VALUETYPE, MVT::LAST_FP_VALUETYPE);
     }
 
-    static mvt_range vector_valuetypes() {
-      return mvt_range(MVT::FIRST_VECTOR_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_VECTOR_VALUETYPE + 1));
+    static auto vector_valuetypes() {
+      return enum_seq(MVT::FIRST_VECTOR_VALUETYPE, MVT::LAST_VECTOR_VALUETYPE);
     }
 
-    static mvt_range fixedlen_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE + 1));
+    static auto fixedlen_vector_valuetypes() {
+      return enum_seq(MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE,
+                      MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE);
     }
 
-    static mvt_range scalable_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_SCALABLE_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_SCALABLE_VECTOR_VALUETYPE + 1));
+    static auto scalable_vector_valuetypes() {
+      return enum_seq(MVT::FIRST_SCALABLE_VECTOR_VALUETYPE,
+                      MVT::LAST_SCALABLE_VECTOR_VALUETYPE);
     }
 
-    static mvt_range integer_fixedlen_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE + 1));
+    static auto integer_fixedlen_vector_valuetypes() {
+      return enum_seq(MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE,
+                      MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE);
     }
 
-    static mvt_range fp_fixedlen_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE + 1));
+    static auto fp_fixedlen_vector_valuetypes() {
+      return enum_seq(MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE,
+                      MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE);
     }
 
-    static mvt_range integer_scalable_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE + 1));
+    static auto integer_scalable_vector_valuetypes() {
+      return enum_seq(MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE,
+                      MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE);
     }
 
-    static mvt_range fp_scalable_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE + 1));
+    static auto fp_scalable_vector_valuetypes() {
+      return enum_seq(MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE,
+                      MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE);
     }
     /// @}
   };
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -4629,8 +4629,7 @@
   EVT InVT = InOp.getValueType();
   if (InVT.getSizeInBits() != VT.getSizeInBits()) {
     EVT InEltVT = InVT.getVectorElementType();
-    for (int i = MVT::FIRST_VECTOR_VALUETYPE, e = MVT::LAST_VECTOR_VALUETYPE; i < e; ++i) {
-      EVT FixedVT = (MVT::SimpleValueType)i;
+    for (EVT FixedVT : MVT::vector_valuetypes()) {
       EVT FixedEltVT = FixedVT.getVectorElementType();
       if (TLI.isTypeLegal(FixedVT) &&
           FixedVT.getSizeInBits() == VT.getSizeInBits() &&
@@ -5157,14 +5156,11 @@
   if (!Scalable && Width == WidenEltWidth)
     return RetVT;
 
-  // See if there is larger legal integer than the element type to load/store.
-  unsigned VT;
   // Don't bother looking for an integer type if the vector is scalable, skip
   // to vector types.
   if (!Scalable) {
-    for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE;
-         VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) {
-      EVT MemVT((MVT::SimpleValueType) VT);
+    // See if there is larger legal integer than the element type to load/store.
+    for (EVT MemVT : reverse(MVT::integer_valuetypes())) {
       unsigned MemVTWidth = MemVT.getSizeInBits();
       if (MemVT.getSizeInBits() <= WidenEltWidth)
         break;
@@ -5185,9 +5181,7 @@
 
   // See if there is a larger vector type to load/store that has the same vector
   // element type and is evenly divisible with the WidenVT.
-  for (VT = (unsigned)MVT::LAST_VECTOR_VALUETYPE;
-       VT >= (unsigned)MVT::FIRST_VECTOR_VALUETYPE; --VT) {
-    EVT MemVT = (MVT::SimpleValueType) VT;
+  for (EVT MemVT : reverse(MVT::vector_valuetypes())) {
     // Skip vector MVTs which don't match the scalable property of WidenVT.
     if (Scalable != MemVT.isScalableVector())
       continue;
diff --git a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp
--- a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp
+++ b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp
@@ -918,9 +918,9 @@
       continue;
     case X86::OperandType::OPERAND_COND_CODE: {
       Exploration = true;
-      auto CondCodes = seq((int)X86::CondCode::COND_O,
-                           1 + (int)X86::CondCode::LAST_VALID_COND);
-      Choices.reserve(std::distance(CondCodes.begin(), CondCodes.end()));
+      auto CondCodes =
+          enum_seq(X86::CondCode::COND_O, X86::CondCode::LAST_VALID_COND);
+      Choices.reserve(CondCodes.size());
       for (int CondCode : CondCodes)
         Choices.emplace_back(MCOperand::createImm(CondCode));
       break;
diff --git a/llvm/unittests/ADT/SequenceTest.cpp b/llvm/unittests/ADT/SequenceTest.cpp
--- a/llvm/unittests/ADT/SequenceTest.cpp
+++ b/llvm/unittests/ADT/SequenceTest.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/STLExtras.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 #include <list>
@@ -48,4 +50,61 @@
   EXPECT_EQ(Backward[2], 7);
 }
 
+enum class CharEnum : char { A = 1, B, C, D, E, LAST = E };
+using Iterable = llvm::IterableEnum<CharEnum>;
+
+TEST(IterableEnumTest, Ctor) {
+  Iterable Init(CharEnum::B);
+  EXPECT_EQ(Init, CharEnum::B);
+}
+
+TEST(IterableEnumTest, Assign) {
+  Iterable Init;
+  Init = CharEnum::LAST;
+  EXPECT_EQ(Init, CharEnum::LAST);
+  Init = Iterable(CharEnum::C);
+  EXPECT_EQ(Init, CharEnum::C);
+}
+
+TEST(IterableEnumTest, Cast) {
+  Iterable Value(CharEnum::LAST);
+  CharEnum Enum = Value;
+  EXPECT_EQ(Enum, CharEnum::LAST);
+}
+
+TEST(IterableEnumTest, ConstCast) {
+  const Iterable Value(CharEnum::LAST);
+  const CharEnum Enum = Value;
+  EXPECT_EQ(Enum, CharEnum::LAST);
+}
+
+TEST(IterableEnumTest, IncDec) {
+  Iterable Value(CharEnum::A);
+  ++Value;
+  EXPECT_EQ(Value, CharEnum::B);
+  --Value;
+  EXPECT_EQ(Value, CharEnum::A);
+}
+
+TEST(IterableEnumTest, Cmp) {
+  Iterable C(CharEnum::C);
+  Iterable D(CharEnum::D);
+  EXPECT_EQ(C, C);
+  EXPECT_NE(C, D);
+  EXPECT_LT(C, D);
+  EXPECT_GT(D, C);
+  EXPECT_GE(C, C);
+  EXPECT_LE(C, C);
+}
+
+TEST(IterableEnumTest, ForwardIteration) {
+  EXPECT_THAT(llvm::enum_seq(CharEnum::C, CharEnum::LAST),
+              testing::ElementsAre(CharEnum::C, CharEnum::D, CharEnum::E));
+}
+
+TEST(IterableEnumTest, BackwardIteration) {
+  EXPECT_THAT(reverse(llvm::enum_seq(CharEnum::B, CharEnum::D)),
+              testing::ElementsAre(CharEnum::D, CharEnum::C, CharEnum::B));
+}
+
 } // anonymous namespace
diff --git a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
--- a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
+++ b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
@@ -18,7 +18,7 @@
 namespace {
 
 TEST(ScalableVectorMVTsTest, IntegerMVTs) {
-  for (auto VecTy : MVT::integer_scalable_vector_valuetypes()) {
+  for (MVT VecTy : MVT::integer_scalable_vector_valuetypes()) {
     ASSERT_TRUE(VecTy.isValid());
     ASSERT_TRUE(VecTy.isInteger());
     ASSERT_TRUE(VecTy.isVector());
@@ -30,7 +30,7 @@
 }
 
 TEST(ScalableVectorMVTsTest, FloatMVTs) {
-  for (auto VecTy : MVT::fp_scalable_vector_valuetypes()) {
+  for (MVT VecTy : MVT::fp_scalable_vector_valuetypes()) {
     ASSERT_TRUE(VecTy.isValid());
     ASSERT_TRUE(VecTy.isFloatingPoint());
     ASSERT_TRUE(VecTy.isVector());