diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -18,10 +18,17 @@
 extern "C" {
 #endif
 
+/// Declares `MlirTypeID mlir<TYPE>TypeGetTypeID(void)`, returning the TypeID of
+/// the builtin `<TYPE>Type`. Undefined again at the end of this header.
+#define MLIR_CAPI_DECLARE_TYPEID(TYPE) \
+  MLIR_CAPI_EXPORTED MlirTypeID mlir##TYPE##TypeGetTypeID(void)
+
 //===----------------------------------------------------------------------===//
 // Integer types.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(Integer);
+
 /// Checks whether the given type is an integer type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type);
 
@@ -56,6 +63,8 @@
 // Index type.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(Index);
+
 /// Checks whether the given type is an index type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type);
 
@@ -67,6 +76,8 @@
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(Float8E5M2);
+
 /// Checks whether the given type is an f8E5M2 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
 
@@ -74,6 +85,8 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
 
+MLIR_CAPI_DECLARE_TYPEID(Float8E4M3FN);
+
 /// Checks whether the given type is an f8E4M3FN type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
 
@@ -81,6 +94,8 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
 
+MLIR_CAPI_DECLARE_TYPEID(Float8E5M2FNUZ);
+
 /// Checks whether the given type is an f8E5M2FNUZ type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
 
@@ -88,6 +103,8 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
 
+MLIR_CAPI_DECLARE_TYPEID(Float8E4M3FNUZ);
+
 /// Checks whether the given type is an f8E4M3FNUZ type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
 
@@ -95,6 +112,8 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
 
+MLIR_CAPI_DECLARE_TYPEID(Float8E4M3B11FNUZ);
+
 /// Checks whether the given type is an f8E4M3B11FNUZ type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
 
@@ -102,6 +121,8 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
 
+MLIR_CAPI_DECLARE_TYPEID(BFloat16);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 
@@ -109,6 +130,8 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx);
 
+MLIR_CAPI_DECLARE_TYPEID(Float16);
+
 /// Checks whether the given type is an f16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
 
@@ -116,6 +139,8 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx);
 
+MLIR_CAPI_DECLARE_TYPEID(Float32);
+
 /// Checks whether the given type is an f32 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
 
@@ -123,6 +148,8 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx);
 
+MLIR_CAPI_DECLARE_TYPEID(Float64);
+
 /// Checks whether the given type is an f64 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
 
@@ -134,6 +161,8 @@
 // None type.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(None);
+
 /// Checks whether the given type is a None type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type);
 
@@ -145,6 +174,8 @@
 // Complex type.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(Complex);
+
 /// Checks whether the given type is a Complex type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type);
 
@@ -202,6 +233,8 @@
 // Vector type.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(Vector);
+
 /// Checks whether the given type is a Vector type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type);
 
@@ -226,9 +259,13 @@
 /// Checks whether the given type is a Tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type);
 
+MLIR_CAPI_DECLARE_TYPEID(RankedTensor);
+
 /// Checks whether the given type is a ranked tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type);
 
+MLIR_CAPI_DECLARE_TYPEID(UnrankedTensor);
+
 /// Checks whether the given type is an unranked tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type);
 
@@ -264,9 +301,13 @@
 // Ranked / Unranked MemRef type.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(MemRef);
+
 /// Checks whether the given type is a MemRef type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type);
 
+MLIR_CAPI_DECLARE_TYPEID(UnrankedMemRef);
+
 /// Checks whether the given type is an UnrankedMemRef type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
 
@@ -326,6 +367,8 @@
 // Tuple type.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(Tuple);
+
 /// Checks whether the given type is a tuple type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type);
 
@@ -345,6 +388,8 @@
 // Function type.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(Function);
+
 /// Checks whether the given type is a function type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type);
 
@@ -373,6 +418,8 @@
 // Opaque type.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_DECLARE_TYPEID(Opaque);
+
 /// Checks whether the given type is an opaque type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type);
 
@@ -396,4 +443,6 @@
 }
 #endif
 
+#undef MLIR_CAPI_DECLARE_TYPEID
+
 #endif // MLIR_C_BUILTINTYPES_H
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,9 +9,11 @@
 #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
 #define MLIR_BINDINGS_PYTHON_GLOBALS_H
 
