// Review notes on this patch (not part of the diff itself; kept ahead of the
// first "diff --git" header, where patch tools treat text as preamble).
//
// Summary: adds a small-size mode to llvm::SetVector through a new trailing
// template parameter `unsigned N = 0`. The same parameter is threaded through
// SetVector::swap, std::swap, SmallSetVector, and MLIR's forward declaration
// and alias in mlir/Support/LLVM.h.
//  - canBeSmall() is true iff N != 0. isSmall() is true iff set_ is empty.
//  - While small, insert/remove/erase/remove_if/contains/count work on
//    vector_ alone, using linear searches (find, is_contained, remove_if).
//    set_ is not touched on these paths.
//  - insert() calls makeBig() once vector_.size() exceeds N. makeBig() copies
//    every vector_ element into set_, so set_ becomes non-empty and later
//    operations take the set-backed path.
//  - The range insert(Start, End) now forwards to the single-element insert(),
//    so both entry points share the small/big logic.
//  - SmallSetVector now passes DenseSet plus N to SetVector, replacing the
//    previous SmallDenseSet argument.
//
// NOTE(review): this copy of the patch appears damaged. Line breaks are
// collapsed, and every `<...>` span (template parameter/argument lists,
// #include targets) is missing, e.g. "#include namespace llvm" and
// "template , typename Set = DenseSet>". It cannot be applied as-is.
// Regenerate it from the original source rather than hand-repairing it.
//
// NOTE(review): nothing bounds N. Each small-mode insert/remove/contains/count
// is a linear scan over vector_, so a large N makes every such call O(N).
// Consider a static_assert capping N; the exact limit is a design choice.
//
// NOTE(review): small-mode insert() tests membership with
// `llvm::find(vector_, X) == vector_.end()`, while contains()/count() use
// `is_contained(vector_, key)`. Using is_contained in insert() as well would
// be more consistent.
//
// NOTE(review): isSmall() treats an empty set_ as "small mode". The small
// paths are only correct if set_ is empty exactly when the container is in
// small mode. That presumably holds, since the big-mode remove() erases from
// set_ and vector_ together, but the rest of that hunk is not shown. It is
// worth a test that shrinks a big SetVector back to empty and then refills it
// past N.
diff --git a/llvm/include/llvm/ADT/SetVector.h b/llvm/include/llvm/ADT/SetVector.h --- a/llvm/include/llvm/ADT/SetVector.h +++ b/llvm/include/llvm/ADT/SetVector.h @@ -29,7 +29,6 @@ #include namespace llvm { - /// A vector that has set insertion semantics. /// /// This adapter class provides a way to keep a set of things that also has the @@ -47,7 +46,7 @@ /// value_type to float and key_type to int can produce very surprising results, /// but it is not explicitly disallowed. template , - typename Set = DenseSet> + typename Set = DenseSet, unsigned N = 0> class SetVector { public: using value_type = typename Vector::value_type; @@ -150,6 +149,17 @@ /// Insert a new element into the SetVector. /// \returns true if the element was inserted into the SetVector. bool insert(const value_type &X) { + if constexpr (canBeSmall()) + if (isSmall()) { + if (llvm::find(vector_, X) == vector_.end()) { + vector_.push_back(X); + if (vector_.size() > N) + makeBig(); + return true; + } + return false; + } + bool result = set_.insert(X).second; if (result) vector_.push_back(X); @@ -160,12 +170,21 @@ template void insert(It Start, It End) { for (; Start != End; ++Start) - if (set_.insert(*Start).second) - vector_.push_back(*Start); + insert(*Start); } /// Remove an item from the set vector. bool remove(const value_type& X) { + if constexpr (canBeSmall()) + if (isSmall()) { + typename vector_type::iterator I = find(vector_, X); + if (I != vector_.end()) { + vector_.erase(I); + return true; + } + return false; + } + if (set_.erase(X)) { typename vector_type::iterator I = find(vector_, X); assert(I != vector_.end() && "Corrupted SetVector instances!"); @@ -180,6 +199,10 @@ /// element erased. This is the end of the SetVector if the last element is /// erased.
iterator erase(const_iterator I) { + if constexpr (canBeSmall()) + if (isSmall()) + return vector_.erase(I); + const key_type &V = *I; assert(set_.count(V) && "Corrupted SetVector instances!"); set_.erase(V); @@ -201,8 +224,15 @@ /// \returns true if any element is removed. template bool remove_if(UnaryPredicate P) { - typename vector_type::iterator I = - llvm::remove_if(vector_, TestAndEraseFromSet(P, set_)); + typename vector_type::iterator I = [this, P] { + if constexpr (canBeSmall()) + if (isSmall()) + return llvm::remove_if(vector_, P); + + return llvm::remove_if(vector_, + TestAndEraseFromSet(P, set_)); + }(); + if (I == vector_.end()) return false; vector_.erase(I, vector_.end()); @@ -211,12 +241,20 @@ /// Check if the SetVector contains the given key. bool contains(const key_type &key) const { + if constexpr (canBeSmall()) + if (isSmall()) + return is_contained(vector_, key); + return set_.find(key) != set_.end(); } /// Count the number of elements of a given key in the SetVector. /// \returns 0 if the element is not in the SetVector, 1 if it is. size_type count(const key_type &key) const { + if constexpr (canBeSmall()) + if (isSmall()) + return is_contained(vector_, key); + return set_.count(key); } @@ -272,7 +310,7 @@ remove(*SI); } - void swap(SetVector &RHS) { + void swap(SetVector &RHS) { set_.swap(RHS.set_); vector_.swap(RHS.vector_); } @@ -301,6 +339,16 @@ } }; + [[nodiscard]] static constexpr bool canBeSmall() { return N != 0; } + + [[nodiscard]] bool isSmall() const { return set_.empty(); } + + void makeBig() { + if constexpr (canBeSmall()) + for (const auto &entry : vector_) + set_.insert(entry); + } + set_type set_; ///< The set. vector_type vector_; ///< The vector. }; @@ -308,8 +356,7 @@ /// A SetVector that performs no allocations if smaller than /// a certain size.
template -class SmallSetVector - : public SetVector, SmallDenseSet> { +class SmallSetVector : public SetVector, DenseSet, N> { public: SmallSetVector() = default; @@ -325,9 +372,9 @@ namespace std { /// Implement std::swap in terms of SetVector swap. -template -inline void -swap(llvm::SetVector &LHS, llvm::SetVector &RHS) { +template +inline void swap(llvm::SetVector &LHS, + llvm::SetVector &RHS) { LHS.swap(RHS); } diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h --- a/mlir/include/mlir/Support/LLVM.h +++ b/mlir/include/mlir/Support/LLVM.h @@ -59,7 +59,7 @@ class MutableArrayRef; template class PointerUnion; -template +template class SetVector; template class SmallPtrSet; @@ -123,8 +123,8 @@ template > using DenseSet = llvm::DenseSet; template , - typename Set = DenseSet> -using SetVector = llvm::SetVector; + typename Set = DenseSet, unsigned N = 0> +using SetVector = llvm::SetVector; template using StringSet = llvm::StringSet; using llvm::MutableArrayRef;