diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -80,6 +80,8 @@ #define MLIR_PYTHON_CAPSULE_PASS_MANAGER \ MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr") #define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_TYPEID \ + MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr") /** Attribute on MLIR Python objects that expose their C-API pointer. * This will be a type-specific capsule created as per one of the helpers
@@ -268,6 +270,25 @@ return op; } +/** Creates a capsule object encapsulating the raw C-API MlirTypeID. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the type in any way. + */ +static inline PyObject *mlirPythonTypeIDToCapsule(MlirTypeID typeID) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(typeID), + MLIR_PYTHON_CAPSULE_TYPEID, NULL); +} + +/** Extracts an MlirTypeID from a capsule as produced from + * mlirPythonTypeIDToCapsule. If the capsule is not of the right type, then + * a null TypeID is returned (as checked via mlirTypeIDIsNull). In such a + * case, the Python APIs will have already set an error. */ +static inline MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_TYPEID); + MlirTypeID typeID = {ptr}; + return typeID; +} + /** Creates a capsule object encapsulating the raw C-API MlirType. * The returned capsule does not extend or affect ownership of any Python * objects that reference the type in any way.
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -22,6 +22,9 @@ // Integer types. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Integer type. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void); + /// Checks whether the given type is an integer type. MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type);
@@ -56,6 +59,9 @@ // Index type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Index type. +MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void); + /// Checks whether the given type is an index type. MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type);
@@ -67,6 +73,9 @@ // Floating-point types. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Float8E5M2 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void); + /// Checks whether the given type is an f8E5M2 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
@@ -74,6 +83,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); +/// Returns the typeID of a Float8E4M3FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void); + /// Checks whether the given type is an f8E4M3FN type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
@@ -81,6 +93,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); +/// Returns the typeID of a Float8E5M2FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void); + /// Checks whether the given type is an f8E5M2FNUZ type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
@@ -88,6 +103,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of a Float8E4M3FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void); + /// Checks whether the given type is an f8E4M3FNUZ type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
@@ -95,6 +113,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of a Float8E4M3B11FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void); + /// Checks whether the given type is an f8E4M3B11FNUZ type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
@@ -102,6 +123,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of a BFloat16 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
@@ -109,6 +133,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx); +/// Returns the typeID of a Float16 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void); + /// Checks whether the given type is an f16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
@@ -116,6 +143,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx); +/// Returns the typeID of a Float32 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void); + /// Checks whether the given type is an f32 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
@@ -123,6 +153,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx); +/// Returns the typeID of a Float64 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void); + /// Checks whether the given type is an f64 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
@@ -134,6 +167,9 @@ // None type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a None type. +MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void); + /// Checks whether the given type is a None type. MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type);
@@ -145,6 +181,9 @@ // Complex type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Complex type. +MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void); + /// Checks whether the given type is a Complex type. MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type);
@@ -202,6 +241,9 @@ // Vector type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Vector type. +MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void); + /// Checks whether the given type is a Vector type. MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type);
@@ -226,9 +268,15 @@ /// Checks whether the given type is a Tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type); +/// Returns the typeID of a RankedTensor type. +MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void); + /// Checks whether the given type is a ranked tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type); +/// Returns the typeID of an UnrankedTensor type. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void); + /// Checks whether the given type is an unranked tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type);
@@ -264,9 +312,15 @@ // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a MemRef type. +MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void); + /// Checks whether the given type is a MemRef type. MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type); +/// Returns the typeID of an UnrankedMemRef type. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void); + /// Checks whether the given type is an UnrankedMemRef type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
@@ -326,6 +380,9 @@ // Tuple type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Tuple type. +MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void); + /// Checks whether the given type is a tuple type. MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type);
@@ -345,6 +402,9 @@ // Function type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Function type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void); + /// Checks whether the given type is a function type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type);
@@ -373,6 +433,9 @@ // Opaque type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Opaque type. +MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void); + /// Checks whether the given type is an opaque type. MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1807,6 +1807,24 @@ rawType); } +//------------------------------------------------------------------------------ +// PyTypeID. +//------------------------------------------------------------------------------ + +py::object PyTypeID::getCapsule() { + return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this)); +} + +PyTypeID PyTypeID::createFromCapsule(py::object capsule) { + MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); + if (mlirTypeIDIsNull(mlirTypeID)) + throw py::error_already_set(); + return PyTypeID(mlirTypeID); +} +bool PyTypeID::operator==(const PyTypeID &other) const { + return mlirTypeIDEqual(typeID, other.typeID); +} + //------------------------------------------------------------------------------ // PyValue and subclases. //------------------------------------------------------------------------------
@@ -3268,16 +3286,44 @@ return printAccum.join(); }, "Returns the assembly form of the type.") - .def("__repr__", [](PyType &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, types are an exception as they typically have compact - // assembly forms and printing them is useful. - PyPrintAccumulator printAccum; - printAccum.parts.append("Type("); - mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); + .def("__repr__", + [](PyType &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, types are an exception as they typically have compact + // assembly forms and printing them is useful. + PyPrintAccumulator printAccum; + printAccum.parts.append("Type("); + mlirTypePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def("get_typeid", [](PyType &self) -> std::optional<PyTypeID> { + MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); + if (!mlirTypeIDIsNull(mlirTypeID)) + return PyTypeID(mlirTypeID); + return std::nullopt; + }); + + //---------------------------------------------------------------------------- + // Mapping of PyTypeID. + //---------------------------------------------------------------------------- + py::class_<PyTypeID>(m, "TypeID", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same, nor whether + // the Python objects are the same (i.e., PyTypeID is a value type). + .def("__eq__", + [](PyTypeID &self, PyTypeID &other) { return self == other; }) + .def("__eq__", + [](PyTypeID &self, const py::object &other) { return false; }) + // Note, this gives the hash value of the underlying TypeID, not the + // hash value of the Python object, nor the hash value of the + // MlirTypeID wrapper. + .def("__hash__", [](PyTypeID &self) { + return static_cast<size_t>(mlirTypeIDHashValue(self)); }); //----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -826,6 +826,29 @@ MlirType type; }; +/// A TypeID provides an efficient and unique identifier for a specific C++ +/// type. This allows for a C++ type to be compared, hashed, and stored in an +/// opaque context. This class wraps around the generic MlirTypeID. +class PyTypeID { +public: + PyTypeID(MlirTypeID typeID) : typeID(typeID) {} + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same, nor whether + // the PyTypeID objects are the same (i.e., PyTypeID is a value type). + bool operator==(const PyTypeID &other) const; + operator MlirTypeID() const { return typeID; } + MlirTypeID get() const { return typeID; } + + /// Gets a capsule wrapping the void* within the MlirTypeID. + pybind11::object getCapsule(); + + /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. + static PyTypeID createFromCapsule(pybind11::object capsule); + +private: + MlirTypeID typeID; +}; + /// CRTP base classes for Python types that subclass Type and should be /// castable from it (i.e. via something like IntegerType(t)). /// By default, type class hierarchies are one level deep (i.e. a
@@ -839,6 +862,8 @@ // const char *pyClassName using ClassTy = pybind11::class_<DerivedTy, BaseTy>; using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteType() = default; PyConcreteType(PyMlirContextRef contextRef, MlirType t)
@@ -866,6 +891,20 @@ return DerivedTy::isaFunction(otherType); }, pybind11::arg("other")); + cls.def_static("get_typeid", []() -> std::optional<PyTypeID> { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + return std::nullopt; + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + DerivedTy::bindDerived(cls); }
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -32,6 +32,8 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerTypeGetTypeID; static constexpr const char *pyClassName = "IntegerType"; using PyConcreteType::PyConcreteType;
@@ -89,6 +91,8 @@ class PyIndexType : public PyConcreteType<PyIndexType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIndexTypeGetTypeID; static constexpr const char *pyClassName = "IndexType"; using PyConcreteType::PyConcreteType;
@@ -107,6 +111,8 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3FNType"; using PyConcreteType::PyConcreteType;
@@ -125,6 +131,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2TypeGetTypeID; static constexpr const char *pyClassName = "Float8E5M2Type"; using PyConcreteType::PyConcreteType;
@@ -143,6 +151,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3FNUZType"; using PyConcreteType::PyConcreteType;
@@ -161,6 +171,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3B11FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; using PyConcreteType::PyConcreteType;
@@ -179,6 +191,8 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E5M2FNUZType"; using PyConcreteType::PyConcreteType;
@@ -197,6 +211,8 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirBFloat16TypeGetTypeID; static constexpr const char *pyClassName = "BF16Type"; using PyConcreteType::PyConcreteType;
@@ -215,6 +231,8 @@ class PyF16Type : public PyConcreteType<PyF16Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat16TypeGetTypeID; static constexpr const char *pyClassName = "F16Type"; using PyConcreteType::PyConcreteType;
@@ -233,6 +251,8 @@ class PyF32Type : public PyConcreteType<PyF32Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat32TypeGetTypeID; static constexpr const char *pyClassName = "F32Type"; using PyConcreteType::PyConcreteType;
@@ -251,6 +271,8 @@ class PyF64Type : public PyConcreteType<PyF64Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat64TypeGetTypeID; static constexpr const char *pyClassName = "F64Type"; using PyConcreteType::PyConcreteType;
@@ -269,6 +291,8 @@ class PyNoneType : public PyConcreteType<PyNoneType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirNoneTypeGetTypeID; static constexpr const char *pyClassName = "NoneType"; using PyConcreteType::PyConcreteType;
@@ -287,6 +311,8 @@ class PyComplexType : public PyConcreteType<PyComplexType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirComplexTypeGetTypeID; static constexpr const char *pyClassName = "ComplexType"; using PyConcreteType::PyConcreteType;
@@ -417,6 +443,8 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirVectorTypeGetTypeID; static constexpr const char *pyClassName = "VectorType"; using PyConcreteType::PyConcreteType;
@@ -442,6 +470,8 @@ : public PyConcreteType<PyRankedTensorType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirRankedTensorTypeGetTypeID; static constexpr const char *pyClassName = "RankedTensorType"; using PyConcreteType::PyConcreteType;
@@ -476,6 +506,8 @@ : public PyConcreteType<PyUnrankedTensorType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedTensorTypeGetTypeID; static constexpr const char *pyClassName = "UnrankedTensorType"; using PyConcreteType::PyConcreteType;
@@ -498,6 +530,8 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirMemRefTypeGetTypeID; static constexpr const char *pyClassName = "MemRefType"; using PyConcreteType::PyConcreteType;
@@ -550,6 +584,8 @@ : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedMemRefTypeGetTypeID; static constexpr const char *pyClassName = "UnrankedMemRefType"; using PyConcreteType::PyConcreteType;
@@ -585,6 +621,8 @@ class PyTupleType : public PyConcreteType<PyTupleType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTupleTypeGetTypeID; static constexpr const char *pyClassName = "TupleType"; using PyConcreteType::PyConcreteType;
@@ -622,6 +660,8 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFunctionTypeGetTypeID; static constexpr const char *pyClassName = "FunctionType"; using PyConcreteType::PyConcreteType;
@@ -676,6 +716,8 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueTypeGetTypeID; static constexpr const char *pyClassName = "OpaqueType"; using PyConcreteType::PyConcreteType;
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -22,6 +22,8 @@ // Integer types. //===----------------------------------------------------------------------===// +MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); } + bool mlirTypeIsAInteger(MlirType type) { return llvm::isa<IntegerType>(unwrap(type)); }
@@ -58,6 +60,8 @@ // Index type. //===----------------------------------------------------------------------===// +MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); } + bool mlirTypeIsAIndex(MlirType type) { return llvm::isa<IndexType>(unwrap(type)); }
@@ -70,6 +74,10 @@ // Floating-point types. //===----------------------------------------------------------------------===// +MlirTypeID mlirFloat8E5M2TypeGetTypeID() { + return wrap(Float8E5M2Type::getTypeID()); +} + bool mlirTypeIsAFloat8E5M2(MlirType type) { return unwrap(type).isFloat8E5M2(); }
@@ -78,6 +86,10 @@ return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { + return wrap(Float8E4M3FNType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3FN(MlirType type) { return unwrap(type).isFloat8E4M3FN(); }
@@ -86,6 +98,10 @@ return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); } +MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { + return wrap(Float8E5M2FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { return unwrap(type).isFloat8E5M2FNUZ(); }
@@ -94,6 +110,10 @@ return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { + return wrap(Float8E4M3FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { return unwrap(type).isFloat8E4M3FNUZ(); }
@@ -102,6 +122,10 @@ return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { + return wrap(Float8E4M3B11FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { return unwrap(type).isFloat8E4M3B11FNUZ(); }
@@ -110,24 +134,34 @@ return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx))); } +MlirTypeID mlirBFloat16TypeGetTypeID() { + return wrap(BFloat16Type::getTypeID()); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { return wrap(FloatType::getBF16(unwrap(ctx))); } +MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); } + bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); } MlirType mlirF16TypeGet(MlirContext ctx) { return wrap(FloatType::getF16(unwrap(ctx))); } +MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } + bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } MlirType mlirF32TypeGet(MlirContext ctx) { return wrap(FloatType::getF32(unwrap(ctx))); } +MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } + bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); } MlirType mlirF64TypeGet(MlirContext ctx) {
@@ -138,6 +172,8 @@ // None type. //===----------------------------------------------------------------------===// +MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); } + bool mlirTypeIsANone(MlirType type) { return llvm::isa<NoneType>(unwrap(type)); }
@@ -150,6 +186,8 @@ // Complex type. //===----------------------------------------------------------------------===// +MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); } + bool mlirTypeIsAComplex(MlirType type) { return llvm::isa<ComplexType>(unwrap(type)); }
@@ -214,6 +252,8 @@ // Vector type. //===----------------------------------------------------------------------===// +MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); } + bool mlirTypeIsAVector(MlirType type) { return llvm::isa<VectorType>(unwrap(type)); }
@@ -239,10 +279,18 @@ return llvm::isa<TensorType>(unwrap(type)); } +MlirTypeID mlirRankedTensorTypeGetTypeID() { + return wrap(RankedTensorType::getTypeID()); +} + bool mlirTypeIsARankedTensor(MlirType type) { return llvm::isa<RankedTensorType>(unwrap(type)); } +MlirTypeID mlirUnrankedTensorTypeGetTypeID() { + return wrap(UnrankedTensorType::getTypeID()); +} + bool mlirTypeIsAUnrankedTensor(MlirType type) { return llvm::isa<UnrankedTensorType>(unwrap(type)); }
@@ -280,6 +328,8 @@ // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// +MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); } + bool mlirTypeIsAMemRef(MlirType type) { return llvm::isa<MemRefType>(unwrap(type)); }
@@ -337,6 +387,10 @@ return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace()); } +MlirTypeID mlirUnrankedMemRefTypeGetTypeID() { + return wrap(UnrankedMemRefType::getTypeID()); +} + bool mlirTypeIsAUnrankedMemRef(MlirType type) { return llvm::isa<UnrankedMemRefType>(unwrap(type)); }
@@ -362,6 +416,8 @@ // Tuple type. //===----------------------------------------------------------------------===// +MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); } + bool mlirTypeIsATuple(MlirType type) { return llvm::isa<TupleType>(unwrap(type)); }
@@ -386,6 +442,10 @@ // Function type. //===----------------------------------------------------------------------===// +MlirTypeID mlirFunctionTypeGetTypeID() { + return wrap(FunctionType::getTypeID()); +} + bool mlirTypeIsAFunction(MlirType type) { return llvm::isa<FunctionType>(unwrap(type)); }
@@ -424,6 +484,8 @@ // Opaque type. //===----------------------------------------------------------------------===// +MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); } + bool mlirTypeIsAOpaque(MlirType type) { return llvm::isa<OpaqueType>(unwrap(type)); }
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py
@@ -3,6 +3,7 @@ import gc from mlir.ir import * + def run(f): print("\nTEST:", f.__name__) f()
@@ -76,6 +77,7 @@ # CHECK: len(s): 2 print("len(s): ", len(s)) + # CHECK-LABEL: TEST: testTypeCast @run def testTypeCast():
@@ -182,6 +184,7 @@ # CHECK: unsigned: ui64 print("unsigned:", IntegerType.get_unsigned(64)) + # CHECK-LABEL: TEST: testIndexType @run def testIndexType():
@@ -259,7 +262,8 @@ # CHECK: rank: 2 print("rank:", vector.rank) # CHECK: whether the shaped type has a static shape: True - print("whether the shaped type has a static shape:", vector.has_static_shape) + print("whether the shaped type has a static shape:", + vector.has_static_shape) # CHECK: whether the dim-th dimension is dynamic: False print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0)) # CHECK: dim size: 3
@@ -311,8 +315,7 @@ shape = [2, 3] loc = Location.unknown() # CHECK: ranked tensor type: tensor<2x3xf32> - print("ranked tensor type:", - RankedTensorType.get(shape, f32)) + print("ranked tensor type:", RankedTensorType.get(shape, f32)) none = NoneType.get() try:
@@ -477,8 +480,7 @@ @run def testFunctionType(): with Context() as ctx: - input_types = [IntegerType.get_signless(32), - IntegerType.get_signless(16)] + input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)] result_types = [IndexType.get()] func = FunctionType.get(input_types, result_types) # CHECK: INPUTS: [Type(i32), Type(i16)]
@@ -509,3 +511,116 @@ print(type(ShapedType.get_dynamic_size())) # CHECK: <class 'int'> print(type(ShapedType.get_dynamic_stride_or_offset())) + + +# CHECK-LABEL: TEST: testTypeIDs +@run +def testTypeIDs(): + with Context(), Location.unknown(): + f32 = F32Type.get() + + types_classes = [ + IntegerType, + IndexType, + Float8E4M3FNType, + Float8E5M2Type, + Float8E4M3FNUZType, + Float8E4M3B11FNUZType, + Float8E5M2FNUZType, + BF16Type, + F16Type, + F32Type, + F64Type, + NoneType, + ComplexType, + VectorType, + RankedTensorType, + UnrankedTensorType, + MemRefType, + UnrankedMemRefType, + TupleType, + FunctionType, + OpaqueType, + ] + + type_instances = [ + IntegerType.get_signless(16), + IndexType.get(), + Float8E4M3FNType.get(), + Float8E5M2Type.get(), + Float8E4M3FNUZType.get(), + Float8E4M3B11FNUZType.get(), + Float8E5M2FNUZType.get(), + BF16Type.get(), + F16Type.get(), + F32Type.get(), + F64Type.get(), + NoneType.get(), + ComplexType.get(f32), + VectorType.get([2, 3], f32), + RankedTensorType.get([2, 3], f32), + UnrankedTensorType.get(f32), + MemRefType.get([2, 3], f32), + UnrankedMemRefType.get(f32, Attribute.parse("2")), + TupleType.get_tuple([ + f32, + ]), + FunctionType.get([], []), + OpaqueType.get("tensor", "bob") + ] + # A reminder for anyone adding more cases here. + assert len(types_classes) == len( + type_instances), "mismatch len type_classes and type_instances" + + # CHECK: IntegerType(i16) + # CHECK: IndexType(index) + # CHECK: Float8E4M3FNType(f8E4M3FN) + # CHECK: Float8E5M2Type(f8E5M2) + # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ) + # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ) + # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ) + # CHECK: BF16Type(bf16) + # CHECK: F16Type(f16) + # CHECK: F32Type(f32) + # CHECK: F64Type(f64) + # CHECK: NoneType(none) + # CHECK: ComplexType(complex<f32>) + # CHECK: VectorType(vector<2x3xf32>) + # CHECK: RankedTensorType(tensor<2x3xf32>) + # CHECK: UnrankedTensorType(tensor<*xf32>) + # CHECK: MemRefType(memref<2x3xf32>) + # CHECK: UnrankedMemRefType(memref<*xf32, 2>) + # CHECK: TupleType(tuple<f32>) + # CHECK: FunctionType(() -> ()) + # CHECK: OpaqueType(!tensor.bob) + for t in type_instances: + print(repr(t)) + + # Test getTypeIdFunction agrees with + # mlirTypeGetTypeID(self) for an instance. + # CHECK: all equal + for t1, t2 in zip(types_classes, type_instances): + tid1, tid2 = t1.get_typeid(), Type(t2).get_typeid() + assert tid1 == tid2 and hash(tid1) == hash( + tid2), f"expected hash and value equality {t1} {t2}" + else: + print("all equal") + + # Test that storing PyTypeID in python dicts + # works as expected. + typeid_dict = dict( + zip([t.get_typeid() for t in types_classes], type_instances)) + assert len(typeid_dict) == len(types_classes) + + # CHECK: all equal + for t1, t2 in typeid_dict.items(): + assert t1 == Type(t2).get_typeid() and hash(t1) == hash( + Type(t2).get_typeid()), f"expected hash and value equality {t1} {t2}" + else: + print("all equal") + + # CHECK: None + print(ShapedType.get_typeid()) + + # CHECK: None + print(ShapedType(Type.parse("vector<2x3xf32>")).get_typeid())