+#include <functional>
+#include <map>
+#include <optional>
 #include <string>
 #include <vector>
-#include <optional>
 
 #include "PybindUtils.h"
 
@@ -64,6 +66,10 @@
   void registerAttributeBuilder(const std::string &attributeKind,
                                 pybind11::function pyFunc);
 
+  /// Adds a user-friendly builder for Python objects of the type identified by
+  /// `mlirTypeId`. Raises an exception if one is already registered for it.
+  void registerTypeBuilder(MlirTypeID mlirTypeId, pybind11::function pyFunc);
+
   /// Adds a concrete implementation dialect class.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
@@ -80,6 +86,10 @@
   std::optional<pybind11::function>
   lookupAttributeBuilder(const std::string &attributeKind);
 
+  /// Returns the type builder registered for `mlirTypeId`, if any.
+  std::optional<pybind11::function>
+  lookupTypeBuilder(const MlirTypeID &mlirTypeId);
+
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
   std::optional<pybind11::object>
@@ -101,6 +111,15 @@
   llvm::StringMap<pybind11::object> operationClassMap;
   /// Map of attribute ODS name to custom builder.
   llvm::StringMap<pybind11::object> attributeBuilderMap;
+  /// std::map needs a strict weak ordering; order MlirTypeIDs by their opaque
+  /// pointer (equality alone, e.g. mlirTypeIDEqual, is not a valid comparator).
+  struct MlirTypeIDCompare {
+    bool operator()(MlirTypeID lhs, MlirTypeID rhs) const {
+      return std::less<const void *>()(lhs.ptr, rhs.ptr);
+    }
+  };
+  /// Map of type TypeID to custom type builder.
+  std::map<MlirTypeID, pybind11::object, MlirTypeIDCompare> typeBuilderMap;
 
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <utility>
 #include <optional>
+#include <utility>
 
 #include "IRModule.h"
 
@@ -1028,8 +1028,7 @@
         py::arg("value"), py::arg("context") = py::none(),
         "Gets a uniqued Type attribute");
     c.def_property_readonly("value", [](PyTypeAttribute &self) {
-      return PyType(self.getContext()->getRef(),
-                    mlirTypeAttrGetValue(self.get()));
+      return mlirTypeAttrGetValue(self.get());
     });
   }
 };
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2085,13 +2085,12 @@
 
 /// Returns the list of types of the values held by container.
 template <typename Container>
-static std::vector<PyType> getValueTypes(Container &container,
-                                         PyMlirContextRef &context) {
-  std::vector<PyType> result;
+static std::vector<py::object> getValueTypes(Container &container,
+                                             PyMlirContextRef &context) {
+  std::vector<py::object> result;
   result.reserve(container.size());
   for (int i = 0, e = container.size(); i < e; ++i) {
-    result.push_back(
-        PyType(context, mlirValueGetType(container.getElement(i).get())));
+    result.push_back(py::cast(mlirValueGetType(container.getElement(i).get())));
   }
   return result;
 }
@@ -3145,11 +3144,8 @@
           "context",
           [](PyAttribute &self) { return self.getContext().getObject(); },
           "Context that owns the Attribute")
-      .def_property_readonly("type",
-                             [](PyAttribute &self) {
-                               return PyType(self.getContext()->getRef(),
-                                             mlirAttributeGetType(self));
-                             })
+      .def_property_readonly(
+          "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
       .def(
           "get_named",
           [](PyAttribute &self, std::string name) {
@@ -3244,7 +3240,7 @@
             mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
         if (mlirTypeIsNull(type))
           throw MLIRError("Unable to parse type", errors.take());
-        return PyType(context->getRef(), type);
+        return type;
       },
       py::arg("asm"), py::arg("context") = py::none(),
       kContextParseTypeDocstring)
@@ -3353,12 +3349,8 @@
             return printAccum.join();
           },
           py::arg("use_local_scope") = false, kGetNameAsOperand)
-      .def_property_readonly("type",
-                             [](PyValue &self) {
-                               return PyType(
-                                   self.getParentOperation()->getContext(),
-                                   mlirValueGetType(self.get()));
-                             })
+      .def_property_readonly(
+          "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
       .def(
           "replace_all_uses_with",
           [](PyValue &self, PyValue &with) {
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -321,11 +321,7 @@
                  py::module_local())
       .def_property_readonly(
           "element_type",
-          [](PyShapedTypeComponents &self) {
-            return PyType(PyMlirContext::forContext(
-                              mlirTypeGetContext(self.elementType)),
-                          self.elementType);
-          },
+          [](PyShapedTypeComponents &self) { return self.elementType; },
           "Returns the element type of the shaped type components.")
       .def_static(
           "get",
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -13,10 +13,12 @@
 #include <utility>
 #include <vector>
 
+#include "Globals.h"
 #include "PybindUtils.h"
 
 #include "mlir-c/AffineExpr.h"
 #include "mlir-c/AffineMap.h"
+#include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
@@ -839,6 +841,10 @@
   // const char *pyClassName
   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
   using IsAFunctionTy = bool (*)(MlirType);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+  /// Optional C-API TypeID getter for the concrete type. Subclasses that set
+  /// it get a type builder registered in bind(); the default nullptr opts out.
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
 
   PyConcreteType() = default;
   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
@@ -867,6 +873,19 @@
         },
         pybind11::arg("other"));
     DerivedTy::bindDerived(cls);
+
+    // Register a builder so that generic types carrying this class's TypeID
+    // are handed to Python as DerivedTy instead of as a plain Type.
+    if (DerivedTy::getTypeIdFunction) {
+      PyGlobals::get().registerTypeBuilder(
+          DerivedTy::getTypeIdFunction(),
+          pybind11::cpp_function([](PyType &pyType) -> pybind11::object {
+            MlirType mlirType = pyType;
+            if (DerivedTy::isaFunction(mlirType))
+              return pybind11::cast(DerivedTy(pyType.getContext(), mlirType));
+            return pybind11::cast(pyType);
+          }));
+    }
   }
 
   /// Implemented by derived classes to add methods to the Python subclass.
@@ -960,9 +979,8 @@
           return DerivedTy::isaFunction(otherAttr);
         },
         pybind11::arg("other"));
-    cls.def_property_readonly("type", [](PyAttribute &attr) {
-      return PyType(attr.getContext(), mlirAttributeGetType(attr));
-    });
+    cls.def_property_readonly(
+        "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
     DerivedTy::bindDerived(cls);
   }
 
@@ -1147,6 +1165,52 @@
 template <>
 struct type_caster<mlir::python::DefaultingPyLocation>
     : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
+
+/// Converts between Python type objects and MlirType.
+///
+/// Python -> C++ accepts a raw MLIR type capsule or any object exposing
+/// MLIR_PYTHON_CAPI_PTR_ATTR; anything else is rejected (without raising) so
+/// pybind11 can try other overloads. C++ -> Python wraps the type as a generic
+/// PyType and, if a builder is registered for its TypeID, returns the
+/// builder's result instead (e.g. an instance of the concrete type class).
+template <>
+struct type_caster<MlirType> {
+  PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
+
+  bool load(handle src, bool) {
+    object capsule;
+    if (PyCapsule_CheckExact(src.ptr()))
+      capsule = reinterpret_borrow<object>(src);
+    else if (hasattr(src, MLIR_PYTHON_CAPI_PTR_ATTR))
+      capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
+    else
+      return false;
+    value = mlirPythonCapsuleToType(capsule.ptr());
+    if (mlirTypeIsNull(value)) {
+      // The capsule lookup may have set a Python error; clear it so pybind11
+      // reports an ordinary argument mismatch instead.
+      PyErr_Clear();
+      return false;
+    }
+    return true;
+  }
+
+  static handle cast(MlirType t, return_value_policy, handle) {
+    if (mlirTypeIsNull(t))
+      return none().release();
+    mlir::python::PyMlirContextRef context =
+        mlir::python::PyMlirContext::forContext(mlirTypeGetContext(t));
+    object generic = pybind11::cast(mlir::python::PyType(context, t));
+    // Builders receive the generic Python Type rather than an MlirType, so
+    // invoking one cannot re-enter this caster. Results are released so the
+    // caller receives an owned reference, not one borrowed from a temporary.
+    std::optional<function> typeBuilder =
+        mlir::python::PyGlobals::get().lookupTypeBuilder(mlirTypeGetTypeID(t));
+    if (!typeBuilder)
+      return generic.release();
+    return (*typeBuilder)(generic).release();
+  }
+};
 
 } // namespace detail
 } // namespace pybind11
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -10,8 +10,8 @@
 #include "Globals.h"
 #include "PybindUtils.h"
 
