diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -116,6 +116,9 @@
 /** Parses a module from the string and transfers ownership to the caller. */
 MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
 
+/** Returns nonzero if the module is null, e.g. after a failed parse. */
+static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
+
 /** Takes a module owned by the caller and deletes it. */
 void mlirModuleDestroy(MlirModule module);
 
@@ -312,6 +315,9 @@
 /** Parses a type. The type is owned by the context. */
 MlirType mlirTypeParseGet(MlirContext context, const char *type);
 
+/** Returns nonzero if the type is null, e.g. after a failed parse. */
+static inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
+
 /** Checks if two types are equal. */
 int mlirTypeEqual(MlirType t1, MlirType t2);
 
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -17,38 +17,44 @@
 namespace python {
 
 class PyMlirContext;
-class PyMlirModule;
+class PyModule;
 
 /// Wrapper around MlirContext.
 class PyMlirContext {
 public:
   PyMlirContext() { context = mlirContextCreate(); }
   ~PyMlirContext() { mlirContextDestroy(context); }
-  /// Parses the module from asm.
-  PyMlirModule parse(const std::string &module);
 
   MlirContext context;
 };
 
 /// Wrapper around MlirModule.
-class PyMlirModule {
+class PyModule {
 public:
-  PyMlirModule(MlirModule module) : module(module) {}
-  PyMlirModule(PyMlirModule &) = delete;
-  PyMlirModule(PyMlirModule &&other) {
+  PyModule(MlirModule module) : module(module) {}
+  PyModule(const PyModule &) = delete;
+  PyModule(PyModule &&other) noexcept {
     module = other.module;
     other.module.ptr = nullptr;
   }
-  ~PyMlirModule() {
+  ~PyModule() {
     if (module.ptr)
       mlirModuleDestroy(module);
   }
-  /// Dumps the module.
-  void dump();
 
   MlirModule module;
 };
 
+/// Wrapper around the generic MlirType.
+/// The lifetime of a type is bound by the PyMlirContext that created it.
+class PyType {
+public:
+  PyType(MlirType type) : type(type) {}
+  bool operator==(const PyType &other) const;
+
+  MlirType type;
+};
+
 void populateIRSubmodule(pybind11::module &m);
 
 } // namespace python
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -9,7 +9,10 @@
 #include "IRModules.h"
 #include "PybindUtils.h"
 
+#include "mlir-c/StandardTypes.h"
+
 namespace py = pybind11;
+using namespace mlir;
 using namespace mlir::python;
 
 //------------------------------------------------------------------------------
@@ -20,6 +23,15 @@
     R"(Parses a module's assembly format from a string.
 
 Returns a new MlirModule or raises a ValueError if the parsing fails.
+
+See also: https://mlir.llvm.org/docs/LangRef/
+)";
+
+static const char kContextParseType[] = R"(Parses the assembly form of a type.
+
+Returns a Type object or raises a ValueError if the type cannot be parsed.
+
+See also: https://mlir.llvm.org/docs/LangRef/#type-system
 )";
 
 static const char kOperationStrDunderDocstring[] =
@@ -30,6 +42,9 @@
 behavior.
 )";
 
+static const char kTypeStrDunderDocstring[] =
+    R"(Prints the assembly form of the type.)";
+
 static const char kDumpDocstring[] =
     R"(Dumps a debug representation of the object to stderr.)";
 
@@ -64,39 +79,154 @@
 } // namespace
 
 //------------------------------------------------------------------------------
-// Context Wrapper Class.
+// PyType.
 //------------------------------------------------------------------------------
 
