diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h
--- a/mlir/include/mlir/IR/Identifier.h
+++ b/mlir/include/mlir/IR/Identifier.h
@@ -50,7 +50,21 @@
   unsigned size() const { return ::strlen(pointer); }

   /// Return true if this identifier is the specified string.
-  bool is(StringRef string) const { return strref().equals(string); }
+  bool is(StringRef string) const {
+    // Walk both strings in lockstep instead of using memcmp or strncmp:
+    // memcmp may read past the end of the shorter buffer, and strncmp stops
+    // early at a nul embedded in `string`, after which indexing `pointer` at
+    // string.size() can read past this identifier's terminator. Identifiers
+    // never contain a nul, so a nul in `string` can never match. This also
+    // avoids passing strncmp the possibly-null data() of an empty StringRef.
+    const char *p = pointer;
+    for (char c : string) {
+      if (c == '\0' || *p != c)
+        return false;
+      ++p;
+    }
+    return *p == '\0';
+  }

   const char *begin() const { return pointer; }
   const char *end() const { return pointer + size(); }
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -87,7 +87,7 @@
 /// NamedAttributes.
 static int compareNamedAttributes(const NamedAttribute *lhs,
                                   const NamedAttribute *rhs) {
-  return lhs->first.strref().compare(rhs->first.strref());
+  return strcmp(lhs->first.data(), rhs->first.data());
 }

 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
@@ -111,7 +111,7 @@
            "DictionaryAttr element names must be unique");

     // Don't invoke a general sort for two element case.
-    if (value[0].first.strref() > value[1].first.strref()) {
+    if (compareNamedAttributes(&value[0], &value[1]) > 0) {
       storage.push_back(value[1]);
       storage.push_back(value[0]);
       value = storage;
@@ -121,7 +121,7 @@
     // Check to see they are sorted already.
     bool isSorted = true;
     for (unsigned i = 0, e = value.size() - 1; i != e; ++i) {
-      if (value[i].first.strref() > value[i + 1].first.strref()) {
+      if (compareNamedAttributes(&value[i], &value[i + 1]) > 0) {
         isSorted = false;
         break;
       }
@@ -152,8 +152,13 @@
 /// Return the specified attribute if present, null otherwise.
Attribute DictionaryAttr::get(StringRef name) const { ArrayRef values = getValue(); - auto compare = [](NamedAttribute attr, StringRef name) { - return attr.first.strref() < name; + auto compare = [](NamedAttribute attr, StringRef name) -> bool { + // This is correct even when attr.first.data()[name.size()] is not a zero + // string terminator, because we only care about a less than comparison. + // This can't use memcmp, because it doesn't guarantee that it will stop + // reading both buffers if one is shorter than the other, even if there is + // a difference. + return strncmp(attr.first.data(), name.data(), name.size()) < 0; }; auto it = llvm::lower_bound(values, name, compare); return it != values.end() && it->first.is(name) ? it->second : Attribute(); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -493,10 +493,6 @@ /// Return an identifier for the specified string. Identifier Identifier::get(StringRef str, MLIRContext *context) { - assert(!str.empty() && "Cannot create an empty identifier"); - assert(str.find('\0') == StringRef::npos && - "Cannot create an identifier with a nul character"); - auto &impl = context->getImpl(); { // Check for an existing identifier in read-only mode. @@ -506,6 +502,13 @@ return Identifier(it->getKeyData()); } + // Check invariants after seeing if we already have something in the + // identifier table - if we already had it in the table, then it already + // passed invariant checks. + assert(!str.empty() && "Cannot create an empty identifier"); + assert(str.find('\0') == StringRef::npos && + "Cannot create an identifier with a nul character"); + // Acquire a writer-lock so that we can safely create the new instance. llvm::sys::SmartScopedWriter contextLock(impl.identifierMutex); auto it = impl.identifiers.insert({str, char()}).first;