diff --git a/llvm/include/llvm/ADT/Sequence.h b/llvm/include/llvm/ADT/Sequence.h
--- a/llvm/include/llvm/ADT/Sequence.h
+++ b/llvm/include/llvm/ADT/Sequence.h
@@ -15,46 +15,29 @@
 #ifndef LLVM_ADT_SEQUENCE_H
 #define LLVM_ADT_SEQUENCE_H
 
-#include <cstddef>  //std::ptrdiff_t
-#include <iterator> //std::random_access_iterator_tag
+#include <cassert>     // assert
+#include <cstddef>     // std::ptrdiff_t
+#include <iterator>    // std::random_access_iterator_tag
+#include <limits>      // std::numeric_limits
+#include <type_traits> // std::underlying_type, std::is_enum
 
 namespace llvm {
 
 namespace detail {
 
-template <typename T, bool IsReversed> struct iota_range_iterator {
+template <typename T, typename U, bool IsReversed> struct iota_range_iterator {
   using iterator_category = std::random_access_iterator_tag;
   using value_type = T;
   using difference_type = std::ptrdiff_t;
   using pointer = T *;
   using reference = T &;
 
-private:
-  struct Forward {
-    static void increment(T &V) { ++V; }
-    static void decrement(T &V) { --V; }
-    static void offset(T &V, difference_type Offset) { V += Offset; }
-    static T add(const T &V, difference_type Offset) { return V + Offset; }
-    static difference_type difference(const T &A, const T &B) { return A - B; }
-  };
-
-  struct Reverse {
-    static void increment(T &V) { --V; }
-    static void decrement(T &V) { ++V; }
-    static void offset(T &V, difference_type Offset) { V -= Offset; }
-    static T add(const T &V, difference_type Offset) { return V - Offset; }
-    static difference_type difference(const T &A, const T &B) { return B - A; }
-  };
-
-  using Op = std::conditional_t<!IsReversed, Forward, Reverse>;
-
-public:
   // default-constructible
   iota_range_iterator() = default;
   // copy-constructible
   iota_range_iterator(const iota_range_iterator &) = default;
   // value constructor
-  explicit iota_range_iterator(T Value) : Value(Value) {}
+  explicit iota_range_iterator(U Value) : Value(Value) {}
   // copy-assignable
   iota_range_iterator &operator=(const iota_range_iterator &) = default;
   // destructible
@@ -83,8 +66,10 @@
   }
 
   // Dereference
-  T operator*() const { return Value; }
-  T operator[](difference_type Offset) const { return Op::add(Value, Offset); }
+  T operator*() const { return static_cast<T>(Value); }
+  T operator[](difference_type Offset) const {
+    return static_cast<T>(Op::add(Value, Offset));
+  }
 
   // Arithmetic
   iota_range_iterator operator+(difference_type Offset) const {
@@ -132,46 +117,109 @@
   }
 
 private:
-  T Value;
+  struct Forward {
+    static void increment(U &V) { ++V; }
+    static void decrement(U &V) { --V; }
+    static void offset(U &V, difference_type Offset) { V += Offset; }
+    static U add(const U &V, difference_type Offset) { return V + Offset; }
+    static difference_type difference(const U &A, const U &B) {
+      return difference_type(A) - difference_type(B);
+    }
+  };
+
+  struct Reverse {
+    static void increment(U &V) { --V; }
+    static void decrement(U &V) { ++V; }
+    static void offset(U &V, difference_type Offset) { V -= Offset; }
+    static U add(const U &V, difference_type Offset) { return V - Offset; }
+    static difference_type difference(const U &A, const U &B) {
+      return difference_type(B) - difference_type(A);
+    }
+  };
+
+  using Op = std::conditional_t<!IsReversed, Forward, Reverse>;
+
+  U Value;
 };
 
+// Providing std::type_identity for C++14.
+template <class T> struct type_identity { using type = T; };
+
 } // namespace detail
 
-template <typename ValueT> struct iota_range {
-  static_assert(std::is_integral<ValueT>::value,
-                "ValueT must be an integral type");
+template <typename T> struct iota_range {
+private:
+  using raw_type = typename std::conditional_t<std::is_enum<T>::value,
+                                               std::underlying_type<T>,
+                                               detail::type_identity<T>>::type;
+
+  static raw_type compute_past_end(raw_type End, bool Inclusive) {
+    if (Inclusive) {
+      // This assertion forbids overflow of `PastEndValue`.
+      assert(End != std::numeric_limits<raw_type>::max() &&
+             "Forbidden End value for seq_inclusive.");
+      return End + 1;
+    }
+    return End;
+  }
+  static raw_type raw(T Value) { return static_cast<raw_type>(Value); }
 
-  using value_type = ValueT;
-  using reference = ValueT &;
-  using const_reference = const ValueT &;
-  using iterator = detail::iota_range_iterator<ValueT, false>;
+  raw_type BeginValue;
+  raw_type PastEndValue;
+
+public:
+  using value_type = T;
+  using reference = T &;
+  using const_reference = const T &;
+  using iterator = detail::iota_range_iterator<value_type, raw_type, false>;
   using const_iterator = iterator;
-  using reverse_iterator = detail::iota_range_iterator<ValueT, true>;
+  using reverse_iterator =
+      detail::iota_range_iterator<value_type, raw_type, true>;
   using const_reverse_iterator = reverse_iterator;
   using difference_type = std::ptrdiff_t;
   using size_type = std::size_t;
 
-  value_type Begin;
-  value_type End;
-
-  explicit iota_range(ValueT Begin, ValueT End) : Begin(Begin), End(End) {}
+  explicit iota_range(T Begin, T End, bool Inclusive)
+      : BeginValue(raw(Begin)),
+        PastEndValue(compute_past_end(raw(End), Inclusive)) {
+    assert(Begin <= End && "Begin must be less or equal to End.");
+  }
 
-  size_t size() const { return End - Begin; }
-  bool empty() const { return Begin == End; }
+  size_t size() const { return PastEndValue - BeginValue; }
+  bool empty() const { return BeginValue == PastEndValue; }
 
-  auto begin() const { return const_iterator(Begin); }
-  auto end() const { return const_iterator(End); }
+  auto begin() const { return const_iterator(BeginValue); }
+  auto end() const { return const_iterator(PastEndValue); }
 
-  auto rbegin() const { return const_reverse_iterator(End - 1); }
-  auto rend() const { return const_reverse_iterator(Begin - 1); }
+  auto rbegin() const { return const_reverse_iterator(PastEndValue - 1); }
+  auto rend() const {
+    assert(std::is_unsigned<raw_type>::value ||
+           (BeginValue != std::numeric_limits<raw_type>::min() &&
+            "Forbidden Begin value for reverse iteration"));
+    return const_reverse_iterator(BeginValue - 1);
+  }
 
 private:
-  static_assert(std::is_same<ValueT, std::remove_cv_t<ValueT>>::value,
-                "ValueT must not be const nor volatile");
+  static_assert(std::is_integral<T>::value || std::is_enum<T>::value,
+                "T must be an integral or enum type");
+  static_assert(std::is_same<T, std::remove_cv_t<T>>::value,
+                "T must not be const nor volatile");
+  static_assert(std::is_integral<raw_type>::value,
+                "raw_type must be an integral type");
 };
 
-template <typename ValueT> auto seq(ValueT Begin, ValueT End) {
-  return iota_range<ValueT>(Begin, End);
+/// Iterate over an integral/enum type from Begin up to - but not including -
+/// End.
+template <typename T> auto seq(T Begin, T End) {
+  return iota_range<T>(Begin, End, false);
+}
+
+/// Iterate over an integral/enum type from Begin to End inclusive.
+/// To prevent overflow, `End` must be different from the maximum value of the
+/// storage type for `T`. That is, in case of `T` being an enum, `End` must be
+/// different from the maximum value of `std::underlying_type<T>::type`.
+template <typename T> auto seq_inclusive(T Begin, T End) {
+  return iota_range<T>(Begin, End, true);
 }
 
 } // end namespace llvm
diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_SUPPORT_MACHINEVALUETYPE_H
 #define LLVM_SUPPORT_MACHINEVALUETYPE_H
 
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
@@ -1388,84 +1389,55 @@
     /// returned as Other, otherwise they are invalid.
     static MVT getVT(Type *Ty, bool HandleUnknown = false);
 
-  private:
-    /// A simple iterator over the MVT::SimpleValueType enum.
-    struct mvt_iterator {
-      SimpleValueType VT;
-
-      mvt_iterator(SimpleValueType VT) : VT(VT) {}
-
-      MVT operator*() const { return VT; }
-      bool operator!=(const mvt_iterator &LHS) const { return VT != LHS.VT; }
-
-      mvt_iterator& operator++() {
-        VT = (MVT::SimpleValueType)((int)VT + 1);
-        assert((int)VT <= MVT::MAX_ALLOWED_VALUETYPE &&
-               "MVT iterator overflowed.");
-        return *this;
-      }
-    };
-
-    /// A range of the MVT::SimpleValueType enum.
-    using mvt_range = iterator_range<mvt_iterator>;
-
   public:
     /// SimpleValueType Iteration
     /// @{
-    static mvt_range all_valuetypes() {
-      return mvt_range(MVT::FIRST_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_VALUETYPE + 1));
+    static auto all_valuetypes() {
+      return seq_inclusive(MVT::FIRST_VALUETYPE, MVT::LAST_VALUETYPE);
     }
 
-    static mvt_range integer_valuetypes() {
-      return mvt_range(MVT::FIRST_INTEGER_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_INTEGER_VALUETYPE + 1));
+    static auto integer_valuetypes() {
+      return seq_inclusive(MVT::FIRST_INTEGER_VALUETYPE,
+                           MVT::LAST_INTEGER_VALUETYPE);
     }
 
-    static mvt_range fp_valuetypes() {
-      return mvt_range(MVT::FIRST_FP_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_FP_VALUETYPE + 1));
+    static auto fp_valuetypes() {
+      return seq_inclusive(MVT::FIRST_FP_VALUETYPE, MVT::LAST_FP_VALUETYPE);
     }
 
-    static mvt_range vector_valuetypes() {
-      return mvt_range(MVT::FIRST_VECTOR_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_VECTOR_VALUETYPE + 1));
+    static auto vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_VECTOR_VALUETYPE,
+                           MVT::LAST_VECTOR_VALUETYPE);
     }
 