-#include <vector>
 #include <optional>
+#include <vector>
 
 #include "mlir-c/Bindings/Python/Interop.h"
 
@@ -72,6 +72,24 @@
   found = std::move(pyFunc);
 }
 
+void PyGlobals::registerTypeBuilder(MlirTypeID mlirTypeId,
+                                    pybind11::function pyFunc) {
+  py::object &found = typeBuilderMap[mlirTypeId];
+  if (found) {
+    throw std::runtime_error(
+        "Type builder for the given TypeID is already registered");
+  }
+  found = std::move(pyFunc);
+}
+
+std::optional<py::function>
+PyGlobals::lookupTypeBuilder(const MlirTypeID &mlirTypeId) {
+  auto foundIt = typeBuilderMap.find(mlirTypeId);
+  if (foundIt == typeBuilderMap.end())
+    return std::nullopt;
+  return foundIt->second;
+}
+
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     py::object pyClass) {
   py::object &found = dialectClassMap[dialectNamespace];
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -33,6 +33,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
   static constexpr const char *pyClassName = "IntegerType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirIntegerTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -90,6 +92,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
   static constexpr const char *pyClassName = "IndexType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirIndexTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -108,6 +112,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
   static constexpr const char *pyClassName = "Float8E4M3FNType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirFloat8E4M3FNTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -126,6 +132,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
   static constexpr const char *pyClassName = "Float8E5M2Type";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirFloat8E5M2TypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -144,6 +152,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
   static constexpr const char *pyClassName = "Float8E4M3FNUZType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirFloat8E4M3FNUZTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -162,6 +172,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
   static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirFloat8E4M3B11FNUZTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -180,6 +192,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
   static constexpr const char *pyClassName = "Float8E5M2FNUZType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirFloat8E5M2FNUZTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -198,6 +212,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
   static constexpr const char *pyClassName = "BF16Type";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirBFloat16TypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -216,6 +232,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
   static constexpr const char *pyClassName = "F16Type";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirFloat16TypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -234,6 +252,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
   static constexpr const char *pyClassName = "F32Type";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirFloat32TypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -252,6 +272,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
   static constexpr const char *pyClassName = "F64Type";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirFloat64TypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -270,6 +292,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
   static constexpr const char *pyClassName = "NoneType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirNoneTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -288,6 +312,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
   static constexpr const char *pyClassName = "ComplexType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirComplexTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -308,9 +334,8 @@
         "Create a complex type");
     c.def_property_readonly(
         "element_type",
-        [](PyComplexType &self) -> PyType {
-          MlirType t = mlirComplexTypeGetElementType(self);
-          return PyType(self.getContext(), t);
+        [](PyComplexType &self) {
+          return mlirComplexTypeGetElementType(self);
         },
         "Returns element type.");
   }
