diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h
--- a/mlir/include/mlir/IR/Identifier.h
+++ b/mlir/include/mlir/IR/Identifier.h
@@ -11,10 +11,12 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/StringMapEntry.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
+class Dialect;
 class MLIRContext;
 
 /// This class represents a uniqued string owned by an MLIRContext. Strings
@@ -25,7 +27,8 @@
 /// value. The underlying data is owned by MLIRContext and is thus immortal for
 /// almost all clients.
 class Identifier {
-  using EntryType = llvm::StringMapEntry<llvm::NoneType>;
+  using EntryType =
+      llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>>;
 
 public:
   /// Return an identifier for the specified string.
@@ -51,6 +54,14 @@
   /// Return the number of bytes in this string.
   unsigned size() const { return entry->getKeyLength(); }
 
+  /// Return the dialect loaded in the context whose namespace equals the text
+  /// before the first '.' of this identifier, or nullptr if no such dialect is
+  /// loaded.
+  Dialect *getDialect() const;
+
+  /// Return the MLIRContext associated with this identifier.
+  MLIRContext *getContext() const;
+
   const char *begin() const { return data(); }
   const char *end() const { return entry->getKeyData() + size(); }
 
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/Twine.h"
@@ -264,9 +265,12 @@
 
   /// Identifiers are uniqued by string value and use the internal string set
   /// for storage.
-  llvm::StringSet<llvm::BumpPtrAllocator &> identifiers;
+  llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>,
+                  llvm::BumpPtrAllocator &>
+      identifiers;
   /// A thread local cache of identifiers to reduce lock contention.
-  ThreadLocalCache<llvm::StringMap<llvm::StringMapEntry<llvm::NoneType> *>>
+  ThreadLocalCache<llvm::StringMap<
+      llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>> *>>
       localIdentifierCache;
 
   /// An allocator used for AbstractAttribute and AbstractType objects.
@@ -481,6 +485,17 @@
 #endif
     dialect = ctor();
     assert(dialect && "dialect ctor failed");
+
+    // Attach the new dialect to identifiers in its namespace that were created
+    // before it was loaded. Use the same rule as Identifier::get (the text
+    // before the first '.' must equal the namespace exactly) so that, e.g.,
+    // "tfl.op" is not attributed to a dialect named "tf".
+    for (auto &identifierEntry : impl.identifiers) {
+      StringRef identifierPrefix = identifierEntry.first().split('.').first;
+      if (!identifierPrefix.empty() && identifierPrefix == dialectNamespace)
+        identifierEntry.second = dialect.get();
+    }
+
     return dialect.get();
   }
 
@@ -697,9 +712,17 @@
   assert(str.find('\0') == StringRef::npos &&
          "Cannot create an identifier with a nul character");
 
+  // Attach the dialect named by the text before the first '.' if it is
+  // already loaded; otherwise store the context itself.
+  PointerUnion<Dialect *, MLIRContext *> dialectOrContext = context;
+  auto dialectNamePair = str.split('.');
+  if (!dialectNamePair.first.empty())
+    if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
+      dialectOrContext = dialect;
+
   auto &impl = context->getImpl();
   if (!context->isMultithreadingEnabled())
-    return Identifier(&*impl.identifiers.insert(str).first);
+    return Identifier(&*impl.identifiers.insert({str, dialectOrContext}).first);
 
   // Check for an existing instance in the local cache.
   auto *&localEntry = (*impl.localIdentifierCache)[str];
@@ -718,11 +741,21 @@
 
   // Acquire a writer-lock so that we can safely create the new instance.
   llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
-  auto it = impl.identifiers.insert(str).first;
+  auto it = impl.identifiers.insert({str, dialectOrContext}).first;
   localEntry = &*it;
   return Identifier(localEntry);
 }
 
+Dialect *Identifier::getDialect() const {
+  return entry->second.dyn_cast<Dialect *>();
+}
+
+MLIRContext *Identifier::getContext() const {
+  if (Dialect *dialect = getDialect())
+    return dialect->getContext();
+  return entry->second.get<MLIRContext *>();
+}
+
 //===----------------------------------------------------------------------===//
 // Type uniquing
 //===----------------------------------------------------------------------===//