diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -21,6 +21,47 @@ class MLIRContext; class Type; +//===----------------------------------------------------------------------===// +// AbstractAttribute +//===----------------------------------------------------------------------===// + +/// This class contains all of the static information common to all instances of +/// a registered Attribute. +class AbstractAttribute { +public: + /// Look up the specified abstract attribute in the MLIRContext and return a + /// reference to it; reports a fatal error if `typeID` is not registered. + static const AbstractAttribute &lookup(TypeID typeID, MLIRContext *context); + + /// This method is used by Dialect objects when they register the list of + /// attributes they contain. + template + static AbstractAttribute get(Dialect &dialect) { + return AbstractAttribute(dialect, T::getInterfaceMap()); + } + + /// Return the dialect this attribute was registered to. + Dialect &getDialect() const { return const_cast(dialect); } + + /// Returns an instance of the concept object for the given interface if it + /// was registered to this attribute, null otherwise. This should not be used + /// directly. + template + typename T::Concept *getInterface() const { + return interfaceMap.lookup(); + } + +private: + AbstractAttribute(Dialect &dialect, detail::InterfaceMap interfaceMap) + : dialect(dialect), interfaceMap(interfaceMap) {} + + /// This is the dialect that this attribute was registered to. + Dialect &dialect; + + /// This is a collection of the interfaces registered to this attribute. + detail::InterfaceMap interfaceMap; +}; + //===----------------------------------------------------------------------===// // AttributeStorage //===----------------------------------------------------------------------===// @@ -39,10 +80,10 @@ /// Get the type of this attribute.
Type getType() const; - /// Get the dialect of this attribute. - Dialect &getDialect() const { - assert(dialect && "Malformed attribute storage object."); - return const_cast(*dialect); + /// Return the abstract descriptor for this attribute. + const AbstractAttribute &getAbstractAttribute() const { + assert(abstractAttribute && "Malformed attribute storage object."); + return *abstractAttribute; } protected: @@ -56,13 +97,15 @@ /// Set the type of this attribute. void setType(Type type); - // Set the dialect for this storage instance. This is used by the + // Set the abstract attribute for this storage instance. This is used by the // AttributeUniquer when initializing a newly constructed storage object. - void initializeDialect(Dialect &newDialect) { dialect = &newDialect; } + void initialize(const AbstractAttribute &abstractAttr) { + abstractAttribute = &abstractAttr; + } private: - /// The dialect for this attribute. - Dialect *dialect; + /// The abstract descriptor for this attribute. + const AbstractAttribute *abstractAttribute; /// The opaque type of the attribute value. const void *type; diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -64,9 +64,10 @@ /// Utility class for implementing attributes. template + typename StorageType = AttributeStorage, + template class... Traits> using AttrBase = detail::StorageUserBase; + detail::AttributeUniquer, Traits...>; using ImplType = AttributeStorage; using ValueType = void; @@ -117,6 +118,11 @@ friend ::llvm::hash_code hash_value(Attribute arg); + /// Return the abstract descriptor for this attribute.
+ const AbstractAttribute &getAbstractAttribute() const { + return impl->getAbstractAttribute(); + } + protected: ImplType *impl; }; @@ -126,6 +132,46 @@ return os; } +//===----------------------------------------------------------------------===// +// AttributeTraitBase +//===----------------------------------------------------------------------===// + +namespace AttributeTrait { +/// This class represents the base of an attribute trait. +template class TraitType> +using TraitBase = detail::StorageUserTraitBase; +} // namespace AttributeTrait + +//===----------------------------------------------------------------------===// +// AttributeInterface +//===----------------------------------------------------------------------===// + +/// This class represents the base of an attribute interface. See the definition +/// of `detail::Interface` for requirements on the `Traits` type. +template +class AttributeInterface + : public detail::Interface { +public: + using Base = AttributeInterface; + using InterfaceBase = detail::Interface; + using InterfaceBase::InterfaceBase; + +private: + /// Returns the impl interface instance for the given attribute. + static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) { + return attr.getAbstractAttribute().getInterface(); + } + + /// Allow access to 'getInterfaceFor'. + friend InterfaceBase; +}; + +//===----------------------------------------------------------------------===// +// StandardAttributes +//===----------------------------------------------------------------------===// + namespace StandardAttributes { enum Kind { AffineMap = Attribute::FIRST_STANDARD_ATTR, diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -190,13 +190,19 @@ /// This method is used by derived classes to add their types to the set.
template void addTypes() { - (void)std::initializer_list{0, (addSymbol(Args::getTypeID()), 0)...}; + (void)std::initializer_list{ + 0, (addType(Args::getTypeID(), AbstractType::get(*this)), 0)...}; } + void addType(TypeID typeID, AbstractType &&typeInfo); /// This method is used by derived classes to add their attributes to the set. template void addAttributes() { - (void)std::initializer_list{0, (addSymbol(Args::getTypeID()), 0)...}; + (void)std::initializer_list{ + 0, + (addAttribute(Args::getTypeID(), AbstractAttribute::get(*this)), + 0)...}; } + void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); /// Enable support for unregistered operations. void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } @@ -214,9 +220,6 @@ } private: - // Register a symbol(e.g. type) with its given unique class identifier. - void addSymbol(TypeID typeID); - Dialect(const Dialect &) = delete; void operator=(Dialect &) = delete; diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -13,6 +13,7 @@ #ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H #define MLIR_IR_STORAGEUNIQUERSUPPORT_H +#include "mlir/Support/InterfaceSupport.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StorageUniquer.h" #include "mlir/Support/TypeID.h" @@ -27,17 +28,41 @@ /// avoid the need to include Location.h. const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx); +//===----------------------------------------------------------------------===// +// StorageUserTraitBase +//===----------------------------------------------------------------------===// + +/// Helper class for implementing traits for storage classes. Clients are not +/// expected to interact with this directly, so its members are all protected.
+template class TraitType> +class StorageUserTraitBase { +protected: + /// Return the derived instance. + ConcreteType getInstance() const { + // We have to cast up to the trait type, then to the concrete type because + // the concrete type will multiply derive from the (content free) TraitBase + // class, and we need to be able to disambiguate the path for the C++ + // compiler. + auto *trait = static_cast *>(this); + return *static_cast(trait); + } +}; + +//===----------------------------------------------------------------------===// +// StorageUserBase +//===----------------------------------------------------------------------===// + /// Utility class for implementing users of storage classes uniqued by a /// StorageUniquer. Clients are not expected to interact with this class /// directly. template -class StorageUserBase : public BaseT { + typename UniquerT, template class... Traits> +class StorageUserBase : public BaseT, public Traits... { public: using BaseT::BaseT; /// Utility declarations for the concrete attribute class. - using Base = StorageUserBase; + using Base = StorageUserBase; using ImplType = StorageT; /// Return a unique identifier for the concrete type. @@ -51,6 +76,12 @@ return ConcreteT::kindof(val.getKind()); } + /// Returns an interface map for the interfaces registered to this storage + /// user. This should not be used directly. + static detail::InterfaceMap getInterfaceMap() { + return detail::InterfaceMap::template get...>(); + } + protected: /// Get or create a new ConcreteT instance within the ctx.
This /// function is guaranteed to return a non null object and will assert if diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -20,6 +20,47 @@ class Dialect; class MLIRContext; +//===----------------------------------------------------------------------===// +// AbstractType +//===----------------------------------------------------------------------===// + +/// This class contains all of the static information common to all instances of +/// a registered Type. +class AbstractType { +public: + /// Look up the specified abstract type in the MLIRContext and return a + /// reference to it; reports a fatal error if `typeID` is not registered. + static const AbstractType &lookup(TypeID typeID, MLIRContext *context); + + /// This method is used by Dialect objects when they register the list of + /// types they contain. + template + static AbstractType get(Dialect &dialect) { + return AbstractType(dialect, T::getInterfaceMap()); + } + + /// Return the dialect this type was registered to. + Dialect &getDialect() const { return const_cast(dialect); } + + /// Returns an instance of the concept object for the given interface if it + /// was registered to this type, null otherwise. This should not be used + /// directly. + template + typename T::Concept *getInterface() const { + return interfaceMap.lookup(); + } + +private: + AbstractType(Dialect &dialect, detail::InterfaceMap interfaceMap) + : dialect(dialect), interfaceMap(interfaceMap) {} + + /// This is the dialect that this type was registered to. + Dialect &dialect; + + /// This is a collection of the interfaces registered to this type.
+ detail::InterfaceMap interfaceMap; +}; + //===----------------------------------------------------------------------===// // TypeStorage //===----------------------------------------------------------------------===// @@ -35,17 +76,18 @@ protected: /// This constructor is used by derived classes as part of the TypeUniquer. - /// When using this constructor, the initializeDialect function must be + /// When using this constructor, the initialize function must be /// invoked afterwards for the storage to be valid. TypeStorage(unsigned subclassData = 0) - : dialect(nullptr), subclassData(subclassData) {} + : abstractType(nullptr), subclassData(subclassData) {} public: - /// Get the dialect that this type is registered to. - Dialect &getDialect() { - assert(dialect && "Malformed type storage object."); - return *dialect; + /// Return the abstract type descriptor for this type. + const AbstractType &getAbstractType() const { + assert(abstractType && "Malformed type storage object."); + return *abstractType; } + /// Get the subclass data. unsigned getSubclassData() const { return subclassData; } @@ -53,12 +95,14 @@ void setSubclassData(unsigned val) { subclassData = val; } private: - // Set the dialect for this storage instance. This is used by the TypeUniquer - // when initializing a newly constructed type storage object. - void initializeDialect(Dialect &newDialect) { dialect = &newDialect; } + /// Set the abstract type for this storage instance. This is used by the + /// TypeUniquer when initializing a newly constructed type storage object. + void initialize(const AbstractType &abstractTy) { + abstractType = &abstractTy; + } - /// The dialect for this type. - Dialect *dialect; + /// The abstract description for this type. + const AbstractType *abstractType; /// Space for subclasses to store data. unsigned subclassData; @@ -72,36 +116,26 @@ // TypeStorageAllocator //===----------------------------------------------------------------------===// -// This is a utility allocator used to allocate memory for instances of derived -// Types.
+/// This is a utility allocator used to allocate memory for instances of derived +/// Types. using TypeStorageAllocator = StorageUniquer::StorageAllocator; //===----------------------------------------------------------------------===// // TypeUniquer //===----------------------------------------------------------------------===// namespace detail { -// A utility class to get, or create, unique instances of types within an -// MLIRContext. This class manages all creation and uniquing of types. -class TypeUniquer { -public: +/// A utility class to get, or create, unique instances of types within an +/// MLIRContext. This class manages all creation and uniquing of types. +struct TypeUniquer { /// Get an uniqued instance of a type T. template static T get(MLIRContext *ctx, unsigned kind, Args &&... args) { return ctx->getTypeUniquer().get( [&](TypeStorage *storage) { - storage->initializeDialect(lookupDialectForType(ctx)); + storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); }, kind, std::forward(args)...); } - -private: - /// Get the dialect that the type 'T' was registered with. - template static Dialect &lookupDialectForType(MLIRContext *ctx) { - return lookupDialectForType(ctx, T::getTypeID()); - } - - /// Get the dialect that registered the type with the provided typeid. - static Dialect &lookupDialectForType(MLIRContext *ctx, TypeID typeID); }; } // namespace detail diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -101,9 +101,10 @@ /// Utility class for implementing types. template + typename StorageType = DefaultTypeStorage, + template class... Traits> using TypeBase = detail::StorageUserBase; + detail::TypeUniquer, Traits...>; using ImplType = TypeStorage; @@ -194,6 +195,11 @@ return Type(reinterpret_cast(const_cast(pointer))); } + /// Return the abstract type descriptor for this type.
+ const AbstractType &getAbstractType() const { + return impl->getAbstractType(); + } + protected: ImplType *impl; }; @@ -203,6 +209,45 @@ return os; } +//===----------------------------------------------------------------------===// +// TypeTraitBase +//===----------------------------------------------------------------------===// + +namespace TypeTrait { +/// This class represents the base of a type trait. +template class TraitType> +using TraitBase = detail::StorageUserTraitBase; +} // namespace TypeTrait + +//===----------------------------------------------------------------------===// +// TypeInterface +//===----------------------------------------------------------------------===// + +/// This class represents the base of a type interface. See the definition of +/// `detail::Interface` for requirements on the `Traits` type. +template +class TypeInterface : public detail::Interface { +public: + using Base = TypeInterface; + using InterfaceBase = + detail::Interface; + using InterfaceBase::InterfaceBase; + +private: + /// Returns the impl interface instance for the given type. + static typename InterfaceBase::Concept *getInterfaceFor(Type type) { + return type.getAbstractType().getInterface(); + } + + /// Allow access to 'getInterfaceFor'. + friend InterfaceBase; +}; + +//===----------------------------------------------------------------------===// +// FunctionType +//===----------------------------------------------------------------------===// + /// Function types map from a list of inputs to a list of results. class FunctionType : public Type::TypeBase { @@ -230,6 +275,10 @@ static bool kindof(unsigned kind) { return kind == Kind::Function; } }; +//===----------------------------------------------------------------------===// +// OpaqueType +//===----------------------------------------------------------------------===// + /// Opaque types represent types of non-registered dialects.
These are types /// represented in their raw string form, and can only usefully be tested for /// type equality. diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -47,7 +47,9 @@ MLIRContext *Attribute::getContext() const { return getType().getContext(); } /// Get the dialect this attribute is registered to. -Dialect &Attribute::getDialect() const { return impl->getDialect(); } +Dialect &Attribute::getDialect() const { + return impl->getAbstractAttribute().getDialect(); +} //===----------------------------------------------------------------------===// // AffineMapAttr diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -282,13 +282,15 @@ /// operations. llvm::StringMap registeredOperations; - /// This is a mapping from type id to Dialect for registered attributes and - /// types. - DenseMap registeredDialectSymbols; - /// These are identifiers uniqued into this MLIRContext. llvm::StringSet identifiers; + /// An allocator used for AbstractAttribute and AbstractType objects. + /// NOTE(review): BumpPtrAllocator never runs destructors, so these objects + /// (and anything their InterfaceMaps own) are not destroyed unless the + /// context destructor invokes their destructors explicitly -- confirm. + llvm::BumpPtrAllocator abstractDialectSymbolAllocator; + //===--------------------------------------------------------------------===// // Affine uniquing //===--------------------------------------------------------------------===// @@ -311,6 +313,8 @@ //===--------------------------------------------------------------------===// // Type uniquing //===--------------------------------------------------------------------===// + + DenseMap registeredTypes; StorageUniquer typeUniquer; /// Cached Type Instances. @@ -322,6 +326,8 @@ //===--------------------------------------------------------------------===// // Attribute uniquing //===--------------------------------------------------------------------===// + + DenseMap registeredAttributes; StorageUniquer attributeUniquer; /// Cached Attribute Instances.
@@ -569,18 +575,47 @@ } } -/// Register a dialect-specific symbol(e.g. type) with the current context. -void Dialect::addSymbol(TypeID typeID) { +/// Register `typeInfo` under `typeID` in the context; aborts on duplicates. +void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { auto &impl = context->getImpl(); // Lock access to the context registry. ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled); - if (!impl.registeredDialectSymbols.insert({typeID, this}).second) { - llvm::errs() << "error: dialect symbol already registered.\n"; + auto *newInfo = + new (impl.abstractDialectSymbolAllocator.Allocate()) + AbstractType(std::move(typeInfo)); + if (!impl.registeredTypes.insert({typeID, newInfo}).second) { + llvm::errs() << "error: dialect type already registered.\n"; abort(); } } +/// Register `attrInfo` under `typeID` in the context; aborts on duplicates. +void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { + auto &impl = context->getImpl(); + + // Lock access to the context registry. + ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled); + auto *newInfo = + new (impl.abstractDialectSymbolAllocator.Allocate()) + AbstractAttribute(std::move(attrInfo)); + if (!impl.registeredAttributes.insert({typeID, newInfo}).second) { + llvm::errs() << "error: dialect attribute already registered.\n"; + abort(); + } +} + +/// Return the AbstractAttribute registered for `typeID`; fatal error if none. +const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, + MLIRContext *context) { + auto &impl = context->getImpl(); + auto it = impl.registeredAttributes.find(typeID); + if (it == impl.registeredAttributes.end()) + llvm::report_fatal_error("Trying to create an Attribute that was not " + "registered in this MLIRContext."); + return *it->second; +} + /// Look up the specified operation in the operation set and return a pointer /// to it if present. Otherwise, return a null pointer. const AbstractOperation *AbstractOperation::lookup(StringRef opName, @@ -595,6 +630,16 @@ return nullptr; } +/// Return the AbstractType registered for `typeID`; fatal error if none.
+const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { + auto &impl = context->getImpl(); + auto it = impl.registeredTypes.find(typeID); + if (it == impl.registeredTypes.end()) + llvm::report_fatal_error( + "Trying to create a Type that was not registered in this MLIRContext."); + return *it->second; +} + //===----------------------------------------------------------------------===// // Identifier uniquing //===----------------------------------------------------------------------===// @@ -628,24 +673,10 @@ // Type uniquing //===----------------------------------------------------------------------===// -static Dialect &lookupDialectForSymbol(MLIRContext *ctx, TypeID typeID) { - auto &impl = ctx->getImpl(); - auto it = impl.registeredDialectSymbols.find(typeID); - if (it == impl.registeredDialectSymbols.end()) - llvm::report_fatal_error( - "Trying to create a type that was not registered in this MLIRContext."); - return *it->second; -} - /// Returns the storage uniquer used for constructing type storage instances. /// This should not be used directly. StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } -/// Get the dialect that registered the type with the provided typeid. -Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx, TypeID typeID) { - return lookupDialectForSymbol(ctx, typeID); -} - FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) { assert(kindof(kind) && "Not a FP kind."); switch (kind) { @@ -738,7 +769,7 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, MLIRContext *ctx, TypeID attrID) { - storage->initializeDialect(lookupDialectForSymbol(ctx, attrID)); + storage->initialize(AbstractAttribute::lookup(attrID, ctx)); // If the attribute did not provide a type, then default to NoneType.
if (!storage->getType()) diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -21,8 +21,10 @@ unsigned Type::getKind() const { return impl->getKind(); } /// Get the dialect this type is registered to. -Dialect &Type::getDialect() const { return impl->getDialect(); } +Dialect &Type::getDialect() const { + return impl->getAbstractType().getDialect(); +} MLIRContext *Type::getContext() const { return getDialect().getContext(); }