-PyMlirModule PyMlirContext::parse(const std::string &module) {
-  auto moduleRef = mlirModuleCreateParse(context, module.c_str());
-  if (!moduleRef.ptr) {
-    throw SetPyError(PyExc_ValueError,
-                     "Unable to parse module assembly (see diagnostics)");
-  }
-  return PyMlirModule(moduleRef);
+bool PyType::operator==(const PyType &other) const {
+  return mlirTypeEqual(type, other.type);
 }
 
 //------------------------------------------------------------------------------
-// Module Wrapper Class.
+// Standard type subclasses.
 //------------------------------------------------------------------------------
 
-void PyMlirModule::dump() { mlirOperationDump(mlirModuleGetOperation(module)); }
+namespace {
+
+/// CRTP base class for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+template <typename T>
+class PyConcreteType : public PyType {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = py::class_<T, PyType>;
+  using IsAFunctionTy = int (*)(MlirType);
+
+  PyConcreteType() = default;
+  PyConcreteType(MlirType t) : PyType(t) {}
+  PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
+
+  static MlirType castFrom(PyType &orig) {
+    if (!T::isaFunction(orig.type)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+                                             T::pyClassName + " (from " +
+                                             origRepr + ")");
+    }
+    return orig.type;
+  }
+  /// Registers T; the cast constructor keeps `orig` (and its context) alive.
+  static void bind(py::module &m) {
+    auto class_ = ClassTy(m, T::pyClassName);
+    class_.def(py::init<PyType &>(), py::keep_alive<1, 2>());
+    T::bindDerived(class_);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+class PyIntegerType : public PyConcreteType<PyIntegerType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "signless",
+        [](PyMlirContext &context, unsigned width) {
+          MlirType t = mlirIntegerTypeGet(context.context, width);
+          return PyIntegerType(t);
+        },
+        py::keep_alive<0, 1>(), "Create a signless integer type");
+    c.def_static(
+        "signed",
+        [](PyMlirContext &context, unsigned width) {
+          MlirType t = mlirIntegerTypeSignedGet(context.context, width);
+          return PyIntegerType(t);
+        },
+        py::keep_alive<0, 1>(), "Create a signed integer type");
+    c.def_static(
+        "unsigned",
+        [](PyMlirContext &context, unsigned width) {
+          MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
+          return PyIntegerType(t);
+        },
+        py::keep_alive<0, 1>(), "Create an unsigned integer type");
+    c.def_property_readonly(
+        "width",
+        [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
+        "Returns the width of the integer type");
+    c.def_property_readonly(
+        "is_signless",
+        [](PyIntegerType &self) -> bool {
+          return mlirIntegerTypeIsSignless(self.type);
+        },
+        "Returns whether this is a signless integer");
+    c.def_property_readonly(
+        "is_signed",
+        [](PyIntegerType &self) -> bool {
+          return mlirIntegerTypeIsSigned(self.type);
+        },
+        "Returns whether this is a signed integer");
+    c.def_property_readonly(
+        "is_unsigned",
+        [](PyIntegerType &self) -> bool {
+          return mlirIntegerTypeIsUnsigned(self.type);
+        },
+        "Returns whether this is an unsigned integer");
+  }
+};
+
+} // namespace
 
 //------------------------------------------------------------------------------
 // Populates the pybind11 IR submodule.
//------------------------------------------------------------------------------ void mlir::python::populateIRSubmodule(py::module &m) { - py::class_(m, "MlirContext") + // Mapping of MlirContext + py::class_(m, "Context") .def(py::init<>()) - .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>(), - kContextParseDocstring); + .def( + "parse_module", + [](PyMlirContext &self, const std::string module) { + auto moduleRef = + mlirModuleCreateParse(self.context, module.c_str()); + if (mlirModuleIsNull(moduleRef)) { + throw SetPyError( + PyExc_ValueError, + "Unable to parse module assembly (see diagnostics)"); + } + return PyModule(moduleRef); + }, + py::keep_alive<0, 1>(), kContextParseDocstring) + .def( + "parse_type", + [](PyMlirContext &self, std::string typeSpec) { + MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str()); + if (mlirTypeIsNull(type)) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Unable to parse type: '") + + typeSpec + "'"); + } + return PyType(type); + }, + py::keep_alive<0, 1>(), kContextParseType); - py::class_(m, "MlirModule") - .def("dump", &PyMlirModule::dump, kDumpDocstring) + // Mapping of Module + py::class_(m, "Module") + .def( + "dump", + [](PyModule &self) { + mlirOperationDump(mlirModuleGetOperation(self.module)); + }, + kDumpDocstring) .def( "__str__", - [](PyMlirModule &self) { + [](PyModule &self) { auto operation = mlirModuleGetOperation(self.module); PyPrintAccumulator printAccum; mlirOperationPrint(operation, printAccum.getCallback(), @@ -104,4 +234,42 @@ return printAccum.join(); }, kOperationStrDunderDocstring); + + // Mapping of Type. 
+ py::class_(m, "Type") + .def("__eq__", + [](PyType &self, py::object &other) { + try { + PyType otherType = other.cast(); + return self == otherType; + } catch (std::exception &e) { + return false; + } + }) + .def( + "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring) + .def( + "__str__", + [](PyType &self) { + PyPrintAccumulator printAccum; + mlirTypePrint(self.type, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + kTypeStrDunderDocstring) + .def("__repr__", [](PyType &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, types are an exception as they typically have compact + // assembly forms and printing them is useful. + PyPrintAccumulator printAccum; + printAccum.parts.append("Type("); + mlirTypePrint(self.type, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + + // Standard type bindings. + PyIntegerType::bind(m); } diff --git a/mlir/test/Bindings/Python/ir_module_test.py b/mlir/test/Bindings/Python/ir_module.py rename from mlir/test/Bindings/Python/ir_module_test.py rename to mlir/test/Bindings/Python/ir_module.py --- a/mlir/test/Bindings/Python/ir_module_test.py +++ b/mlir/test/Bindings/Python/ir_module.py @@ -3,15 +3,15 @@ import mlir def run(f): - print("TEST:", f.__name__) + print("\nTEST:", f.__name__) f() # Verify successful parse. # CHECK-LABEL: TEST: testParseSuccess # CHECK: module @successfulParse def testParseSuccess(): - ctx = mlir.ir.MlirContext() - module = ctx.parse(r"""module @successfulParse {}""") + ctx = mlir.ir.Context() + module = ctx.parse_module(r"""module @successfulParse {}""") module.dump() # Just outputs to stderr. Verifies that it functions. 
   print(str(module))
 
@@ -22,9 +22,9 @@
 # CHECK-LABEL: TEST: testParseError
 # CHECK: testParseError: Unable to parse module assembly (see diagnostics)
 def testParseError():
-  ctx = mlir.ir.MlirContext()
+  ctx = mlir.ir.Context()
   try:
-    module = ctx.parse(r"""}SYNTAX ERROR{""")
+    module = ctx.parse_module(r"""}SYNTAX ERROR{""")  # Must raise ValueError.
   except ValueError as e:
     print("testParseError:", e)
   else:
@@ -40,8 +40,8 @@
 # CHECK: func @roundtripUnicode()
 # CHECK: foo = "\F0\9F\98\8A"
 def testRoundtripUnicode():
-  ctx = mlir.ir.MlirContext()
-  module = ctx.parse(r"""
+  ctx = mlir.ir.Context()
+  module = ctx.parse_module(r"""
     func @roundtripUnicode() attributes { foo = "😊" }
   """)
   print(str(module))
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -0,0 +1,126 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):
+  print("\nTEST:", f.__name__)  # Blank line before each test's label.
+  f()
+
+
+# CHECK-LABEL: TEST: testParsePrint
+def testParsePrint():
+  ctx = mlir.ir.Context()
+  t = ctx.parse_type("i32")  # Yields a generic mlir.ir.Type.
+  # CHECK: i32
+  print(str(t))
+  # CHECK: Type(i32)
+  print(repr(t))
+
+run(testParsePrint)
+
+
+# CHECK-LABEL: TEST: testParseError
+# TODO: Hook the diagnostic manager to capture a more meaningful error
+# message.
+def testParseError():
+  ctx = mlir.ir.Context()
+  try:
+    t = ctx.parse_type("BAD_TYPE_DOES_NOT_EXIST")  # Expected to raise ValueError.
+  except ValueError as e:
+    # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
+    print("testParseError:", e)
+  else:
+    print("Exception not produced")
+
+run(testParseError)
+
+
+# CHECK-LABEL: TEST: testTypeEq
+def testTypeEq():
+  ctx = mlir.ir.Context()
+  t1 = ctx.parse_type("i32")
+  t2 = ctx.parse_type("f32")
+  t3 = ctx.parse_type("i32")  # Parsed separately from t1 but must compare equal.
+  # CHECK: t1 == t1: True
+  print("t1 == t1:", t1 == t1)
+  # CHECK: t1 == t2: False
+  print("t1 == t2:", t1 == t2)
+  # CHECK: t1 == t3: True
+  print("t1 == t3:", t1 == t3)
+  # CHECK: t1 == None: False
+  print("t1 == None:", t1 == None)
+
+run(testTypeEq)
+
+
+# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
+def testTypeEqDoesNotRaise():
+  ctx = mlir.ir.Context()
+  t1 = ctx.parse_type("i32")
+  not_a_type = "foo"  # __eq__ must return False, not raise, for non-Type values.
+  # CHECK: False
+  print(t1 == not_a_type)
+  # CHECK: False
+  print(t1 == None)
+  # CHECK: True
+  print(t1 != None)  # Only __eq__ is bound; Python 3 derives != from it.
+
+run(testTypeEqDoesNotRaise)
+
+
+# CHECK-LABEL: TEST: testStandardTypeCasts
+def testStandardTypeCasts():
+  ctx = mlir.ir.Context()
+  t1 = ctx.parse_type("i32")
+  tint = mlir.ir.IntegerType(t1)
+  tself = mlir.ir.IntegerType(tint)  # Re-casting an IntegerType must also work.
+  # CHECK: Type(i32)
+  print(repr(tint))
+  try:
+    tillegal = mlir.ir.IntegerType(ctx.parse_type("f32"))  # f32 is not an integer type.
+  except ValueError as e:
+    # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
+    print("ValueError:", e)
+  else:
+    print("Exception not produced")
+
+run(testStandardTypeCasts)
+
+
+# CHECK-LABEL: TEST: testIntegerType
+def testIntegerType():
+  ctx = mlir.ir.Context()
+  i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))  # "i" prefix: signless.
+  # CHECK: i32 width: 32
+  print("i32 width:", i32.width)
+  # CHECK: i32 signless: True
+  print("i32 signless:", i32.is_signless)
+  # CHECK: i32 signed: False
+  print("i32 signed:", i32.is_signed)
+  # CHECK: i32 unsigned: False
+  print("i32 unsigned:", i32.is_unsigned)
+
+  s32 = mlir.ir.IntegerType(ctx.parse_type("si32"))  # "si" prefix: signed.
+  # CHECK: s32 signless: False
+  print("s32 signless:", s32.is_signless)
+  # CHECK: s32 signed: True
+  print("s32 signed:", s32.is_signed)
+  # CHECK: s32 unsigned: False
+  print("s32 unsigned:", s32.is_unsigned)
+
+  u32 = mlir.ir.IntegerType(ctx.parse_type("ui32"))  # "ui" prefix: unsigned.
+  # CHECK: u32 signless: False
+  print("u32 signless:", u32.is_signless)
+  # CHECK: u32 signed: False
+  print("u32 signed:", u32.is_signed)
+  # CHECK: u32 unsigned: True
+  print("u32 unsigned:", u32.is_unsigned)
+
+  # CHECK: signless: i16
+  print("signless:", mlir.ir.IntegerType.signless(ctx, 16))  # Factories take (context, width).
+  # CHECK: signed: si8
+  print("signed:", mlir.ir.IntegerType.signed(ctx, 8))
+  # CHECK: unsigned: ui64
+  print("unsigned:", mlir.ir.IntegerType.unsigned(ctx, 64))
+
+run(testIntegerType)