@@ -320,15 +345,13 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
   static constexpr const char *pyClassName = "ShapedType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
     c.def_property_readonly(
         "element_type",
-        [](PyShapedType &self) {
-          MlirType t = mlirShapedTypeGetElementType(self);
-          return PyType(self.getContext(), t);
-        },
+        [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
         "Returns the element type of the shaped type.");
     c.def_property_readonly(
         "has_rank",
@@ -418,6 +441,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
   static constexpr const char *pyClassName = "VectorType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirVectorTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -443,6 +468,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
   static constexpr const char *pyClassName = "RankedTensorType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirRankedTensorTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -477,6 +504,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
   static constexpr const char *pyClassName = "UnrankedTensorType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirUnrankedTensorTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -499,6 +528,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
   static constexpr const char *pyClassName = "MemRefType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirMemRefTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -551,6 +582,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
   static constexpr const char *pyClassName = "UnrankedMemRefType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirUnrankedMemRefTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -586,6 +619,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
   static constexpr const char *pyClassName = "TupleType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirTupleTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -604,9 +639,8 @@
         "Create a tuple type");
     c.def(
         "get_type",
-        [](PyTupleType &self, intptr_t pos) -> PyType {
-          MlirType t = mlirTupleTypeGetType(self, pos);
-          return PyType(self.getContext(), t);
+        [](PyTupleType &self, intptr_t pos) {
+          return mlirTupleTypeGetType(self, pos);
         },
         py::arg("pos"), "Returns the pos-th type in the tuple type.");
     c.def_property_readonly(
@@ -623,6 +657,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
   static constexpr const char *pyClassName = "FunctionType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirFunctionTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
@@ -647,7 +683,7 @@
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
                ++i) {
-            types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
+            types.append(mlirFunctionTypeGetInput(t, i));
           }
           return types;
         },
@@ -659,8 +695,7 @@
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
                ++i) {
-            types.append(
-                PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
+            types.append(mlirFunctionTypeGetResult(self, i));
           }
           return types;
         },
@@ -677,6 +712,8 @@
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
   static constexpr const char *pyClassName = "OpaqueType";
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      &mlirOpaqueTypeGetTypeID;
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -18,10 +18,15 @@
 
 using namespace mlir;
 
+#define TypeID(TYPE) \
+  MlirTypeID mlir##TYPE##TypeGetTypeID() { return wrap(TYPE##Type::getTypeID()); }
+
 //===----------------------------------------------------------------------===//
 // Integer types.
 //===----------------------------------------------------------------------===//
 
+TypeID(Integer);
+
 bool mlirTypeIsAInteger(MlirType type) {
   return llvm::isa<IntegerType>(unwrap(type));
 }
@@ -58,6 +63,8 @@
 // Index type.
 //===----------------------------------------------------------------------===//
 
+TypeID(Index);
+
 bool mlirTypeIsAIndex(MlirType type) {
   return llvm::isa<IndexType>(unwrap(type));
 }
@@ -70,6 +77,8 @@
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+TypeID(Float8E5M2);
+
 bool mlirTypeIsAFloat8E5M2(MlirType type) {
   return unwrap(type).isFloat8E5M2();
 }
@@ -78,6 +87,8 @@
   return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
 }
 
+TypeID(Float8E4M3FN);
+
 bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
   return unwrap(type).isFloat8E4M3FN();
 }
@@ -86,6 +97,8 @@
   return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
 }
 
+TypeID(Float8E5M2FNUZ);
+
 bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
   return unwrap(type).isFloat8E5M2FNUZ();
 }
@@ -94,6 +107,8 @@
   return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx)));
 }
 
+TypeID(Float8E4M3FNUZ);
+
 bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
   return unwrap(type).isFloat8E4M3FNUZ();
 }
@@ -102,6 +117,8 @@
   return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
 }
 
+TypeID(Float8E4M3B11FNUZ);
+
 bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
   return unwrap(type).isFloat8E4M3B11FNUZ();
 }
@@ -110,24 +127,32 @@
   return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
 }
 
+TypeID(BFloat16);
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {
   return wrap(FloatType::getBF16(unwrap(ctx)));
 }
 
+TypeID(Float16);
+
 bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
 
 MlirType mlirF16TypeGet(MlirContext ctx) {
   return wrap(FloatType::getF16(unwrap(ctx)));
 }
 
+TypeID(Float32);
+
 bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
 
 MlirType mlirF32TypeGet(MlirContext ctx) {
   return wrap(FloatType::getF32(unwrap(ctx)));
 }
 
