diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h
--- a/mlir/include/mlir/IR/Identifier.h
+++ b/mlir/include/mlir/IR/Identifier.h
@@ -11,7 +11,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringMapEntry.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
@@ -25,6 +25,9 @@
 /// value. The underlying data is owned by MLIRContext and is thus immortal for
 /// almost all clients.
 class Identifier {
+  /// Entry type of the MLIRContext's uniqued identifier table.
+  using EntryType = llvm::StringMapEntry<llvm::NoneType>;
+
 public:
   /// Return an identifier for the specified string.
   static Identifier get(StringRef str, MLIRContext *context);
@@ -32,7 +35,7 @@
   Identifier &operator=(const Identifier &other) = default;
 
   /// Return a StringRef for the string.
-  StringRef strref() const { return StringRef(pointer, size()); }
+  StringRef strref() const { return entry->first(); }
 
   /// Identifiers implicitly convert to StringRefs.
   operator StringRef() const { return strref(); }
@@ -41,39 +44,39 @@
   std::string str() const { return strref().str(); }
 
   /// Return a null terminated C string.
-  const char *c_str() const { return pointer; }
+  const char *c_str() const { return entry->getKeyData(); }
 
   /// Return a pointer to the start of the string data.
-  const char *data() const { return pointer; }
+  const char *data() const { return entry->getKeyData(); }
 
   /// Return the number of bytes in this string.
-  unsigned size() const { return ::strlen(pointer); }
+  unsigned size() const { return entry->getKeyLength(); }
 
   /// Return true if this identifier is the specified string.
-  bool is(StringRef string) const {
-    // Note: this can't use memcmp, because memcmp doesn't guarantee that it
-    // will stop reading both buffers if one is shorter than the other.
-    return strncmp(pointer, string.data(), string.size()) == 0 &&
-           pointer[string.size()] == '\0';
-  }
+  bool is(StringRef string) const { return strref() == string; }
 
-  const char *begin() const { return pointer; }
-  const char *end() const { return pointer + size(); }
+  const char *begin() const { return data(); }
+  const char *end() const { return data() + size(); }
+
+  /// Identifiers are uniqued, so equality is a comparison of table entries.
+  bool operator==(Identifier other) const { return entry == other.entry; }
+  bool operator!=(Identifier other) const { return !(*this == other); }
 
   void print(raw_ostream &os) const;
   void dump() const;
 
   const void *getAsOpaquePointer() const {
-    return static_cast<const void *>(pointer);
+    return static_cast<const void *>(entry);
   }
   static Identifier getFromOpaquePointer(const void *pointer) {
-    return Identifier((const char *)pointer);
+    return Identifier(static_cast<const EntryType *>(pointer));
   }
 
 private:
-  /// These are the bytes of the string, which is a nul terminated string.
-  const char *pointer;
-  explicit Identifier(const char *pointer) : pointer(pointer) {}
+  /// The uniqued table entry holding this identifier's nul terminated bytes
+  /// and their length.
+  const EntryType *entry;
+  explicit Identifier(const EntryType *entry) : entry(entry) {}
 };
 
 inline raw_ostream &operator<<(raw_ostream &os, Identifier identifier) {
@@ -81,14 +84,7 @@
   return os;
 }
 
-inline bool operator==(Identifier lhs, Identifier rhs) {
-  return lhs.data() == rhs.data();
-}
-
-inline bool operator!=(Identifier lhs, Identifier rhs) {
-  return lhs.data() != rhs.data();
-}
-
+// Identifier/Identifier equality comparisons are members of Identifier.
 inline bool operator==(Identifier lhs, StringRef rhs) { return lhs.is(rhs); }
 inline bool operator!=(Identifier lhs, StringRef rhs) { return !lhs.is(rhs); }
 inline bool operator==(StringRef lhs, Identifier rhs) { return rhs.is(lhs); }
@@ -97,7 +93,7 @@
 // Make identifiers hashable.
 inline llvm::hash_code hash_value(Identifier arg) {
   // Identifiers are uniqued, so we can just hash the pointer they contain.
-  return llvm::hash_value(static_cast<const void *>(arg.data()));
+  return llvm::hash_value(arg.getAsOpaquePointer());
 }
 
 } // end namespace mlir
@@ -114,7 +110,7 @@
     return mlir::Identifier::getFromOpaquePointer(pointer);
   }
   static unsigned getHashValue(mlir::Identifier val) {
-    return DenseMapInfo<const void *>::getHashValue(val.data());
+    return mlir::hash_value(val);
   }
   static bool isEqual(mlir::Identifier lhs, mlir::Identifier rhs) {
     return lhs == rhs;
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -533,7 +533,7 @@
     llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
     auto it = impl.identifiers.find(str);
     if (it != impl.identifiers.end())
-      return Identifier(it->getKeyData());
+      return Identifier(&*it);
   }
 
   // Check invariants after seeing if we already have something in the
@@ -546,7 +546,7 @@
   // Acquire a writer-lock so that we can safely create the new instance.
   llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
   auto it = impl.identifiers.insert(str).first;
-  return Identifier(it->getKeyData());
+  return Identifier(&*it);
 }
 
 //===----------------------------------------------------------------------===//