diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -80,6 +80,8 @@
 #define MLIR_PYTHON_CAPSULE_PASS_MANAGER                                      \
   MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr")
 #define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr")
+#define MLIR_PYTHON_CAPSULE_TYPEID                                             \
+  MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr")
 
 /** Attribute on MLIR Python objects that expose their C-API pointer.
  * This will be a type-specific capsule created as per one of the helpers
@@ -268,6 +270,25 @@
   return op;
 }
 
+/** Creates a capsule object encapsulating the raw C-API MlirTypeID.
+ * The capsule is created without a destructor: it only carries the opaque
+ * TypeID pointer and does not own or keep alive any Python object.
+ */
+static inline PyObject *mlirPythonTypeIDToCapsule(MlirTypeID typeID) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(typeID),
+                       MLIR_PYTHON_CAPSULE_TYPEID, NULL);
+}
+
+/** Extracts an MlirTypeID from a capsule as produced from
+ * mlirPythonTypeIDToCapsule. If the capsule is not of the right type, then
+ * a null TypeID is returned (as checked via mlirTypeIDIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+static inline MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_TYPEID);
+  MlirTypeID typeID = {ptr};
+  return typeID;
+}
+
 /** Creates a capsule object encapsulating the raw C-API MlirType.
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the type in any way.
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -22,6 +22,9 @@
 // Integer types.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of an Integer type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void);
+
 /// Checks whether the given type is an integer type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type);
 
@@ -56,6 +59,9 @@
 // Index type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of an Index type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void);
+
 /// Checks whether the given type is an index type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type);
 
@@ -67,6 +73,9 @@
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Float8E5M2 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E5M2 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
 
@@ -74,6 +83,9 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float8E4M3FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E4M3FN type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
 
@@ -81,6 +93,9 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float8E5M2FNUZ type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E5M2FNUZ type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
 
@@ -88,6 +103,9 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float8E4M3FNUZ type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E4M3FNUZ type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
 
@@ -95,6 +113,9 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float8E4M3B11FNUZ type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E4M3B11FNUZ type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
 
@@ -102,6 +123,9 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
 
+/// Returns the typeID of a BFloat16 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 
@@ -109,6 +133,9 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float16 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void);
+
 /// Checks whether the given type is an f16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
 
@@ -116,6 +143,9 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float32 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void);
+
 /// Checks whether the given type is an f32 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
 
@@ -123,6 +153,9 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float64 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void);
+
 /// Checks whether the given type is an f64 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
 
@@ -134,6 +167,9 @@
 // None type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a None type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void);
+
 /// Checks whether the given type is a None type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type);
 
@@ -145,6 +181,9 @@
 // Complex type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Complex type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void);
+
 /// Checks whether the given type is a Complex type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type);
 
@@ -202,6 +241,9 @@
 // Vector type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Vector type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void);
+
 /// Checks whether the given type is a Vector type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type);
 
@@ -226,9 +268,15 @@
 /// Checks whether the given type is a Tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type);
 
+/// Returns the typeID of a RankedTensor type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void);
+
 /// Checks whether the given type is a ranked tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type);
 
+/// Returns the typeID of an UnrankedTensor type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void);
+
 /// Checks whether the given type is an unranked tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type);
 
@@ -264,9 +312,15 @@
 // Ranked / Unranked MemRef type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a MemRef type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void);
+
 /// Checks whether the given type is a MemRef type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type);
 
+/// Returns the typeID of an UnrankedMemRef type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void);
+
 /// Checks whether the given type is an UnrankedMemRef type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
 
@@ -326,6 +380,9 @@
 // Tuple type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Tuple type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void);
+
 /// Checks whether the given type is a tuple type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type);
 
@@ -345,6 +402,9 @@
 // Function type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Function type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void);
+
 /// Checks whether the given type is a function type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type);
 
@@ -373,6 +433,9 @@
 // Opaque type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of an Opaque type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void);
+
 /// Checks whether the given type is an opaque type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type);
 
diff --git a/mlir/include/mlir/Support/TypeID.h b/mlir/include/mlir/Support/TypeID.h
--- a/mlir/include/mlir/Support/TypeID.h
+++ b/mlir/include/mlir/Support/TypeID.h
@@ -147,9 +147,10 @@
 
 /// Enable hashing TypeID.
 inline ::llvm::hash_code hash_value(TypeID id) {
-  return DenseMapInfo<const void *>::getHashValue(id.storage);
+  return DenseMapInfo<uintptr_t>::getHashValue(
+      reinterpret_cast<uintptr_t>(id.storage));
 }
 
 //===----------------------------------------------------------------------===//
 // TypeIDResolver
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1807,6 +1807,25 @@
       rawType);
 }
 
+//------------------------------------------------------------------------------
+// PyTypeID.
+//------------------------------------------------------------------------------
+
+py::object PyTypeID::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
+}
+
+PyTypeID PyTypeID::createFromCapsule(py::object capsule) {
+  MlirTypeID rawTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
+  if (mlirTypeIDIsNull(rawTypeID))
+    throw py::error_already_set();
+  return PyTypeID(rawTypeID);
+}
+
+bool PyTypeID::operator==(const PyTypeID &other) const {
+  return mlirTypeIDEqual(typeID, other.typeID);
+}
+
 //------------------------------------------------------------------------------
 // PyValue and subclases.
 //------------------------------------------------------------------------------
@@ -3280,6 +3299,27 @@
         return printAccum.join();
       });
 
+  //----------------------------------------------------------------------------
+  // Mapping of PyTypeID.
+  //----------------------------------------------------------------------------
+  py::class_<PyTypeID>(m, "TypeID", py::module_local())
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
+      // Note, this tests whether the underlying TypeIDs are the same,
+      // not whether the wrapper MlirTypeIDs are the same, nor whether
+      // the Python objects are the same.
+      .def("__eq__",
+           [](PyTypeID &self, PyTypeID &other) { return self == other; })
+      // Comparing against any non-TypeID object yields False.
+      .def("__eq__",
+           [](PyTypeID &self, const py::object &other) { return false; })
+      // Note, this gives the hash value of the underlying TypeID, not the
+      // hash value of the Python object, nor the hash value of the
+      // MlirTypeID wrapper.
+      .def("__hash__", [](PyTypeID &self) {
+        return static_cast<size_t>(mlirTypeIDHashValue(self));
+      });
+
   //----------------------------------------------------------------------------
   // Mapping of Value.
//---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -826,6 +826,28 @@ MlirType type; }; +/// A TypeID provides an efficient and unique identifier for a specific C++ +/// type. This allows for a C++ type to be compared, hashed, and stored in an +/// opaque context. This class wraps around the generic MlirTypeID. +class PyTypeID { +public: + PyTypeID(MlirTypeID typeID) : typeID(typeID) {} + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same. + bool operator==(const PyTypeID &other) const; + operator MlirTypeID() const { return typeID; } + MlirTypeID get() { return typeID; } + + /// Gets a capsule wrapping the void* within the MlirTypeID. + pybind11::object getCapsule(); + + /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. + static PyTypeID createFromCapsule(pybind11::object capsule); + +private: + MlirTypeID typeID; +}; + /// CRTP base classes for Python types that subclass Type and should be /// castable from it (i.e. via something like IntegerType(t)). /// By default, type class hierarchies are one level deep (i.e. 
a @@ -839,6 +861,8 @@ // const char *pyClassName using ClassTy = pybind11::class_; using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteType() = default; PyConcreteType(PyMlirContextRef contextRef, MlirType t) @@ -866,6 +890,20 @@ return DerivedTy::isaFunction(otherType); }, pybind11::arg("other")); + cls.def_static("get_typeid", []() -> std::optional { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + return std::nullopt; + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -32,6 +32,8 @@ class PyIntegerType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerTypeGetTypeID; static constexpr const char *pyClassName = "IntegerType"; using PyConcreteType::PyConcreteType; @@ -89,6 +91,8 @@ class PyIndexType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIndexTypeGetTypeID; static constexpr const char *pyClassName = "IndexType"; using PyConcreteType::PyConcreteType; @@ -107,6 +111,8 @@ class PyFloat8E4M3FNType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNTypeGetTypeID; static constexpr const char *pyClassName 
= "Float8E4M3FNType"; using PyConcreteType::PyConcreteType; @@ -125,6 +131,8 @@ class PyFloat8E5M2Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2TypeGetTypeID; static constexpr const char *pyClassName = "Float8E5M2Type"; using PyConcreteType::PyConcreteType; @@ -143,6 +151,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3FNUZType"; using PyConcreteType::PyConcreteType; @@ -161,6 +171,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3B11FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; using PyConcreteType::PyConcreteType; @@ -179,6 +191,8 @@ class PyFloat8E5M2FNUZType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E5M2FNUZType"; using PyConcreteType::PyConcreteType; @@ -197,6 +211,8 @@ class PyBF16Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirBFloat16TypeGetTypeID; static constexpr const char *pyClassName = "BF16Type"; using PyConcreteType::PyConcreteType; @@ -215,6 +231,8 @@ class PyF16Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat16TypeGetTypeID; static 
constexpr const char *pyClassName = "F16Type"; using PyConcreteType::PyConcreteType; @@ -233,6 +251,8 @@ class PyF32Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat32TypeGetTypeID; static constexpr const char *pyClassName = "F32Type"; using PyConcreteType::PyConcreteType; @@ -251,6 +271,8 @@ class PyF64Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat64TypeGetTypeID; static constexpr const char *pyClassName = "F64Type"; using PyConcreteType::PyConcreteType; @@ -269,6 +291,8 @@ class PyNoneType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirNoneTypeGetTypeID; static constexpr const char *pyClassName = "NoneType"; using PyConcreteType::PyConcreteType; @@ -287,6 +311,8 @@ class PyComplexType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirComplexTypeGetTypeID; static constexpr const char *pyClassName = "ComplexType"; using PyConcreteType::PyConcreteType; @@ -417,6 +443,8 @@ class PyVectorType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirVectorTypeGetTypeID; static constexpr const char *pyClassName = "VectorType"; using PyConcreteType::PyConcreteType; @@ -442,6 +470,8 @@ : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirRankedTensorTypeGetTypeID; static constexpr const char *pyClassName = "RankedTensorType"; using PyConcreteType::PyConcreteType; @@ -476,6 +506,8 @@ : 
public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedTensorTypeGetTypeID; static constexpr const char *pyClassName = "UnrankedTensorType"; using PyConcreteType::PyConcreteType; @@ -498,6 +530,8 @@ class PyMemRefType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirMemRefTypeGetTypeID; static constexpr const char *pyClassName = "MemRefType"; using PyConcreteType::PyConcreteType; @@ -550,6 +584,8 @@ : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedMemRefTypeGetTypeID; static constexpr const char *pyClassName = "UnrankedMemRefType"; using PyConcreteType::PyConcreteType; @@ -585,6 +621,8 @@ class PyTupleType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTupleTypeGetTypeID; static constexpr const char *pyClassName = "TupleType"; using PyConcreteType::PyConcreteType; @@ -622,6 +660,8 @@ class PyFunctionType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFunctionTypeGetTypeID; static constexpr const char *pyClassName = "FunctionType"; using PyConcreteType::PyConcreteType; @@ -676,6 +716,8 @@ class PyOpaqueType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueTypeGetTypeID; static constexpr const char *pyClassName = "OpaqueType"; using PyConcreteType::PyConcreteType; diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- 
a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -18,15 +18,23 @@ using namespace mlir; +//#define DEFINETYPEID(TYPE) \ +// MlirTypeID mlir##TYPE##GetTypeID() { return wrap(TYPE::getTypeID()); } +//FORALL_CONCRETETYPES(DEFINETYPEID) +//#undef DEFINETYPEID + //===----------------------------------------------------------------------===// // Integer types. //===----------------------------------------------------------------------===// +MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); } + bool mlirTypeIsAInteger(MlirType type) { return llvm::isa(unwrap(type)); } MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { + return wrap(IntegerType::get(unwrap(ctx), bitwidth)); } @@ -58,6 +66,8 @@ // Index type. //===----------------------------------------------------------------------===// +MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); } + bool mlirTypeIsAIndex(MlirType type) { return llvm::isa(unwrap(type)); } @@ -70,6 +80,10 @@ // Floating-point types. 
//===----------------------------------------------------------------------===// +MlirTypeID mlirFloat8E5M2TypeGetTypeID() { + return wrap(Float8E5M2Type::getTypeID()); +} + bool mlirTypeIsAFloat8E5M2(MlirType type) { return unwrap(type).isFloat8E5M2(); } @@ -78,6 +92,10 @@ return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { + return wrap(Float8E4M3FNType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3FN(MlirType type) { return unwrap(type).isFloat8E4M3FN(); } @@ -86,6 +104,10 @@ return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); } +MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { + return wrap(Float8E5M2FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { return unwrap(type).isFloat8E5M2FNUZ(); } @@ -94,6 +116,10 @@ return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { + return wrap(Float8E4M3FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { return unwrap(type).isFloat8E4M3FNUZ(); } @@ -102,6 +128,10 @@ return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { + return wrap(Float8E4M3B11FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { return unwrap(type).isFloat8E4M3B11FNUZ(); } @@ -110,24 +140,34 @@ return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx))); } +MlirTypeID mlirBFloat16TypeGetTypeID() { + return wrap(BFloat16Type::getTypeID()); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { return wrap(FloatType::getBF16(unwrap(ctx))); } +MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); } + bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); } MlirType mlirF16TypeGet(MlirContext ctx) { return wrap(FloatType::getF16(unwrap(ctx))); } +MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } + bool 
mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } MlirType mlirF32TypeGet(MlirContext ctx) { return wrap(FloatType::getF32(unwrap(ctx))); } +MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } + bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); } MlirType mlirF64TypeGet(MlirContext ctx) { @@ -138,6 +178,8 @@ // None type. //===----------------------------------------------------------------------===// +MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); } + bool mlirTypeIsANone(MlirType type) { return llvm::isa(unwrap(type)); } @@ -150,6 +192,8 @@ // Complex type. //===----------------------------------------------------------------------===// +MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); } + bool mlirTypeIsAComplex(MlirType type) { return llvm::isa(unwrap(type)); } @@ -214,6 +258,8 @@ // Vector type. //===----------------------------------------------------------------------===// +MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); } + bool mlirTypeIsAVector(MlirType type) { return llvm::isa(unwrap(type)); } @@ -239,10 +285,18 @@ return llvm::isa(unwrap(type)); } +MlirTypeID mlirRankedTensorTypeGetTypeID() { + return wrap(RankedTensorType::getTypeID()); +} + bool mlirTypeIsARankedTensor(MlirType type) { return llvm::isa(unwrap(type)); } +MlirTypeID mlirUnrankedTensorTypeGetTypeID() { + return wrap(UnrankedTensorType::getTypeID()); +} + bool mlirTypeIsAUnrankedTensor(MlirType type) { return llvm::isa(unwrap(type)); } @@ -280,6 +334,8 @@ // Ranked / Unranked MemRef type. 
//===----------------------------------------------------------------------===// +MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); } + bool mlirTypeIsAMemRef(MlirType type) { return llvm::isa(unwrap(type)); } @@ -337,6 +393,10 @@ return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } +MlirTypeID mlirUnrankedMemRefTypeGetTypeID() { + return wrap(UnrankedMemRefType::getTypeID()); +} + bool mlirTypeIsAUnrankedMemRef(MlirType type) { return llvm::isa(unwrap(type)); } @@ -362,6 +422,8 @@ // Tuple type. //===----------------------------------------------------------------------===// +MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); } + bool mlirTypeIsATuple(MlirType type) { return llvm::isa(unwrap(type)); } @@ -386,6 +448,10 @@ // Function type. //===----------------------------------------------------------------------===// +MlirTypeID mlirFunctionTypeGetTypeID() { + return wrap(FunctionType::getTypeID()); +} + bool mlirTypeIsAFunction(MlirType type) { return llvm::isa(unwrap(type)); } @@ -424,6 +490,8 @@ // Opaque type. 
//===----------------------------------------------------------------------===// +MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); } + bool mlirTypeIsAOpaque(MlirType type) { return llvm::isa(unwrap(type)); } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -2085,6 +2085,70 @@ mlirOperationDestroy(constZero); mlirOperationDestroy(unregisteredOp); +#define FORALL_CONCRETETYPES(_) \ + _(BFloat16Type) \ + _(ComplexType) \ + _(Float8E4M3B11FNUZType) \ + _(Float8E4M3FNType) \ + _(Float8E4M3FNUZType) \ + _(Float8E5M2Type) \ + _(Float8E5M2FNUZType) \ + _(Float16Type) \ + _(Float32Type) \ + _(Float64Type) \ + _(FunctionType) \ + _(IndexType) \ + _(IntegerType) \ + _(MemRefType) \ + _(NoneType) \ + _(OpaqueType) \ + _(RankedTensorType) \ + _(TupleType) \ + _(UnrankedMemRefType) \ + _(UnrankedTensorType) \ + _(VectorType) + +#define DECLARETYPEID(TYPE) MlirTypeID TYPE##ID = mlir##TYPE##GetTypeID(); + FORALL_CONCRETETYPES(DECLARETYPEID) +#undef DECLARETYPEID + + typedef struct { + MlirTypeID typeID; + char typeName[100]; + } typeIDTuple; + + typeIDTuple typeIDTuples[] = { +#define DECLARETYPEID(TYPE) {TYPE##ID, #TYPE}, + FORALL_CONCRETETYPES(DECLARETYPEID) +#undef DECLARETYPEID + }; + + bool equalError = false; + bool hashError = false; + for (int i = 0; i < sizeof(typeIDTuples) / sizeof(typeIDTuples[0]); ++i) { + for (int j = i + 1; j < sizeof(typeIDTuples) / sizeof(typeIDTuples[0]); + ++j) { + typeIDTuple type1 = typeIDTuples[i]; + typeIDTuple type2 = typeIDTuples[j]; + if (mlirTypeIDEqual(type1.typeID, type2.typeID)) { + fprintf(stderr, "ERROR: Expected type id %s to be unequal type id %s\n", + type1.typeName, type2.typeName); + equalError = true; + } + if (mlirTypeIDHashValue(type1.typeID) == + mlirTypeIDHashValue(type2.typeID)) { + fprintf(stderr, + "ERROR: Expected hash of type id %s to be unequal hash type id %s\n", + type1.typeName, type2.typeName); + hashError = true; + } + } + } 
+  if (equalError)
+    return 10;
+  if (hashError)
+    return 11;
+
   return 0;
 }
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -509,3 +509,101 @@
     print(type(ShapedType.get_dynamic_size()))
     # CHECK: <class 'int'>
     print(type(ShapedType.get_dynamic_stride_or_offset()))
+
+
+# CHECK-LABEL: TEST: testTypeIDs
+@run
+def testTypeIDs():
+    with Context(), Location.unknown():
+        f32 = F32Type.get()
+
+        types1 = [
+            IntegerType,
+            IndexType,
+            Float8E4M3FNType,
+            Float8E5M2Type,
+            Float8E4M3FNUZType,
+            Float8E4M3B11FNUZType,
+            Float8E5M2FNUZType,
+            BF16Type,
+            F16Type,
+            F32Type,
+            F64Type,
+            NoneType,
+            ComplexType,
+            VectorType,
+            RankedTensorType,
+            UnrankedTensorType,
+            MemRefType,
+            UnrankedMemRefType,
+            TupleType,
+            FunctionType,
+            OpaqueType,
+        ]
+
+        # CHECK: IntegerType
+        # CHECK: IndexType
+        # CHECK: Float8E4M3FNType
+        # CHECK: Float8E5M2Type
+        # CHECK: Float8E4M3FNUZType
+        # CHECK: Float8E4M3B11FNUZType
+        # CHECK: Float8E5M2FNUZType
+        # CHECK: BF16Type
+        # CHECK: F16Type
+        # CHECK: F32Type
+        # CHECK: F64Type
+        # CHECK: NoneType
+        # CHECK: ComplexType
+        # CHECK: VectorType
+        # CHECK: RankedTensorType
+        # CHECK: UnrankedTensorType
+        # CHECK: MemRefType
+        # CHECK: UnrankedMemRefType
+        # CHECK: TupleType
+        # CHECK: FunctionType
+        # CHECK: OpaqueType
+        for t in types1:
+            print(t.__name__)
+
+        # One instance of each class in types1, in the same order.
+        types2 = [
+            IntegerType.get_signless(16),
+            IndexType.get(),
+            Float8E4M3FNType.get(),
+            Float8E5M2Type.get(),
+            Float8E4M3FNUZType.get(),
+            Float8E4M3B11FNUZType.get(),
+            Float8E5M2FNUZType.get(),
+            BF16Type.get(),
+            F16Type.get(),
+            F32Type.get(),
+            F64Type.get(),
+            NoneType.get(),
+            ComplexType.get(f32),
+            VectorType.get([2, 3], f32),
+            RankedTensorType.get([2, 3], f32),
+            UnrankedTensorType.get(f32),
+            MemRefType.get([2, 3], f32),
+            UnrankedMemRefType.get(f32, Attribute.parse("2")),
+            TupleType.get_tuple([f32]),
+            FunctionType.get([], []),
+            OpaqueType.get("tensor", "bob"),
+        ]
+
+        # ShapedType does not provide a TypeID, so get_typeid() returns None.
+        assert ShapedType.get_typeid() is None
+
+        # CHECK: all equal
+        for t1, t2 in zip(types1, types2):
+            id1, id2 = t1.get_typeid(), t2.get_typeid()
+            assert id1 == id2, f"expected value equality {t1} {t2}"
+            assert hash(id1) == hash(id2), f"expected hash equality {t1} {t2}"
+        print("all equal")
+
+        # CHECK: all unequal
+        for i, t1 in enumerate(types1):
+            for t2 in types2[i + 1 :]:
+                id1, id2 = t1.get_typeid(), t2.get_typeid()
+                assert id1 != id2, f"expected value inequality {t1} {t2}"
+                assert hash(id1) != hash(id2), f"expected hash inequality {t1} {t2}"
+        print("all unequal")
diff --git a/mlir/unittests/IR/TypeTest.cpp b/mlir/unittests/IR/TypeTest.cpp
--- a/mlir/unittests/IR/TypeTest.cpp
+++ b/mlir/unittests/IR/TypeTest.cpp
@@ -64,3 +64,53 @@
 
   EXPECT_EQ(8u, cast<IntegerType>(intTy).getWidth());
 }
+
+TEST(Type, TypeIDComparison) {
+  struct NamedTypeID {
+    TypeID typeID;
+    const char *typeName;
+  };
+
+#define FORALL_CONCRETETYPES(_)                                                \
+  _(BFloat16Type)                                                              \
+  _(ComplexType)                                                               \
+  _(Float8E4M3B11FNUZType)                                                     \
+  _(Float8E4M3FNType)                                                          \
+  _(Float8E4M3FNUZType)                                                        \
+  _(Float8E5M2Type)                                                            \
+  _(Float8E5M2FNUZType)                                                        \
+  _(Float16Type)                                                               \
+  _(Float32Type)                                                               \
+  _(Float64Type)                                                               \
+  _(FunctionType)                                                              \
+  _(IndexType)                                                                 \
+  _(IntegerType)                                                               \
+  _(MemRefType)                                                                \
+  _(NoneType)                                                                  \
+  _(OpaqueType)                                                                \
+  _(RankedTensorType)                                                          \
+  _(TupleType)                                                                 \
+  _(UnrankedMemRefType)                                                        \
+  _(UnrankedTensorType)                                                        \
+  _(VectorType)
+
+#define NAMED_TYPEID(TYPE) {TypeID::get<TYPE>(), #TYPE},
+  const NamedTypeID namedTypeIDs[] = {FORALL_CONCRETETYPES(NAMED_TYPEID)};
+#undef NAMED_TYPEID
+#undef FORALL_CONCRETETYPES
+
+  // Every pair of distinct builtin types is expected to have distinct
+  // TypeIDs and distinct TypeID hash values.
+  for (size_t i = 0, e = std::size(namedTypeIDs); i < e; ++i) {
+    for (size_t j = i + 1; j < e; ++j) {
+      const NamedTypeID &lhs = namedTypeIDs[i];
+      const NamedTypeID &rhs = namedTypeIDs[j];
+      EXPECT_NE(lhs.typeID, rhs.typeID)
+          << "type ids of " << lhs.typeName << " and " << rhs.typeName;
+      EXPECT_NE(static_cast<size_t>(hash_value(lhs.typeID)),
+                static_cast<size_t>(hash_value(rhs.typeID)))
+          << "hashes of type ids of " << lhs.typeName << " and "
+          << rhs.typeName;
+    }
+  }
+}