+TypeID(Float64);
+
 bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
 
 MlirType mlirF64TypeGet(MlirContext ctx) {
@@ -138,6 +163,8 @@
 // None type.
 //===----------------------------------------------------------------------===//
 
+TypeID(None);
+
 bool mlirTypeIsANone(MlirType type) {
   return llvm::isa<NoneType>(unwrap(type));
 }
@@ -150,6 +177,8 @@
 // Complex type.
 //===----------------------------------------------------------------------===//
 
+TypeID(Complex);
+
 bool mlirTypeIsAComplex(MlirType type) {
   return llvm::isa<ComplexType>(unwrap(type));
 }
@@ -214,6 +243,8 @@
 // Vector type.
 //===----------------------------------------------------------------------===//
 
+TypeID(Vector);
+
 bool mlirTypeIsAVector(MlirType type) {
   return llvm::isa<VectorType>(unwrap(type));
 }
@@ -239,10 +270,14 @@
   return llvm::isa<TensorType>(unwrap(type));
 }
 
+TypeID(RankedTensor);
+
 bool mlirTypeIsARankedTensor(MlirType type) {
   return llvm::isa<RankedTensorType>(unwrap(type));
 }
 
+TypeID(UnrankedTensor);
+
 bool mlirTypeIsAUnrankedTensor(MlirType type) {
   return llvm::isa<UnrankedTensorType>(unwrap(type));
 }
@@ -280,6 +315,8 @@
 // Ranked / Unranked MemRef type.
 //===----------------------------------------------------------------------===//
 
+TypeID(MemRef);
+
 bool mlirTypeIsAMemRef(MlirType type) {
   return llvm::isa<MemRefType>(unwrap(type));
 }
@@ -337,6 +374,8 @@
   return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
 }
 
+TypeID(UnrankedMemRef);
+
 bool mlirTypeIsAUnrankedMemRef(MlirType type) {
   return llvm::isa<UnrankedMemRefType>(unwrap(type));
 }
@@ -362,6 +401,8 @@
 // Tuple type.
 //===----------------------------------------------------------------------===//
 
+TypeID(Tuple);
+
 bool mlirTypeIsATuple(MlirType type) {
   return llvm::isa<TupleType>(unwrap(type));
 }
@@ -386,6 +427,8 @@
 // Function type.
//===----------------------------------------------------------------------===//
 
+TypeID(Function);
+
 bool mlirTypeIsAFunction(MlirType type) {
   return llvm::isa<FunctionType>(unwrap(type));
 }
@@ -424,6 +467,8 @@
 // Opaque type.
 //===----------------------------------------------------------------------===//
 
+TypeID(Opaque);
+
 bool mlirTypeIsAOpaque(MlirType type) {
   return llvm::isa<OpaqueType>(unwrap(type));
 }
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -557,3 +557,57 @@
     print(f"rank: {len(attr.strides)}")
     # CHECK: strides are dynamic: [True, True, True]
     print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
+
+
+# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
+@run
+def testConcreteTypesRoundTrip():
+  with Context(), Location.unknown():
+    def print_item(attr):
+      print(repr(attr.type))
+
+    # CHECK: F32Type(f32)
+    print_item(Attribute.parse("42.0 : f32"))
+    # CHECK: F32Type(f32)
+    print_item(FloatAttr.get_f32(42.0))
+    # CHECK: F64Type(f64)
+    print_item(FloatAttr.get_f64(42.0))
+    # CHECK: IntegerType(i32)
+    print_item(IntegerAttr.get(IntegerType.get_signless(32), 42))
+    # CHECK: IntegerType(i64)
+    print_item(IntegerAttr.get(IntegerType.get_signless(64), 42))
+
+    def print_container_item(attr_asm):
+      attr = DenseElementsAttr(Attribute.parse(attr_asm))
+      print(repr(attr.type))
+      print(repr(attr.type.element_type))
+
+    # CHECK: RankedTensorType(tensor<i16>)
+    # CHECK: IntegerType(i16)
+    print_container_item("dense<123> : tensor<i16>")
+    # CHECK: RankedTensorType(tensor<i32>)
+    # CHECK: IntegerType(i32)
+    print_container_item("dense<123> : tensor<i32>")
+    # CHECK: RankedTensorType(tensor<i64>)
+    # CHECK: IntegerType(i64)
+    print_container_item("dense<123> : tensor<i64>")
+
+    # CHECK: RankedTensorType(tensor<f16>)
+    # CHECK: F16Type(f16)
+    print_container_item("dense<1.0> : tensor<f16>")
+    # CHECK: RankedTensorType(tensor<f32>)
+    # CHECK: F32Type(f32)
+    print_container_item("dense<1.0> : tensor<f32>")
+    # CHECK: RankedTensorType(tensor<f64>)
+    # CHECK: F64Type(f64)
+    print_container_item("dense<1.0> : tensor<f64>")
+
+    raw = Attribute.parse("vector<4xf32>")
+    # CHECK: attr: vector<4xf32>
+    print("attr:", raw)
+    type_attr = TypeAttr(raw)
+
+    # CHECK: VectorType(vector<4xf32>)
+    print(repr(type_attr.value))
+    # CHECK: F32Type(f32)
+    print(repr(type_attr.value.element_type))
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -2,6 +2,7 @@
 
 import gc
 from mlir.ir import *
