diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -303,6 +303,11 @@
   bool empty() const { return size() == 0; }
   size_t size() const;
 
+  /// Sorts the NamedAttributes in the array ordered by name as expected by
+  /// getWithSorted.
+  /// Requires: uniquely named attributes.
+  static void sort(SmallVectorImpl<NamedAttribute> &array);
+
   /// Methods for supporting type inquiry through isa, cast, and dyn_cast.
   static bool kindof(unsigned kind) {
     return kind == StandardAttributes::Dictionary;
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -310,6 +310,9 @@
     attributes.append(newAttributes.begin(), newAttributes.end());
   }
 
+  /// Sorts `attributes` in place by name; attribute names must be unique.
+  void sortAttributes() { DictionaryAttr::sort(attributes); }
+
   /// Add an array of successors.
   void addSuccessors(ArrayRef<Block *> newSuccessors) {
     successors.append(newSuccessors.begin(), newSuccessors.end());
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -101,45 +101,45 @@
   return strncmp(attr.first.data(), name.data(), name.size()) < 0;
 }
 
-DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
-                                   MLIRContext *context) {
-  assert(llvm::all_of(value,
-                      [](const NamedAttribute &attr) { return attr.second; }) &&
-         "value cannot have null entries");
-
-  // We need to sort the element list to canonicalize it, but we also don't want
-  // to do a ton of work in the super common case where the element list is
-  // already sorted.
-  SmallVector<NamedAttribute, 8> storage;
+/// Sorts NamedAttributes by name. If inPlace, `storage` is both source and
+/// destination and `value` is ignored; otherwise `value` is the source and the
+/// sorted copy is written to `storage` only if reordering was needed. Returns
+/// true iff the elements were reordered (i.e. the input was not sorted).
+template <bool inPlace>
+static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
+                               SmallVectorImpl<NamedAttribute> &storage) {
+  if (inPlace)
+    value = storage;
+  // Specialize for the common case.
   switch (value.size()) {
   case 0:
-    break;
   case 1:
-    // A single element is already sorted.
+    // Zero or one elements are already sorted.
     break;
   case 2:
     assert(value[0].first != value[1].first &&
            "DictionaryAttr element names must be unique");
-
-    // Don't invoke a general sort for two element case.
     if (compareNamedAttributes(&value[0], &value[1]) > 0) {
-      storage.push_back(value[1]);
-      storage.push_back(value[0]);
-      value = storage;
+      if (inPlace) {
+        std::swap(storage[0], storage[1]);
+      } else {
+        storage.push_back(value[1]);
+        storage.push_back(value[0]);
+      }
+      return true;
     }
     break;
   default:
     // Check to see they are sorted already.
-    bool isSorted = true;
-    for (unsigned i = 0, e = value.size() - 1; i != e; ++i) {
-      if (compareNamedAttributes(&value[i], &value[i + 1]) > 0) {
-        isSorted = false;
-        break;
-      }
-    }
+    auto it = std::adjacent_find(value.begin(), value.end(),
+                                 [](NamedAttribute l, NamedAttribute r) {
+                                   return compareNamedAttributes(&l, &r) > 0;
+                                 });
+    bool isSorted = it == value.end();
     // If not, do a general sort.
     if (!isSorted) {
-      storage.append(value.begin(), value.end());
+      if (!inPlace)
+        storage.append(value.begin(), value.end());
       llvm::array_pod_sort(storage.begin(), storage.end(),
                            compareNamedAttributes);
       value = storage;
@@ -151,7 +151,28 @@
                                 return l.first == r.first;
                               }) == value.end() &&
            "DictionaryAttr element names must be unique");
+    return !isSorted;
   }
+  return false;
+}
+
+/// Sorts the NamedAttributes in the array ordered by name as expected by
+/// getWithSorted.
+/// Requires: uniquely named attributes.
+void DictionaryAttr::sort(SmallVectorImpl<NamedAttribute> &array) {
+  dictionaryAttrSort</*inPlace=*/true>({}, array);
+}
+
+DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
+                                   MLIRContext *context) {
+  assert(llvm::all_of(value,
+                      [](const NamedAttribute &attr) { return attr.second; }) &&
+         "value cannot have null entries");
+
+  // We need to sort the element list to canonicalize it.
+  SmallVector<NamedAttribute, 8> storage;
+  if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
+    value = storage;
 
   return Base::get(context, StandardAttributes::Dictionary, value);
 }