diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1763,6 +1763,25 @@
   static void bindDerived(ClassTy &m) {}
 };
 
+/// Python binding for the builtin affine map attribute, exposed as
+/// "AffineMapAttr"; `AffineMapAttr.get` wraps an existing AffineMap.
+class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
+  static constexpr const char *pyClassName = "AffineMapAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyAffineMap &affineMap) {
+          MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
+          return PyAffineMapAttribute(affineMap.getContext(), attr);
+        },
+        py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
+  }
+};
+
 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@@ -3994,17 +4013,18 @@
           "The underlying generic attribute of the NamedAttribute binding");
 
   // Builtin attribute bindings.
-  PyFloatAttribute::bind(m);
+  PyAffineMapAttribute::bind(m);
   PyArrayAttribute::bind(m);
   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
-  PyIntegerAttribute::bind(m);
   PyBoolAttribute::bind(m);
-  PyFlatSymbolRefAttribute::bind(m);
-  PyStringAttribute::bind(m);
   PyDenseElementsAttribute::bind(m);
-  PyDenseIntElementsAttribute::bind(m);
   PyDenseFPElementsAttribute::bind(m);
+  PyDenseIntElementsAttribute::bind(m);
   PyDictAttribute::bind(m);
+  PyFlatSymbolRefAttribute::bind(m);
+  PyFloatAttribute::bind(m);
+  PyIntegerAttribute::bind(m);
+  PyStringAttribute::bind(m);
   PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
 
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -107,6 +107,22 @@
 run(testStandardAttrCasts)
 
 
+# CHECK-LABEL: TEST: testAffineMapAttr
+def testAffineMapAttr():
+  with Context():
+    # Map with 2 dims, 3 symbols and no results.
+    map0 = AffineMap.get(2, 3, [])
+
+    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
+    attr_built = AffineMapAttr.get(map0)
+    print(str(attr_built))
+
+    attr_parsed = Attribute.parse(str(attr_built))
+    assert attr_built == attr_parsed
+
+run(testAffineMapAttr)
+
+
 # CHECK-LABEL: TEST: testFloatAttr
 def testFloatAttr():
   with Context(), Location.unknown():