+from mlir.dialects import arith, tensor, func, memref
 
 def run(f):
   print("\nTEST:", f.__name__)
@@ -380,15 +381,15 @@
     f32 = F32Type.get()
     shape = [2, 3]
     loc = Location.unknown()
-    memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
+    memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
     # CHECK: memref type: memref<2x3xf32, 2>
-    print("memref type:", memref)
+    print("memref type:", memref_f32)
     # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
-    print("memref layout:", memref.layout)
+    print("memref layout:", memref_f32.layout)
     # CHECK: memref affine map: (d0, d1) -> (d0, d1)
-    print("memref affine map:", memref.affine_map)
+    print("memref affine map:", memref_f32.affine_map)
     # CHECK: memory space: 2
-    print("memory space:", memref.memory_space)
+    print("memory space:", memref_f32.memory_space)
 
     layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
     memref_layout = MemRefType.get(shape, f32, layout=layout)
@@ -411,7 +412,7 @@
     else:
       print("Exception not produced")
 
-    assert memref.shape == shape
+    assert memref_f32.shape == shape
 
 
 # CHECK-LABEL: TEST: testUnrankedMemRefType
@@ -481,9 +482,9 @@
                    IntegerType.get_signless(16)]
     result_types = [IndexType.get()]
     func = FunctionType.get(input_types, result_types)
-    # CHECK: INPUTS: [Type(i32), Type(i16)]
+    # CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)]
     print("INPUTS:", func.inputs)
-    # CHECK: RESULTS: [Type(index)]
+    # CHECK: RESULTS: [IndexType(index)]
     print("RESULTS:", func.results)
 
 
@@ -509,3 +510,128 @@
   print(type(ShapedType.get_dynamic_size()))
   # CHECK: <class 'int'>
   print(type(ShapedType.get_dynamic_stride_or_offset()))
