diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h
--- a/mlir/include/mlir/IR/Identifier.h
+++ b/mlir/include/mlir/IR/Identifier.h
@@ -11,10 +11,12 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/StringMapEntry.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
+class Dialect;
 class MLIRContext;
 
 /// This class represents a uniqued string owned by an MLIRContext.  Strings
@@ -22,10 +24,17 @@
 /// zero length.
 ///
 /// This is a POD type with pointer size, so it should be passed around by
-/// value.  The underlying data is owned by MLIRContext and is thus immortal for
+/// value. The underlying data is owned by MLIRContext and is thus immortal for
 /// almost all clients.
+///
+/// An Identifier may be prefixed with a dialect namespace followed by a single
+/// dot `.`. This is particularly useful when used as a key in a NamedAttribute
+/// to differentiate a dependent attribute (specific to an operation) from a
+/// generic attribute defined by the dialect (in general applicable to multiple
+/// operations).
 class Identifier {
-  using EntryType = llvm::StringMapEntry<llvm::NoneType>;
+  using EntryType =
+      llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>>;
 
 public:
   /// Return an identifier for the specified string.
@@ -51,6 +60,16 @@
   /// Return the number of bytes in this string.
   unsigned size() const { return entry->getKeyLength(); }
 
+  /// Return the dialect loaded in the context for this identifier or nullptr if
+  /// this identifier isn't prefixed with a loaded dialect namespace followed by
+  /// a dot and a non-empty suffix. For example the `llvm.fastmathflags`
+  /// identifier would return the LLVM dialect here, assuming it is loaded in
+  /// the context.
+  Dialect *getDialect() const;
+
+  /// Return the current MLIRContext associated with this identifier.
+  MLIRContext *getContext() const;
+
   const char *begin() const { return data(); }
   const char *end() const { return entry->getKeyData() + size(); }
 
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -264,9 +264,13 @@
 
-  /// Identifiers are uniqued by string value and use the internal string set
-  /// for storage.
-  llvm::StringSet<llvm::BumpPtrAllocator &> identifiers;
+  /// Identifiers are uniqued by string value and use the internal string map
+  /// for storage. Each entry maps to the dialect matching the identifier's
+  /// namespace prefix if that dialect is loaded, or to the context otherwise.
+  llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>,
+                  llvm::BumpPtrAllocator &>
+      identifiers;
   /// A thread local cache of identifiers to reduce lock contention.
-  ThreadLocalCache<llvm::StringMap<llvm::StringMapEntry<llvm::NoneType> *>>
+  ThreadLocalCache<llvm::StringMap<
+      llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>> *>>
       localIdentifierCache;
 
   /// An allocator used for AbstractAttribute and AbstractType objects.
@@ -481,6 +485,19 @@
 #endif
     dialect = ctor();
     assert(dialect && "dialect ctor failed");
+
+    // Refresh the dialect field of identifiers prefixed with this dialect
+    // namespace: some may have been created before the dialect was loaded.
+    // Use the same prefix rule as `Identifier::get`: the namespace must be
+    // followed by a dot and a non-empty suffix, so that e.g. `llvmx.foo` or a
+    // bare `llvm` are not attached to the `llvm` dialect.
+    for (auto &identifierEntry : impl.identifiers) {
+      auto dialectNamePair = identifierEntry.first().split('.');
+      if (dialectNamePair.first == dialectNamespace &&
+          !dialectNamePair.second.empty())
+        identifierEntry.second = dialect.get();
+    }
+
     return dialect.get();
   }
 
@@ -707,9 +724,26 @@
   assert(str.find('\0') == StringRef::npos &&
          "Cannot create an identifier with a nul character");
 
   auto &impl = context->getImpl();
+
+  // Find or create the table entry for `str`. A new entry is associated with
+  // the dialect whose namespace prefixes `str` (as in `namespace.suffix`) if
+  // that dialect is loaded, and with the context otherwise; loading a dialect
+  // later refreshes existing entries using the same rule. The dialect lookup
+  // only happens on insertion, keeping it off the path of known identifiers.
+  auto getOrInsertEntry = [&]() {
+    auto insertion = impl.identifiers.try_emplace(str, context);
+    if (insertion.second) {
+      auto dialectNamePair = str.split('.');
+      if (!dialectNamePair.first.empty() && !dialectNamePair.second.empty())
+        if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
+          insertion.first->second = dialect;
+    }
+    return &*insertion.first;
+  };
+
   if (!context->isMultithreadingEnabled())
-    return Identifier(&*impl.identifiers.insert(str).first);
+    return Identifier(getOrInsertEntry());
 
   // Check for an existing instance in the local cache.
   auto *&localEntry = (*impl.localIdentifierCache)[str];
@@ -728,11 +762,20 @@
 
   // Acquire a writer-lock so that we can safely create the new instance.
   llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
-  auto it = impl.identifiers.insert(str).first;
-  localEntry = &*it;
+  localEntry = getOrInsertEntry();
   return Identifier(localEntry);
 }
 
+Dialect *Identifier::getDialect() const {
+  return entry->second.dyn_cast<Dialect *>();
+}
+
+MLIRContext *Identifier::getContext() const {
+  if (Dialect *dialect = getDialect())
+    return dialect->getContext();
+  return entry->second.get<MLIRContext *>();
+}
+
 //===----------------------------------------------------------------------===//
 // Type uniquing
 //===----------------------------------------------------------------------===//