diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -387,6 +387,10 @@
     Optional<NamedAttribute> getNamed(StringRef name) const;
     Optional<NamedAttribute> getNamed(Identifier name) const;

+    /// Return whether the specified attribute is present.
+    bool contains(StringRef name) const;
+    bool contains(Identifier name) const;
+
     /// Support range iteration.
     using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
     iterator begin() const;
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -329,8 +329,8 @@

   /// Return true if the operation has an attribute with the provided name,
   /// false otherwise.
-  bool hasAttr(Identifier name) { return static_cast<bool>(getAttr(name)); }
-  bool hasAttr(StringRef name) { return static_cast<bool>(getAttr(name)); }
+  bool hasAttr(Identifier name) { return attrs.contains(name); }
+  bool hasAttr(StringRef name) { return attrs.contains(name); }
   template <typename AttrClass, typename NameT>
   bool hasAttrOfType(NameT &&name) {
     return static_cast<bool>(
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -245,6 +245,65 @@
   ArrayRef<Identifier> attributeNames;
 };

+//===----------------------------------------------------------------------===//
+// Attribute Dictionary-Like Interface
+//===----------------------------------------------------------------------===//
+
+/// Attribute collections provide a dictionary-like interface. Define common
+/// lookup functions.
+namespace impl {
+
+/// Unsorted string search or identifier lookups are linear scans.
+template <typename IteratorT, typename NameT>
+std::pair<IteratorT, bool> findAttrUnsorted(IteratorT first, IteratorT last,
+                                            NameT name) {
+  for (auto it = first; it != last; ++it) {
+    if (it->first == name)
+      return {it, true};
+  }
+  return {last, false};
+}
+
+/// Using llvm::lower_bound requires an extra string comparison to check whether
+/// the returned iterator points to the found element or whether it indicates
+/// the lower bound. Skip this redundant comparison by checking if `compare ==
+/// 0` during the binary search.
+template <typename IteratorT>
+std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
+                                          StringRef name) {
+  auto length = std::distance(first, last);
+  using difference_type = decltype(length);
+
+  while (length > 0) {
+    difference_type half = length / 2;
+    IteratorT mid = first + half;
+    auto compare = mid->first.strref().compare(name);
+    if (compare < 0) {
+      first = mid + 1;
+      length = length - half - 1;
+    } else if (compare > 0) {
+      length = half;
+    } else {
+      return {mid, true};
+    }
+  }
+  return {first, false};
+}
+
+/// Identifier lookups on large attribute lists will switch to string binary
+/// search. String binary searches become significantly faster than linear
+/// identifier scans once the list has more than `kSmallAttributeList` entries.
+template <typename IteratorT>
+std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
+                                          Identifier name) {
+  constexpr unsigned kSmallAttributeList = 16;
+  if (std::distance(first, last) > kSmallAttributeList)
+    return findAttrSorted(first, last, name.strref());
+  return findAttrUnsorted(first, last, name);
+}
+
+} // end namespace impl
+
 //===----------------------------------------------------------------------===//
 // NamedAttrList
 //===----------------------------------------------------------------------===//
@@ -253,9 +312,10 @@
 /// and does some basic work to remain sorted.
class NamedAttrList { public: + using iterator = SmallVectorImpl::iterator; using const_iterator = SmallVectorImpl::const_iterator; - using const_reference = const NamedAttribute &; using reference = NamedAttribute &; + using const_reference = const NamedAttribute &; using size_type = size_t; NamedAttrList() : dictionarySorted({}, true) {} @@ -346,6 +406,8 @@ Attribute erase(Identifier name); Attribute erase(StringRef name); + iterator begin() { return attrs.begin(); } + iterator end() { return attrs.end(); } const_iterator begin() const { return attrs.begin(); } const_iterator end() const { return attrs.end(); } @@ -359,6 +421,14 @@ /// Erase the attribute at the given iterator position. Attribute eraseImpl(SmallVectorImpl::iterator it); + /// Lookup an attribute in the list. + template + static auto findAttr(AttrListT &attrs, NameT name) { + return attrs.isSorted() + ? impl::findAttrSorted(attrs.begin(), attrs.end(), name) + : impl::findAttrUnsorted(attrs.begin(), attrs.end(), name); + } + // These are marked mutable as they may be modified (e.g., sorted) mutable SmallVector attrs; // Pair with cached DictionaryAttr and status of whether attrs is sorted. diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -176,26 +176,30 @@ /// Return the specified attribute if present, null otherwise. Attribute DictionaryAttr::get(StringRef name) const { - Optional attr = getNamed(name); - return attr ? attr->second : nullptr; + auto it = impl::findAttrSorted(begin(), end(), name); + return it.second ? it.first->second : Attribute(); } Attribute DictionaryAttr::get(Identifier name) const { - Optional attr = getNamed(name); - return attr ? attr->second : nullptr; + auto it = impl::findAttrSorted(begin(), end(), name); + return it.second ? it.first->second : Attribute(); } /// Return the specified named attribute if present, None otherwise. 
Optional DictionaryAttr::getNamed(StringRef name) const { - ArrayRef values = getValue(); - const auto *it = llvm::lower_bound(values, name); - return it != values.end() && it->first == name ? *it - : Optional(); + auto it = impl::findAttrSorted(begin(), end(), name); + return it.second ? *it.first : Optional(); } Optional DictionaryAttr::getNamed(Identifier name) const { - for (auto elt : getValue()) - if (elt.first == name) - return elt; - return llvm::None; + auto it = impl::findAttrSorted(begin(), end(), name); + return it.second ? *it.first : Optional(); +} + +/// Return whether the specified attribute is present. +bool DictionaryAttr::contains(StringRef name) const { + return impl::findAttrSorted(begin(), end(), name).second; +} +bool DictionaryAttr::contains(Identifier name) const { + return impl::findAttrSorted(begin(), end(), name).second; } DictionaryAttr::iterator DictionaryAttr::begin() const { diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -81,42 +81,24 @@ attrs.push_back(newAttribute); } -/// Helper function to find attribute in possible sorted vector of -/// NamedAttributes. -template -static auto *findAttr(SmallVectorImpl &attrs, T name, - bool sorted) { - if (!sorted) { - return llvm::find_if( - attrs, [name](NamedAttribute attr) { return attr.first == name; }); - } - - auto *it = llvm::lower_bound(attrs, name); - if (it == attrs.end() || it->first != name) - return attrs.end(); - return it; -} - /// Return the specified attribute if present, null otherwise. Attribute NamedAttrList::get(StringRef name) const { - auto *it = findAttr(attrs, name, isSorted()); - return it != attrs.end() ? it->second : nullptr; + auto it = findAttr(*this, name); + return it.second ? it.first->second : Attribute(); } - -/// Return the specified attribute if present, null otherwise. 
Attribute NamedAttrList::get(Identifier name) const { - auto *it = findAttr(attrs, name, isSorted()); - return it != attrs.end() ? it->second : nullptr; + auto it = findAttr(*this, name); + return it.second ? it.first->second : Attribute(); } /// Return the specified named attribute if present, None otherwise. Optional NamedAttrList::getNamed(StringRef name) const { - auto *it = findAttr(attrs, name, isSorted()); - return it != attrs.end() ? *it : Optional(); + auto it = findAttr(*this, name); + return it.second ? *it.first : Optional(); } Optional NamedAttrList::getNamed(Identifier name) const { - auto *it = findAttr(attrs, name, isSorted()); - return it != attrs.end() ? *it : Optional(); + auto it = findAttr(*this, name); + return it.second ? *it.first : Optional(); } /// If the an attribute exists with the specified name, change it to the new @@ -125,33 +107,31 @@ assert(value && "attributes may never be null"); // Look for an existing value for the given name, and set it in-place. - auto *it = findAttr(attrs, name, isSorted()); - if (it != attrs.end()) { + auto it = findAttr(*this, name); + if (it.second) { // Only update if the value is different from the existing. - Attribute oldValue = it->second; - if (oldValue != value) { + if (it.first->second != value) { + std::swap(it.first->second, value); dictionarySorted.setPointer(nullptr); - it->second = value; } - return oldValue; + return value; } - - // Otherwise, insert the new attribute into its sorted position. - it = llvm::lower_bound(attrs, name); + // Perform a string lookup to insert the new attribute into its sorted + // position. 
+ if (isSorted()) + it = findAttr(*this, name.strref()); + attrs.insert(it.first, {name, value}); dictionarySorted.setPointer(nullptr); - attrs.insert(it, {name, value}); return Attribute(); } + Attribute NamedAttrList::set(StringRef name, Attribute value) { - assert(value && "setting null attribute not supported"); + assert(value && "attributes may never be null"); return set(mlir::Identifier::get(name, value.getContext()), value); } Attribute NamedAttrList::eraseImpl(SmallVectorImpl::iterator it) { - if (it == attrs.end()) - return nullptr; - // Erasing does not affect the sorted property. Attribute attr = it->second; attrs.erase(it); @@ -160,11 +140,13 @@ } Attribute NamedAttrList::erase(Identifier name) { - return eraseImpl(findAttr(attrs, name, isSorted())); + auto it = findAttr(*this, name); + return it.second ? eraseImpl(it.first) : Attribute(); } Attribute NamedAttrList::erase(StringRef name) { - return eraseImpl(findAttr(attrs, name, isSorted())); + auto it = findAttr(*this, name); + return it.second ? eraseImpl(it.first) : Attribute(); } NamedAttrList &