diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1643,6 +1643,36 @@
   }
 };
 
+/// Python binding for FlatSymbolRefAttr. Its textual form is `@name`:
+/// `FlatSymbolRefAttr.get("name")` builds such an attribute, and the `value`
+/// property returns `name` without the leading `@`.
+class PyFlatSymbolRefAttribute
+    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
+  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::string value, DefaultingPyMlirContext context) {
+          MlirAttribute attr =
+              mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
+          return PyFlatSymbolRefAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets a uniqued FlatSymbolRef attribute");
+    c.def_property_readonly(
+        "value",
+        [](PyFlatSymbolRefAttribute &self) {
+          MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
+          return py::str(stringRef.data, stringRef.length);
+        },
+        "Returns the value of the FlatSymbolRef attribute");
+  }
+};
+
 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
@@ -3229,6 +3259,7 @@
   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
   PyIntegerAttribute::bind(m);
   PyBoolAttribute::bind(m);
+  PyFlatSymbolRefAttribute::bind(m);
   PyStringAttribute::bind(m);
   PyDenseElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -165,6 +165,20 @@
 run(testBoolAttr)
 
 
+# CHECK-LABEL: TEST: testFlatSymbolRefAttr
+def testFlatSymbolRefAttr():
+  with Context() as ctx:
+    sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
+    # CHECK: symattr value: symbol
+    print("symattr value:", sattr.value)
+
+    # Test factory methods.
+    # CHECK: default_get: @foobar
+    print("default_get:", FlatSymbolRefAttr.get("foobar"))
+
+run(testFlatSymbolRefAttr)
+
+
 # CHECK-LABEL: TEST: testStringAttr
 def testStringAttr():
   with Context() as ctx: