diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1922,6 +1922,34 @@
   }
 };
 
+/// Type Attribute subclass - TypeAttr. Uses mlirAttributeIsAType as its isa
+/// predicate; `get` builds the attribute with mlirTypeAttrGet and `value`
+/// returns the result of mlirTypeAttrGetValue wrapped as a PyType.
+class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
+  static constexpr const char *pyClassName = "TypeAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyType value, DefaultingPyMlirContext context) {
+          MlirAttribute attr = mlirTypeAttrGet(value.get());
+          // NOTE(review): the wrapper is bound to the `context` argument, not
+          // to the context of `value` - confirm the two always agree.
+          return PyTypeAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets a uniqued Type attribute");
+    // Read-only `value` property: the type held by this attribute.
+    c.def_property_readonly("value", [](PyTypeAttribute &self) {
+      return PyType(self.getContext()->getRef(),
+                    mlirTypeAttrGetValue(self.get()));
+    });
+  }
+};
+
 /// Unit Attribute subclass. Unit attributes don't have values.
 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
 public:
@@ -3073,6 +3101,7 @@
   PyDenseElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
   PyDenseFPElementsAttribute::bind(m);
+  PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
 
   //----------------------------------------------------------------------------
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -255,3 +255,17 @@
 
 
 run(testDenseFPAttr)
+
+
+# CHECK-LABEL: TEST: testTypeAttr
+def testTypeAttr():
+  with Context():
+    raw = Attribute.parse("vector<4xf32>")
+    # CHECK: attr: vector<4xf32>
+    print("attr:", raw)
+    type_attr = TypeAttr(raw)
+    # CHECK: f32
+    print(ShapedType(type_attr.value).element_type)
+
+
+run(testTypeAttr)