Index: llvm/include/llvm/ADT/SmallVector.h =================================================================== --- llvm/include/llvm/ADT/SmallVector.h +++ llvm/include/llvm/ADT/SmallVector.h @@ -136,11 +136,16 @@ this->Size = this->Capacity = 0; // FIXME: Setting Capacity to 0 is suspect. } + // Return true if V is in this vector. + bool isReferenceToStorage(const void *V) const { + return V >= this->begin() && V < this->end(); + } + /// Return true unless Elt will be invalidated by resizing the vector to /// NewSize. bool isSafeToReferenceAfterResize(const void *Elt, size_t NewSize) { // Past the end. - if (LLVM_LIKELY(Elt < this->begin() || Elt >= this->end())) + if (LLVM_LIKELY(!isReferenceToStorage(Elt))) return true; // Return false if Elt will be destroyed by shrinking. @@ -283,6 +288,9 @@ std::is_trivially_destructible::value> class SmallVectorTemplateBase : public SmallVectorTemplateCommon { protected: + static constexpr bool TakesParamByValue = false; + using ValueParamT = const T &; + SmallVectorTemplateBase(size_t Size) : SmallVectorTemplateCommon(Size) {} static void destroy_range(T *S, T *E) { @@ -312,20 +320,36 @@ /// element, or MinSize more elements if specified. void grow(size_t MinSize = 0); + /// Reserve enough space to add N copies of Elt, and return the updated + /// element pointer in case it was a reference to the storage. + T *reserveForAndGetAddress(T &Elt, size_t N = 1) { + size_t NewSize = this->size() + N; + if (LLVM_LIKELY(NewSize <= this->capacity())) + return const_cast(&Elt); + + bool ReferencesStorage = false; + int64_t Index = -1; + if (LLVM_UNLIKELY(this->isReferenceToStorage(&Elt))) { + ReferencesStorage = true; + Index = &Elt - this->begin(); + } + this->grow(NewSize); + return ReferencesStorage ? this->begin() + Index : &Elt; + } + const T *reserveForAndGetAddress(const T &Elt, size_t N = 1) { + return reserveForAndGetAddress(const_cast(Elt), N); + } + public: void push_back(const T &Elt) { - this->assertSafeToAdd(&Elt); - if (LLVM_UNLIKELY(this->size() >= this->capacity())) - this->grow(); - ::new ((void*) this->end()) T(Elt); + const T *EltPtr = reserveForAndGetAddress(Elt); + ::new ((void *)this->end()) T(*EltPtr); this->set_size(this->size() + 1); } void push_back(T &&Elt) { - this->assertSafeToAdd(&Elt); - if (LLVM_UNLIKELY(this->size() >= this->capacity())) - this->grow(); - ::new ((void*) this->end()) T(::std::move(Elt)); + T *EltPtr = reserveForAndGetAddress(Elt); + ::new ((void *)this->end()) T(::std::move(*EltPtr)); this->set_size(this->size() + 1); } @@ -376,6 +400,15 @@ template class SmallVectorTemplateBase : public SmallVectorTemplateCommon { protected: + /// True if it's cheap enough to take parameters by value. Doing so avoids + /// overhead related to mitigations for reference invalidation. + static constexpr bool TakesParamByValue = sizeof(T) <= 2 * sizeof(void *); + + /// Either const T& or T, depending on whether it's cheap enough to take + /// parameters by value. + using ValueParamT = + typename std::conditional::type; + SmallVectorTemplateBase(size_t Size) : SmallVectorTemplateCommon(Size) {} // No need to do a destroy loop for POD's. @@ -416,12 +449,27 @@ /// least one more element or MinSize if specified. void grow(size_t MinSize = 0) { this->grow_pod(MinSize, sizeof(T)); } + /// Reserve enough space to add N copies of Elt, and return the updated + /// element pointer in case it was a reference to the storage. + T *reserveForAndGetAddress(const T &Elt, size_t N = 1) { + size_t NewSize = this->size() + N; + if (LLVM_LIKELY(NewSize <= this->capacity())) + return const_cast(&Elt); + + bool ReferencesStorage = false; + int64_t Index = -1; + if (LLVM_UNLIKELY(!TakesParamByValue && this->isReferenceToStorage(&Elt))) { + ReferencesStorage = true; + Index = &Elt - this->begin(); + } + this->grow(NewSize); + return ReferencesStorage ? this->begin() + Index : const_cast(&Elt); + } + public: - void push_back(const T &Elt) { - this->assertSafeToAdd(&Elt); - if (LLVM_UNLIKELY(this->size() >= this->capacity())) - this->grow(); - memcpy(reinterpret_cast(this->end()), &Elt, sizeof(T)); + void push_back(ValueParamT Elt) { + const T *EltPtr = reserveForAndGetAddress(Elt); + memcpy(reinterpret_cast(this->end()), EltPtr, sizeof(T)); this->set_size(this->size() + 1); } @@ -441,6 +489,9 @@ using size_type = typename SuperClass::size_type; protected: + using SmallVectorTemplateBase::TakesParamByValue; + using ValueParamT = typename SuperClass::ValueParamT; + // Default ctor - Initialize to empty. explicit SmallVectorImpl(unsigned N) : SmallVectorTemplateBase(N) {} @@ -473,7 +524,7 @@ } } - void resize(size_type N, const T &NV) { + void resize(size_type N, ValueParamT NV) { if (N == this->size()) return; @@ -483,10 +534,8 @@ return; } - this->assertSafeToReferenceAfterResize(&NV, N); - if (this->capacity() < N) - this->grow(N); - std::uninitialized_fill(this->end(), this->begin() + N, NV); + const T *NVPtr = this->reserveForAndGetAddress(NV, N - this->size()); + std::uninitialized_fill(this->end(), this->begin() + N, *NVPtr); this->set_size(N); } @@ -525,12 +574,9 @@ } /// Append \p NumInputs copies of \p Elt to the end. - void append(size_type NumInputs, const T &Elt) { - this->assertSafeToAdd(&Elt, NumInputs); - if (NumInputs > this->capacity() - this->size()) - this->grow(this->size()+NumInputs); - - std::uninitialized_fill_n(this->end(), NumInputs, Elt); + void append(size_type NumInputs, ValueParamT Elt) { + const T *EltPtr = this->reserveForAndGetAddress(Elt, NumInputs); + std::uninitialized_fill_n(this->end(), NumInputs, *EltPtr); this->set_size(this->size() + NumInputs); } @@ -608,14 +654,11 @@ assert(I >= this->begin() && "Insertion iterator is out of bounds."); assert(I <= this->end() && "Inserting past the end of the vector."); - // Check that adding an element won't invalidate Elt. - this->assertSafeToAdd(&Elt); - - if (this->size() >= this->capacity()) { - size_t EltNo = I-this->begin(); - this->grow(); - I = this->begin()+EltNo; - } + // Grow if necessary. + size_t Index = I - this->begin(); + std::remove_reference_t *EltPtr = + this->reserveForAndGetAddress(Elt); + I = this->begin() + Index; ::new ((void*) this->end()) T(::std::move(this->back())); // Push everything else over. @@ -623,23 +666,48 @@ this->set_size(this->size() + 1); // If we just moved the element we're inserting, be sure to update - // the reference. - std::remove_reference_t *EltPtr = &Elt; - if (I <= EltPtr && EltPtr < this->end()) + // the reference (never happens if TakesParamByValue). + if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end()) ++EltPtr; *I = ::std::forward(*EltPtr); return I; } + template < + class ArgType, + std::enable_if_t< + std::is_same>, + T>::value && + !TakesParamByValue, + bool> = false> + iterator insert_one_maybe_copy(iterator I, ArgType &&Elt) { + return insert_one_impl(I, std::forward(Elt)); + } + + template < + class ArgType, + std::enable_if_t< + std::is_same>, + T>::value && + TakesParamByValue, + bool> = false> + iterator insert_one_maybe_copy(iterator I, ArgType &&Elt) { + // Copy Elt in order to mitigate reference invalidation without needing to + // update the pointer values in insert_one_impl. + return insert_one_impl(I, T(Elt)); + } + public: iterator insert(iterator I, T &&Elt) { - return insert_one_impl(I, std::move(Elt)); + return insert_one_maybe_copy(I, std::move(Elt)); } - iterator insert(iterator I, const T &Elt) { return insert_one_impl(I, Elt); } + iterator insert(iterator I, const T &Elt) { + return insert_one_maybe_copy(I, Elt); + } - iterator insert(iterator I, size_type NumToInsert, const T &Elt) { + iterator insert(iterator I, size_type NumToInsert, ValueParamT Elt) { // Convert iterator to elt# to avoid invalidating iterator when we reserve() size_t InsertElt = I - this->begin(); @@ -651,11 +719,9 @@ assert(I >= this->begin() && "Insertion iterator is out of bounds."); assert(I <= this->end() && "Inserting past the end of the vector."); - // Check that adding NumToInsert elements won't invalidate Elt. - this->assertSafeToAdd(&Elt, NumToInsert); - - // Ensure there is enough space. - reserve(this->size() + NumToInsert); + // Ensure there is enough space, and get the (maybe updated) address of + // Elt. + const T *EltPtr = this->reserveForAndGetAddress(Elt, NumToInsert); // Uninvalidate the iterator. I = this->begin()+InsertElt; @@ -672,7 +738,12 @@ // Copy the existing elements that get replaced. std::move_backward(I, OldEnd-NumToInsert, OldEnd); - std::fill_n(I, NumToInsert, Elt); + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). + if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end()) + EltPtr += NumToInsert; + + std::fill_n(I, NumToInsert, *EltPtr); return I; } @@ -685,11 +756,16 @@ size_t NumOverwritten = OldEnd-I; this->uninitialized_move(I, OldEnd, this->end()-NumOverwritten); + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). + if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end()) + EltPtr += NumToInsert; + // Replace the overwritten part. - std::fill_n(I, NumOverwritten, Elt); + std::fill_n(I, NumOverwritten, *EltPtr); // Insert the non-overwritten middle part. - std::uninitialized_fill_n(OldEnd, NumToInsert-NumOverwritten, Elt); + std::uninitialized_fill_n(OldEnd, NumToInsert-NumOverwritten, *EltPtr); return I; } Index: llvm/unittests/ADT/SmallVectorTest.cpp =================================================================== --- llvm/unittests/ADT/SmallVectorTest.cpp +++ llvm/unittests/ADT/SmallVectorTest.cpp @@ -53,6 +53,7 @@ Constructable(Constructable && src) : constructed(true) { value = src.value; + src.value = 0; ++numConstructorCalls; ++numMoveConstructorCalls; } @@ -74,6 +75,7 @@ Constructable & operator=(Constructable && src) { EXPECT_TRUE(constructed); value = src.value; + src.value = 0; ++numAssignmentCalls; ++numMoveAssignmentCalls; return *this; @@ -1031,11 +1033,16 @@ return N; } + template static bool isValueType() { + return std::is_same::value; + } + void SetUp() override { SmallVectorTestBase::SetUp(); // Fill up the small size so that insertions move the elements. - V.append({0, 0, 0}); + for (int I = 0, E = NumBuiltinElts(V); I != E; ++I) + V.emplace_back(I + 1); } }; @@ -1050,38 +1057,48 @@ TYPED_TEST(SmallVectorReferenceInvalidationTest, PushBack) { auto &V = this->V; - (void)V; -#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST - EXPECT_DEATH(V.push_back(V.back()), this->AssertionMessage); -#endif + V.push_back(V.back()); + int N = this->NumBuiltinElts(V); + EXPECT_EQ(N, V.back()); + if (this->template isValueType()) + EXPECT_EQ(N, V[N - 1]); } TYPED_TEST(SmallVectorReferenceInvalidationTest, PushBackMoved) { auto &V = this->V; - (void)V; -#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST - EXPECT_DEATH(V.push_back(std::move(V.back())), this->AssertionMessage); -#endif + V.push_back(std::move(V.back())); + int N = this->NumBuiltinElts(V); + EXPECT_EQ(N, V.back()); + if (this->template isValueType()) + EXPECT_EQ(0, V[N - 1]); } TYPED_TEST(SmallVectorReferenceInvalidationTest, Resize) { auto &V = this->V; (void)V; int N = this->NumBuiltinElts(V); -#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST - EXPECT_DEATH(V.resize(N + 1, V.back()), this->AssertionMessage); -#endif - - // No assertion when shrinking, since the parameter isn't accessed. - V.resize(N - 1, V.back()); + V.resize(N + 1, V.back()); + EXPECT_EQ(N, V.back()); + + // Append enough more elements that V will grow again. This time the old + // storage will have been freed at time of access and sanitizers can catch + // the use-after-free. + V.resize(V.capacity() + 1, V.front()); + EXPECT_EQ(1, V.back()); } TYPED_TEST(SmallVectorReferenceInvalidationTest, Append) { auto &V = this->V; (void)V; -#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST - EXPECT_DEATH(V.append(1, V.back()), this->AssertionMessage); -#endif + V.append(1, V.back()); + int N = this->NumBuiltinElts(V); + EXPECT_EQ(N, V[N - 1]); + + // Append enough more elements that V will grow again. This time the old + // storage will have been freed at time of access and sanitizers can catch + // the use-after-free. + V.append(V.capacity(), V.front()); + EXPECT_EQ(1, V.back()); } TYPED_TEST(SmallVectorReferenceInvalidationTest, AppendRange) { @@ -1127,26 +1144,54 @@ TYPED_TEST(SmallVectorReferenceInvalidationTest, Insert) { auto &V = this->V; (void)V; -#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST - EXPECT_DEATH(V.insert(V.begin(), V.back()), this->AssertionMessage); -#endif + V.insert(V.begin(), V.back()); + int N = this->NumBuiltinElts(V); + EXPECT_EQ(N, V.front()); + if (this->template isValueType()) { + // Check that the value was copied out (not moved). + EXPECT_EQ(N, V.back()); + } + + // The logic for reference invalidation when NOT growing is disabled when + // TakesParamByValue since it relies on taking a copy. Check that we actually + // took a copy. + V.insert(V.begin(), V.back()); + EXPECT_EQ(N, V.front()); } TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertMoved) { auto &V = this->V; (void)V; -#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST - EXPECT_DEATH(V.insert(V.begin(), std::move(V.back())), - this->AssertionMessage); -#endif + int N = this->NumBuiltinElts(V); + V.insert(V.begin(), std::move(V.back())); + EXPECT_EQ(N, V.front()); + if (this->template isValueType()) { + // Check that the value got moved out, then copy it back in. + EXPECT_EQ(0, V.back()); + V.back() = V.front(); + } + + // The logic for reference invalidation when NOT growing is disabled when + // TakesParamByValue since it relies on taking a copy. Check that we actually + // took a copy. + V.insert(V.begin(), std::move(V.back())); + EXPECT_EQ(N, V.front()); } TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertN) { auto &V = this->V; (void)V; -#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST - EXPECT_DEATH(V.insert(V.begin(), 2, V.back()), this->AssertionMessage); -#endif + + // Cover NumToInsert <= this->end() - I. + V.insert(V.begin() + 1, 1, V.back()); + int N = this->NumBuiltinElts(V); + EXPECT_EQ(N, V[1]); + + // Cover NumToInsert > this->end() - I. Also insert enough more elements that + // V will grow again. This time the old storage will have been freed at time + // of access and sanitizers can catch the use-after-free. + V.insert(V.begin(), V.capacity(), V.front()); + EXPECT_EQ(1, V.front()); } TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertRange) {