diff --git a/llvm/include/llvm/ADT/Sequence.h b/llvm/include/llvm/ADT/Sequence.h --- a/llvm/include/llvm/ADT/Sequence.h +++ b/llvm/include/llvm/ADT/Sequence.h @@ -31,6 +31,50 @@ /// /// Prints: `0 1 2 3 `. /// +/// Similar to `seq` and `seq_inclusive`, the `enum_seq` and +/// `enum_seq_inclusive` functions produce sequences of enum values that can be +/// iterated over. +/// To enable iteration with enum types, you need to either mark enums as safe +/// to iterate on by specializing `enum_iteration_traits`, or opt into +/// potentially unsafe iteration at every callsite by passing +/// `force_iteration_on_noniterable_enum`. +/// +/// Examples with enum types (specializations go in `namespace llvm`): +/// ``` +/// namespace X { +/// enum class MyEnum : unsigned {A = 0, B, C}; +/// } // namespace X +/// +/// template <> struct enum_iteration_traits { +/// static constexpr bool is_iterable = true; +/// }; +/// +/// class MyClass { +/// public: +/// enum Safe { D = 3, E, F }; +/// enum MaybeUnsafe { G = 1, H = 2, I = 4 }; +/// }; +/// +/// template <> struct enum_iteration_traits { +/// static constexpr bool is_iterable = true; +/// }; +/// ``` +/// +/// ``` +/// for (auto v : enum_seq(MyClass::Safe::D, MyClass::Safe::F)) +/// outs() << int(v) << " "; +/// ``` +/// +/// Prints: `3 4 `. +/// +/// ``` +/// for (auto v : enum_seq(MyClass::MaybeUnsafe::H, MyClass::MaybeUnsafe::I, +/// force_iteration_on_noniterable_enum)) +/// outs() << int(v) << " "; +/// ``` +/// +/// Prints: `2 3 `. +/// //===----------------------------------------------------------------------===// #ifndef LLVM_ADT_SEQUENCE_H @@ -39,12 +83,28 @@ #include // assert #include // std::random_access_iterator_tag #include // std::numeric_limits -#include // std::underlying_type, std::is_enum +#include // std::is_integral, std::is_enum, std::underlying_type, std::enable_if #include "llvm/Support/MathExtras.h" // AddOverflow / SubOverflow namespace llvm { +// Enum traits that mark enums as safe or unsafe to iterate over.
+// By default, enum types are *not* considered safe for iteration. +// To allow iteration for your enum type, provide a specialization with +// `is_iterable` set to `true` in the `llvm` namespace. +// Alternatively, you can pass the `force_iteration_on_noniterable_enum` tag +// to `enum_seq` or `enum_seq_inclusive`. +template struct enum_iteration_traits { + static constexpr bool is_iterable = false; +}; + +struct force_iteration_on_noniterable_enum_t { + explicit force_iteration_on_noniterable_enum_t() = default; +}; +constexpr force_iteration_on_noniterable_enum_t + force_iteration_on_noniterable_enum; + namespace detail { // Returns whether a value of type U can be represented with type T. @@ -234,27 +294,81 @@ iterator PastEndValue; }; -/// Iterate over an integral/enum type from Begin up to - but not including - -/// End. -/// Note on enum iteration: `seq` will generate each consecutive value, even if -/// no enumerator with that value exists. +/// Iterate over an integral type from Begin up to - but not including - End. /// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX] for /// forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX] for reverse /// iteration). -template auto seq(T Begin, T End) { +template ::value && + !std::is_enum::value>> +auto seq(T Begin, T End) { return iota_range(Begin, End, false); } -/// Iterate over an integral/enum type from Begin to End inclusive. -/// Note on enum iteration: `seq_inclusive` will generate each consecutive -/// value, even if no enumerator with that value exists. +/// Iterate over an integral type from Begin to End inclusive. /// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX - 1] /// for forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX - 1] for reverse /// iteration).
-template auto seq_inclusive(T Begin, T End) { +template ::value && + !std::is_enum::value>> +auto seq_inclusive(T Begin, T End) { return iota_range(Begin, End, true); } +/// Iterate over an enum type from Begin up to - but not including - End. +/// Note: `enum_seq` will generate each consecutive value, even if no +/// enumerator with that value exists. +/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX] for +/// forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX] for reverse +/// iteration). +template ::value>> +auto enum_seq(EnumT Begin, EnumT End) { + static_assert(enum_iteration_traits::is_iterable, + "Enum type is not marked as iterable."); + return iota_range(Begin, End, false); +} + +/// Iterate over an enum type from Begin up to - but not including - End, even +/// when `EnumT` is not marked as safely iterable by `enum_iteration_traits`. +/// Note: `enum_seq` will generate each consecutive value, even if no +/// enumerator with that value exists. +/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX] for +/// forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX] for reverse +/// iteration). +template ::value>> +auto enum_seq(EnumT Begin, EnumT End, force_iteration_on_noniterable_enum_t) { + return iota_range(Begin, End, false); +} + +/// Iterate over an enum type from Begin to End inclusive. +/// Note: `enum_seq_inclusive` will generate each consecutive value, even if no +/// enumerator with that value exists. +/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX - 1] +/// for forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX - 1] for reverse +/// iteration).
+template ::value>> +auto enum_seq_inclusive(EnumT Begin, EnumT End) { + static_assert(enum_iteration_traits::is_iterable, + "Enum type is not marked as iterable."); + return iota_range(Begin, End, true); +} + +/// Iterate over an enum type from Begin to End inclusive, even when `EnumT` +/// is not marked as safely iterable by `enum_iteration_traits`. +/// Note: `enum_seq_inclusive` will generate each consecutive value, even if no +/// enumerator with that value exists. +/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX - 1] +/// for forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX - 1] for reverse +/// iteration). +template ::value>> +auto enum_seq_inclusive(EnumT Begin, EnumT End, + force_iteration_on_noniterable_enum_t) { + return iota_range(Begin, End, true); +} + } // end namespace llvm #endif // LLVM_ADT_SEQUENCE_H diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -19,6 +19,7 @@ #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" @@ -1051,6 +1052,10 @@ DEFINE_TRANSPARENT_OPERAND_ACCESSORS(CmpInst, Value) +template <> struct enum_iteration_traits { + static constexpr bool is_iterable = true; +}; + /// A lightweight accessor for an operand bundle meant to be passed /// around by value.
struct OperandBundleUse { diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h --- a/llvm/include/llvm/Support/MachineValueType.h +++ b/llvm/include/llvm/Support/MachineValueType.h @@ -1405,51 +1405,61 @@ /// SimpleValueType Iteration /// @{ static auto all_valuetypes() { - return seq_inclusive(MVT::FIRST_VALUETYPE, MVT::LAST_VALUETYPE); + return enum_seq_inclusive(MVT::FIRST_VALUETYPE, MVT::LAST_VALUETYPE, + force_iteration_on_noniterable_enum); } static auto integer_valuetypes() { - return seq_inclusive(MVT::FIRST_INTEGER_VALUETYPE, - MVT::LAST_INTEGER_VALUETYPE); + return enum_seq_inclusive(MVT::FIRST_INTEGER_VALUETYPE, + MVT::LAST_INTEGER_VALUETYPE, + force_iteration_on_noniterable_enum); } static auto fp_valuetypes() { - return seq_inclusive(MVT::FIRST_FP_VALUETYPE, MVT::LAST_FP_VALUETYPE); + return enum_seq_inclusive(MVT::FIRST_FP_VALUETYPE, MVT::LAST_FP_VALUETYPE, + force_iteration_on_noniterable_enum); } static auto vector_valuetypes() { - return seq_inclusive(MVT::FIRST_VECTOR_VALUETYPE, - MVT::LAST_VECTOR_VALUETYPE); + return enum_seq_inclusive(MVT::FIRST_VECTOR_VALUETYPE, + MVT::LAST_VECTOR_VALUETYPE, + force_iteration_on_noniterable_enum); } static auto fixedlen_vector_valuetypes() { - return seq_inclusive(MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE, - MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE); + return enum_seq_inclusive(MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE, + MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE, + force_iteration_on_noniterable_enum); } static auto scalable_vector_valuetypes() { - return seq_inclusive(MVT::FIRST_SCALABLE_VECTOR_VALUETYPE, - MVT::LAST_SCALABLE_VECTOR_VALUETYPE); + return enum_seq_inclusive(MVT::FIRST_SCALABLE_VECTOR_VALUETYPE, + MVT::LAST_SCALABLE_VECTOR_VALUETYPE, + force_iteration_on_noniterable_enum); } static auto integer_fixedlen_vector_valuetypes() { - return seq_inclusive(MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE, - MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE); + return
enum_seq_inclusive(MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE, + MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE, + force_iteration_on_noniterable_enum); } static auto fp_fixedlen_vector_valuetypes() { - return seq_inclusive(MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE, - MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE); + return enum_seq_inclusive(MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE, + MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE, + force_iteration_on_noniterable_enum); } static auto integer_scalable_vector_valuetypes() { - return seq_inclusive(MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE, - MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE); + return enum_seq_inclusive(MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE, + MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE, + force_iteration_on_noniterable_enum); } static auto fp_scalable_vector_valuetypes() { - return seq_inclusive(MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE, - MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE); + return enum_seq_inclusive(MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE, + MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE, + force_iteration_on_noniterable_enum); } /// @} }; diff --git a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp --- a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp +++ b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp @@ -918,8 +918,9 @@ continue; case X86::OperandType::OPERAND_COND_CODE: { Exploration = true; - auto CondCodes = - seq_inclusive(X86::CondCode::COND_O, X86::CondCode::LAST_VALID_COND); + auto CondCodes = enum_seq_inclusive(X86::CondCode::COND_O, + X86::CondCode::LAST_VALID_COND, + force_iteration_on_noniterable_enum); Choices.reserve(CondCodes.size()); for (int CondCode : CondCodes) Choices.emplace_back(MCOperand::createImm(CondCode)); diff --git a/llvm/unittests/ADT/SequenceTest.cpp b/llvm/unittests/ADT/SequenceTest.cpp --- a/llvm/unittests/ADT/SequenceTest.cpp +++ b/llvm/unittests/ADT/SequenceTest.cpp @@ -16,6 +16,7 @@ using namespace llvm; using testing::ElementsAre; +using
testing::IsEmpty; namespace { @@ -68,17 +69,6 @@ EXPECT_EQ(Actual - (Actual + 2), -2); } -TEST(StrongIntTest, Enums) { - enum UntypedEnum { A = 3 }; - EXPECT_EQ(CheckedInt::from(A).to(), A); - - enum TypedEnum : uint32_t { B = 3 }; - EXPECT_EQ(CheckedInt::from(B).to(), B); - - enum class ScopedEnum : uint16_t { C = 3 }; - EXPECT_EQ(CheckedInt::from(ScopedEnum::C).to(), ScopedEnum::C); -} - #if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG) TEST(StrongIntDeathTest, OutOfBounds) { // Values above 'INTMAX_MAX' are not representable. @@ -215,4 +205,94 @@ EXPECT_EQ(Backward[2], 7); } -} // anonymous namespace +enum UntypedEnum { A = 3 }; +enum TypedEnum : uint32_t { B = 3 }; + +namespace X { +enum class ScopedEnum : uint16_t { C = 3 }; +} // namespace X + +struct S { + enum NestedEnum { D = 4 }; + enum NestedEnum2 { E = 5 }; + +private: + enum NestedEnum3 { F = 6 }; + friend struct llvm::enum_iteration_traits; + +public: + static auto getNestedEnum3() { return NestedEnum3::F; } +}; + +} // namespace + +namespace llvm { + +template <> struct enum_iteration_traits { + static constexpr bool is_iterable = true; +}; + +template <> struct enum_iteration_traits { + static constexpr bool is_iterable = true; +}; + +template <> struct enum_iteration_traits { + static constexpr bool is_iterable = true; +}; + +template <> struct enum_iteration_traits { + static constexpr bool is_iterable = true; +}; + +template <> struct enum_iteration_traits { + static constexpr bool is_iterable = true; +}; + +} // namespace llvm + +namespace { + +TEST(StrongIntTest, Enums) { + EXPECT_EQ(CheckedInt::from(A).to(), A); + EXPECT_EQ(CheckedInt::from(B).to(), B); + EXPECT_EQ(CheckedInt::from(X::ScopedEnum::C).to(), + X::ScopedEnum::C); +} + +TEST(SequenceTest, IterableEnums) { + EXPECT_THAT(enum_seq(UntypedEnum::A, UntypedEnum::A), IsEmpty()); + EXPECT_THAT(enum_seq_inclusive(UntypedEnum::A, UntypedEnum::A), + ElementsAre(UntypedEnum::A)); + + EXPECT_THAT(enum_seq(TypedEnum::B, TypedEnum::B),
IsEmpty()); + EXPECT_THAT(enum_seq_inclusive(TypedEnum::B, TypedEnum::B), + ElementsAre(TypedEnum::B)); + + EXPECT_THAT(enum_seq(X::ScopedEnum::C, X::ScopedEnum::C), IsEmpty()); + EXPECT_THAT(enum_seq_inclusive(X::ScopedEnum::C, X::ScopedEnum::C), + ElementsAre(X::ScopedEnum::C)); + + EXPECT_THAT(enum_seq_inclusive(S::NestedEnum::D, S::NestedEnum::D), + ElementsAre(S::NestedEnum::D)); + EXPECT_THAT(enum_seq_inclusive(S::getNestedEnum3(), S::getNestedEnum3()), + ElementsAre(S::getNestedEnum3())); +} + +TEST(SequenceTest, NonIterableEnums) { + EXPECT_THAT(enum_seq(S::NestedEnum2::E, S::NestedEnum2::E, + force_iteration_on_noniterable_enum), + IsEmpty()); + EXPECT_THAT(enum_seq_inclusive(S::NestedEnum2::E, S::NestedEnum2::E, + force_iteration_on_noniterable_enum), + ElementsAre(S::NestedEnum2::E)); + + // Check that this also works with enums marked as iterable. + EXPECT_THAT(enum_seq(UntypedEnum::A, UntypedEnum::A, + force_iteration_on_noniterable_enum), + IsEmpty()); + EXPECT_THAT(enum_seq_inclusive(UntypedEnum::A, UntypedEnum::A, + force_iteration_on_noniterable_enum), + ElementsAre(UntypedEnum::A)); +} + +} // namespace diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp --- a/llvm/unittests/IR/ConstantRangeTest.cpp +++ b/llvm/unittests/IR/ConstantRangeTest.cpp @@ -1572,8 +1572,8 @@ } TEST(ConstantRange, ICmp) { - for (auto Pred : seq_inclusive(CmpInst::Predicate::FIRST_ICMP_PREDICATE, - CmpInst::Predicate::LAST_ICMP_PREDICATE)) + for (auto Pred : enum_seq_inclusive(CmpInst::Predicate::FIRST_ICMP_PREDICATE, + CmpInst::Predicate::LAST_ICMP_PREDICATE)) ICmpTestImpl(Pred); } @@ -2531,8 +2531,9 @@ } TEST_F(ConstantRangeTest, areInsensitiveToSignednessOfICmpPredicate) { - for (auto Pred : seq_inclusive(ICmpInst::Predicate::FIRST_ICMP_PREDICATE, - ICmpInst::Predicate::LAST_ICMP_PREDICATE)) { + for (auto Pred : + enum_seq_inclusive(ICmpInst::Predicate::FIRST_ICMP_PREDICATE, + ICmpInst::Predicate::LAST_ICMP_PREDICATE)) { if
(ICmpInst::isEquality(Pred)) continue; ICmpInst::Predicate FlippedSignednessPred = @@ -2548,8 +2549,9 @@ } TEST_F(ConstantRangeTest, areInsensitiveToSignednessOfInvertedICmpPredicate) { - for (auto Pred : seq_inclusive(ICmpInst::Predicate::FIRST_ICMP_PREDICATE, - ICmpInst::Predicate::LAST_ICMP_PREDICATE)) { + for (auto Pred : + enum_seq_inclusive(ICmpInst::Predicate::FIRST_ICMP_PREDICATE, + ICmpInst::Predicate::LAST_ICMP_PREDICATE)) { if (ICmpInst::isEquality(Pred)) continue; ICmpInst::Predicate InvertedFlippedSignednessPred = @@ -2567,8 +2569,9 @@ } TEST_F(ConstantRangeTest, getEquivalentPredWithFlippedSignedness) { - for (auto Pred : seq_inclusive(ICmpInst::Predicate::FIRST_ICMP_PREDICATE, - ICmpInst::Predicate::LAST_ICMP_PREDICATE)) { + for (auto Pred : + enum_seq_inclusive(ICmpInst::Predicate::FIRST_ICMP_PREDICATE, + ICmpInst::Predicate::LAST_ICMP_PREDICATE)) { if (ICmpInst::isEquality(Pred)) continue; testConstantRangeICmpPredEquivalence(