+
+
+# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
+@run
+def testConcreteTypesRoundTrip():
+  with Context() as ctx, Location.unknown():
+    ctx.allow_unregistered_dialects = True
+
+    def print_item(typ, v):
+      cst = arith.ConstantOp(typ, v).result
+      print(type(cst.type).__name__)
+      print(repr(cst.type))
+
+    # CHECK: F16Type
+    # CHECK: F16Type(f16)
+    print_item(F16Type.get(), 0.0)
+    # CHECK: F32Type
+    # CHECK: F32Type(f32)
+    print_item(F32Type.get(), 0.0)
+    # CHECK: F64Type
+    # CHECK: F64Type(f64)
+    print_item(F64Type.get(), 0.0)
+    # CHECK: Float8E4M3B11FNUZType
+    # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
+    print_item(Float8E4M3B11FNUZType.get(), 0.0)
+    # CHECK: Float8E4M3FNType
+    # CHECK: Float8E4M3FNType(f8E4M3FN)
+    print_item(Float8E4M3FNType.get(), 0.0)
+    # CHECK: Float8E4M3FNUZType
+    # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
+    print_item(Float8E4M3FNUZType.get(), 0.0)
+    # CHECK: Float8E5M2Type
+    # CHECK: Float8E5M2Type(f8E5M2)
+    print_item(Float8E5M2Type.get(), 0.0)
+    # CHECK: Float8E5M2FNUZType
+    # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+    print_item(Float8E5M2FNUZType.get(), 0.0)
+    # CHECK: BF16Type
+    # CHECK: BF16Type(bf16)
+    print_item(BF16Type.get(), 0.0)
+    # CHECK: IndexType
+    # CHECK: IndexType(index)
+    print_item(IndexType.get(), 0)
+    # CHECK: IntegerType
+    # CHECK: IntegerType(i32)
+    print_item(IntegerType.get_signless(32), 0)
+
+    f32 = F32Type.get()
+    ranked_tensor = tensor.EmptyOp([10, 10], f32).result
+    # CHECK: RankedTensorType
+    print(type(ranked_tensor.type).__name__)
+    # CHECK: RankedTensorType(tensor<10x10xf32>)
+    print(repr(ranked_tensor.type))
+
+    cf32 = ComplexType.get(f32)
+    # CHECK: ComplexType
+    print(type(cf32).__name__)
+    # CHECK: ComplexType(complex<f32>)
+    print(repr(cf32))
+
+    vector = VectorType.get([10, 10], f32)
+    tuple_type = TupleType.get_tuple([f32, vector])
+    # CHECK: TupleType
+    print(type(tuple_type).__name__)
+    # CHECK: TupleType(tuple<f32, vector<10x10xf32>>)
+    print(repr(tuple_type))
+    # CHECK: F32Type(f32)
+    print(repr(tuple_type.get_type(0)))
+    # CHECK: VectorType(vector<10x10xf32>)
+    print(repr(tuple_type.get_type(1)))
+
+    index_type = IndexType.get()
+
+    @func.FuncOp.from_py_func()
+    def default_builder():
+      c0 = arith.ConstantOp(f32, 0.0)
+      # tensor.from_elements needs a static shape; cast to get an unranked
+      # tensor value.
+      elements = tensor.FromElementsOp(RankedTensorType.get([1], f32), [c0])
+      unranked_tensor_type = UnrankedTensorType.get(f32)
+      unranked_tensor = tensor.CastOp(unranked_tensor_type, elements).result
+      # CHECK: UnrankedTensorType
+      print(type(unranked_tensor.type).__name__)
+      # CHECK: UnrankedTensorType(tensor<*xf32>)
+      print(repr(unranked_tensor.type))
+
+      c10 = arith.ConstantOp(index_type, 10)
+      # A fully static memref type takes no dynamic size operands.
+      memref_f32_t = MemRefType.get([10, 10], f32)
+      memref_f32 = memref.AllocOp(memref_f32_t, [], []).result
+      # CHECK: MemRefType
+      print(type(memref_f32.type).__name__)
+      # CHECK: MemRefType(memref<10x10xf32>)
+      print(repr(memref_f32.type))
+
+      # memref.alloc cannot produce an unranked memref; cast a ranked one in
+      # the same memory space instead.
+      memory_space = Attribute.parse("2")
+      ranked_memref_t = MemRefType.get([10, 10], f32, memory_space=memory_space)
+      ranked_memref = memref.AllocOp(ranked_memref_t, [], []).result
+      unranked_memref_t = UnrankedMemRefType.get(f32, memory_space)
+      unranked_memref = memref.CastOp(unranked_memref_t, ranked_memref).result
+      # CHECK: UnrankedMemRefType
+      print(type(unranked_memref.type).__name__)
+      # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
+      print(repr(unranked_memref.type))
+
+      tuple_value = Operation.parse(
+          '"test.make_tuple"() : () -> tuple<i32, f32>').result
+      # CHECK: TupleType
+      print(type(tuple_value.type).__name__)
+      # CHECK: TupleType(tuple<i32, f32>)
+      print(repr(tuple_value.type))
+
+      return c0, c10
+
+    func_op = default_builder.func_op
+    # CHECK: FunctionType
+    print(type(func_op.type).__name__)
+    # CHECK: FunctionType(() -> (f32, index))
+    print(repr(func_op.type))
+    # CHECK: []
+    print(func_op.type.inputs)
+    # CHECK: [F32Type(f32), IndexType(index)]
+    print(func_op.type.results)