-    static mvt_range fixedlen_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE + 1));
+    static auto fixedlen_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE,
+                           MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE);
     }
 
-    static mvt_range scalable_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_SCALABLE_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_SCALABLE_VECTOR_VALUETYPE + 1));
+    static auto scalable_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_SCALABLE_VECTOR_VALUETYPE,
+                           MVT::LAST_SCALABLE_VECTOR_VALUETYPE);
     }
 
-    static mvt_range integer_fixedlen_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE + 1));
+    static auto integer_fixedlen_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE,
+                           MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE);
     }
 
-    static mvt_range fp_fixedlen_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE + 1));
+    static auto fp_fixedlen_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE,
+                           MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE);
     }
 
-    static mvt_range integer_scalable_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE + 1));
+    static auto integer_scalable_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE,
+                           MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE);
     }
 
-    static mvt_range fp_scalable_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE + 1));
+    static auto fp_scalable_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE,
+                           MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE);
     }
     /// @}
   };
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -4629,8 +4629,7 @@
   EVT InVT = InOp.getValueType();
   if (InVT.getSizeInBits() != VT.getSizeInBits()) {
     EVT InEltVT = InVT.getVectorElementType();
-    for (int i = MVT::FIRST_VECTOR_VALUETYPE, e = MVT::LAST_VECTOR_VALUETYPE; i < e; ++i) {
-      EVT FixedVT = (MVT::SimpleValueType)i;
+    for (EVT FixedVT : MVT::vector_valuetypes()) {
       EVT FixedEltVT = FixedVT.getVectorElementType();
       if (TLI.isTypeLegal(FixedVT) &&
           FixedVT.getSizeInBits() == VT.getSizeInBits() &&
@@ -5157,14 +5156,11 @@
   if (!Scalable && Width == WidenEltWidth)
     return RetVT;
 
-  // See if there is larger legal integer than the element type to load/store.
-  unsigned VT;
   // Don't bother looking for an integer type if the vector is scalable, skip
   // to vector types.
   if (!Scalable) {
-    for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE;
-         VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) {
-      EVT MemVT((MVT::SimpleValueType) VT);
+    // See if there is larger legal integer than the element type to load/store.
+    for (EVT MemVT : reverse(MVT::integer_valuetypes())) {
       unsigned MemVTWidth = MemVT.getSizeInBits();
       if (MemVT.getSizeInBits() <= WidenEltWidth)
         break;
@@ -5185,9 +5181,7 @@
 
   // See if there is a larger vector type to load/store that has the same vector
   // element type and is evenly divisible with the WidenVT.
-  for (VT = (unsigned)MVT::LAST_VECTOR_VALUETYPE;
-       VT >= (unsigned)MVT::FIRST_VECTOR_VALUETYPE; --VT) {
-    EVT MemVT = (MVT::SimpleValueType) VT;
+  for (EVT MemVT : reverse(MVT::vector_valuetypes())) {
     // Skip vector MVTs which don't match the scalable property of WidenVT.
     if (Scalable != MemVT.isScalableVector())
       continue;
diff --git a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp
--- a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp
+++ b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp
@@ -918,9 +918,9 @@
       continue;
     case X86::OperandType::OPERAND_COND_CODE: {
       Exploration = true;
-      auto CondCodes = seq((int)X86::CondCode::COND_O,
-                           1 + (int)X86::CondCode::LAST_VALID_COND);
-      Choices.reserve(std::distance(CondCodes.begin(), CondCodes.end()));
+      auto CondCodes =
+          seq_inclusive(X86::CondCode::COND_O, X86::CondCode::LAST_VALID_COND);
+      Choices.reserve(CondCodes.size());
       for (int CondCode : CondCodes)
         Choices.emplace_back(MCOperand::createImm(CondCode));
       break;
diff --git a/llvm/unittests/ADT/SequenceTest.cpp b/llvm/unittests/ADT/SequenceTest.cpp
--- a/llvm/unittests/ADT/SequenceTest.cpp
+++ b/llvm/unittests/ADT/SequenceTest.cpp
@@ -7,12 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/Sequence.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 #include <list>
 
 using namespace llvm;
 
+using testing::ElementsAre;
+
 namespace {
 
 TEST(SequenceTest, Forward) {
@@ -48,4 +51,104 @@
   EXPECT_EQ(Backward[2], 7);
 }
 
+enum class CharEnum : char { A = 1, B, C, D, E };
+
+TEST(SequenceTest, ForwardIteration) {
+  EXPECT_THAT(seq_inclusive(CharEnum::C, CharEnum::E),
+              ElementsAre(CharEnum::C, CharEnum::D, CharEnum::E));
+}
+
+TEST(SequenceTest, BackwardIteration) {
+  EXPECT_THAT(reverse(seq_inclusive(CharEnum::B, CharEnum::D)),
+              ElementsAre(CharEnum::D, CharEnum::C, CharEnum::B));
+}
+
+using IntegralTypes = testing::Types<uint8_t, uint16_t, uint32_t, uint64_t,
+                                     int8_t, int16_t, int32_t, int64_t>;
+
+template <class T> class SequenceTest : public testing::Test {
+public:
+  const T min = std::numeric_limits<T>::min();
+  const T minp1 = min + 1;
+  const T max = std::numeric_limits<T>::max();
+  const T maxm1 = max - 1;
+
+  void checkIteration() const {
+    // Forward
+    EXPECT_THAT(seq(min, min), ElementsAre());
+    EXPECT_THAT(seq(min, minp1), ElementsAre(min));
+    EXPECT_THAT(seq(maxm1, max), ElementsAre(maxm1));
+    EXPECT_THAT(seq(max, max), ElementsAre());
+    // Reverse
+    if (std::is_unsigned<T>::value) {
+      EXPECT_THAT(reverse(seq(min, min)), ElementsAre());
+      EXPECT_THAT(reverse(seq(min, minp1)), ElementsAre(min));
+    }
+    EXPECT_THAT(reverse(seq(maxm1, max)), ElementsAre(maxm1));
+    EXPECT_THAT(reverse(seq(max, max)), ElementsAre());
+    // Inclusive
+    EXPECT_THAT(seq_inclusive(min, min), ElementsAre(min));
+    EXPECT_THAT(seq_inclusive(min, minp1), ElementsAre(min, minp1));
+    EXPECT_THAT(seq_inclusive(maxm1, maxm1), ElementsAre(maxm1));
+    // Inclusive Reverse
+    if (std::is_unsigned<T>::value) {
+      EXPECT_THAT(reverse(seq_inclusive(min, min)), ElementsAre(min));
+      EXPECT_THAT(reverse(seq_inclusive(min, minp1)), ElementsAre(minp1, min));
+    }
+    EXPECT_THAT(reverse(seq_inclusive(maxm1, maxm1)), ElementsAre(maxm1));
+  }
+
+  void checkIterators() const {
+    auto checkValidIterators = [](auto sequence) {
+      EXPECT_LE(sequence.begin(), sequence.end());
+    };
+    checkValidIterators(seq(min, min));
+    checkValidIterators(seq(max, max));
+    checkValidIterators(seq_inclusive(min, min));
+    checkValidIterators(seq_inclusive(maxm1, maxm1));
+  }
+};
+TYPED_TEST_SUITE(SequenceTest, IntegralTypes);
+TYPED_TEST(SequenceTest, Boundaries) {
+  this->checkIteration();
+  this->checkIterators();
+}
+
+// DEATH tests
+#if !defined(NDEBUG)
+template <class T> class SequenceDeathTest : public SequenceTest<T> {
+public:
+  using SequenceTest<T>::min;
+  using SequenceTest<T>::minp1;
+  using SequenceTest<T>::max;
+  using SequenceTest<T>::maxm1;
+
+  void checkInvalidOrder() const {
+    EXPECT_DEATH(seq(max, min), "Begin must be less or equal to End.");
+    EXPECT_DEATH(seq(minp1, min), "Begin must be less or equal to End.");
+    EXPECT_DEATH(seq_inclusive(maxm1, min),
+                 "Begin must be less or equal to End.");
+    EXPECT_DEATH(seq_inclusive(minp1, min),
+                 "Begin must be less or equal to End.");
+  }
+  void checkInvalidValues() const {
+    EXPECT_DEATH(seq_inclusive(min, max),
+                 "Forbidden End value for seq_inclusive.");
+    EXPECT_DEATH(seq_inclusive(minp1, max),
+                 "Forbidden End value for seq_inclusive.");
+    if (std::is_signed<T>::value) {
+      EXPECT_DEATH(reverse(seq(min, min)),
+                   "Forbidden Begin value for reverse iteration");
+      EXPECT_DEATH(reverse(seq_inclusive(min, min)),
+                   "Forbidden Begin value for reverse iteration");
+    }
+  }
+};
+TYPED_TEST_SUITE(SequenceDeathTest, IntegralTypes);
+TYPED_TEST(SequenceDeathTest, DeathTests) {
+  this->checkInvalidOrder();
+  this->checkInvalidValues();
+}
+#endif // !defined(NDEBUG)
+
 } // anonymous namespace
diff --git a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
--- a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
+++ b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
@@ -18,7 +18,7 @@
 namespace {
 
 TEST(ScalableVectorMVTsTest, IntegerMVTs) {
-  for (auto VecTy : MVT::integer_scalable_vector_valuetypes()) {
+  for (MVT VecTy : MVT::integer_scalable_vector_valuetypes()) {
     ASSERT_TRUE(VecTy.isValid());
     ASSERT_TRUE(VecTy.isInteger());
     ASSERT_TRUE(VecTy.isVector());
@@ -30,7 +30,7 @@
 }
 
 TEST(ScalableVectorMVTsTest, FloatMVTs) {
-  for (auto VecTy : MVT::fp_scalable_vector_valuetypes()) {
+  for (MVT VecTy : MVT::fp_scalable_vector_valuetypes()) {
     ASSERT_TRUE(VecTy.isValid());
     ASSERT_TRUE(VecTy.isFloatingPoint());
     ASSERT_TRUE(VecTy.isVector());
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1551,9 +1551,9 @@
 }
 
 TEST(ConstantRange, ICmp) {
-  for (auto Pred : seq(CmpInst::Predicate::FIRST_ICMP_PREDICATE,
-                       1 + CmpInst::Predicate::LAST_ICMP_PREDICATE))
-    ICmpTestImpl((CmpInst::Predicate)Pred);
+  for (auto Pred : seq_inclusive(CmpInst::Predicate::FIRST_ICMP_PREDICATE,
+                                 CmpInst::Predicate::LAST_ICMP_PREDICATE))
+    ICmpTestImpl(Pred);
 }
 
 TEST(ConstantRange, MakeGuaranteedNoWrapRegion) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1572,7 +1572,7 @@
   unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
   AffineExpr expr;
   SmallVector<Value, 2> dynamicDims;
-  for (auto dim : llvm::seq(startPos, endPos + 1)) {
+  for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
     dynamicDims.push_back(builder.createOrFold<memref::DimOp>(loc, src, dim));
     AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
     expr = (expr ? expr * currExpr : currExpr);
@@ -1605,7 +1605,7 @@
         map.value().getResults().front().cast<AffineDimExpr>().getPosition();
     unsigned endPos =
         map.value().getResults().back().cast<AffineDimExpr>().getPosition();
-    for (auto dim : llvm::seq(startPos, endPos + 1)) {
+    for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
       expandedDimToCollapsedDim[dim] = map.index();
     }
   }