diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -80,6 +80,8 @@ #define MLIR_PYTHON_CAPSULE_PASS_MANAGER \ MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr") #define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_TYPEID \ + MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr") /** Attribute on MLIR Python objects that expose their C-API pointer. * This will be a type-specific capsule created as per one of the helpers @@ -268,6 +270,25 @@ return op; } +/** Creates a capsule object encapsulating the raw C-API MlirTypeID. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the TypeID in any way. + */ +static inline PyObject *mlirPythonTypeIDToCapsule(MlirTypeID typeID) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(typeID), + MLIR_PYTHON_CAPSULE_TYPEID, NULL); +} + +/** Extracts an MlirTypeID from a capsule as produced from + * mlirPythonTypeIDToCapsule. If the capsule is not of the right type, then + * a null TypeID is returned (as checked via mlirTypeIDIsNull). In such a + * case, the Python APIs will have already set an error. */ +static inline MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_TYPEID); + MlirTypeID typeID = {ptr}; + return typeID; +} + /** Creates a capsule object encapsulating the raw C-API MlirType. * The returned capsule does not extend or affect ownership of any Python * objects that reference the type in any way. diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -22,6 +22,9 @@ // Integer types. 
//===----------------------------------------------------------------------===// +/// Returns the typeID of an Integer type. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void); + /// Checks whether the given type is an integer type. MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type); @@ -56,6 +59,9 @@ // Index type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Index type. +MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void); + /// Checks whether the given type is an index type. MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type); @@ -67,6 +73,9 @@ // Floating-point types. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Float8E5M2 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void); + /// Checks whether the given type is an f8E5M2 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); @@ -74,6 +83,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); +/// Returns the typeID of a Float8E4M3FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void); + /// Checks whether the given type is an f8E4M3FN type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type); @@ -81,6 +93,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); +/// Returns the typeID of a Float8E5M2FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void); + /// Checks whether the given type is an f8E5M2FNUZ type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type); @@ -88,6 +103,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of a Float8E4M3FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void); + /// Checks whether the given type is an f8E4M3FNUZ type. 
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type); @@ -95,6 +113,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of a Float8E4M3B11FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void); + /// Checks whether the given type is an f8E4M3B11FNUZ type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type); @@ -102,6 +123,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of a BFloat16 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); @@ -109,6 +133,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx); +/// Returns the typeID of a Float16 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void); + /// Checks whether the given type is an f16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type); @@ -116,6 +143,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx); +/// Returns the typeID of a Float32 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void); + /// Checks whether the given type is an f32 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type); @@ -123,6 +153,9 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx); +/// Returns the typeID of a Float64 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void); + /// Checks whether the given type is an f64 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type); @@ -134,6 +167,9 @@ // None type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a None type. +MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void); + /// Checks whether the given type is a None type. 
MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type); @@ -145,6 +181,9 @@ // Complex type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Complex type. +MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void); + /// Checks whether the given type is a Complex type. MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type); @@ -202,6 +241,9 @@ // Vector type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Vector type. +MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void); + /// Checks whether the given type is a Vector type. MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type); @@ -226,9 +268,15 @@ /// Checks whether the given type is a Tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type); +/// Returns the typeID of a RankedTensor type. +MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void); + /// Checks whether the given type is a ranked tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type); +/// Returns the typeID of an UnrankedTensor type. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void); + /// Checks whether the given type is an unranked tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type); @@ -264,9 +312,15 @@ // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a MemRef type. +MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void); + /// Checks whether the given type is a MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type); +/// Returns the typeID of an UnrankedMemRef type. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void); + /// Checks whether the given type is an UnrankedMemRef type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type); @@ -326,6 +380,9 @@ // Tuple type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Tuple type. +MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void); + /// Checks whether the given type is a tuple type. MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type); @@ -345,6 +402,9 @@ // Function type. //===----------------------------------------------------------------------===// +/// Returns the typeID of a Function type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void); + /// Checks whether the given type is a function type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type); @@ -373,6 +433,9 @@ // Opaque type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Opaque type. +MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void); + /// Checks whether the given type is an opaque type. MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type); diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -236,6 +236,27 @@ } }; +/// Casts object <-> MlirTypeID.
+template <> +struct type_caster<MlirTypeID> { + PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToTypeID(capsule.ptr()); + return !mlirTypeIDIsNull(value); + } + static handle cast(MlirTypeID v, return_value_policy, handle) { + if (mlirTypeIDIsNull(v)) + return py::none().release(); + py::object capsule = + py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("TypeID") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + /// Casts object <-> MlirType. template <> struct type_caster<MlirType> { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -17,6 +17,7 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -1807,6 +1808,24 @@ rawType); } +//------------------------------------------------------------------------------ +// PyTypeID. +//------------------------------------------------------------------------------ + +py::object PyTypeID::getCapsule() { + return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this)); +} + +PyTypeID PyTypeID::createFromCapsule(py::object capsule) { + MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); + if (mlirTypeIDIsNull(mlirTypeID)) + throw py::error_already_set(); + return PyTypeID(mlirTypeID); +} +bool PyTypeID::operator==(const PyTypeID &other) const { + return mlirTypeIDEqual(typeID, other.typeID); +} + //------------------------------------------------------------------------------ // PyValue and subclases.
//------------------------------------------------------------------------------ @@ -3268,16 +3287,47 @@ return printAccum.join(); }, "Returns the assembly form of the type.") - .def("__repr__", [](PyType &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, types are an exception as they typically have compact - // assembly forms and printing them is useful. - PyPrintAccumulator printAccum; - printAccum.parts.append("Type("); - mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); + .def("__repr__", + [](PyType &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, types are an exception as they typically have compact + // assembly forms and printing them is useful. + PyPrintAccumulator printAccum; + printAccum.parts.append("Type("); + mlirTypePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); + if (!mlirTypeIDIsNull(mlirTypeID)) + return mlirTypeID; + auto origRepr = + pybind11::repr(pybind11::cast(self)).cast<std::string>(); + throw py::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str()); + }); + + //---------------------------------------------------------------------------- + // Mapping of PyTypeID.
+ //---------------------------------------------------------------------------- + py::class_<PyTypeID>(m, "TypeID", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same, nor whether + // the Python objects are the same (i.e., PyTypeID is a value type). + .def("__eq__", + [](PyTypeID &self, PyTypeID &other) { return self == other; }) + .def("__eq__", + [](PyTypeID &self, const py::object &other) { return false; }) + // Note, this gives the hash value of the underlying TypeID, not the + // hash value of the Python object, nor the hash value of the + // MlirTypeID wrapper. + .def("__hash__", [](PyTypeID &self) { + return static_cast<size_t>(mlirTypeIDHashValue(self)); }); //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -20,6 +20,7 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -826,6 +827,29 @@ MlirType type; }; +/// A TypeID provides an efficient and unique identifier for a specific C++ +/// type. This allows for a C++ type to be compared, hashed, and stored in an +/// opaque context. This class wraps around the generic MlirTypeID. +class PyTypeID { +public: + PyTypeID(MlirTypeID typeID) : typeID(typeID) {} + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same, nor whether + // the PyTypeID objects are the same (i.e., PyTypeID is a value type).
+ bool operator==(const PyTypeID &other) const; + operator MlirTypeID() const { return typeID; } + MlirTypeID get() const { return typeID; } + + /// Gets a capsule wrapping the void* within the MlirTypeID. + pybind11::object getCapsule(); + + /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. + static PyTypeID createFromCapsule(pybind11::object capsule); + +private: + MlirTypeID typeID; +}; + /// CRTP base classes for Python types that subclass Type and should be /// castable from it (i.e. via something like IntegerType(t)). /// By default, type class hierarchies are one level deep (i.e. a @@ -839,10 +863,12 @@ // const char *pyClassName using ClassTy = pybind11::class_<DerivedTy, BaseTy>; using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteType() = default; PyConcreteType(PyMlirContextRef contextRef, MlirType t) : BaseTy(std::move(contextRef), t) {} PyConcreteType(PyType &orig) : PyConcreteType(orig.getContext(), castFrom(orig)) {} @@ -866,6 +892,28 @@ return DerivedTy::isaFunction(otherType); }, pybind11::arg("other")); + // Register once per class (not per instance) so a generic Type is accepted. + pybind11::implicitly_convertible<PyType, DerivedTy>(); + cls.def_property_readonly_static( + "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + throw SetPyError(PyExc_AttributeError, + DerivedTy::pyClassName + + llvm::Twine(" has no typeid.")); + }); + cls.def_property_readonly("typeid", [](PyType &self) { + return py::cast(self).attr("typeid").cast<MlirTypeID>(); + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + DerivedTy::bindDerived(cls); } diff --git
a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -32,6 +32,8 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerTypeGetTypeID; static constexpr const char *pyClassName = "IntegerType"; using PyConcreteType::PyConcreteType; @@ -89,6 +91,8 @@ class PyIndexType : public PyConcreteType<PyIndexType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIndexTypeGetTypeID; static constexpr const char *pyClassName = "IndexType"; using PyConcreteType::PyConcreteType; @@ -107,6 +111,8 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3FNType"; using PyConcreteType::PyConcreteType; @@ -125,6 +131,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2TypeGetTypeID; static constexpr const char *pyClassName = "Float8E5M2Type"; using PyConcreteType::PyConcreteType; @@ -143,6 +151,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3FNUZType"; using PyConcreteType::PyConcreteType; @@ -161,6 +171,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr
GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3B11FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; using PyConcreteType::PyConcreteType; @@ -179,6 +191,8 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E5M2FNUZType"; using PyConcreteType::PyConcreteType; @@ -197,6 +211,8 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirBFloat16TypeGetTypeID; static constexpr const char *pyClassName = "BF16Type"; using PyConcreteType::PyConcreteType; @@ -215,6 +231,8 @@ class PyF16Type : public PyConcreteType<PyF16Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat16TypeGetTypeID; static constexpr const char *pyClassName = "F16Type"; using PyConcreteType::PyConcreteType; @@ -233,6 +251,8 @@ class PyF32Type : public PyConcreteType<PyF32Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat32TypeGetTypeID; static constexpr const char *pyClassName = "F32Type"; using PyConcreteType::PyConcreteType; @@ -251,6 +271,8 @@ class PyF64Type : public PyConcreteType<PyF64Type> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat64TypeGetTypeID; static constexpr const char *pyClassName = "F64Type"; using PyConcreteType::PyConcreteType; @@ -269,6 +291,8 @@ class PyNoneType : public PyConcreteType<PyNoneType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirNoneTypeGetTypeID;
static constexpr const char *pyClassName = "NoneType"; using PyConcreteType::PyConcreteType; @@ -287,6 +311,8 @@ class PyComplexType : public PyConcreteType<PyComplexType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirComplexTypeGetTypeID; static constexpr const char *pyClassName = "ComplexType"; using PyConcreteType::PyConcreteType; @@ -417,6 +443,8 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirVectorTypeGetTypeID; static constexpr const char *pyClassName = "VectorType"; using PyConcreteType::PyConcreteType; @@ -442,6 +470,8 @@ : public PyConcreteType<PyRankedTensorType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirRankedTensorTypeGetTypeID; static constexpr const char *pyClassName = "RankedTensorType"; using PyConcreteType::PyConcreteType; @@ -476,6 +506,8 @@ : public PyConcreteType<PyUnrankedTensorType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedTensorTypeGetTypeID; static constexpr const char *pyClassName = "UnrankedTensorType"; using PyConcreteType::PyConcreteType; @@ -498,6 +530,8 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirMemRefTypeGetTypeID; static constexpr const char *pyClassName = "MemRefType"; using PyConcreteType::PyConcreteType; @@ -550,6 +584,8 @@ : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedMemRefTypeGetTypeID; static constexpr const char *pyClassName = "UnrankedMemRefType"; using
PyConcreteType::PyConcreteType; @@ -585,6 +621,8 @@ class PyTupleType : public PyConcreteType<PyTupleType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTupleTypeGetTypeID; static constexpr const char *pyClassName = "TupleType"; using PyConcreteType::PyConcreteType; @@ -622,6 +660,8 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFunctionTypeGetTypeID; static constexpr const char *pyClassName = "FunctionType"; using PyConcreteType::PyConcreteType; @@ -676,6 +716,8 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueTypeGetTypeID; static constexpr const char *pyClassName = "OpaqueType"; using PyConcreteType::PyConcreteType; diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -22,6 +22,8 @@ // Integer types. //===----------------------------------------------------------------------===// +MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); } + bool mlirTypeIsAInteger(MlirType type) { return llvm::isa<IntegerType>(unwrap(type)); } @@ -58,6 +60,8 @@ // Index type. //===----------------------------------------------------------------------===// +MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); } + bool mlirTypeIsAIndex(MlirType type) { return llvm::isa<IndexType>(unwrap(type)); } @@ -70,6 +74,10 @@ // Floating-point types.
//===----------------------------------------------------------------------===// +MlirTypeID mlirFloat8E5M2TypeGetTypeID() { + return wrap(Float8E5M2Type::getTypeID()); +} + bool mlirTypeIsAFloat8E5M2(MlirType type) { return unwrap(type).isFloat8E5M2(); } @@ -78,6 +86,10 @@ return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { + return wrap(Float8E4M3FNType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3FN(MlirType type) { return unwrap(type).isFloat8E4M3FN(); } @@ -86,6 +98,10 @@ return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); } +MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { + return wrap(Float8E5M2FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { return unwrap(type).isFloat8E5M2FNUZ(); } @@ -94,6 +110,10 @@ return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { + return wrap(Float8E4M3FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { return unwrap(type).isFloat8E4M3FNUZ(); } @@ -102,6 +122,10 @@ return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { + return wrap(Float8E4M3B11FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { return unwrap(type).isFloat8E4M3B11FNUZ(); } @@ -110,24 +134,34 @@ return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx))); } +MlirTypeID mlirBFloat16TypeGetTypeID() { + return wrap(BFloat16Type::getTypeID()); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { return wrap(FloatType::getBF16(unwrap(ctx))); } +MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); } + bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); } MlirType mlirF16TypeGet(MlirContext ctx) { return wrap(FloatType::getF16(unwrap(ctx))); } +MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } + bool 
mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } MlirType mlirF32TypeGet(MlirContext ctx) { return wrap(FloatType::getF32(unwrap(ctx))); } +MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } + bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); } MlirType mlirF64TypeGet(MlirContext ctx) { @@ -138,6 +172,8 @@ // None type. //===----------------------------------------------------------------------===// +MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); } + bool mlirTypeIsANone(MlirType type) { return llvm::isa<NoneType>(unwrap(type)); } @@ -150,6 +186,8 @@ // Complex type. //===----------------------------------------------------------------------===// +MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); } + bool mlirTypeIsAComplex(MlirType type) { return llvm::isa<ComplexType>(unwrap(type)); } @@ -214,6 +252,8 @@ // Vector type. //===----------------------------------------------------------------------===// +MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); } + bool mlirTypeIsAVector(MlirType type) { return llvm::isa<VectorType>(unwrap(type)); } @@ -239,10 +279,18 @@ return llvm::isa<TensorType>(unwrap(type)); } +MlirTypeID mlirRankedTensorTypeGetTypeID() { + return wrap(RankedTensorType::getTypeID()); +} + bool mlirTypeIsARankedTensor(MlirType type) { return llvm::isa<RankedTensorType>(unwrap(type)); } +MlirTypeID mlirUnrankedTensorTypeGetTypeID() { + return wrap(UnrankedTensorType::getTypeID()); +} + bool mlirTypeIsAUnrankedTensor(MlirType type) { return llvm::isa<UnrankedTensorType>(unwrap(type)); } @@ -280,6 +328,8 @@ // Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===// +MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); } + bool mlirTypeIsAMemRef(MlirType type) { return llvm::isa<MemRefType>(unwrap(type)); } @@ -337,6 +387,10 @@ return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace()); } +MlirTypeID mlirUnrankedMemRefTypeGetTypeID() { + return wrap(UnrankedMemRefType::getTypeID()); +} + bool mlirTypeIsAUnrankedMemRef(MlirType type) { return llvm::isa<UnrankedMemRefType>(unwrap(type)); } @@ -362,6 +416,8 @@ // Tuple type. //===----------------------------------------------------------------------===// +MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); } + bool mlirTypeIsATuple(MlirType type) { return llvm::isa<TupleType>(unwrap(type)); } @@ -386,6 +442,10 @@ // Function type. //===----------------------------------------------------------------------===// +MlirTypeID mlirFunctionTypeGetTypeID() { + return wrap(FunctionType::getTypeID()); +} + bool mlirTypeIsAFunction(MlirType type) { return llvm::isa<FunctionType>(unwrap(type)); } @@ -424,6 +484,8 @@ // Opaque type.
//===----------------------------------------------------------------------===// +MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); } + bool mlirTypeIsAOpaque(MlirType type) { return llvm::isa<OpaqueType>(unwrap(type)); } diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue +from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType def register_python_test_dialect(context, load=True): from .._mlir_libs import _mlirPythonTest diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -278,6 +278,14 @@ # The following cast must not assert.
b = test.TestType(a) + # Instance custom types should have typeids + assert isinstance(b.typeid, TypeID) + # Subclasses of ir.Type should not have a static_typeid + # CHECK: 'TestType' object has no attribute 'static_typeid' + try: + b.static_typeid + except AttributeError as e: + print(e) i8 = IntegerType.get_signless(8) try: @@ -332,6 +340,12 @@ # CHECK: False print(tt.is_null()) + # Classes of custom types that inherit from concrete types should have + # static_typeid + assert isinstance(test.TestTensorType.static_typeid, TypeID) + # And it should be equal to the in-tree concrete type + assert test.TestTensorType.static_typeid == t.type.typeid + # CHECK-LABEL: TEST: inferReturnTypeComponents @run diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -3,6 +3,7 @@ import gc from mlir.ir import * + def run(f): print("\nTEST:", f.__name__) f() @@ -76,6 +77,7 @@ # CHECK: len(s): 2 print("len(s): ", len(s)) + # CHECK-LABEL: TEST: testTypeCast @run def testTypeCast(): @@ -182,6 +184,7 @@ # CHECK: unsigned: ui64 print("unsigned:", IntegerType.get_unsigned(64)) + # CHECK-LABEL: TEST: testIndexType @run def testIndexType(): @@ -259,7 +262,8 @@ # CHECK: rank: 2 print("rank:", vector.rank) # CHECK: whether the shaped type has a static shape: True - print("whether the shaped type has a static shape:", vector.has_static_shape) + print("whether the shaped type has a static shape:", + vector.has_static_shape) # CHECK: whether the dim-th dimension is dynamic: False print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0)) # CHECK: dim size: 3 @@ -311,8 +315,7 @@ shape = [2, 3] loc = Location.unknown() # CHECK: ranked tensor type: tensor<2x3xf32> - print("ranked tensor type:", - RankedTensorType.get(shape, f32)) + print("ranked tensor type:", RankedTensorType.get(shape, f32)) none = NoneType.get() try: @@ -477,8 +480,7 @@ @run def
testFunctionType(): with Context() as ctx: - input_types = [IntegerType.get_signless(32), - IntegerType.get_signless(16)] + input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)] result_types = [IndexType.get()] func = FunctionType.get(input_types, result_types) # CHECK: INPUTS: [Type(i32), Type(i16)] @@ -509,3 +511,91 @@ print(type(ShapedType.get_dynamic_size())) # CHECK: <class 'int'> print(type(ShapedType.get_dynamic_stride_or_offset())) + + +# CHECK-LABEL: TEST: testTypeIDs +@run +def testTypeIDs(): + with Context(), Location.unknown(): + f32 = F32Type.get() + + types = [ + (IntegerType, IntegerType.get_signless(16)), + (IndexType, IndexType.get()), + (Float8E4M3FNType, Float8E4M3FNType.get()), + (Float8E5M2Type, Float8E5M2Type.get()), + (Float8E4M3FNUZType, Float8E4M3FNUZType.get()), + (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()), + (Float8E5M2FNUZType, Float8E5M2FNUZType.get()), + (BF16Type, BF16Type.get()), + (F16Type, F16Type.get()), + (F32Type, F32Type.get()), + (F64Type, F64Type.get()), + (NoneType, NoneType.get()), + (ComplexType, ComplexType.get(f32)), + (VectorType, VectorType.get([2, 3], f32)), + (RankedTensorType, RankedTensorType.get([2, 3], f32)), + (UnrankedTensorType, UnrankedTensorType.get(f32)), + (MemRefType, MemRefType.get([2, 3], f32)), + (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))), + (TupleType, TupleType.get_tuple([f32])), + (FunctionType, FunctionType.get([], [])), + (OpaqueType, OpaqueType.get("tensor", "bob")), + ] + + # CHECK: IntegerType(i16) + # CHECK: IndexType(index) + # CHECK: Float8E4M3FNType(f8E4M3FN) + # CHECK: Float8E5M2Type(f8E5M2) + # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ) + # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ) + # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ) + # CHECK: BF16Type(bf16) + # CHECK: F16Type(f16) + # CHECK: F32Type(f32) + # CHECK: F64Type(f64) + # CHECK: NoneType(none) + # CHECK: ComplexType(complex<f32>) + # CHECK: VectorType(vector<2x3xf32>) + # CHECK:
RankedTensorType(tensor<2x3xf32>) + # CHECK: UnrankedTensorType(tensor<*xf32>) + # CHECK: MemRefType(memref<2x3xf32>) + # CHECK: UnrankedMemRefType(memref<*xf32, 2>) + # CHECK: TupleType(tuple<f32>) + # CHECK: FunctionType(() -> ()) + # CHECK: OpaqueType(!tensor.bob) + for _, t in types: + print(repr(t)) + + # Test getTypeIdFunction agrees with + # mlirTypeGetTypeID(self) for an instance. + # CHECK: all equal + for t1, t2 in types: + tid1, tid2 = t1.static_typeid, Type(t2).typeid + assert tid1 == tid2 and hash(tid1) == hash( + tid2), f"expected hash and value equality {t1} {t2}" + else: + print("all equal") + + # Key a dict by the concrete type classes and check that + # each class's static_typeid matches its instance's typeid. + typeid_dict = dict(types) + assert len(typeid_dict) == len(types) + + # CHECK: all equal + for t1, t2 in typeid_dict.items(): + assert t1.static_typeid == t2.typeid and hash( + t1.static_typeid) == hash( + t2.typeid), f"expected hash and value equality {t1} {t2}" + else: + print("all equal") + + # CHECK: ShapedType has no typeid.
+ try: + print(ShapedType.static_typeid) + except AttributeError as e: + print(e) + + vector_type = Type.parse("vector<2x3xf32>") + # CHECK: True + print(ShapedType(vector_type).typeid == vector_type.typeid) diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "PythonTestCAPI.h" +#include "mlir-c/BuiltinTypes.h" #include "mlir/Bindings/Python/PybindAdaptors.h" namespace py = pybind11; @@ -40,6 +41,9 @@ return cls(mlirPythonTestTestTypeGet(ctx)); }, py::arg("cls"), py::arg("context") = py::none()); + mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor, + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("RankedTensorType")); mlir_value_subclass(m, "TestTensorValue", mlirTypeIsAPythonTestTestTensorValue) .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });