diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -336,6 +336,9 @@ /** Parses an attribute. The attribute is owned by the context. */ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr); +/** Checks whether an attribute is null. */ +static inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } + /** Checks if two attributes are equal. */ int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2); diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -45,6 +45,37 @@ MlirModule module; }; +/// Wrapper around the generic MlirAttribute. +/// The lifetime of an attribute is bound by the PyContext that created it. +class PyAttribute { +public: + PyAttribute(MlirAttribute attr) : attr(attr) {} + bool operator==(const PyAttribute &other) const; + + MlirAttribute attr; +}; + +/// Represents a Python MlirNamedAttribute, carrying an owned copy of its name. +class PyNamedAttribute { +public: + /// Constructs a PyNamedAttribute that retains an owned name. This should be + /// used in any code that originates an MlirNamedAttribute from a Python + /// string. + /// The passed attribute (and the context owning it) must remain valid for + /// the whole lifetime of the PyNamedAttribute. + PyNamedAttribute(MlirAttribute attr, std::string ownedName); + + MlirNamedAttribute namedAttr; + +private: + // Since the MlirNamedAttribute contains an internal pointer to the actual + // memory of the owned string, it must be heap allocated to remain valid. + // Otherwise, strings that fit within the small string optimization buffer + // will have their memory address change as the containing object is moved, + // resulting in a dangling pointer. + std::unique_ptr<std::string> ownedName; +}; + /// Wrapper around the generic MlirType. /// The lifetime of a type is bound by the PyContext that created it.
class PyType { diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -9,6 +9,7 @@ #include "IRModules.h" #include "PybindUtils.h" +#include "mlir-c/StandardAttributes.h" #include "mlir-c/StandardTypes.h" namespace py = pybind11; @@ -76,8 +77,49 @@ } }; +struct PySinglePartStringAccumulator { + void *getUserData() { return this; } + /// Returns a callback that stores the single string part it receives. + MlirStringCallback getCallback() { + return [](const char *part, intptr_t size, void *userData) { + PySinglePartStringAccumulator *accum = + static_cast<PySinglePartStringAccumulator *>(userData); + assert(!accum->invoked && + "PySinglePartStringAccumulator called back multiple times"); + accum->invoked = true; + accum->value = py::str(part, size); + }; + } + /// Moves out the received string; the callback must have been invoked. + py::str takeValue() { + assert(invoked && "PySinglePartStringAccumulator not called back"); + return std::move(value); + } + +private: + py::str value; + bool invoked = false; +}; + } // namespace +//------------------------------------------------------------------------------ +// PyAttribute. +//------------------------------------------------------------------------------ + +bool PyAttribute::operator==(const PyAttribute &other) const { + return mlirAttributeEqual(attr, other.attr); +} + +//------------------------------------------------------------------------------ +// PyNamedAttribute. +//------------------------------------------------------------------------------ + +PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) + : ownedName(std::make_unique<std::string>(std::move(ownedName))) { + namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr); +} + //------------------------------------------------------------------------------ // PyType.
//------------------------------------------------------------------------------ @@ -86,6 +128,86 @@ return mlirTypeEqual(type, other.type); } +//------------------------------------------------------------------------------ +// Standard attribute subclasses. +//------------------------------------------------------------------------------ + +namespace { + +/// CRTP base class for Python attributes that subclass Attribute and should +/// be castable from it (e.g. via something like StringAttr(attr)). +template <typename T> +class PyConcreteAttribute : public PyAttribute { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = py::class_<T, PyAttribute>; + using IsAFunctionTy = int (*)(MlirAttribute); + + PyConcreteAttribute() = default; + PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {} + PyConcreteAttribute(PyAttribute &orig) + : PyConcreteAttribute(castFrom(orig)) {} + /// Returns orig's MlirAttribute if it is a T; otherwise raises ValueError. + static MlirAttribute castFrom(PyAttribute &orig) { + if (!T::isaFunction(orig.attr)) { + auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); + throw SetPyError(PyExc_ValueError, + llvm::Twine("Cannot cast attribute to ") + + T::pyClassName + " (from " + origRepr + ")"); + } + return orig.attr; + } + /// Registers T; objects cast-constructed from an attribute keep it alive. + static void bind(py::module &m) { + auto class_ = ClassTy(m, T::pyClassName); + class_.def(py::init<PyAttribute &>(), py::keep_alive<1, 2>()); + T::bindDerived(class_); + } + + /// Implemented by derived classes to add methods to the Python subclass.
+ static void bindDerived(ClassTy &c) {} +}; +/// Python binding for StringAttr, castable from a generic Attribute. +class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; + static constexpr const char *pyClassName = "StringAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyMlirContext &context, std::string value) { + MlirAttribute attr = + mlirStringAttrGet(context.context, value.size(), &value[0]); + return PyStringAttribute(attr); + }, + py::keep_alive<0, 1>(), "Gets a uniqued string attribute"); + c.def_static( + "get_typed", + [](PyType &type, std::string value) { + MlirAttribute attr = + mlirStringAttrTypedGet(type.type, value.size(), &value[0]); + return PyStringAttribute(attr); + }, + py::keep_alive<0, 1>(), + "Gets a uniqued string attribute associated to a type"); + c.def_property_readonly( + "value", + [](PyStringAttribute &self) { + PySinglePartStringAccumulator accum; + mlirStringAttrGetValue(self.attr, accum.getCallback(), + accum.getUserData()); + return accum.takeValue(); + }, + "Returns the value of the string attribute"); + } +}; + +} // namespace + //------------------------------------------------------------------------------ // Standard type subclasses.
//------------------------------------------------------------------------------ @@ -135,21 +257,21 @@ static void bindDerived(ClassTy &c) { c.def_static( - "signless", + "get_signless", [](PyMlirContext &context, unsigned width) { MlirType t = mlirIntegerTypeGet(context.context, width); return PyIntegerType(t); }, py::keep_alive<0, 1>(), "Create a signless integer type"); c.def_static( - "signed", + "get_signed", [](PyMlirContext &context, unsigned width) { MlirType t = mlirIntegerTypeSignedGet(context.context, width); return PyIntegerType(t); }, py::keep_alive<0, 1>(), "Create a signed integer type"); c.def_static( - "unsigned", + "get_unsigned", [](PyMlirContext &context, unsigned width) { MlirType t = mlirIntegerTypeUnsignedGet(context.context, width); return PyIntegerType(t); @@ -203,6 +325,19 @@ return PyModule(moduleRef); }, py::keep_alive<0, 1>(), kContextParseDocstring) + .def( + "parse_attr", + [](PyMlirContext &self, std::string attrSpec) { + MlirAttribute attr = + mlirAttributeParseGet(self.context, attrSpec.c_str()); + if (mlirAttributeIsNull(attr)) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Unable to parse attribute: '") + + attrSpec + "'"); + } + return PyAttribute(attr); + }, + py::keep_alive<0, 1>(), "Parses the assembly form of an Attribute") .def( "parse_type", [](PyMlirContext &self, std::string typeSpec) { @@ -235,6 +370,79 @@ }, kOperationStrDunderDocstring); + // Mapping of Attribute.
+ py::class_<PyAttribute>(m, "Attribute") + .def( + "named", + [](PyAttribute &self, std::string name) { + return PyNamedAttribute(self.attr, std::move(name)); + }, + py::keep_alive<0, 1>(), "Binds a name to the attribute") + .def("__eq__", + [](PyAttribute &self, py::object &other) { + // Comparing against a non-Attribute (including None) yields False + // rather than raising; check the type instead of relying on a + // failed cast throwing. + if (!py::isinstance<PyAttribute>(other)) + return false; + return self == other.cast<PyAttribute>(); + }) + .def( + "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); }, + kDumpDocstring) + .def( + "__str__", + [](PyAttribute &self) { + PyPrintAccumulator printAccum; + mlirAttributePrint(self.attr, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + "Returns the assembly form of the Attribute.") + .def("__repr__", [](PyAttribute &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, attribute values are generally considered useful and are + // printed. This may need to be re-evaluated if debug dumps end up + // being excessive.
+ PyPrintAccumulator printAccum; + printAccum.parts.append("Attribute("); + mlirAttributePrint(self.attr, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + + py::class_<PyNamedAttribute>(m, "NamedAttribute") + .def("__repr__", + [](PyNamedAttribute &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("NamedAttribute("); + printAccum.parts.append(self.namedAttr.name); + printAccum.parts.append("="); + mlirAttributePrint(self.namedAttr.attribute, + printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "name", + [](PyNamedAttribute &self) { + return py::str(self.namedAttr.name, strlen(self.namedAttr.name)); + }, + "The name of the NamedAttribute binding") + .def_property_readonly( + "attr", + [](PyNamedAttribute &self) { + return PyAttribute(self.namedAttr.attribute); + }, + py::keep_alive<0, 1>(), + "The underlying generic attribute of the NamedAttribute binding"); + + // Standard attribute bindings. + PyStringAttribute::bind(m); + // Mapping of Type. py::class_<PyType>(m, "Type") .def("__eq__", diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/ir_attributes.py @@ -0,0 +1,119 @@ +# RUN: %PYTHON %s | FileCheck %s + +import mlir + +def run(f): + print("\nTEST:", f.__name__) + f() + + +# CHECK-LABEL: TEST: testParsePrint +def testParsePrint(): + ctx = mlir.ir.Context() + attr = ctx.parse_attr('"hello"') + # CHECK: "hello" + print(str(attr)) + # CHECK: Attribute("hello") + print(repr(attr)) + +run(testParsePrint) + + +# CHECK-LABEL: TEST: testParseError +# TODO: Hook the diagnostic manager to capture a more meaningful error +# message.
+def testParseError(): + ctx = mlir.ir.Context() + try: + t = ctx.parse_attr("BAD_ATTR_DOES_NOT_EXIST") + except ValueError as e: + # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST' + print("testParseError:", e) + else: + print("Exception not produced") + +run(testParseError) + + +# CHECK-LABEL: TEST: testAttrEq +def testAttrEq(): + ctx = mlir.ir.Context() + a1 = ctx.parse_attr('"attr1"') + a2 = ctx.parse_attr('"attr2"') + a3 = ctx.parse_attr('"attr1"') + # CHECK: a1 == a1: True + print("a1 == a1:", a1 == a1) + # CHECK: a1 == a2: False + print("a1 == a2:", a1 == a2) + # CHECK: a1 == a3: True + print("a1 == a3:", a1 == a3) + # CHECK: a1 == None: False + print("a1 == None:", a1 == None)  # noqa: E711 (deliberately exercises __eq__ with None) + +run(testAttrEq) + + +# CHECK-LABEL: TEST: testAttrEqDoesNotRaise +def testAttrEqDoesNotRaise(): + ctx = mlir.ir.Context() + a1 = ctx.parse_attr('"attr1"') + not_an_attr = "foo" + # CHECK: False + print(a1 == not_an_attr) + # CHECK: False + print(a1 == None)  # noqa: E711 (deliberately exercises __eq__ with None) + # CHECK: True + print(a1 != None)  # noqa: E711 (deliberately exercises __ne__ with None) + +run(testAttrEqDoesNotRaise) + + +# CHECK-LABEL: TEST: testStandardAttrCasts +def testStandardAttrCasts(): + ctx = mlir.ir.Context() + a1 = ctx.parse_attr('"attr1"') + astr = mlir.ir.StringAttr(a1) + aself = mlir.ir.StringAttr(astr)  # Re-casting an existing StringAttr must not raise. + # CHECK: Attribute("attr1") + print(repr(astr)) + try: + tillegal = mlir.ir.StringAttr(ctx.parse_attr("1.0")) + except ValueError as e: + # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64)) + print("ValueError:", e) + else: + print("Exception not produced") + +run(testStandardAttrCasts) + + +# CHECK-LABEL: TEST: testStringAttr +def testStringAttr(): + ctx = mlir.ir.Context() + sattr = mlir.ir.StringAttr(ctx.parse_attr('"stringattr"')) + # CHECK: sattr value: stringattr + print("sattr value:", sattr.value) + + # Test factory methods.
+ # CHECK: default_get: "foobar" + print("default_get:", mlir.ir.StringAttr.get(ctx, "foobar")) + # CHECK: typed_get: "12345" : i32 + print("typed_get:", mlir.ir.StringAttr.get_typed( + mlir.ir.IntegerType.get_signless(ctx, 32), "12345")) + +run(testStringAttr) + + +# CHECK-LABEL: TEST: testNamedAttr +def testNamedAttr(): + ctx = mlir.ir.Context() + a = ctx.parse_attr('"stringattr"') + named = a.named("foobar") # Note: short enough for std::string's small-string optimization, so the owned name must be heap allocated + # CHECK: attr: "stringattr" + print("attr:", named.attr) + # CHECK: name: foobar + print("name:", named.name) + # CHECK: named: NamedAttribute(foobar="stringattr") + print("named:", named) + +run(testNamedAttr) diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -117,10 +117,10 @@ print("u32 unsigned:", u32.is_unsigned) # CHECK: signless: i16 - print("signless:", mlir.ir.IntegerType.signless(ctx, 16)) + print("signless:", mlir.ir.IntegerType.get_signless(ctx, 16)) # CHECK: signed: si8 - print("signed:", mlir.ir.IntegerType.signed(ctx, 8)) + print("signed:", mlir.ir.IntegerType.get_signed(ctx, 8)) # CHECK: unsigned: ui64 - print("unsigned:", mlir.ir.IntegerType.unsigned(ctx, 64)) + print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64)) run(testIntegerType)