diff --git a/llvm/include/llvm/ADT/Sequence.h b/llvm/include/llvm/ADT/Sequence.h --- a/llvm/include/llvm/ADT/Sequence.h +++ b/llvm/include/llvm/ADT/Sequence.h @@ -15,163 +15,218 @@ #ifndef LLVM_ADT_SEQUENCE_H #define LLVM_ADT_SEQUENCE_H -#include //std::ptrdiff_t -#include //std::random_access_iterator_tag +#include // assert +#include // std::ptrdiff_t +#include // std::random_access_iterator_tag +#include // std::numeric_limits +#include // std::underlying_type, std::is_enum + +#include "llvm/Support/MathExtras.h" namespace llvm { namespace detail { -template struct iota_range_iterator { - using iterator_category = std::random_access_iterator_tag; - using value_type = T; - using difference_type = std::ptrdiff_t; - using pointer = T *; - using reference = T &; +// Returns whether a value of type U can be represented with type T. +template bool canTypeFitValue(const U value) { + const intmax_t botT = intmax_t(std::numeric_limits::min()); + const intmax_t botU = intmax_t(std::numeric_limits::min()); + const uintmax_t topT = uintmax_t(std::numeric_limits::max()); + const uintmax_t topU = uintmax_t(std::numeric_limits::max()); + return !((botT > botU && value < static_cast(botT)) || + (topT < topU && value > static_cast(topT))); +} -private: - struct Forward { - static void increment(T &V) { ++V; } - static void decrement(T &V) { --V; } - static void offset(T &V, difference_type Offset) { V += Offset; } - static T add(const T &V, difference_type Offset) { return V + Offset; } - static difference_type difference(const T &A, const T &B) { return A - B; } - }; - - struct Reverse { - static void increment(T &V) { --V; } - static void decrement(T &V) { ++V; } - static void offset(T &V, difference_type Offset) { V -= Offset; } - static T add(const T &V, difference_type Offset) { return V - Offset; } - static difference_type difference(const T &A, const T &B) { return B - A; } - }; - - using Op = std::conditional_t; - -public: - // default-constructible - iota_range_iterator() = default; - // copy-constructible - iota_range_iterator(const iota_range_iterator &) = default; - // value constructor - explicit iota_range_iterator(T Value) : Value(Value) {} - // copy-assignable - iota_range_iterator &operator=(const iota_range_iterator &) = default; - // destructible - ~iota_range_iterator() = default; - - // Can be compared for equivalence using the equality/inequality operators, - bool operator!=(const iota_range_iterator &RHS) const { - return Value != RHS.Value; - } - bool operator==(const iota_range_iterator &RHS) const { - return Value == RHS.Value; +// An integer type that asserts when: +// - constructed from a value that doesn't fit into intmax_t, +// - casted to a type that cannot hold the current value, +// - its internal representation overflows. +struct StrongInt { + // Integral constructor, asserts if Value cannot be represented as intmax_t. + template ::value, bool> = 0> + static StrongInt from(Integral FromValue) { + if (!canTypeFitValue(FromValue)) + assertOutOfBounds(); + StrongInt Result; + Result.Value = static_cast(FromValue); + return Result; } - // Comparison - bool operator<(const iota_range_iterator &Other) const { - return Op::difference(Value, Other.Value) < 0; - } - bool operator<=(const iota_range_iterator &Other) const { - return Op::difference(Value, Other.Value) <= 0; - } - bool operator>(const iota_range_iterator &Other) const { - return Op::difference(Value, Other.Value) > 0; - } - bool operator>=(const iota_range_iterator &Other) const { - return Op::difference(Value, Other.Value) >= 0; + // Enum constructor, asserts if Value cannot be represented as intmax_t. + template ::value, bool> = 0> + static StrongInt from(Enum FromValue) { + using type = typename std::underlying_type::type; + return from(static_cast(FromValue)); } - // Dereference - T operator*() const { return Value; } - T operator[](difference_type Offset) const { return Op::add(Value, Offset); } + // Equality + bool operator==(const StrongInt &O) const { return Value == O.Value; } + bool operator!=(const StrongInt &O) const { return Value != O.Value; } - // Arithmetic - iota_range_iterator operator+(difference_type Offset) const { - return {Op::add(Value, Offset)}; - } - iota_range_iterator operator-(difference_type Offset) const { - return {Op::add(Value, -Offset)}; + StrongInt safe_add(intmax_t Offset) const { + StrongInt Result; + if (AddOverflow(Value, Offset, Result.Value)) + assertOutOfBounds(); + return Result; } - // Iterator difference - difference_type operator-(const iota_range_iterator &Other) const { - return Op::difference(Value, Other.Value); + intmax_t safe_diff(StrongInt Other) const { + intmax_t Result; + if (SubOverflow(Value, Other.Value, Result)) + assertOutOfBounds(); + return Result; } - // Pre/Post Increment - iota_range_iterator &operator++() { - Op::increment(Value); - return *this; + // Convert to integral, asserts if Value cannot be represented as Integral. + template ::value, bool> = 0> + Integral to() const { + if (!canTypeFitValue(Value)) + assertOutOfBounds(); + return static_cast(Value); } - iota_range_iterator operator++(int) { - iota_range_iterator Tmp = *this; - Op::increment(Value); - return Tmp; + + // Convert to enum, asserts if Value cannot be represented as Enum's + // underlying type. + template ::value, bool> = 0> + Enum to() const { + using type = typename std::underlying_type::type; + return Enum(to()); } - // Pre/Post Decrement - iota_range_iterator &operator--() { - Op::decrement(Value); - return *this; +private: + static void assertOutOfBounds() { assert(false && "Out of bounds"); } + + intmax_t Value; +}; + +template struct SafeIntIterator { + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = intmax_t; + using pointer = T *; + using reference = T &; + + // Construct from T. + explicit SafeIntIterator(T Value) : SI(StrongInt::from(Value)) {} + // Construct from other direction. + SafeIntIterator(const SafeIntIterator &O) : SI(O.SI) {} + + // Dereference + value_type operator*() const { return SI.to(); } + // Indexing + value_type operator[](intmax_t Offset) const { return *(*this + Offset); } + + // Can be compared for equivalence using the equality/inequality operators. + bool operator==(const SafeIntIterator &O) const { return SI == O.SI; } + bool operator!=(const SafeIntIterator &O) const { return SI != O.SI; } + // Comparison + bool operator<(const SafeIntIterator &O) const { return (*this - O) < 0; } + bool operator>(const SafeIntIterator &O) const { return (*this - O) > 0; } + bool operator<=(const SafeIntIterator &O) const { return (*this - O) <= 0; } + bool operator>=(const SafeIntIterator &O) const { return (*this - O) >= 0; } + + // Pre Increment/Decrement + void operator++() { offset(1); } + void operator--() { offset(-1); } + + // Post Increment/Decrement + SafeIntIterator operator++(int) { + const auto Copy = *this; + ++*this; + return Copy; } - iota_range_iterator operator--(int) { - iota_range_iterator Tmp = *this; - Op::decrement(Value); - return Tmp; + SafeIntIterator operator--(int) { + const auto Copy = *this; + --*this; + return Copy; } // Compound assignment operators - iota_range_iterator &operator+=(difference_type Offset) { - Op::offset(Value, Offset); - return *this; - } - iota_range_iterator &operator-=(difference_type Offset) { - Op::offset(Value, -Offset); - return *this; + void operator+=(intmax_t Offset) { offset(Offset); } + void operator-=(intmax_t Offset) { offset(-Offset); } + + // Arithmetic + SafeIntIterator operator+(intmax_t Offset) const { return add(Offset); } + SafeIntIterator operator-(intmax_t Offset) const { return add(-Offset); } + + // Difference + intmax_t operator-(const SafeIntIterator &O) const { + return IsReverse ? O.SI.safe_diff(SI) : SI.safe_diff(O.SI); } private: - T Value; + SafeIntIterator(const StrongInt &SI) : SI(SI) {} + + StrongInt add(intmax_t Offset) const { + return SI.safe_add(IsReverse ? -Offset : Offset); + } + + void offset(intmax_t Offset) { SI = add(Offset); } + + StrongInt SI; + + // To allow construction from the other direction. + template friend struct SafeIntIterator; }; } // namespace detail -template struct iota_range { - static_assert(std::is_integral::value, - "ValueT must be an integral type"); - - using value_type = ValueT; - using reference = ValueT &; - using const_reference = const ValueT &; - using iterator = detail::iota_range_iterator; +template struct iota_range { + using value_type = T; + using reference = T &; + using const_reference = const T &; + using iterator = detail::SafeIntIterator; using const_iterator = iterator; - using reverse_iterator = detail::iota_range_iterator; + using reverse_iterator = detail::SafeIntIterator; using const_reverse_iterator = reverse_iterator; - using difference_type = std::ptrdiff_t; + using difference_type = intmax_t; using size_type = std::size_t; - value_type Begin; - value_type End; - - explicit iota_range(ValueT Begin, ValueT End) : Begin(Begin), End(End) {} + explicit iota_range(T Begin, T End, bool Inclusive) + : BeginValue(Begin), PastEndValue(End) { + assert(Begin <= End && "Begin must be less or equal to End."); + if (Inclusive) + ++PastEndValue; + } - size_t size() const { return End - Begin; } - bool empty() const { return Begin == End; } + size_t size() const { return PastEndValue - BeginValue; } + bool empty() const { return BeginValue == PastEndValue; } - auto begin() const { return const_iterator(Begin); } - auto end() const { return const_iterator(End); } + auto begin() const { return const_iterator(BeginValue); } + auto end() const { return const_iterator(PastEndValue); } - auto rbegin() const { return const_reverse_iterator(End - 1); } - auto rend() const { return const_reverse_iterator(Begin - 1); } + auto rbegin() const { return const_reverse_iterator(PastEndValue - 1); } + auto rend() const { return const_reverse_iterator(BeginValue - 1); } private: - static_assert(std::is_same>::value, - "ValueT must not be const nor volatile"); + static_assert(std::is_integral::value || std::is_enum::value, + "T must be an integral or enum type"); + static_assert(std::is_same>::value, + "T must not be const nor volatile"); + + iterator BeginValue; + iterator PastEndValue; }; -template auto seq(ValueT Begin, ValueT End) { - return iota_range(Begin, End); +/// Iterate over an integral/enum type from Begin up to - but not including - +/// End. +/// Note on enum iteration: `seq` will generate each consecutive value, even if +/// no enumerator with that value exists. +template auto seq(T Begin, T End) { + return iota_range(Begin, End, false); +} + +/// Iterate over an integral/enum type from Begin to End inclusive. +/// Note on enum iteration: `seq_inclusive` will generate each consecutive +/// value, even if no enumerator with that value exists. +/// To prevent overflow, `End` must be different from INTMAX_MAX if T is signed +/// (resp. UINTMAX_MAX if T is unsigned). +template auto seq_inclusive(T Begin, T End) { + return iota_range(Begin, End, true); } } // end namespace llvm diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h --- a/llvm/include/llvm/Support/MachineValueType.h +++ b/llvm/include/llvm/Support/MachineValueType.h @@ -14,6 +14,7 @@ #ifndef LLVM_SUPPORT_MACHINEVALUETYPE_H #define LLVM_SUPPORT_MACHINEVALUETYPE_H +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" @@ -1398,84 +1399,55 @@ /// returned as Other, otherwise they are invalid. static MVT getVT(Type *Ty, bool HandleUnknown = false); - private: - /// A simple iterator over the MVT::SimpleValueType enum. - struct mvt_iterator { - SimpleValueType VT; - - mvt_iterator(SimpleValueType VT) : VT(VT) {} - - MVT operator*() const { return VT; } - bool operator!=(const mvt_iterator &LHS) const { return VT != LHS.VT; } - - mvt_iterator& operator++() { - VT = (MVT::SimpleValueType)((int)VT + 1); - assert((int)VT <= MVT::MAX_ALLOWED_VALUETYPE && - "MVT iterator overflowed."); - return *this; - } - }; - - /// A range of the MVT::SimpleValueType enum. - using mvt_range = iterator_range; - public: /// SimpleValueType Iteration /// @{ - static mvt_range all_valuetypes() { - return mvt_range(MVT::FIRST_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_VALUETYPE + 1)); + static auto all_valuetypes() { + return seq_inclusive(MVT::FIRST_VALUETYPE, MVT::LAST_VALUETYPE); } - static mvt_range integer_valuetypes() { - return mvt_range(MVT::FIRST_INTEGER_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_INTEGER_VALUETYPE + 1)); + static auto integer_valuetypes() { + return seq_inclusive(MVT::FIRST_INTEGER_VALUETYPE, + MVT::LAST_INTEGER_VALUETYPE); } - static mvt_range fp_valuetypes() { - return mvt_range(MVT::FIRST_FP_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_FP_VALUETYPE + 1)); + static auto fp_valuetypes() { + return seq_inclusive(MVT::FIRST_FP_VALUETYPE, MVT::LAST_FP_VALUETYPE); } - static mvt_range vector_valuetypes() { - return mvt_range(MVT::FIRST_VECTOR_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_VECTOR_VALUETYPE + 1)); + static auto vector_valuetypes() { + return seq_inclusive(MVT::FIRST_VECTOR_VALUETYPE, + MVT::LAST_VECTOR_VALUETYPE); } - static mvt_range fixedlen_vector_valuetypes() { - return mvt_range( - MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE + 1)); + static auto fixedlen_vector_valuetypes() { + return seq_inclusive(MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE, + MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE); } - static mvt_range scalable_vector_valuetypes() { - return mvt_range( - MVT::FIRST_SCALABLE_VECTOR_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_SCALABLE_VECTOR_VALUETYPE + 1)); + static auto scalable_vector_valuetypes() { + return seq_inclusive(MVT::FIRST_SCALABLE_VECTOR_VALUETYPE, + MVT::LAST_SCALABLE_VECTOR_VALUETYPE); } - static mvt_range integer_fixedlen_vector_valuetypes() { - return mvt_range( - MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE + 1)); + static auto integer_fixedlen_vector_valuetypes() { + return seq_inclusive(MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE, + MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE); } - static mvt_range fp_fixedlen_vector_valuetypes() { - return mvt_range( - MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE + 1)); + static auto fp_fixedlen_vector_valuetypes() { + return seq_inclusive(MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE, + MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE); } - static mvt_range integer_scalable_vector_valuetypes() { - return mvt_range( - MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE + 1)); + static auto integer_scalable_vector_valuetypes() { + return seq_inclusive(MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE, + MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE); } - static mvt_range fp_scalable_vector_valuetypes() { - return mvt_range( - MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE, - (MVT::SimpleValueType)(MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE + 1)); + static auto fp_scalable_vector_valuetypes() { + return seq_inclusive(MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE, + MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE); } /// @} }; diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -4634,8 +4634,7 @@ EVT InVT = InOp.getValueType(); if (InVT.getSizeInBits() != VT.getSizeInBits()) { EVT InEltVT = InVT.getVectorElementType(); - for (int i = MVT::FIRST_VECTOR_VALUETYPE, e = MVT::LAST_VECTOR_VALUETYPE; i < e; ++i) { - EVT FixedVT = (MVT::SimpleValueType)i; + for (EVT FixedVT : MVT::vector_valuetypes()) { EVT FixedEltVT = FixedVT.getVectorElementType(); if (TLI.isTypeLegal(FixedVT) && FixedVT.getSizeInBits() == VT.getSizeInBits() && @@ -5162,14 +5161,11 @@ if (!Scalable && Width == WidenEltWidth) return RetVT; - // See if there is larger legal integer than the element type to load/store. - unsigned VT; // Don't bother looking for an integer type if the vector is scalable, skip // to vector types. if (!Scalable) { - for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE; - VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) { - EVT MemVT((MVT::SimpleValueType) VT); + // See if there is larger legal integer than the element type to load/store. + for (EVT MemVT : reverse(MVT::integer_valuetypes())) { unsigned MemVTWidth = MemVT.getSizeInBits(); if (MemVT.getSizeInBits() <= WidenEltWidth) break; @@ -5190,9 +5186,7 @@ // See if there is a larger vector type to load/store that has the same vector // element type and is evenly divisible with the WidenVT. - for (VT = (unsigned)MVT::LAST_VECTOR_VALUETYPE; - VT >= (unsigned)MVT::FIRST_VECTOR_VALUETYPE; --VT) { - EVT MemVT = (MVT::SimpleValueType) VT; + for (EVT MemVT : reverse(MVT::vector_valuetypes())) { // Skip vector MVTs which don't match the scalable property of WidenVT. if (Scalable != MemVT.isScalableVector()) continue; diff --git a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp --- a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp +++ b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp @@ -918,9 +918,9 @@ continue; case X86::OperandType::OPERAND_COND_CODE: { Exploration = true; - auto CondCodes = seq((int)X86::CondCode::COND_O, - 1 + (int)X86::CondCode::LAST_VALID_COND); - Choices.reserve(std::distance(CondCodes.begin(), CondCodes.end())); + auto CondCodes = + seq_inclusive(X86::CondCode::COND_O, X86::CondCode::LAST_VALID_COND); + Choices.reserve(CondCodes.size()); for (int CondCode : CondCodes) Choices.emplace_back(MCOperand::createImm(CondCode)); break; diff --git a/llvm/tools/llvm-reduce/deltas/ReduceAttributes.cpp b/llvm/tools/llvm-reduce/deltas/ReduceAttributes.cpp --- a/llvm/tools/llvm-reduce/deltas/ReduceAttributes.cpp +++ b/llvm/tools/llvm-reduce/deltas/ReduceAttributes.cpp @@ -84,7 +84,8 @@ AttrPtrVecVecTy &AttributeSetsToPreserve) { assert(AttributeSetsToPreserve.empty() && "Should not be sharing vectors."); AttributeSetsToPreserve.reserve(AL.getNumAttrSets()); - for (unsigned SetIdx : seq(AL.index_begin(), AL.index_end())) { + for (unsigned SetIdx = AL.index_begin(), SetEndIdx = AL.index_end(); + SetIdx != SetEndIdx; ++SetIdx) { AttrPtrIdxVecVecTy AttributesToPreserve; AttributesToPreserve.first = SetIdx; visitAttributeSet(AL.getAttributes(AttributesToPreserve.first), diff --git a/llvm/unittests/ADT/SequenceTest.cpp b/llvm/unittests/ADT/SequenceTest.cpp --- a/llvm/unittests/ADT/SequenceTest.cpp +++ b/llvm/unittests/ADT/SequenceTest.cpp @@ -7,30 +7,199 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/Sequence.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" -#include +#include +#include +#include using namespace llvm; +using testing::ElementsAre; + namespace { -TEST(SequenceTest, Forward) { - int X = 0; - for (int I : seq(0, 10)) { - EXPECT_EQ(X, I); - ++X; - } - EXPECT_EQ(10, X); +using detail::canTypeFitValue; +using detail::StrongInt; + +using IntegralTypes = testing::Types; + +template class StrongIntTest : public testing::Test {}; +TYPED_TEST_SUITE(StrongIntTest, IntegralTypes); +TYPED_TEST(StrongIntTest, Operations) { + using T = TypeParam; + auto Max = std::numeric_limits::max(); + auto Min = std::numeric_limits::min(); + + // We bail out for types that are not entirely representable within intmax_t. + if (!canTypeFitValue(Max) || !canTypeFitValue(Min)) + return; + + // All representable values convert back and forth. + EXPECT_EQ(StrongInt::from(Min).template to(), Min); + EXPECT_EQ(StrongInt::from(Max).template to(), Max); + + // Addition -2, -1, 0, 1, 2. + const T Expected = Max / 2; + const StrongInt Actual = StrongInt::from(Expected); + EXPECT_EQ(Actual.safe_add(-2).template to(), Expected - 2); + EXPECT_EQ(Actual.safe_add(-1).template to(), Expected - 1); + EXPECT_EQ(Actual.safe_add(0).template to(), Expected); + EXPECT_EQ(Actual.safe_add(1).template to(), Expected + 1); + EXPECT_EQ(Actual.safe_add(2).template to(), Expected + 2); + + // EQ/NEQ + EXPECT_EQ(Actual, Actual); + EXPECT_NE(Actual, Actual.safe_add(1)); + + // Difference + EXPECT_EQ(Actual.safe_diff(Actual), 0); + EXPECT_EQ(Actual.safe_add(1).safe_diff(Actual), 1); + EXPECT_EQ(Actual.safe_diff(Actual.safe_add(2)), -2); +} + +TEST(StrongIntTest, Enums) { + enum UntypedEnum { A = 3 }; + EXPECT_EQ(StrongInt::from(A).to(), A); + + enum TypedEnum : uint32_t { B = 3 }; + EXPECT_EQ(StrongInt::from(B).to(), B); + + enum class ScopedEnum : uint16_t { C = 3 }; + EXPECT_EQ(StrongInt::from(ScopedEnum::C).to(), ScopedEnum::C); +} + +#if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG) +TEST(StrongIntDeathTest, OutOfBounds) { + // Values above 'INTMAX_MAX' are not representable. + EXPECT_DEATH(StrongInt::from(INTMAX_MAX + 1ULL), "Out of bounds"); + EXPECT_DEATH(StrongInt::from(UINTMAX_MAX), "Out of bounds"); + // Casting to narrower type asserts when out of bounds. + EXPECT_DEATH(StrongInt::from(-1).to(), "Out of bounds"); + EXPECT_DEATH(StrongInt::from(256).to(), "Out of bounds"); + // Operations leading to intmax_t overflow assert. + EXPECT_DEATH(StrongInt::from(INTMAX_MAX).safe_add(1), "Out of bounds"); + EXPECT_DEATH(StrongInt::from(INTMAX_MIN).safe_add(-1), "Out of bounds"); + EXPECT_DEATH( + StrongInt::from(INTMAX_MIN).safe_diff(StrongInt::from(INTMAX_MAX)), + "Out of bounds"); +} +#endif + +TEST(SafeIntIteratorTest, Operations) { + detail::SafeIntIterator Forward(0); + detail::SafeIntIterator Reverse(0); + + const auto SetToZero = [&]() { + Forward = detail::SafeIntIterator(0); + Reverse = detail::SafeIntIterator(0); + }; + + // Equality / Comparisons + SetToZero(); + EXPECT_EQ(Forward, Forward); + EXPECT_LT(Forward - 1, Forward); + EXPECT_LE(Forward, Forward); + EXPECT_LE(Forward - 1, Forward); + EXPECT_GT(Forward + 1, Forward); + EXPECT_GE(Forward, Forward); + EXPECT_GE(Forward + 1, Forward); + + EXPECT_EQ(Reverse, Reverse); + EXPECT_LT(Reverse - 1, Reverse); + EXPECT_LE(Reverse, Reverse); + EXPECT_LE(Reverse - 1, Reverse); + EXPECT_GT(Reverse + 1, Reverse); + EXPECT_GE(Reverse, Reverse); + EXPECT_GE(Reverse + 1, Reverse); + + // Dereference + SetToZero(); + EXPECT_EQ(*Forward, 0); + EXPECT_EQ(*Reverse, 0); + + // Indexing + SetToZero(); + EXPECT_EQ(Forward[2], 2); + EXPECT_EQ(Reverse[2], -2); + + // Pre-increment + SetToZero(); + ++Forward; + EXPECT_EQ(*Forward, 1); + ++Reverse; + EXPECT_EQ(*Reverse, -1); + + // Pre-decrement + SetToZero(); + --Forward; + EXPECT_EQ(*Forward, -1); + --Reverse; + EXPECT_EQ(*Reverse, 1); + + // Post-increment + SetToZero(); + EXPECT_EQ(*(Forward++), 0); + EXPECT_EQ(*Forward, 1); + EXPECT_EQ(*(Reverse++), 0); + EXPECT_EQ(*Reverse, -1); + + // Post-decrement + SetToZero(); + EXPECT_EQ(*(Forward--), 0); + EXPECT_EQ(*Forward, -1); + EXPECT_EQ(*(Reverse--), 0); + EXPECT_EQ(*Reverse, 1); + + // Compound assignment operators + SetToZero(); + Forward += 1; + EXPECT_EQ(*Forward, 1); + Reverse += 1; + EXPECT_EQ(*Reverse, -1); + SetToZero(); + Forward -= 2; + EXPECT_EQ(*Forward, -2); + Reverse -= 2; + EXPECT_EQ(*Reverse, 2); + + // Arithmetic + SetToZero(); + EXPECT_EQ(*(Forward + 3), 3); + EXPECT_EQ(*(Reverse + 3), -3); + SetToZero(); + EXPECT_EQ(*(Forward - 4), -4); + EXPECT_EQ(*(Reverse - 4), 4); + + // Difference + SetToZero(); + EXPECT_EQ(Forward - Forward, 0); + EXPECT_EQ(Reverse - Reverse, 0); + EXPECT_EQ((Forward + 1) - Forward, 1); + EXPECT_EQ(Forward - (Forward + 1), -1); + EXPECT_EQ((Reverse + 1) - Reverse, 1); + EXPECT_EQ(Reverse - (Reverse + 1), -1); } -TEST(SequenceTest, Backward) { - int X = 9; - for (int I : reverse(seq(0, 10))) { - EXPECT_EQ(X, I); - --X; - } - EXPECT_EQ(-1, X); +TEST(SequenceTest, Iteration) { + EXPECT_THAT(seq(-4, 5), ElementsAre(-4, -3, -2, -1, 0, 1, 2, 3, 4)); + EXPECT_THAT(reverse(seq(-4, 5)), ElementsAre(4, 3, 2, 1, 0, -1, -2, -3, -4)); + + EXPECT_THAT(seq_inclusive(-4, 5), + ElementsAre(-4, -3, -2, -1, 0, 1, 2, 3, 4, 5)); + EXPECT_THAT(reverse(seq_inclusive(-4, 5)), + ElementsAre(5, 4, 3, 2, 1, 0, -1, -2, -3, -4)); } TEST(SequenceTest, Distance) { diff --git a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp --- a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp +++ b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp @@ -18,7 +18,7 @@ namespace { TEST(ScalableVectorMVTsTest, IntegerMVTs) { - for (auto VecTy : MVT::integer_scalable_vector_valuetypes()) { + for (MVT VecTy : MVT::integer_scalable_vector_valuetypes()) { ASSERT_TRUE(VecTy.isValid()); ASSERT_TRUE(VecTy.isInteger()); ASSERT_TRUE(VecTy.isVector()); @@ -30,7 +30,7 @@ } TEST(ScalableVectorMVTsTest, FloatMVTs) { - for (auto VecTy : MVT::fp_scalable_vector_valuetypes()) { + for (MVT VecTy : MVT::fp_scalable_vector_valuetypes()) { ASSERT_TRUE(VecTy.isValid()); ASSERT_TRUE(VecTy.isFloatingPoint()); ASSERT_TRUE(VecTy.isVector()); diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp --- a/llvm/unittests/IR/ConstantRangeTest.cpp +++ b/llvm/unittests/IR/ConstantRangeTest.cpp @@ -1551,9 +1551,9 @@ } TEST(ConstantRange, ICmp) { - for (auto Pred : seq(CmpInst::Predicate::FIRST_ICMP_PREDICATE, - 1 + CmpInst::Predicate::LAST_ICMP_PREDICATE)) - ICmpTestImpl((CmpInst::Predicate)Pred); + for (auto Pred : seq_inclusive(CmpInst::Predicate::FIRST_ICMP_PREDICATE, + CmpInst::Predicate::LAST_ICMP_PREDICATE)) + ICmpTestImpl(Pred); } TEST(ConstantRange, MakeGuaranteedNoWrapRegion) { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1157,7 +1157,7 @@ unsigned endPos = map.getResults().back().cast().getPosition(); AffineExpr expr; SmallVector dynamicDims; - for (auto dim : llvm::seq(startPos, endPos + 1)) { + for (auto dim : llvm::seq_inclusive(startPos, endPos)) { dynamicDims.push_back(builder.createOrFold(loc, src, dim)); AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos); expr = (expr ? expr * currExpr : currExpr); @@ -1190,7 +1190,7 @@ map.value().getResults().front().cast().getPosition(); unsigned endPos = map.value().getResults().back().cast().getPosition(); - for (auto dim : llvm::seq(startPos, endPos + 1)) { + for (auto dim : llvm::seq_inclusive(startPos, endPos)) { expandedDimToCollapsedDim[dim] = map.index(); } }