diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.h b/flang/include/flang/Optimizer/Dialect/FIRAttr.h --- a/flang/include/flang/Optimizer/Dialect/FIRAttr.h +++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.h @@ -48,7 +48,6 @@ mlir::Type getType() const; - static constexpr bool kindof(unsigned kind) { return kind == getId(); } static constexpr unsigned getId() { return AttributeKind::FIR_EXACTTYPE; } }; @@ -64,7 +63,6 @@ mlir::Type getType() const; - static constexpr bool kindof(unsigned kind) { return kind == getId(); } static constexpr unsigned getId() { return AttributeKind::FIR_SUBCLASS; } }; @@ -82,7 +80,6 @@ static constexpr llvm::StringRef getAttrName() { return "interval"; } static ClosedIntervalAttr get(mlir::MLIRContext *ctxt); - static constexpr bool kindof(unsigned kind) { return kind == getId(); } static constexpr unsigned getId() { return AttributeKind::FIR_CLOSEDCLOSED_INTERVAL; } @@ -100,7 +97,6 @@ static constexpr llvm::StringRef getAttrName() { return "upper"; } static UpperBoundAttr get(mlir::MLIRContext *ctxt); - static constexpr bool kindof(unsigned kind) { return kind == getId(); } static constexpr unsigned getId() { return AttributeKind::FIR_OPENCLOSED_INTERVAL; } @@ -118,7 +114,6 @@ static constexpr llvm::StringRef getAttrName() { return "lower"; } static LowerBoundAttr get(mlir::MLIRContext *ctxt); - static constexpr bool kindof(unsigned kind) { return kind == getId(); } static constexpr unsigned getId() { return AttributeKind::FIR_CLOSEDOPEN_INTERVAL; } @@ -136,7 +131,6 @@ static constexpr llvm::StringRef getAttrName() { return "point"; } static PointIntervalAttr get(mlir::MLIRContext *ctxt); - static constexpr bool kindof(unsigned kind) { return kind == getId(); } static constexpr unsigned getId() { return AttributeKind::FIR_POINT; } }; @@ -157,7 +151,6 @@ int getFKind() const; llvm::APFloat getValue() const; - static constexpr bool kindof(unsigned kind) { return kind == getId(); } static constexpr unsigned getId() { return 
AttributeKind::FIR_REAL_ATTR; } }; diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -114,7 +114,6 @@ /// Boilerplate mixin template template struct IntrinsicTypeMixin { - static constexpr bool kindof(unsigned kind) { return kind == getId(); } static constexpr unsigned getId() { return Id; } }; @@ -194,7 +193,6 @@ public: using Base::Base; static BoxType get(mlir::Type eleTy, mlir::AffineMapAttr map = {}); - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_BOX; } mlir::Type getEleTy() const; mlir::AffineMapAttr getLayoutMap() const; @@ -211,7 +209,6 @@ public: using Base::Base; static BoxCharType get(mlir::MLIRContext *ctxt, KindTy kind); - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_BOXCHAR; } CharacterType getEleTy() const; }; @@ -223,7 +220,6 @@ public: using Base::Base; static BoxProcType get(mlir::Type eleTy); - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_BOXPROC; } mlir::Type getEleTy() const; static mlir::LogicalResult verifyConstructionInvariants(mlir::Location, @@ -239,7 +235,6 @@ public: using Base::Base; static DimsType get(mlir::MLIRContext *ctx, unsigned rank); - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_DIMS; } /// returns -1 if the rank is unknown unsigned getRank() const; @@ -253,7 +248,6 @@ public: using Base::Base; static FieldType get(mlir::MLIRContext *ctxt); - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_FIELD; } }; /// The type of a heap pointer. 
Fortran entities with the ALLOCATABLE attribute @@ -265,7 +259,6 @@ public: using Base::Base; static HeapType get(mlir::Type elementType); - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_HEAP; } mlir::Type getEleTy() const; @@ -281,7 +274,6 @@ public: using Base::Base; static LenType get(mlir::MLIRContext *ctxt); - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_LEN; } }; /// The type of entities with the POINTER attribute. These pointers are @@ -292,7 +284,6 @@ public: using Base::Base; static PointerType get(mlir::Type elementType); - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_POINTER; } mlir::Type getEleTy() const; @@ -307,7 +298,6 @@ public: using Base::Base; static ReferenceType get(mlir::Type elementType); - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_REFERENCE; } mlir::Type getEleTy() const; @@ -361,8 +351,6 @@ /// The value `-1` represents an unknown extent for a dimension static constexpr Extent getUnknownExtent() { return -1; } - static bool kindof(unsigned kind) { return kind == TypeKind::FIR_SEQUENCE; } - static mlir::LogicalResult verifyConstructionInvariants(mlir::Location loc, const Shape &shape, mlir::Type eleTy, mlir::AffineMapAttr map); @@ -379,9 +367,6 @@ public: using Base::Base; static TypeDescType get(mlir::Type ofType); - static constexpr bool kindof(unsigned kind) { - return kind == TypeKind::FIR_TYPEDESC; - } mlir::Type getOfTy() const; static mlir::LogicalResult verifyConstructionInvariants(mlir::Location, @@ -415,7 +400,6 @@ static RecordType get(mlir::MLIRContext *ctxt, llvm::StringRef name); void finalize(llvm::ArrayRef lenPList, llvm::ArrayRef typeList); - static constexpr bool kindof(unsigned kind) { return kind == getId(); } static constexpr unsigned getId() { return TypeKind::FIR_DERIVED; } detail::RecordTypeStorage const *uniqueKey() const; diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md 
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -89,10 +89,6 @@ /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; - /// This static method is used to support type inquiry through isa, cast, - /// and dyn_cast. - static bool kindof(unsigned kind) { return kind == MyTypes::Simple; } - /// This method is used to get an instance of the 'SimpleType'. Given that /// this is a parameterless type, it just needs to take the context for /// uniquing purposes. @@ -193,10 +189,6 @@ /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; - /// This static method is used to support type inquiry through isa, cast, - /// and dyn_cast. - static bool kindof(unsigned kind) { return kind == MyTypes::Complex; } - /// This method is used to get an instance of the 'ComplexType'. This method /// asserts that all of the construction invariants were satisfied. To /// gracefully handle failed construction, getChecked should be used instead. @@ -327,10 +319,6 @@ /// Inherit parent constructors. using Base::Base; - /// This static method is used to support type inquiry through isa, cast, - /// and dyn_cast. - static bool kindof(unsigned kind) { return kind == MyTypes::Recursive; } - /// Creates an instance of the Recursive type. This only takes the type name /// and returns the type with uninitialized body. static RecursiveType get(MLIRContext *ctx, StringRef name) { diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md --- a/mlir/docs/Tutorials/Toy/Ch-7.md +++ b/mlir/docs/Tutorials/Toy/Ch-7.md @@ -184,10 +184,6 @@ /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; - /// This static method is used to support type inquiry through isa, cast, - /// and dyn_cast. - static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; } - /// Create an instance of a `StructType` with the given element types. 
There /// *must* be at least one element type. static StructType get(llvm::ArrayRef elementTypes) { diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h --- a/mlir/examples/toy/Ch7/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h @@ -81,10 +81,6 @@ /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; - /// This static method is used to support type inquiry through isa, cast, - /// and dyn_cast. - static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; } - /// Create an instance of a `StructType` with the given element types. There /// *must* be atleast one element type. static StructType get(llvm::ArrayRef elementTypes); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -61,7 +61,7 @@ /// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR /// context, have an immutable identifier (for most types except identified /// structs, the entire type is the identifier) and are thread-safe. -class LLVMType : public Type::TypeBase { +class LLVMType : public Type { public: enum Kind { // Keep non-parametric types contiguous in the enum. @@ -92,7 +92,7 @@ }; /// Inherit base constructors. - using Base::Base; + using Type::Type; /// Support for PointerLikeTypeTraits. using Type::getAsOpaquePointer; @@ -101,8 +101,9 @@ } /// Support for isa/cast. 
- static bool kindof(unsigned kind) { - return FIRST_NEW_LLVM_TYPE <= kind && kind <= LAST_NEW_LLVM_TYPE; + static bool classof(Type type) { + return type.getKind() >= FIRST_NEW_LLVM_TYPE && + type.getKind() <= LAST_NEW_LLVM_TYPE; } LLVMDialect &getDialect(); @@ -256,7 +257,6 @@ class ClassName : public Type::TypeBase { \ public: \ using Base::Base; \ - static bool kindof(unsigned kind) { return kind == Kind; } \ static ClassName get(MLIRContext *context) { \ return Base::get(context, Kind); \ } \ @@ -290,9 +290,6 @@ /// Inherit base constructors. using Base::Base; - /// Support for isa/cast. - static bool kindof(unsigned kind) { return kind == LLVMType::ArrayType; } - /// Gets or creates an instance of LLVM dialect array type containing /// `numElements` of `elementType`, in the same context as `elementType`. static LLVMArrayType get(LLVMType elementType, unsigned numElements); @@ -318,9 +315,6 @@ /// Inherit base constructors. using Base::Base; - /// Support for isa/cast. - static bool kindof(unsigned kind) { return kind == LLVMType::FunctionType; } - /// Gets or creates an instance of LLVM dialect function in the same context /// as the `result` type. static LLVMFunctionType get(LLVMType result, ArrayRef arguments, @@ -354,9 +348,6 @@ /// Inherit base constructor. using Base::Base; - /// Support for isa/cast. - static bool kindof(unsigned kind) { return kind == LLVMType::IntegerType; } - /// Gets or creates an instance of the integer of the specified `bitwidth` in /// the given context. static LLVMIntegerType get(MLIRContext *ctx, unsigned bitwidth); @@ -378,9 +369,6 @@ /// Inherit base constructors. using Base::Base; - /// Support for isa/cast. - static bool kindof(unsigned kind) { return kind == LLVMType::PointerType; } - /// Gets or creates an instance of LLVM dialect pointer type pointing to an /// object of `pointee` type in the given address space. The pointer type is /// created in the same context as `pointee`. 
@@ -427,9 +415,6 @@ /// Inherit base construtors. using Base::Base; - /// Support for isa/cast. - static bool kindof(unsigned kind) { return kind == LLVMType::StructType; } - /// Gets or creates an identified struct with the given name in the provided /// context. Note that unlike llvm::StructType::create, this function will /// _NOT_ rename a struct in case a struct with the same name already exists @@ -485,17 +470,13 @@ /// LLVM dialect vector type, represents a sequence of elements that can be /// processed as one, typically in SIMD context. This is a base class for fixed /// and scalable vectors. -class LLVMVectorType : public Type::TypeBase { +class LLVMVectorType : public LLVMType { public: /// Inherit base constructor. - using Base::Base; + using LLVMType::LLVMType; - /// Support for isa/cast. - static bool kindof(unsigned kind) { - return kind == LLVMType::FixedVectorType || - kind == LLVMType::ScalableVectorType; - } + /// Support type casting functionality. + static bool classof(Type type); /// Returns the element type of the vector. LLVMType getElementType(); @@ -517,11 +498,6 @@ /// Inherit base constructor. using Base::Base; - /// Support for isa/cast. - static bool kindof(unsigned kind) { - return kind == LLVMType::FixedVectorType; - } - /// Gets or creates a fixed vector type containing `numElements` of /// `elementType` in the same context as `elementType`. static LLVMFixedVectorType get(LLVMType elementType, unsigned numElements); @@ -544,11 +520,6 @@ /// Inherit base constructor. using Base::Base; - /// Support for isa/cast. - static bool kindof(unsigned kind) { - return kind == LLVMType::ScalableVectorType; - } - /// Gets or creates a scalable vector type containing a non-zero multiple of /// `minNumElements` of `elementType` in the same context as `elementType`. 
static LLVMScalableVectorType get(LLVMType elementType, diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -41,8 +41,6 @@ /// Custom, uniq'ed construction in the MLIRContext. return Base::get(context, LinalgTypes::Range); } - /// Used to implement llvm-style cast. - static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; } }; } // namespace linalg diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -211,9 +211,6 @@ public: using Base::Base; - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { return kind == QuantizationTypes::Any; } - /// Gets an instance of the type with all parameters specified but not /// checked. static AnyQuantizedType get(unsigned flags, Type storageType, @@ -292,11 +289,6 @@ int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { - return kind == QuantizationTypes::UniformQuantized; - } - /// Gets the scale term. The scale designates the difference between the real /// values corresponding to consecutive quantized values differing by 1. double getScale() const; @@ -357,11 +349,6 @@ int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax); - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { - return kind == QuantizationTypes::UniformQuantizedPerAxis; - } - /// Gets the quantization scales. The scales designate the difference between /// the real values corresponding to consecutive quantized values differing /// by 1. 
The ith scale corresponds to the ith slice in the diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h @@ -68,10 +68,6 @@ /// Returns `spirv::StorageClass`. Optional getStorageClass(); - static bool kindof(unsigned kind) { - return kind == AttrKind::InterfaceVarABI; - } - static LogicalResult verifyConstructionInvariants(Location loc, IntegerAttr descriptorSet, IntegerAttr binding, @@ -123,8 +119,6 @@ /// Returns the capabilities as an integer array attribute. ArrayAttr getCapabilitiesAttr(); - static bool kindof(unsigned kind) { return kind == AttrKind::VerCapExt; } - static LogicalResult verifyConstructionInvariants(Location loc, IntegerAttr version, ArrayAttr capabilities, @@ -165,8 +159,6 @@ /// Returns the target resource limits. ResourceLimitsAttr getResourceLimits(); - static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; } - static LogicalResult verifyConstructionInvariants(Location loc, VerCapExtAttr triple, DictionaryAttr limits); diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -170,8 +170,6 @@ public: using Base::Base; - static bool kindof(unsigned kind) { return kind == TypeKind::Array; } - static ArrayType get(Type elementType, unsigned elementCount); /// Returns an array type with the given stride in bytes. 
@@ -202,8 +200,6 @@ public: using Base::Base; - static bool kindof(unsigned kind) { return kind == TypeKind::Image; } - static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth = ImageDepthInfo::DepthUnknown, @@ -243,8 +239,6 @@ public: using Base::Base; - static bool kindof(unsigned kind) { return kind == TypeKind::Pointer; } - static PointerType get(Type pointeeType, StorageClass storageClass); Type getPointeeType() const; @@ -264,8 +258,6 @@ public: using Base::Base; - static bool kindof(unsigned kind) { return kind == TypeKind::RuntimeArray; } - static RuntimeArrayType get(Type elementType); /// Returns a runtime array type with the given stride in bytes. @@ -318,8 +310,6 @@ } }; - static bool kindof(unsigned kind) { return kind == TypeKind::Struct; } - /// Construct a StructType with at least one member. static StructType get(ArrayRef memberTypes, ArrayRef offsetInfo = {}, @@ -385,10 +375,6 @@ public: using Base::Base; - static bool kindof(unsigned kind) { - return kind == TypeKind::CooperativeMatrix; - } - static CooperativeMatrixNVType get(Type elementType, spirv::Scope scope, unsigned rows, unsigned columns); Type getElementType() const; @@ -412,8 +398,6 @@ public: using Base::Base; - static bool kindof(unsigned kind) { return kind == TypeKind::Matrix; } - static MatrixType get(Type columnType, uint32_t columnCount); static MatrixType getChecked(Type columnType, uint32_t columnCount, diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -49,11 +49,6 @@ static ComponentType get(MLIRContext *context) { return Base::get(context, ShapeTypes::Kind::Component); } - - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { - return kind == ShapeTypes::Kind::Component; - } }; /// The element type of the shaped type. 
@@ -64,11 +59,6 @@ static ElementType get(MLIRContext *context) { return Base::get(context, ShapeTypes::Kind::Element); } - - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { - return kind == ShapeTypes::Kind::Element; - } }; /// The shape descriptor type represents rank and dimension sizes. @@ -79,9 +69,6 @@ static ShapeType get(MLIRContext *context) { return Base::get(context, ShapeTypes::Kind::Shape); } - - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { return kind == ShapeTypes::Kind::Shape; } }; /// The type of a single dimension. @@ -92,9 +79,6 @@ static SizeType get(MLIRContext *context) { return Base::get(context, ShapeTypes::Kind::Size); } - - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { return kind == ShapeTypes::Kind::Size; } }; /// The ValueShape represents a (potentially unknown) runtime value and shape. @@ -106,11 +90,6 @@ static ValueShapeType get(MLIRContext *context) { return Base::get(context, ShapeTypes::Kind::ValueShape); } - - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { - return kind == ShapeTypes::Kind::ValueShape; - } }; /// The Witness represents a runtime constraint, to be used as shape related @@ -122,11 +101,6 @@ static WitnessType get(MLIRContext *context) { return Base::get(context, ShapeTypes::Kind::Witness); } - - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { - return kind == ShapeTypes::Kind::Witness; - } }; #define GET_OP_CLASSES diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -36,7 +36,7 @@ /// This method is used by Dialect objects when they register the list of /// attributes they contain. 
template static AbstractAttribute get(Dialect &dialect) { - return AbstractAttribute(dialect, T::getInterfaceMap()); + return AbstractAttribute(dialect, T::getInterfaceMap(), T::getTypeID()); } /// Return the dialect this attribute was registered to. @@ -49,15 +49,23 @@ return interfaceMap.lookup(); } + /// Return the unique identifier representing the concrete attribute class. + TypeID getTypeID() const { return typeID; } + private: - AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap) - : dialect(dialect), interfaceMap(std::move(interfaceMap)) {} + AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap, + TypeID typeID) + : dialect(dialect), interfaceMap(std::move(interfaceMap)), + typeID(typeID) {} /// This is the dialect that this attribute was registered to. Dialect &dialect; /// This is a collection of the interfaces registered to this attribute. detail::InterfaceMap interfaceMap; + + /// The unique identifier of the derived Attribute class. + TypeID typeID; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -97,6 +97,10 @@ /// Return the classification for this attribute. unsigned getKind() const { return impl->getKind(); } + /// Return a unique identifier for the concrete attribute type. This is used + /// to support dynamic type casting. + TypeID getTypeID() const { return impl->getAbstractAttribute().getTypeID(); } + /// Return the type of this attribute. Type getType() const; @@ -231,11 +235,6 @@ static AffineMapAttr get(AffineMap value); AffineMap getValue() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) { - return kind == StandardAttributes::AffineMap; - } }; //===----------------------------------------------------------------------===// @@ -262,11 +261,6 @@ size_t size() const { return getValue().size(); } bool empty() const { return size() == 0; } - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::Array; - } - private: /// Class for underlying value iterator support. template @@ -357,11 +351,6 @@ /// Requires: uniquely named attributes. static bool sortInPlace(SmallVectorImpl &array); - /// Methods for supporting type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::Dictionary; - } - private: /// Return empty dictionary. static DictionaryAttr getEmpty(MLIRContext *context); @@ -394,11 +383,6 @@ double getValueAsDouble() const; static double getValueAsDouble(APFloat val); - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::Float; - } - /// Verify the construction invariants for a double value. static LogicalResult verifyConstructionInvariants(Location loc, Type type, double value); @@ -432,11 +416,6 @@ /// an unsigned integer. uint64_t getUInt() const; - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::Integer; - } - static LogicalResult verifyConstructionInvariants(Location loc, Type type, int64_t value); static LogicalResult verifyConstructionInvariants(Location loc, Type type, @@ -480,11 +459,6 @@ static IntegerSetAttr get(IntegerSet value); IntegerSet getValue() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. 
- static bool kindof(unsigned kind) { - return kind == StandardAttributes::IntegerSet; - } }; //===----------------------------------------------------------------------===// @@ -520,10 +494,6 @@ Identifier dialect, StringRef attrData, Type type); - - static bool kindof(unsigned kind) { - return kind == StandardAttributes::Opaque; - } }; //===----------------------------------------------------------------------===// @@ -543,11 +513,6 @@ static StringAttr get(StringRef bytes, Type type); StringRef getValue() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::String; - } }; //===----------------------------------------------------------------------===// @@ -584,11 +549,6 @@ /// Returns the set of nested references representing the path to the symbol /// nested under the root reference. ArrayRef getNestedReferences() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::SymbolRef; - } }; /// A symbol reference with a reference path containing a single element. This @@ -630,9 +590,6 @@ static TypeAttr get(Type value); Type getValue() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; } }; //===----------------------------------------------------------------------===// @@ -647,8 +604,6 @@ using Base::Base; static UnitAttr get(MLIRContext *context); - - static bool kindof(unsigned kind) { return kind == StandardAttributes::Unit; } }; //===----------------------------------------------------------------------===// @@ -1229,11 +1184,6 @@ public: using Base::Base; - /// Method for support type inquiry through isa, cast and dyn_cast. 
- static bool kindof(unsigned kind) { - return kind == StandardAttributes::DenseStringElements; - } - /// Overload of the raw 'get' method that asserts that the given type is of /// integer or floating-point type. This method is used to verify type /// invariants that the templatized 'get' method cannot. @@ -1252,11 +1202,6 @@ public: using Base::Base; - /// Method for support type inquiry through isa, cast and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::DenseIntOrFPElements; - } - protected: friend DenseElementsAttr; @@ -1394,11 +1339,6 @@ /// Returns dialect associated with this opaque constant. Dialect *getDialect() const; - - /// Method for support type inquiry through isa, cast and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::OpaqueElements; - } }; /// An attribute that represents a reference to a sparse vector or tensor @@ -1460,11 +1400,6 @@ /// expected to refer to a valid element. Attribute getValue(ArrayRef index) const; - /// Method for support type inquiry through isa, cast and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::SparseElements; - } - private: /// Get a zero APFloat for the given sparse attribute. APFloat getZeroAPFloat() const; diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -120,11 +120,6 @@ /// The caller's location. Location getCaller() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::CallSiteLocation; - } }; /// Represents a location derived from a file/line/column location. The column @@ -146,11 +141,6 @@ unsigned getLine() const; unsigned getColumn() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. 
- static bool kindof(unsigned kind) { - return kind == StandardAttributes::FileLineColLocation; - } }; /// Represents a value composed of multiple source constructs, with an optional @@ -174,11 +164,6 @@ /// Returns the optional metadata attached to this fused location. Given that /// it is optional, the return value may be a null node. Attribute getMetadata() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::FusedLocation; - } }; /// Represents an identity name attached to a child location. @@ -199,11 +184,6 @@ /// Return the child location. Location getChildLoc() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::NameLocation; - } }; /// Represents an unknown location. This is always a singleton for a given @@ -215,11 +195,6 @@ /// Get an instance of the UnknownLoc. static Location get(MLIRContext *context); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == StandardAttributes::UnknownLocation; - } }; /// Represents a location that is external to MLIR. Contains a pointer to some @@ -283,11 +258,6 @@ /// Returns a fallback location. Location getFallbackLocation() const; - /// Methods for support type inquiry through isa, cast, and dyn_cast. 
- static bool kindof(unsigned kind) { - return kind == StandardAttributes::OpaqueLocation; - } - private: static Location get(uintptr_t underlyingLocation, TypeID typeID, Location fallbackLocation); diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -92,8 +92,6 @@ Type elementType); Type getElementType(); - - static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; } }; //===----------------------------------------------------------------------===// @@ -109,9 +107,6 @@ /// Get an instance of the IndexType. static IndexType get(MLIRContext *context); - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { return kind == StandardTypes::Index; } - /// Storage bit width used for IndexType by internal compiler data structures. static constexpr unsigned kInternalStorageBitWidth = 64; }; @@ -177,9 +172,6 @@ /// Return true if this is an unsigned integer type. bool isUnsigned() const { return getSignedness() == Unsigned; } - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; } - /// Integer representation maximal bitwidth. static constexpr unsigned kMaxWidth = 4096; }; @@ -208,12 +200,6 @@ return get(StandardTypes::F64, ctx); } - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE && - kind <= StandardTypes::LAST_FLOATING_POINT_TYPE; - } - /// Return the bitwidth of this float type. unsigned getWidth(); @@ -233,8 +219,6 @@ /// Get an instance of the NoneType. 
static NoneType get(MLIRContext *context); - - static bool kindof(unsigned kind) { return kind == StandardTypes::None; } }; //===----------------------------------------------------------------------===// @@ -361,9 +345,6 @@ } ArrayRef getShape() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; } }; //===----------------------------------------------------------------------===// @@ -422,10 +403,6 @@ Type elementType); ArrayRef getShape() const; - - static bool kindof(unsigned kind) { - return kind == StandardTypes::RankedTensor; - } }; //===----------------------------------------------------------------------===// @@ -454,10 +431,6 @@ Type elementType); ArrayRef getShape() const { return llvm::None; } - - static bool kindof(unsigned kind) { - return kind == StandardTypes::UnrankedTensor; - } }; //===----------------------------------------------------------------------===// @@ -568,8 +541,6 @@ return ShapedType::kDynamicStrideOrOffset; } - static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; } - private: /// Get or create a new MemRefType defined by the arguments. If the resulting /// type would be ill-formed, return nullptr. If the location is provided, @@ -611,9 +582,6 @@ /// Returns the memory space in which data referred to by this memref resides. 
unsigned getMemorySpace() const; - static bool kindof(unsigned kind) { - return kind == StandardTypes::UnrankedMemRef; - } }; //===----------------------------------------------------------------------===// @@ -659,8 +627,6 @@ assert(index < size() && "invalid index for tuple type"); return getTypes()[index]; } - - static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; } }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -68,12 +68,12 @@ /// Return a unique identifier for the concrete type. static TypeID getTypeID() { return TypeID::get(); } - /// Provide a default implementation of 'classof' that invokes a 'kindof' - /// method on the concrete type. + /// Provide an implementation of 'classof' that compares the type id of the + /// provided value with that of the concrete type. template static bool classof(T val) { static_assert(std::is_convertible::value, "casting from a non-convertible type"); - return ConcreteT::kindof(val.getKind()); + return val.getTypeID() == getTypeID(); } /// Returns an interface map for the interfaces registered to this storage @@ -107,8 +107,7 @@ /// Mutate the current storage instance. This will not change the unique key. /// The arguments are forwarded to 'ConcreteT::mutate'. - template - LogicalResult mutate(Args &&...args) { + template LogicalResult mutate(Args &&...args) { return UniquerT::template mutate(this->getContext(), getImpl(), std::forward(args)...); } diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -35,7 +35,7 @@ /// This method is used by Dialect objects when they register the list of /// types they contain.
template static AbstractType get(Dialect &dialect) { - return AbstractType(dialect, T::getInterfaceMap()); + return AbstractType(dialect, T::getInterfaceMap(), T::getTypeID()); } /// Return the dialect this type was registered to. @@ -48,15 +48,23 @@ return interfaceMap.lookup(); } + /// Return the unique identifier representing the concrete type class. + TypeID getTypeID() const { return typeID; } + private: - AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap) - : dialect(dialect), interfaceMap(std::move(interfaceMap)) {} + AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, + TypeID typeID) + : dialect(dialect), interfaceMap(std::move(interfaceMap)), + typeID(typeID) {} /// This is the dialect that this type was registered to. Dialect &dialect; /// This is a collection of the interfaces registered to this type. detail::InterfaceMap interfaceMap; + + /// The unique identifier of the derived Type class. + TypeID typeID; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -44,11 +44,6 @@ /// /// Derived type classes are expected to implement several required /// implementation hooks: -/// * Required: -/// - static bool kindof(unsigned kind); -/// * Returns if the provided type kind corresponds to an instance of the -/// current type. Used for isa/dyn_cast casting functionality. -/// /// * Optional: /// - static LogicalResult verifyConstructionInvariants(Location loc, /// Args... args) @@ -137,6 +132,10 @@ // Support type casting Type to itself. static bool classof(Type) { return true; } + /// Return a unique identifier for the concrete type. This is used to support + /// dynamic type casting. + TypeID getTypeID() { return impl->getAbstractType().getTypeID(); } + /// Return the classification for this type. 
unsigned getKind() const; @@ -262,7 +261,6 @@ // Input types. unsigned getNumInputs() const; - Type getInput(unsigned i) const { return getInputs()[i]; } ArrayRef getInputs() const; @@ -270,9 +268,6 @@ unsigned getNumResults() const; Type getResult(unsigned i) const { return getResults()[i]; } ArrayRef getResults() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { return kind == Kind::Function; } }; //===----------------------------------------------------------------------===// @@ -307,8 +302,6 @@ static LogicalResult verifyConstructionInvariants(Location loc, Identifier dialect, StringRef typeData); - - static bool kindof(unsigned kind) { return kind == Kind::Opaque; } }; // Make Type hashable. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -322,6 +322,11 @@ //===----------------------------------------------------------------------===// // Vector types. +/// Support type casting functionality. +bool LLVMVectorType::classof(Type type) { + return type.isa(); +} + LLVMType LLVMVectorType::getElementType() { // Both derived classes share the implementation type. return static_cast(impl)->elementType; @@ -331,7 +336,7 @@ // Both derived classes share the implementation type. 
return llvm::ElementCount( static_cast(impl)->numElements, - this->isa()); + isa()); } LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType, diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -666,7 +666,6 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) { - assert(kindof(kind) && "Not a FP kind."); switch (kind) { case StandardTypes::BF16: return context->getImpl().bf16Ty; diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -26,10 +26,6 @@ TestTypeInterface::Trait> { using Base::Base; - static bool kindof(unsigned kind) { - return kind == Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE; - } - static TestType get(MLIRContext *context) { return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE); } @@ -76,10 +72,6 @@ public: using Base::Base; - static bool kindof(unsigned kind) { - return kind == Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1; - } - static TestRecursiveType create(MLIRContext *ctx, StringRef name) { return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1, name);