diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -30,7 +30,7 @@ DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable) DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute) -DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier) +DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr) DEFINE_C_API_METHODS(MlirLocation, mlir::Location) DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp) DEFINE_C_API_METHODS(MlirType, mlir::Type) diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -15,6 +15,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/StorageUniquerSupport.h" +#include "mlir/IR/Types.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/Twine.h" @@ -118,7 +119,7 @@ public: /// Get the type of this attribute. - Type getType() const; + Type getType() const { return type; } /// Return the abstract descriptor for this attribute. const AbstractAttribute &getAbstractAttribute() const { @@ -131,24 +132,27 @@ /// Note: All attributes require a valid type. If no type is provided here, /// the type of the attribute will automatically default to NoneType /// upon initialization in the uniquer. - AttributeStorage(Type type); - AttributeStorage(); + AttributeStorage(Type type = nullptr) : type(type) {} /// Set the type of this attribute. - void setType(Type type); + void setType(Type newType) { type = newType; } - // Set the abstract attribute for this storage instance. This is used by the - // AttributeUniquer when initializing a newly constructed storage object. - void initialize(const AbstractAttribute &abstractAttr) { + /// Set the abstract attribute for this storage instance. This is used by the + /// AttributeUniquer when initializing a newly constructed storage object. + void initializeAbstractAttribute(const AbstractAttribute &abstractAttr) { abstractAttribute = &abstractAttr; } + /// Default initialization for attribute storage classes that require no + /// additional initialization. + void initialize(MLIRContext *context) {} + private: + /// The type of the attribute value. + Type type; + /// The abstract descriptor for this attribute. const AbstractAttribute *abstractAttribute; - - /// The opaque type of the attribute value. - const void *type; }; /// Default storage type for attributes that require no additional @@ -188,6 +192,10 @@ return ctx->getAttributeUniquer().get( [ctx](AttributeStorage *storage) { initializeAttributeStorage(storage, ctx, T::getTypeID()); + + // Execute any additional attribute storage initialization with the + // context. + static_cast(storage)->initialize(ctx); }, T::getTypeID(), std::forward(args)...); } diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -13,7 +13,10 @@ #include "llvm/Support/PointerLikeTypeTraits.h" namespace mlir { -class Identifier; +class StringAttr; + +// TODO: Remove this when all usages have been replaced with StringAttr. +using Identifier = StringAttr; /// Attributes are known-constant values of operations. /// @@ -61,7 +64,7 @@ TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); } /// Return the type of this attribute. - Type getType() const; + Type getType() const { return impl->getType(); } /// Return the context this attribute belongs to. MLIRContext *getContext() const; @@ -126,7 +129,7 @@ } inline ::llvm::hash_code hash_value(Attribute arg) { - return ::llvm::hash_value(arg.impl); + return DenseMapInfo::getHashValue(arg.impl); } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -885,7 +885,35 @@ }; return iterator(llvm::seq(0, getNumElements()).begin(), mapFn); } -} // end namespace mlir. + +//===----------------------------------------------------------------------===// +// StringAttr +//===----------------------------------------------------------------------===// + +/// Define comparisons for StringAttr against nullptr and itself to avoid the +/// StringRef overloads from being chosen when not desirable. +inline bool operator==(StringAttr lhs, std::nullptr_t) { return !lhs; } +inline bool operator!=(StringAttr lhs, std::nullptr_t) { + return static_cast(lhs); +} +inline bool operator==(StringAttr lhs, StringAttr rhs) { + return (Attribute)lhs == (Attribute)rhs; +} +inline bool operator!=(StringAttr lhs, StringAttr rhs) { return !(lhs == rhs); } + +/// Allow direct comparison with StringRef. +inline bool operator==(StringAttr lhs, StringRef rhs) { + return lhs.getValue() == rhs; +} +inline bool operator!=(StringAttr lhs, StringRef rhs) { return !(lhs == rhs); } +inline bool operator==(StringRef lhs, StringAttr rhs) { + return rhs.getValue() == lhs; +} +inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); } + +inline Type StringAttr::getType() const { return Attribute::getType(); } + +} // end namespace mlir //===----------------------------------------------------------------------===// // Attribute Utilities @@ -893,12 +921,30 @@ namespace llvm { +template <> +struct DenseMapInfo : public DenseMapInfo { + static mlir::StringAttr getEmptyKey() { + const void *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::StringAttr::getFromOpaquePointer(pointer); + } + static mlir::StringAttr getTombstoneKey() { + const void *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::StringAttr::getFromOpaquePointer(pointer); + } +}; +template <> +struct PointerLikeTypeTraits + : public PointerLikeTypeTraits { + static inline mlir::StringAttr getFromVoidPointer(void *p) { + return mlir::StringAttr::getFromOpaquePointer(p); + } +}; + template <> struct PointerLikeTypeTraits : public PointerLikeTypeTraits { static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) { - return PointerLikeTypeTraits::getFromVoidPointer(ptr) - .cast(); + return mlir::SymbolRefAttr::getFromOpaquePointer(ptr); } }; diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -915,6 +915,44 @@ let extraClassDeclaration = [{ using ValueType = StringRef; + /// If the value of this string is prefixed with a dialect namespace, + /// returns the dialect corresponding to that namespace if it is loaded, + /// nullptr otherwise. For example, the string `llvm.fastmathflags` would + /// return the LLVM dialect, assuming it is loaded in the context. + Dialect *getReferencedDialect() const; + + /// Enable conversion to StringRef. + operator StringRef() const { return getValue(); } + + /// Returns the underlying string value + StringRef strref() const { return getValue(); } + + /// Convert the underling value to an std::string. + std::string str() const { return getValue().str(); } + + /// Return a pointer to the start of the string data. + const char *data() const { return getValue().data(); } + + /// Return the number of bytes in this string. + size_t size() const { return getValue().size(); } + + /// Iterate over the underlying string data. + StringRef::iterator begin() const { return getValue().begin(); } + StringRef::iterator end() const { return getValue().end(); } + + /// Compare the underlying string value to the one in `rhs`. + int compare(StringAttr rhs) const { + if (*this == rhs) + return 0; + return getValue().compare(rhs.getValue()); + } + + /// FIXME: Defined as part of transition of Identifier->StringAttr. Prefer + /// using the other `get` methods instead. + static StringAttr get(const Twine &str, MLIRContext *context) { + return get(context, str); + } + private: /// Return an empty StringAttr with NoneType type. This is a special variant /// of the `get` method that is used by the MLIRContext to cache the @@ -923,6 +961,7 @@ friend MLIRContext; public: }]; + let genStorageClass = 0; let skipDefaultBuilders = 1; } diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -20,11 +20,14 @@ class AffineExpr; class AffineMap; class FloatType; -class Identifier; class IndexType; class IntegerType; +class StringAttr; class TypeRange; +// TODO: Remove this when all usages have been replaced with StringAttr. +using Identifier = StringAttr; + //===----------------------------------------------------------------------===// // FloatType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -24,7 +24,6 @@ namespace mlir { class DiagnosticEngine; -class Identifier; struct LogicalResult; class MLIRContext; class Operation; @@ -196,6 +195,7 @@ arguments.push_back(DiagnosticArgument(std::forward(val))); return *this; } + Diagnostic &operator<<(StringAttr val); /// Stream in a string literal. Diagnostic &operator<<(const char *val) { @@ -208,9 +208,6 @@ Diagnostic &operator<<(const Twine &val); Diagnostic &operator<<(Twine &&val); - /// Stream in an Identifier. - Diagnostic &operator<<(Identifier val); - /// Stream in an OperationName. Diagnostic &operator<<(OperationName val); diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -612,7 +612,7 @@ if (!attr.first.strref().contains('.')) return funcOp.emitOpError( "arguments may only have dialect attributes"); - if (Dialect *dialect = attr.first.getDialect()) { + if (Dialect *dialect = attr.first.getReferencedDialect()) { if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, /*argIndex=*/i, attr))) return failure(); @@ -645,7 +645,7 @@ for (auto attr : resultAttrs) { if (!attr.first.strref().contains('.')) return funcOp.emitOpError("results may only have dialect attributes"); - if (Dialect *dialect = attr.first.getDialect()) { + if (Dialect *dialect = attr.first.getReferencedDialect()) { if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, /*resultIndex=*/i, attr))) diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h --- a/mlir/include/mlir/IR/Identifier.h +++ b/mlir/include/mlir/IR/Identifier.h @@ -9,151 +9,12 @@ #ifndef MLIR_IR_IDENTIFIER_H #define MLIR_IR_IDENTIFIER_H -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseMapInfo.h" -#include "llvm/ADT/PointerUnion.h" -#include "llvm/ADT/StringMapEntry.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/PointerLikeTypeTraits.h" +#include "mlir/IR/BuiltinAttributes.h" namespace mlir { -class Dialect; -class MLIRContext; - -/// This class represents a uniqued string owned by an MLIRContext. Strings -/// represented by this type cannot contain nul characters, and may not have a -/// zero length. -/// -/// This is a POD type with pointer size, so it should be passed around by -/// value. The underlying data is owned by MLIRContext and is thus immortal for -/// almost all clients. -/// -/// An Identifier may be prefixed with a dialect namespace followed by a single -/// dot `.`. This is particularly useful when used as a key in a NamedAttribute -/// to differentiate a dependent attribute (specific to an operation) from a -/// generic attribute defined by the dialect (in general applicable to multiple -/// operations). -class Identifier { - using EntryType = - llvm::StringMapEntry>; - -public: - /// Return an identifier for the specified string. - static Identifier get(const Twine &string, MLIRContext *context); - - Identifier(const Identifier &) = default; - Identifier &operator=(const Identifier &other) = default; - - /// Return a StringRef for the string. - StringRef strref() const { return entry->first(); } - - /// Identifiers implicitly convert to StringRefs. - operator StringRef() const { return strref(); } - - /// Return an std::string. - std::string str() const { return strref().str(); } - - /// Return a null terminated C string. - const char *c_str() const { return entry->getKeyData(); } - - /// Return a pointer to the start of the string data. - const char *data() const { return entry->getKeyData(); } - - /// Return the number of bytes in this string. - unsigned size() const { return entry->getKeyLength(); } - - /// Return the dialect loaded in the context for this identifier or nullptr if - /// this identifier isn't prefixed with a loaded dialect. For example the - /// `llvm.fastmathflags` identifier would return the LLVM dialect here, - /// assuming it is loaded in the context. - Dialect *getDialect(); - - /// Return the current MLIRContext associated with this identifier. - MLIRContext *getContext(); - - const char *begin() const { return data(); } - const char *end() const { return entry->getKeyData() + size(); } - - bool operator==(Identifier other) const { return entry == other.entry; } - bool operator!=(Identifier rhs) const { return !(*this == rhs); } - - void print(raw_ostream &os) const; - void dump() const; - - const void *getAsOpaquePointer() const { - return static_cast(entry); - } - static Identifier getFromOpaquePointer(const void *entry) { - return Identifier(static_cast(entry)); - } - - /// Compare the underlying StringRef. - int compare(Identifier rhs) const { return strref().compare(rhs.strref()); } - -private: - /// This contains the bytes of the string, which is guaranteed to be nul - /// terminated. - const EntryType *entry; - explicit Identifier(const EntryType *entry) : entry(entry) {} -}; - -inline raw_ostream &operator<<(raw_ostream &os, Identifier identifier) { - identifier.print(os); - return os; -} - -// Identifier/Identifier equality comparisons are defined inline. -inline bool operator==(Identifier lhs, StringRef rhs) { - return lhs.strref() == rhs; -} -inline bool operator!=(Identifier lhs, StringRef rhs) { return !(lhs == rhs); } - -inline bool operator==(StringRef lhs, Identifier rhs) { - return rhs.strref() == lhs; -} -inline bool operator!=(StringRef lhs, Identifier rhs) { return !(lhs == rhs); } - -// Make identifiers hashable. -inline llvm::hash_code hash_value(Identifier arg) { - // Identifiers are uniqued, so we can just hash the pointer they contain. - return llvm::hash_value(arg.getAsOpaquePointer()); -} +/// NOTICE: Identifier is deprecated and usages of it should be replaced with +/// StringAttr. +using Identifier = StringAttr; } // end namespace mlir -namespace llvm { -// Identifiers hash just like pointers, there is no need to hash the bytes. -template <> -struct DenseMapInfo { - static mlir::Identifier getEmptyKey() { - auto pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::Identifier::getFromOpaquePointer(pointer); - } - static mlir::Identifier getTombstoneKey() { - auto pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::Identifier::getFromOpaquePointer(pointer); - } - static unsigned getHashValue(mlir::Identifier val) { - return mlir::hash_value(val); - } - static bool isEqual(mlir::Identifier lhs, mlir::Identifier rhs) { - return lhs == rhs; - } -}; - -/// The pointer inside of an identifier comes from a StringMap, so its alignment -/// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to -/// steal the low bits. -template <> -struct PointerLikeTypeTraits { -public: - static inline void *getAsVoidPointer(mlir::Identifier i) { - return const_cast(i.getAsOpaquePointer()); - } - static inline mlir::Identifier getFromVoidPointer(void *p) { - return mlir::Identifier::getFromOpaquePointer(p); - } - static constexpr int NumLowBitsAvailable = 2; -}; - -} // end namespace llvm #endif diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -19,7 +19,6 @@ namespace mlir { -class Identifier; class Location; class WalkResult; diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -456,7 +456,7 @@ Dialect *getDialect() const { if (const auto *abstractOp = getAbstractOperation()) return &abstractOp->dialect; - return representation.get().getDialect(); + return representation.get().getReferencedDialect(); } /// Return the operation name with dialect name stripped, if it has one. diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -164,8 +164,7 @@ /// Get an instance of the concrete type from a void pointer. static ConcreteT getFromOpaquePointer(const void *ptr) { - return ptr ? BaseT::getFromOpaquePointer(ptr).template cast() - : nullptr; + return ConcreteT((const typename BaseT::ImplType *)ptr); } protected: diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -15,8 +15,6 @@ #include "llvm/ADT/StringMap.h" namespace mlir { -class Identifier; -class Operation; /// This class allows for representing and managing the symbol table used by /// operations with the 'SymbolTable' trait. Inserting into and erasing from diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -27,12 +27,15 @@ namespace mlir { class AnalysisManager; -class Identifier; class MLIRContext; class Operation; class Pass; class PassInstrumentation; class PassInstrumentor; +class StringAttr; + +// TODO: Remove this when all usages have been replaced with StringAttr. +using Identifier = StringAttr; namespace detail { struct OpPassManagerImpl; diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -105,8 +105,13 @@ /// Copy the provided string into memory managed by our bump pointer /// allocator. StringRef copyInto(StringRef str) { - auto result = copyInto(ArrayRef(str.data(), str.size())); - return StringRef(result.data(), str.size()); + if (str.empty()) + return StringRef(); + + char *result = allocator.Allocate(str.size() + 1); + std::uninitialized_copy(str.begin(), str.end(), result); + result[str.size()] = 0; + return StringRef(result, str.size()); } /// Allocate an instance of the provided type. diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h --- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h +++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h @@ -82,7 +82,7 @@ amendOperation(Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const { if (const LLVMTranslationDialectInterface *iface = - getInterfaceFor(attribute.first.getDialect())) { + getInterfaceFor(attribute.first.getReferencedDialect())) { return iface->amendOperation(op, attribute, moduleTranslation); } return success(); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1845,7 +1845,8 @@ mlirOperationGetAttribute(operation->get(), index); return PyNamedAttribute( namedAttr.attribute, - std::string(mlirIdentifierStr(namedAttr.name).data)); + std::string(mlirIdentifierStr(namedAttr.name).data, + mlirIdentifierStr(namedAttr.name).length)); } void dunderSetItem(const std::string &name, PyAttribute attr) { @@ -2601,7 +2602,8 @@ PyPrintAccumulator printAccum; printAccum.parts.append("NamedAttribute("); printAccum.parts.append( - mlirIdentifierStr(self.namedAttr.name).data); + py::str(mlirIdentifierStr(self.namedAttr.name).data, + mlirIdentifierStr(self.namedAttr.name).length)); printAccum.parts.append("="); mlirAttributePrint(self.namedAttr.attribute, printAccum.getCallback(), diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -186,11 +186,11 @@ } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { - return wrap(StringAttr::get(unwrap(ctx), unwrap(str))); + return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str))); } MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { - return wrap(StringAttr::get(unwrap(str), unwrap(type))); + return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type))); } MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -805,7 +805,7 @@ MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation) { - return wrap(unwrap(symbolTable)->insert(unwrap(operation))); + return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation))); } void mlirSymbolTableErase(MlirSymbolTable symbolTable, diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp --- a/mlir/lib/Dialect/DLTI/DLTI.cpp +++ b/mlir/lib/Dialect/DLTI/DLTI.cpp @@ -154,7 +154,7 @@ } else { auto id = entry.getKey().get(); if (!ids.insert(id).second) - return emitError() << "repeated layout entry key: " << id; + return emitError() << "repeated layout entry key: " << id.getValue(); } } return success(); @@ -221,7 +221,7 @@ for (const auto &kvp : newEntriesForID) { Identifier id = kvp.second.getKey().get(); - Dialect *dialect = id.getDialect(); + Dialect *dialect = id.getReferencedDialect(); if (!entriesForID.count(id)) { entriesForID[id] = kvp.second; continue; @@ -377,6 +377,6 @@ return success(); } - return op->emitError() << "attribute '" << attr.first + return op->emitError() << "attribute '" << attr.first.getValue() << "' not supported by dialect"; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -753,7 +753,7 @@ // Copy over unknown attributes. They might be load bearing for some flow. ArrayRef odsAttrs = genericOp.getAttributeNames(); for (NamedAttribute kv : genericOp->getAttrs()) { - if (!llvm::is_contained(odsAttrs, kv.first.c_str())) { + if (!llvm::is_contained(odsAttrs, kv.first.getValue())) { newOp->setAttr(kv.first, kv.second); } } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -46,10 +46,6 @@ using namespace mlir; using namespace mlir::detail; -void Identifier::print(raw_ostream &os) const { os << str(); } - -void Identifier::dump() const { print(llvm::errs()); } - void OperationName::print(raw_ostream &os) const { os << getStringRef(); } void OperationName::dump() const { print(llvm::errs()); } @@ -1339,7 +1335,7 @@ }) .Case([&](FileLineColLoc loc) { if (pretty) { - os << loc.getFilename(); + os << loc.getFilename().getValue(); } else { os << "\""; printEscapedString(loc.getFilename(), os); @@ -1693,7 +1689,7 @@ if (printerFlags.shouldElideElementsAttr(opaqueAttr)) { printElidedElementsAttr(os); } else { - os << "opaque<\"" << opaqueAttr.getDialect() << "\", \"0x" + os << "opaque<" << opaqueAttr.getDialect() << ", \"0x" << llvm::toHex(opaqueAttr.getValue()) << "\">"; } diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -319,6 +319,41 @@ ArrayRef data; }; +//===----------------------------------------------------------------------===// +// StringAttr +//===----------------------------------------------------------------------===// + +struct StringAttrStorage : public AttributeStorage { + StringAttrStorage(StringRef value, Type type) + : AttributeStorage(type), value(value), referencedDialect(nullptr) {} + + /// The hash key is a tuple of the parameter types. + using KeyTy = std::pair; + bool operator==(const KeyTy &key) const { + return value == key.first && getType() == key.second; + } + static ::llvm::hash_code hashKey(const KeyTy &key) { + return DenseMapInfo::getHashValue(key); + } + + /// Define a construction method for creating a new instance of this + /// storage. + static StringAttrStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + StringAttrStorage(allocator.copyInto(key.first), key.second); + } + + /// Initialize the storage given an MLIRContext. + void initialize(MLIRContext *context); + + /// The raw string value. + StringRef value; + /// If the string value contains a dialect namespace prefix (e.g. + /// dialect.blah), this is the dialect referenced. + Dialect *referencedDialect; +}; + } // namespace detail } // namespace mlir diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -12,28 +12,10 @@ using namespace mlir; using namespace mlir::detail; -//===----------------------------------------------------------------------===// -// AttributeStorage -//===----------------------------------------------------------------------===// - -AttributeStorage::AttributeStorage(Type type) - : type(type.getAsOpaquePointer()) {} -AttributeStorage::AttributeStorage() : type(nullptr) {} - -Type AttributeStorage::getType() const { - return Type::getFromOpaquePointer(type); -} -void AttributeStorage::setType(Type newType) { - type = newType.getAsOpaquePointer(); -} - //===----------------------------------------------------------------------===// // Attribute //===----------------------------------------------------------------------===// -/// Return the type of this attribute. -Type Attribute::getType() const { return impl->getType(); } - /// Return the context this attribute belongs to. MLIRContext *Attribute::getContext() const { return getDialect().getContext(); } @@ -42,13 +24,8 @@ //===----------------------------------------------------------------------===// bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { - return strcmp(lhs.first.data(), rhs.first.data()) < 0; + return lhs.first.compare(rhs.first) < 0; } bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { - // This is correct even when attr.first.data()[name.size()] is not a zero - // string terminator, because we only care about a less than comparison. - // This can't use memcmp, because it doesn't guarantee that it will stop - // reading both buffers if one is shorter than the other, even if there is - // a difference. - return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0; + return lhs.first.getValue().compare(rhs) < 0; } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -264,6 +264,12 @@ return Base::get(type.getContext(), twine.toStringRef(tempStr), type); } +StringRef StringAttr::getValue() const { return getImpl()->value; } + +Dialect *StringAttr::getReferencedDialect() const { + return getImpl()->referencedDialect; +} + //===----------------------------------------------------------------------===// // FloatAttr //===----------------------------------------------------------------------===// @@ -1250,7 +1256,7 @@ //===----------------------------------------------------------------------===// bool OpaqueElementsAttr::decode(ElementsAttr &result) { - Dialect *dialect = getDialect().getDialect(); + Dialect *dialect = getContext()->getLoadedDialect(getDialect()); if (!dialect) return true; auto *interface = diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -253,7 +253,7 @@ attr.first.strref())) return op.emitOpError() << "can only contain attributes with " "dialect-prefixed names, found: '" - << attr.first << "'"; + << attr.first.getValue() << "'"; } // Check that there is at most one data layout spec attribute. @@ -266,7 +266,8 @@ op.emitOpError() << "expects at most one data layout attribute"; diag.attachNote() << "'" << layoutSpecAttrName << "' is a data layout attribute"; - diag.attachNote() << "'" << na.first << "' is a data layout attribute"; + diag.attachNote() << "'" << na.first.getValue() + << "' is a data layout attribute"; } layoutSpecAttrName = na.first.strref(); layoutSpec = spec; diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -8,7 +8,6 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/Identifier.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" @@ -109,11 +108,8 @@ return *this; } -/// Stream in an Identifier. -Diagnostic &Diagnostic::operator<<(Identifier val) { - // An identifier is stored in the context, so we don't need to worry about the - // lifetime of its data. - arguments.push_back(DiagnosticArgument(val.strref())); +Diagnostic &Diagnostic::operator<<(StringAttr val) { + arguments.push_back(DiagnosticArgument(val)); return *this; } @@ -469,7 +465,7 @@ // the constructor of SMDiagnostic that takes a location. std::string locStr; llvm::raw_string_ostream locOS(locStr); - locOS << fileLoc->getFilename() << ":" << fileLoc->getLine() << ":" + locOS << fileLoc->getFilename().getValue() << ":" << fileLoc->getLine() << ":" << fileLoc->getColumn(); llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str()); diag.print(nullptr, os); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -18,7 +18,6 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/Identifier.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" @@ -33,6 +32,7 @@ #include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Mutex.h" #include "llvm/Support/RWMutex.h" #include "llvm/Support/ThreadPool.h" #include "llvm/Support/raw_ostream.h" @@ -227,14 +227,6 @@ /// An action manager for use within the context. DebugActionManager debugActionManager; - //===--------------------------------------------------------------------===// - // Identifier uniquing - //===--------------------------------------------------------------------===// - - // Identifier allocator and mutex for thread safety. - llvm::BumpPtrAllocator identifierAllocator; - llvm::sys::SmartRWMutex identifierMutex; - //===--------------------------------------------------------------------===// // Diagnostics //===--------------------------------------------------------------------===// @@ -289,12 +281,6 @@ /// operations. llvm::StringMap registeredOperations; - /// Identifiers are uniqued by string value and use the internal string set - /// for storage. - llvm::StringMap, - llvm::BumpPtrAllocator &> - identifiers; - /// An allocator used for AbstractAttribute and AbstractType objects. llvm::BumpPtrAllocator abstractDialectSymbolAllocator; @@ -349,10 +335,15 @@ DictionaryAttr emptyDictionaryAttr; StringAttr emptyStringAttr; + /// Map of string attributes that may reference a dialect, that are awaiting + /// that dialect to be loaded. + llvm::sys::SmartMutex dialectRefStrAttrMutex; + DenseMap> + dialectReferencingStrAttrs; + public: MLIRContextImpl(bool threadingIsEnabled) - : threadingIsEnabled(threadingIsEnabled), - identifiers(identifierAllocator) { + : threadingIsEnabled(threadingIsEnabled) { if (threadingIsEnabled) { ownedThreadPool = std::make_unique(); threadPool = ownedThreadPool.get(); @@ -541,12 +532,12 @@ // Refresh all the identifiers dialect field, this catches cases where a // dialect may be loaded after identifier prefixed with this dialect name // were already created. - llvm::SmallString<32> dialectPrefix(dialectNamespace); - dialectPrefix.push_back('.'); - for (auto &identifierEntry : impl.identifiers) - if (identifierEntry.second.is() && - identifierEntry.first().startswith(dialectPrefix)) - identifierEntry.second = dialect.get(); + auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace); + if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) { + for (StringAttrStorage *storage : stringAttrsIt->second) + storage->referencedDialect = dialect.get(); + impl.dialectReferencingStrAttrs.erase(stringAttrsIt); + } // Actually register the interfaces with delayed registration. impl.dialectsRegistry.registerDelayedInterfaces(dialect.get()); @@ -784,7 +775,8 @@ MutableArrayRef cachedAttrNames; if (!attrNames.empty()) { cachedAttrNames = MutableArrayRef( - impl.identifierAllocator.Allocate(attrNames.size()), + impl.abstractDialectSymbolAllocator.Allocate( + attrNames.size()), attrNames.size()); for (unsigned i : llvm::seq(0, attrNames.size())) new (&cachedAttrNames[i]) Identifier(Identifier::get(attrNames[i], ctx)); @@ -840,63 +832,6 @@ return it->second; } -//===----------------------------------------------------------------------===// -// Identifier uniquing -//===----------------------------------------------------------------------===// - -/// Return an identifier for the specified string. -Identifier Identifier::get(const Twine &string, MLIRContext *context) { - SmallString<32> tempStr; - StringRef str = string.toStringRef(tempStr); - - // Check invariants after seeing if we already have something in the - // identifier table - if we already had it in the table, then it already - // passed invariant checks. - assert(!str.empty() && "Cannot create an empty identifier"); - assert(!str.contains('\0') && - "Cannot create an identifier with a nul character"); - - auto getDialectOrContext = [&]() { - PointerUnion dialectOrContext = context; - auto dialectNamePair = str.split('.'); - if (!dialectNamePair.first.empty()) - if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first)) - dialectOrContext = dialect; - return dialectOrContext; - }; - - auto &impl = context->getImpl(); - if (!context->isMultithreadingEnabled()) { - auto insertedIt = impl.identifiers.insert({str, nullptr}); - if (insertedIt.second) - insertedIt.first->second = getDialectOrContext(); - return Identifier(&*insertedIt.first); - } - - // Check for an existing identifier in read-only mode. - { - llvm::sys::SmartScopedReader contextLock(impl.identifierMutex); - auto it = impl.identifiers.find(str); - if (it != impl.identifiers.end()) - return Identifier(&*it); - } - - // Acquire a writer-lock so that we can safely create the new instance. - llvm::sys::SmartScopedWriter contextLock(impl.identifierMutex); - auto it = impl.identifiers.insert({str, getDialectOrContext()}).first; - return Identifier(&*it); -} - -Dialect *Identifier::getDialect() { - return entry->second.dyn_cast(); -} - -MLIRContext *Identifier::getContext() { - if (Dialect *dialect = getDialect()) - return dialect->getContext(); - return entry->second.get(); -} - //===----------------------------------------------------------------------===// // Type uniquing //===----------------------------------------------------------------------===// @@ -995,7 +930,7 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, MLIRContext *ctx, TypeID attrID) { - storage->initialize(AbstractAttribute::lookup(attrID, ctx)); + storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx)); // If the attribute did not provide a type, then default to NoneType. if (!storage->getType()) @@ -1019,6 +954,24 @@ return context->getImpl().emptyDictionaryAttr; } +void StringAttrStorage::initialize(MLIRContext *context) { + // Check for a dialect namespace prefix, if there isn't one we don't need to + // do any additional initialization. + auto dialectNamePair = value.split('.'); + if (dialectNamePair.first.empty() || dialectNamePair.second.empty()) + return; + + // If one exists, we check to see if this dialect is loaded. If it is, we set + // the dialect now, if it isn't we record this storage for initialization + // later if the dialect ever gets loaded. + if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first))) + return; + + MLIRContextImpl &impl = context->getImpl(); + llvm::sys::SmartScopedLock lock(impl.dialectRefStrAttrMutex); + impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this); +} + /// Return an empty string. StringAttr StringAttr::get(MLIRContext *context) { return context->getImpl().emptyStringAttr; diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -73,10 +73,10 @@ void NamedAttrList::push_back(NamedAttribute newAttribute) { assert(newAttribute.second && "unexpected null attribute"); - if (isSorted()) - dictionarySorted.setInt( - attrs.empty() || - strcmp(attrs.back().first.data(), newAttribute.first.data()) < 0); + if (isSorted()) { + dictionarySorted.setInt(attrs.empty() || + attrs.back().first.compare(newAttribute.first) < 0); + } dictionarySorted.setPointer(nullptr); attrs.push_back(newAttribute); } diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -170,7 +170,7 @@ /// Verify that all of the attributes are okay. for (auto attr : op.getAttrs()) { // Check for any optional dialect specific attributes. - if (auto *dialect = attr.first.getDialect()) + if (auto *dialect = attr.first.getReferencedDialect()) if (failed(dialect->verifyOperationAttribute(&op, attr))) return failure(); } diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -431,7 +431,7 @@ for (const auto &kvp : ids) { Identifier identifier = kvp.second.getKey().get(); - Dialect *dialect = identifier.getDialect(); + Dialect *dialect = identifier.getReferencedDialect(); // Ignore attributes that belong to an unknown dialect, the dialect may // actually implement the relevant interface but we don't know about that. diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -273,7 +273,7 @@ return emitError("expected attribute name"); if (!seenKeys.insert(*nameId).second) return emitError("duplicate key '") - << *nameId << "' in dictionary attribute"; + << nameId->getValue() << "' in dictionary attribute"; consumeToken(); // Lazy load a dialect in the context if there is a possible namespace. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1127,7 +1127,7 @@ Optional duplicate = opState.attributes.findDuplicate(); if (duplicate) return emitError(getNameLoc(), "attribute '") - << duplicate->first + << duplicate->first.getValue() << "' occurs more than once in the attribute list"; return success(); } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -822,7 +822,7 @@ auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { if (llvm::is_contained(exclude, attr.first.strref())) return success(); - os << "/* " << attr.first << " */"; + os << "/* " << attr.first.getValue() << " */"; if (failed(emitAttribute(op.getLoc(), attr.second))) return failure(); return success(); diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -221,7 +221,7 @@ if (printAttrs) { os << "\n"; for (const NamedAttribute &attr : op->getAttrs()) { - os << '\n' << attr.first << ": "; + os << '\n' << attr.first.getValue() << ": "; emitMlirAttr(os, attr.second); } } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -494,7 +494,7 @@ LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::Loops) - .setDistributionOptions(cyclicNprocsEqNiters), + .setDistributionOptions(cyclicNprocsEqNiters), LinalgTransformationFilter( Identifier::get("tensors_distribute1", context), Identifier::get("tensors_after_distribute1", context))); @@ -508,8 +508,7 @@ MLIRContext *ctx = funcOp.getContext(); SmallVector stage1Patterns; if (testMatmulToVectorPatterns1dTiling) { - fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), - stage1Patterns); + fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); } else if (testMatmulToVectorPatterns2dTiling) { stage1Patterns.emplace_back( ctx, std::make_unique>( @@ -519,8 +518,7 @@ .setInterchange({1, 2, 0}), LinalgTransformationFilter(Identifier::get("START", ctx), Identifier::get("L2", ctx)))); - fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), - stage1Patterns); + fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); } { // Canonicalization patterns diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -243,7 +243,7 @@ if (!dAttr) return; for (auto d : dAttr) - dOp.emitRemark() << d.first << " = " << d.second; + dOp.emitRemark() << d.first.getValue() << " = " << d.second; }); } diff --git a/mlir/test/lib/IR/TestPrintNesting.cpp b/mlir/test/lib/IR/TestPrintNesting.cpp --- a/mlir/test/lib/IR/TestPrintNesting.cpp +++ b/mlir/test/lib/IR/TestPrintNesting.cpp @@ -37,8 +37,8 @@ if (!op->getAttrs().empty()) { printIndent() << op->getAttrs().size() << " attributes:\n"; for (NamedAttribute attr : op->getAttrs()) - printIndent() << " - '" << attr.first << "' : '" << attr.second - << "'\n"; + printIndent() << " - '" << attr.first.getValue() << "' : '" + << attr.second << "'\n"; } // Recurse into each of the regions attached to the operation.