diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -292,6 +292,12 @@ /// Requires: uniquely named attributes. static bool sortInPlace(SmallVectorImpl &array); + /// Returns an entry with a duplicate name in the given array, if it exists, + /// else returns llvm::None. If `isSorted` is true, the array is assumed to be + /// sorted; otherwise it is sorted in place before searching for a duplicate. + static Optional + findDuplicate(SmallVectorImpl &array, bool isSorted); + private: /// Return empty dictionary. static DictionaryAttr getEmpty(MLIRContext *context); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -237,6 +237,10 @@ /// Pop last element from list. void pop_back() { attrs.pop_back(); } + /// Returns an entry with a duplicate name in this list, if one exists, else + /// returns llvm::None. Sorts the underlying list first if it is unsorted. + Optional findDuplicate() const; + /// Return a dictionary attribute for the underlying dictionary. This will /// return an empty dictionary attribute if empty rather than null. DictionaryAttr getDictionary(MLIRContext *context) const; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -99,8 +99,6 @@ storage.assign({value[0]}); break; case 2: { - assert(value[0].first != value[1].first && - "DictionaryAttr element names must be unique"); bool isSorted = value[0] < value[1]; if (inPlace) { if (!isSorted) @@ -122,25 +120,52 @@ llvm::array_pod_sort(storage.begin(), storage.end()); value = storage; } - - // Ensure that the attribute elements are unique. 
- assert(std::adjacent_find(value.begin(), value.end(), - [](NamedAttribute l, NamedAttribute r) { - return l.first == r.first; - }) == value.end() && - "DictionaryAttr element names must be unique"); return !isSorted; } return false; } +/// Returns an entry with a duplicate name from a sorted array of named +/// attributes. Returns llvm::None if all elements have unique names. +Optional findDuplicateElement(ArrayRef value) { + if (value.size() < 2) + return llvm::None; + + if (value.size() == 2) { + if (value[0].first == value[1].first) + return value[0]; + return llvm::None; + } + + auto it = std::adjacent_find( + value.begin(), value.end(), + [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); + if (it != value.end()) + return *it; + return llvm::None; +} + bool DictionaryAttr::sort(ArrayRef value, SmallVectorImpl &storage) { - return dictionaryAttrSort(value, storage); + bool isSorted = dictionaryAttrSort(value, storage); + assert(!findDuplicateElement(storage) && + "DictionaryAttr element names must be unique"); + return isSorted; } bool DictionaryAttr::sortInPlace(SmallVectorImpl &array) { - return dictionaryAttrSort(array, array); + bool isSorted = dictionaryAttrSort(array, array); + assert(!findDuplicateElement(array) && + "DictionaryAttr element names must be unique"); + return isSorted; +} + +Optional +DictionaryAttr::findDuplicate(SmallVectorImpl &array, + bool isSorted) { + if (!isSorted) + dictionaryAttrSort(array, array); + return findDuplicateElement(array); } DictionaryAttr DictionaryAttr::get(ArrayRef value, @@ -155,7 +180,8 @@ SmallVector storage; if (dictionaryAttrSort(value, storage)) value = storage; - + assert(!findDuplicateElement(value) && + "DictionaryAttr element names must be unique"); return Base::get(context, value); } /// Construct a dictionary with an array of values that is known to already be @@ -170,10 +196,7 @@ return l.first.strref() < r.first.strref(); }) && "expected attribute values to be sorted"); - 
assert(std::adjacent_find(value.begin(), value.end(), - [](NamedAttribute l, NamedAttribute r) { - return l.first == r.first; - }) == value.end() && + assert(!findDuplicateElement(value) && "DictionaryAttr element names must be unique"); return Base::get(context, value); } diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -32,6 +32,16 @@ ArrayRef NamedAttrList::getAttrs() const { return attrs; } +Optional NamedAttrList::findDuplicate() const { + Optional duplicate = + DictionaryAttr::findDuplicate(attrs, isSorted()); + // DictionaryAttr::findDuplicate will sort the list, so reset the sorted + // state. + if (!isSorted()) + dictionarySorted.setPointerAndInt(nullptr, true); + return duplicate; +} + DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const { if (!isSorted()) { DictionaryAttr::sortInPlace(attrs); diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -249,7 +249,8 @@ else return emitError("expected attribute name"); if (!seenKeys.insert(*nameId).second) - return emitError("duplicate key in dictionary attribute"); + return emitError("duplicate key '") + << *nameId << "' in dictionary attribute"; consumeToken(); // Lazy load a dialect in the context if there is a possible namespace. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -845,6 +845,15 @@ ParseResult parseOperation(OperationState &opState) { if (opDefinition->parseAssembly(*this, opState)) return failure(); + // Verify that the parsed attributes do not contain duplicate entries. + // This can happen if an attribute set during parsing is also specified in + // the attribute dictionary in the assembly, or the attribute is set + // multiple times during parsing. 
+ Optional duplicate = opState.attributes.findDuplicate(); + if (duplicate) + return emitError(getNameLoc(), "attribute '") + << duplicate->first + << "' occurs more than once in the attribute list"; return success(); } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -1513,12 +1513,17 @@ // ----- func @duplicate_dictionary_attr_key() { - // expected-error @+1 {{duplicate key in dictionary attribute}} + // expected-error @+1 {{duplicate key 'a' in dictionary attribute}} "foo.op"() {a, a} : () -> () } // ----- +// expected-error @+1 {{attribute 'attr' occurs more than once in the attribute list}} +test.format_symbol_name_attr_op @name { attr = "xx" } + +// ----- + func @forward_reference_type_check() -> (i8) { br ^bb2