diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -276,9 +276,18 @@ using Base::Base; using ValueType = ArrayRef; + /// Construct a dictionary attribute with the provided list of named + /// attributes. The list may be in any order. If the caller can guarantee + /// that the attributes are already sorted by name and uniqued, + /// getWithSorted should be used instead. static DictionaryAttr get(ArrayRef value, MLIRContext *context); + /// Construct a dictionary with an array of values that is known to already be + /// sorted by name and uniqued. + static DictionaryAttr getWithSorted(ArrayRef value, + MLIRContext *context); + ArrayRef getValue() const; /// Return the specified attribute if present, null otherwise. @@ -1390,8 +1399,8 @@ // NamedAttributeList //===----------------------------------------------------------------------===// -/// A NamedAttributeList is used to manage a list of named attributes. This -/// provides simple interfaces for adding/removing/finding attributes from +/// A NamedAttributeList is a mutable wrapper around a DictionaryAttr. It +/// provides additional interfaces for adding, removing, replacing attributes /// within a DictionaryAttr. /// /// We assume there will be relatively few attributes on a given operation diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -90,6 +90,17 @@ return strcmp(lhs->first.data(), rhs->first.data()); } +/// Returns true if the name of the given attribute precedes that of 'name'. +static bool compareNamedAttributeWithName(const NamedAttribute &attr, + StringRef name) { + // This is correct even when attr.first.data()[name.size()] is not a zero + // string terminator, because we only care about a less than comparison.
+ // This can't use memcmp, because it doesn't guarantee that it will stop + // reading both buffers if one is shorter than the other, even if there is + // a difference. + return strncmp(attr.first.data(), name.data(), name.size()) < 0; +} + DictionaryAttr DictionaryAttr::get(ArrayRef value, MLIRContext *context) { assert(llvm::all_of(value, @@ -145,6 +156,24 @@ return Base::get(context, StandardAttributes::Dictionary, value); } +/// Construct a dictionary with an array of values that is known to already be +/// sorted by name and uniqued. +DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef value, + MLIRContext *context) { + // Ensure that the attribute elements are unique and sorted. + assert(llvm::is_sorted(value, + [](NamedAttribute l, NamedAttribute r) { + return l.first.strref() < r.first.strref(); + }) && + "expected attribute values to be sorted"); + assert(std::adjacent_find(value.begin(), value.end(), + [](NamedAttribute l, NamedAttribute r) { + return l.first == r.first; + }) == value.end() && + "DictionaryAttr element names must be unique"); + return Base::get(context, StandardAttributes::Dictionary, value); +} + ArrayRef DictionaryAttr::getValue() const { return getImpl()->getElements(); } @@ -152,15 +181,7 @@ /// Return the specified attribute if present, null otherwise. Attribute DictionaryAttr::get(StringRef name) const { ArrayRef values = getValue(); - auto compare = [](NamedAttribute attr, StringRef name) -> bool { - // This is correct even when attr.first.data()[name.size()] is not a zero - // string terminator, because we only care about a less than comparison.
- // This can't use memcmp, because it doesn't guarantee that it will stop - // reading both buffers if one is shorter than the other, even if there is - // a difference. - return strncmp(attr.first.data(), name.data(), name.size()) < 0; - }; - auto it = llvm::lower_bound(values, name, compare); + auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName); return it != values.end() && it->first == name ? it->second : Attribute(); } Attribute DictionaryAttr::get(Identifier name) const { @@ -1124,19 +1145,29 @@ void NamedAttributeList::set(Identifier name, Attribute value) { assert(value && "attributes may never be null"); - // If we already have this attribute, replace it. - auto origAttrs = getAttrs(); - SmallVector newAttrs(origAttrs.begin(), origAttrs.end()); - for (auto &elt : newAttrs) - if (elt.first == name) { - elt.second = value; - attrs = DictionaryAttr::get(newAttrs, value.getContext()); + // If the name already has a value, replace it at the same position. + ArrayRef values = getAttrs(); + auto it = llvm::find_if( + values, [name](NamedAttribute attr) { return attr.first == name; }); + if (it != values.end()) { + // Bail out early if the value is the same as what we already have. + if (it->second == value) return; - } - // Otherwise, add it. + SmallVector newAttrs(values.begin(), values.end()); + newAttrs[it - values.begin()].second = value; + attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); + return; + } + + // Otherwise, insert the new attribute into its sorted position. + it = llvm::lower_bound(values, name, compareNamedAttributeWithName); + SmallVector newAttrs; + newAttrs.reserve(values.size() + 1); + newAttrs.append(values.begin(), it); newAttrs.push_back({name, value}); - attrs = DictionaryAttr::get(newAttrs, value.getContext()); + newAttrs.append(it, values.end()); + attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); } /// Remove the attribute with the specified name if it exists.
The return @@ -1155,7 +1186,8 @@ newAttrs.reserve(origAttrs.size() - 1); newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); - attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext()); + attrs = DictionaryAttr::getWithSorted(newAttrs, + newAttrs[0].second.getContext()); return RemoveResult::Removed; } }