diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h
--- a/mlir/include/mlir-c/StandardAttributes.h
+++ b/mlir/include/mlir-c/StandardAttributes.h
@@ -93,6 +93,12 @@
 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
                                      double value);
 
+/** Same as "mlirFloatAttrDoubleGet" but returns a null MlirAttribute (check
+ * with mlirAttributeIsNull) on illegal arguments, e.g. a non-floating point
+ * type, emitting appropriate diagnostics at `loc`. */
+MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
+                                            MlirLocation loc);
+
 /** Returns the value stored in the given floating point attribute, interpreting
  * the value as double. */
 double mlirFloatAttrGetValueDouble(MlirAttribute attr);
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -742,6 +742,101 @@
   static void bindDerived(ClassTy &m) {}
 };
 
+/// Floating point attribute subclass - FloatAttr.
+class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr const char *pyClassName = "FloatAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_typed",
+        // TODO: Make the location optional and create a default location.
+        [](PyType &type, double value, PyLocation &loc) {
+          // The checked variant returns a null attribute on illegal
+          // arguments (e.g. a non-floating point type) instead of failing.
+          MlirAttribute attr =
+              mlirFloatAttrDoubleGetChecked(type.type, value, loc.loc);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API.
+          if (mlirAttributeIsNull(attr)) {
+            throw SetPyError(PyExc_ValueError,
+                             llvm::Twine("invalid '") +
+                                 py::repr(py::cast(type)).cast<std::string>() +
+                                 "' and expected floating point type.");
+          }
+          return PyFloatAttribute(type.getContext(), attr);
+        },
+        py::arg("type"), py::arg("value"), py::arg("loc"),
+        "Gets a uniqued floating point attribute associated to a type");
+    c.def_property_readonly(
+        "value",
+        [](PyFloatAttribute &self) {
+          return mlirFloatAttrGetValueDouble(self.attr);
+        },
+        "Returns the value of the floating point attribute");
+  }
+};
+
+/// Integer attribute subclass - IntegerAttr.
+class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_typed",
+        [](PyType &type, int64_t value) {
+          // Unlike FloatAttr, there is no checked C API constructor here, so
+          // validate the type up front instead of handing a non-integer type
+          // to mlirIntegerAttrGet.
+          if (!mlirTypeIsAInteger(type.type) && !mlirTypeIsAIndex(type.type)) {
+            throw SetPyError(PyExc_ValueError,
+                             llvm::Twine("invalid '") +
+                                 py::repr(py::cast(type)).cast<std::string>() +
+                                 "' and expected integer or index type.");
+          }
+          MlirAttribute attr = mlirIntegerAttrGet(type.type, value);
+          return PyIntegerAttribute(type.getContext(), attr);
+        },
+        py::arg("type"), py::arg("value"),
+        "Gets a uniqued integer attribute associated to a type");
+    c.def_property_readonly(
+        "value",
+        [](PyIntegerAttribute &self) {
+          return mlirIntegerAttrGetValueInt(self.attr);
+        },
+        "Returns the value of the integer attribute");
+  }
+};
+
+/// Bool attribute subclass - BoolAttr.
+class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+  static constexpr const char *pyClassName = "BoolAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        // Taking `bool` lets Python pass True/False directly; ints are still
+        // accepted via pybind11's implicit truth-value conversion.
+        [](PyMlirContext &context, bool value) {
+          MlirAttribute attr = mlirBoolAttrGet(context.get(), value);
+          return PyBoolAttribute(context.getRef(), attr);
+        },
+        py::arg("context"), py::arg("value"), "Gets a uniqued bool attribute");
+    c.def_property_readonly(
+        "value",
+        [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); },
+        "Returns the value of the bool attribute");
+  }
+};
+
 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
@@ -1630,6 +1725,9 @@
       "The underlying generic attribute of the NamedAttribute binding");
 
   // Standard attribute bindings.
+  PyFloatAttribute::bind(m);
+  PyIntegerAttribute::bind(m);
+  PyBoolAttribute::bind(m);
   PyStringAttribute::bind(m);
 
   // Mapping of Type.
diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -102,6 +102,13 @@
   return wrap(FloatAttr::get(unwrap(type), value));
 }
 
+// Unlike mlirFloatAttrDoubleGet, this goes through FloatAttr::getChecked, which
+// reports invalid arguments as a diagnostic at `loc` and yields a null attribute.
+MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
+                                            MlirLocation loc) {
+  return wrap(FloatAttr::getChecked(unwrap(type), value, unwrap(loc)));
+}
+
 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
   return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
 }
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -92,6 +92,61 @@
 run(testStandardAttrCasts)
 
 
+# CHECK-LABEL: TEST: testFloatAttr
+def testFloatAttr():
+  ctx = mlir.ir.Context()
+  fattr = mlir.ir.FloatAttr(ctx.parse_attr("42.0 : f32"))
+  # CHECK: fattr value: 42.0
+  print("fattr value:", fattr.value)
+
+  # Test factory methods.
+  loc = ctx.get_unknown_location()
+  # CHECK: typed_get: 4.200000e+01 : f32
+  print("typed_get:", mlir.ir.FloatAttr.get_typed(
+      mlir.ir.F32Type(ctx), 42.0, loc))
+  try:
+    fattr_invalid = mlir.ir.FloatAttr.get_typed(
+        mlir.ir.IntegerType.get_signless(ctx, 32), 42, loc)
+  except ValueError as e:
+    # CHECK: invalid 'Type(i32)' and expected floating point type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testFloatAttr)
+
+
+# CHECK-LABEL: TEST: testIntegerAttr
+def testIntegerAttr():
+  ctx = mlir.ir.Context()
+  iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42"))
+  # CHECK: iattr value: 42
+  print("iattr value:", iattr.value)
+
+  # Test factory methods.
+  # CHECK: typed_get: 42 : i32
+  print("typed_get:", mlir.ir.IntegerAttr.get_typed(
+      mlir.ir.IntegerType.get_signless(ctx, 32), 42))
+
+run(testIntegerAttr)
+
+
+# CHECK-LABEL: TEST: testBoolAttr
+def testBoolAttr():
+  ctx = mlir.ir.Context()
+  battr = mlir.ir.BoolAttr(ctx.parse_attr("true"))
+  # CHECK: battr value: 1
+  print("battr value:", battr.value)
+
+  # Test factory methods.
+  # CHECK: default_get: true
+  print("default_get:", mlir.ir.BoolAttr.get(ctx, 1))
+  # CHECK: default_get_false: false
+  print("default_get_false:", mlir.ir.BoolAttr.get(ctx, False))
+
+run(testBoolAttr)
+
+
 # CHECK-LABEL: TEST: testStringAttr
 def testStringAttr():
   ctx = mlir.ir.Context()