diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -245,17 +245,124 @@ ArrayRef attributeNames; }; +//===----------------------------------------------------------------------===// +// Attribute Dictionary-Like Interface +//===----------------------------------------------------------------------===// + +/// Attribute collections provide a dictionary-like interface. Define common +/// lookup functions. +namespace impl { + +/// Unsorted string search or identifier lookups are linear scans. +template +std::pair findAttrUnsorted(IteratorT first, IteratorT last, + NameT name) { + for (auto it = first; it != last; ++it) { + if (it->first == name) + return {it, true}; + } + return {last, false}; +} + +/// Using llvm::lower_bound requires an extra string comparison to check whether +/// the returned iterator points to the found element or only the lower bound. +/// Skip this redundant comparison by checking if `compare == 0` during the +/// binary search. On a miss, the returned iterator is the lower bound. +template +std::pair findAttrSorted(IteratorT first, IteratorT last, + StringRef name) { + auto length = std::distance(first, last); + using difference_type = decltype(length); + + while (length > 0) { + difference_type half = length / 2; + IteratorT mid = first + half; + auto compare = mid->first.strref().compare(name); + if (compare < 0) { + first = mid + 1; + length = length - half - 1; + } else if (compare > 0) { + length = half; + } else { + return {mid, true}; + } + } + return {first, false}; +} + +/// Identifier lookups on large attribute lists will switch to string binary +/// search. String binary searches become significantly faster when the size of +/// the attribute list exceeds 16.
+template +std::pair findAttrSorted(IteratorT first, IteratorT last, + Identifier name) { + constexpr unsigned kSmallAttributeList = 16; + if (std::distance(first, last) > kSmallAttributeList) + return findAttrSorted(first, last, name.strref()); + return findAttrUnsorted(first, last, name); +} + +/// CRTP class to implement a dictionary-like interface on a collection of +/// named attributes. Derived classes are required simply to implement `begin`, +/// `end`, and `isSorted`, which returns true if the underlying representation +/// has been sorted by the attributes' string names. +template +class AttributeDictionaryLike { +protected: + /// Common attribute lookup functions. Each returns an (iterator, found) pair. + template + static auto findAttr(DerivedPtrT derived, NameT name) { + return derived->isSorted() + ? findAttrSorted(derived->begin(), derived->end(), name) + : findAttrUnsorted(derived->begin(), derived->end(), name); + } + template + auto findAttr(NameT name) { + return findAttr(static_cast(this), name); + } + template + auto findAttr(NameT name) const { + return findAttr(static_cast(this), name); + } + +public: + using BaseT = AttributeDictionaryLike; + + /// Lookup the value of an attribute. Returns null if it was not found. + template + Attribute get(NameT name) const { + auto it = findAttr(name); + return it.second ? it.first->second : Attribute(); + } + + /// Lookup an attribute and return it as a named attribute if one was found. + template + Optional getNamed(NameT name) const { + auto it = findAttr(name); + return it.second ? *it.first : Optional(); + } + + /// Returns whether an attribute exists.
+ template + bool has(NameT name) const { + return findAttr(name).second; + } +}; + +} // end namespace impl + //===----------------------------------------------------------------------===// // NamedAttrList //===----------------------------------------------------------------------===// /// NamedAttrList is array of NamedAttributes that tracks whether it is sorted /// and does some basic work to remain sorted. -class NamedAttrList { +class NamedAttrList : public impl::AttributeDictionaryLike { public: + using iterator = SmallVectorImpl::iterator; using const_iterator = SmallVectorImpl::const_iterator; - using const_reference = const NamedAttribute &; using reference = NamedAttribute &; + using const_reference = const NamedAttribute &; using size_type = size_t; NamedAttrList() : dictionarySorted({}, true) {} @@ -325,14 +432,6 @@ /// Return all of the attributes on this operation. ArrayRef getAttrs() const; - /// Return the specified attribute if present, null otherwise. - Attribute get(Identifier name) const; - Attribute get(StringRef name) const; - - /// Return the specified named attribute if present, None otherwise. - Optional getNamed(StringRef name) const; - Optional getNamed(Identifier name) const; - /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. /// Returns the previous attribute value of `name`, or null if no @@ -346,6 +445,8 @@ Attribute erase(Identifier name); Attribute erase(StringRef name); + iterator begin() { return attrs.begin(); } + iterator end() { return attrs.end(); } const_iterator begin() const { return attrs.begin(); } const_iterator end() const { return attrs.end(); } @@ -366,6 +467,9 @@ // but the case where there is a DictionaryAttr but attrs isn't sorted should // not occur. mutable llvm::PointerIntPair dictionarySorted; + + /// Give AttributeDictionaryLike access to `isSorted`.
+ friend class impl::AttributeDictionaryLike; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -176,26 +176,22 @@ /// Return the specified attribute if present, null otherwise. Attribute DictionaryAttr::get(StringRef name) const { - Optional attr = getNamed(name); - return attr ? attr->second : nullptr; + auto it = impl::findAttrSorted(begin(), end(), name); + return it.second ? it.first->second : Attribute(); } Attribute DictionaryAttr::get(Identifier name) const { - Optional attr = getNamed(name); - return attr ? attr->second : nullptr; + auto it = impl::findAttrSorted(begin(), end(), name); + return it.second ? it.first->second : Attribute(); } /// Return the specified named attribute if present, None otherwise. Optional DictionaryAttr::getNamed(StringRef name) const { - ArrayRef values = getValue(); - const auto *it = llvm::lower_bound(values, name); - return it != values.end() && it->first == name ? *it - : Optional(); + auto it = impl::findAttrSorted(begin(), end(), name); + return it.second ? *it.first : Optional(); } Optional DictionaryAttr::getNamed(Identifier name) const { - for (auto elt : getValue()) - if (elt.first == name) - return elt; - return llvm::None; + auto it = impl::findAttrSorted(begin(), end(), name); + return it.second ? *it.first : Optional(); } DictionaryAttr::iterator DictionaryAttr::begin() const { diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -81,77 +81,37 @@ attrs.push_back(newAttribute); } -/// Helper function to find attribute in possible sorted vector of -/// NamedAttributes.
-template -static auto *findAttr(SmallVectorImpl &attrs, T name, - bool sorted) { - if (!sorted) { - return llvm::find_if( - attrs, [name](NamedAttribute attr) { return attr.first == name; }); - } - - auto *it = llvm::lower_bound(attrs, name); - if (it == attrs.end() || it->first != name) - return attrs.end(); - return it; -} - -/// Return the specified attribute if present, null otherwise. -Attribute NamedAttrList::get(StringRef name) const { - auto *it = findAttr(attrs, name, isSorted()); - return it != attrs.end() ? it->second : nullptr; -} - -/// Return the specified attribute if present, null otherwise. -Attribute NamedAttrList::get(Identifier name) const { - auto *it = findAttr(attrs, name, isSorted()); - return it != attrs.end() ? it->second : nullptr; -} - -/// Return the specified named attribute if present, None otherwise. -Optional NamedAttrList::getNamed(StringRef name) const { - auto *it = findAttr(attrs, name, isSorted()); - return it != attrs.end() ? *it : Optional(); -} -Optional NamedAttrList::getNamed(Identifier name) const { - auto *it = findAttr(attrs, name, isSorted()); - return it != attrs.end() ? *it : Optional(); -} - /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. Attribute NamedAttrList::set(Identifier name, Attribute value) { assert(value && "attributes may never be null"); // Look for an existing value for the given name, and set it in-place. - auto *it = findAttr(attrs, name, isSorted()); - if (it != attrs.end()) { + auto it = findAttr(name); + if (it.second) { // Only update if the value is different from the existing. - Attribute oldValue = it->second; - if (oldValue != value) { + if (it.first->second != value) { + std::swap(it.first->second, value); // `value` now holds the old value. dictionarySorted.setPointer(nullptr); - it->second = value; } - return oldValue; + return value; } - - // Otherwise, insert the new attribute into its sorted position.
- it = llvm::lower_bound(attrs, name); + // The identifier lookup may be a linear scan yielding end() on a miss, so + // re-search by string for the sorted slot; unsorted lists append at end(). + if (isSorted()) + it = findAttr(name.strref()); + attrs.insert(it.first, {name, value}); dictionarySorted.setPointer(nullptr); - attrs.insert(it, {name, value}); return Attribute(); } + Attribute NamedAttrList::set(StringRef name, Attribute value) { - assert(value && "setting null attribute not supported"); + assert(value && "attributes may never be null"); return set(mlir::Identifier::get(name, value.getContext()), value); } Attribute NamedAttrList::eraseImpl(SmallVectorImpl::iterator it) { - if (it == attrs.end()) - return nullptr; - // Erasing does not affect the sorted property. Attribute attr = it->second; attrs.erase(it); @@ -160,11 +120,13 @@ } Attribute NamedAttrList::erase(Identifier name) { - return eraseImpl(findAttr(attrs, name, isSorted())); + auto it = findAttr(name); + return it.second ? eraseImpl(it.first) : Attribute(); } Attribute NamedAttrList::erase(StringRef name) { - return eraseImpl(findAttr(attrs, name, isSorted())); + auto it = findAttr(name); + return it.second ? eraseImpl(it.first) : Attribute(); } NamedAttrList &