diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -4,6 +4,8 @@
 
 set(PY_SRC_FILES
   mlir/__init__.py
+  mlir/extensions/__init__.py
+  mlir/extensions/std_dialect.py
 )
 
 add_custom_target(MLIRBindingsPythonSources ALL
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -187,6 +187,45 @@
   PyMlirContextRef contextRef;
 };
 
+/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` to
+/// differentiate this "low-level" object from the high-level, user-facing
+/// `Dialect` base class, which plugins extend through extension python code.
+class PyDialectDescriptor : public BaseContextObject {
+public:
+  PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
+      : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
+
+  /// Returns the wrapped C API dialect handle.
+  MlirDialect get() const { return dialect; }
+
+private:
+  MlirDialect dialect;
+};
+
+/// User-level object for accessing dialects by name, e.g.:
+///   ctx.dialect.std (attribute syntax) or ctx.dialect["std"] (index syntax)
+class PyDialects : public BaseContextObject {
+public:
+  explicit PyDialects(PyMlirContextRef contextRef)
+      : BaseContextObject(std::move(contextRef)) {}
+
+  /// Gets or loads the dialect named `key`. If it cannot be found, throws a
+  /// python AttributeError when `attrError` is true, IndexError otherwise.
+  MlirDialect getDialectForKey(const std::string &key, bool attrError);
+};
+
+/// User-level dialect object. Registered dialect extensions derive from it;
+/// dialects without an extension are returned as plain instances of it.
+class PyDialect {
+public:
+  explicit PyDialect(pybind11::object descriptor)
+      : descriptor(std::move(descriptor)) {}
+
+  pybind11::object getDescriptor() const { return descriptor; }
+private:
+  pybind11::object descriptor;
+};
+
 /// Wrapper around an MlirLocation.
class PyLocation : public BaseContextObject { public: @@ -412,7 +451,7 @@ MlirValue value; }; -void populateIRSubmodule(pybind11::module &m); +void populateIRSubmodule(pybind11::module m); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -646,6 +646,17 @@ return PyOperation::createDetached(getRef(), operation).releaseObject(); } +MlirDialect PyDialects::getDialectForKey(const std::string &key, + bool attrError) { + MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), + {key.data(), key.size()}); + if (mlirDialectIsNull(dialect)) { + throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, + llvm::Twine("Dialect '") + key + "' not found"); + } + return dialect; +} + //------------------------------------------------------------------------------ // PyModule //------------------------------------------------------------------------------ @@ -1913,8 +1924,75 @@ // Populates the pybind11 IR submodule. //------------------------------------------------------------------------------ -void mlir::python::populateIRSubmodule(py::module &m) { +void mlir::python::populateIRSubmodule(py::module m) { + //---------------------------------------------------------------------------- + // Module local extension management. + // This is complicated because it is python state carried at the module + // level, effectively as statics. In order to ensure proper cleanup and + // sharing (in the case of defining multiple modules), this is managed + // via python constructs, making the code somewhat more complicated. + //---------------------------------------------------------------------------- + + // The list of module names to search for python extensions to standard + // classes. 
Each element in the list is the name of a parent module, + // containing sub-modules with well-defined names relating to various + // extensions. Modules are search in order for matches. + m.attr("extension_search_list") = py::list(); + + // Extensions modules loaded against this module, keyed by extension name. + // When an extension is not in the dict, it is to be referenced by searching + // for submodules with the given name in the extension_search_list. Once + // an extension has been loaded, it will be remembered, even if the search + // list changes. + m.attr("_extensions") = py::dict(); + + // Loads an extension module by extension name. This is done by searching + // the extension_search_list and constructing a derived module as: + // . + // Returns None on failure to find. + auto loadExtension = [m](py::str extensionName) -> py::module { + py::dict extensions = m.attr("_extensions"); + py::module extModule = extensions.attr("get")(extensionName); + if (!extModule.is_none()) + return extModule; + + auto searchList = m.attr("extension_search_list"); + for (auto it : searchList) { + std::string parentModule = py::str(it); + parentModule.append("."); + parentModule.append(py::str(extensionName)); + try { + extModule = py::module::import(parentModule.c_str()); + } catch (py::error_already_set &e) { + if (e.matches(PyExc_ImportError)) { + continue; + } + throw; + } + extensions[extensionName] = extModule; + return extModule; + } + + return py::none(); + }; + + auto createDialectExtension = [=](const std::string &dialectName, + py::object descriptor) -> py::object { + std::string extensionName = dialectName; + extensionName.append("_dialect"); + py::module extModule = loadExtension(extensionName); + if (extModule.is_none()) { + // Return a non extension base version. + return py::cast(PyDialect{std::move(descriptor)}); + } else { + // Instantiate the "Dialect" class in the extension. 
+ return extModule.attr("Dialect")(std::move(descriptor)); + } + }; + + //---------------------------------------------------------------------------- // Mapping of MlirContext + //---------------------------------------------------------------------------- py::class_(m, "Context") .def(py::init<>(&PyMlirContext::createNewContextForInit)) .def_static("_get_live_count", &PyMlirContext::getLiveCount) @@ -1928,6 +2006,25 @@ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) + .def_property_readonly( + "dialect", + [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Gets a container for accessing dialects by name") + .def_property_readonly( + "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Alias for 'dialect'") + .def( + "get_dialect_descriptor", + [=](PyMlirContext &self, std::string &name) { + MlirDialect dialect = mlirContextGetOrLoadDialect( + self.get(), {name.data(), name.size()}); + if (mlirDialectIsNull(dialect)) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Dialect '") + name + "' not found"); + } + return PyDialectDescriptor(self.getRef(), dialect); + }, + "Gets or loads a dialect by name, returning its descriptor object") .def_property( "allow_unregistered_dialects", [](PyMlirContext &self) -> bool { @@ -2009,6 +2106,62 @@ kContextGetFileLocationDocstring, py::arg("filename"), py::arg("line"), py::arg("col")); + //---------------------------------------------------------------------------- + // Mapping of PyDialectDescriptor + //---------------------------------------------------------------------------- + py::class_(m, "DialectDescriptor") + .def_property_readonly("namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = + mlirDialectGetNamespace(self.get()); + return py::str(ns.data, ns.length); + }) + .def("__repr__", [](PyDialectDescriptor &self) { + MlirStringRef ns = 
mlirDialectGetNamespace(self.get()); + std::string repr(""); + return repr; + }); + + //---------------------------------------------------------------------------- + // Mapping of PyDialects + //---------------------------------------------------------------------------- + py::class_(m, "Dialects") + .def("__getitem__", + [=](PyDialects &self, std::string keyName) { + MlirDialect dialect = + self.getDialectForKey(keyName, /*attrError=*/false); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createDialectExtension(keyName, std::move(descriptor)); + }) + .def("__getattr__", [=](PyDialects &self, std::string attrName) { + MlirDialect dialect = + self.getDialectForKey(attrName, /*attrError=*/true); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createDialectExtension(attrName, std::move(descriptor)); + }); + + //---------------------------------------------------------------------------- + // Mapping of PyDialect + //---------------------------------------------------------------------------- + py::class_(m, "Dialect") + .def(py::init(), "descriptor") + .def_property_readonly( + "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) + .def("__repr__", [](py::object self) { + auto clazz = self.attr("__class__"); + return py::str(""); + }); + + //---------------------------------------------------------------------------- + // Mapping of Location + //---------------------------------------------------------------------------- py::class_(m, "Location") .def_property_readonly( "context", @@ -2021,7 +2174,9 @@ return printAccum.join(); }); + //---------------------------------------------------------------------------- // Mapping of Module + //---------------------------------------------------------------------------- py::class_(m, "Module") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, 
&PyModule::createFromCapsule) @@ -2055,7 +2210,9 @@ }, kOperationStrDunderDocstring); + //---------------------------------------------------------------------------- // Mapping of Operation. + //---------------------------------------------------------------------------- py::class_(m, "Operation") .def_property_readonly( "context", @@ -2098,7 +2255,9 @@ py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, kOperationGetAsmDocstring); + //---------------------------------------------------------------------------- // Mapping of PyRegion. + //---------------------------------------------------------------------------- py::class_(m, "Region") .def_property_readonly( "blocks", @@ -2123,7 +2282,9 @@ } }); + //---------------------------------------------------------------------------- // Mapping of PyBlock. + //---------------------------------------------------------------------------- py::class_(m, "Block") .def_property_readonly( "arguments", @@ -2167,7 +2328,9 @@ }, "Returns the assembly form of the block."); + //---------------------------------------------------------------------------- // Mapping of PyAttribute. + //---------------------------------------------------------------------------- py::class_(m, "Attribute") .def_property_readonly( "context", @@ -2219,6 +2382,9 @@ return printAccum.join(); }); + //---------------------------------------------------------------------------- + // Mapping of PyNamedAttribute + //---------------------------------------------------------------------------- py::class_(m, "NamedAttribute") .def("__repr__", [](PyNamedAttribute &self) { @@ -2257,7 +2423,9 @@ PyStringAttribute::bind(m); PyDenseElementsAttribute::bind(m); + //---------------------------------------------------------------------------- // Mapping of PyType. 
+  //----------------------------------------------------------------------------
   py::class_<PyType>(m, "Type")
       .def_property_readonly(
           "context", [](PyType &self) { return self.getContext().getObject(); },
@@ -2313,7 +2481,9 @@
   PyTupleType::bind(m);
   PyFunctionType::bind(m);
 
+  //----------------------------------------------------------------------------
   // Mapping of Value.
+  //----------------------------------------------------------------------------
   py::class_<PyValue>(m, "Value")
       .def_property_readonly(
           "context",
diff --git a/mlir/lib/Bindings/Python/mlir/__init__.py b/mlir/lib/Bindings/Python/mlir/__init__.py
--- a/mlir/lib/Bindings/Python/mlir/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/__init__.py
@@ -9,3 +9,6 @@
 # scenarios.
 
 from _mlir import *
+
+# Search the mlir.extensions package for extension modules (e.g. std_dialect).
+ir.extension_search_list.append("mlir.extensions")
diff --git a/mlir/lib/Bindings/Python/mlir/extensions/__init__.py b/mlir/lib/Bindings/Python/mlir/extensions/__init__.py
new file mode 100644
diff --git a/mlir/lib/Bindings/Python/mlir/extensions/std_dialect.py b/mlir/lib/Bindings/Python/mlir/extensions/std_dialect.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/extensions/std_dialect.py
@@ -0,0 +1,21 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from mlir import ir
+
+class Dialect(ir.Dialect):
+  """Extension class instantiated for the std dialect (ctx.dialect.std)."""
+
+
+# Example custom op:
+# - The class decorator does some meta-programming to clean up its API and
+#   sets it as an attr on the static Dialect.
+# - Can be accessed with "ctx.dialect.std.AddFOp"
+# @Dialect.register_op
+# class AddFOp(ir.Op):
+#   @staticmethod
+#   def create(ctx, loc, lhs, rhs):
+#     return ctx.create_operation(
+#         "std.addf", loc, operands=[lhs, rhs], results=[lhs.type])
+
diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -0,0 +1,62 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import mlir
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert mlir.ir.Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: testDialectDescriptor
+def testDialectDescriptor():
+  ctx = mlir.ir.Context()
+  d = ctx.get_dialect_descriptor("std")
+  # CHECK: <DialectDescriptor std>
+  print(d)
+  # CHECK-NEXT: std
+  print(d.namespace)
+  try:
+    _ = ctx.get_dialect_descriptor("not_existing")
+  except ValueError:
+    pass
+  else:
+    assert False, "Expected exception"
+
+run(testDialectDescriptor)
+
+
+# CHECK-LABEL: TEST: testUserDialects
+def testUserDialects():
+  ctx = mlir.ir.Context()
+  # Access using attribute.
+  d = ctx.dialect.std
+  # CHECK: <Dialect std (class mlir.extensions.std_dialect.Dialect)>
+  print(d)
+  try:
+    _ = ctx.dialect.not_existing
+  except AttributeError:
+    pass
+  else:
+    assert False, "Expected exception"
+
+  # Access using index.
+  d = ctx.dialect["std"]
+  # CHECK: <Dialect std (class mlir.extensions.std_dialect.Dialect)>
+  print(d)
+  try:
+    _ = ctx.dialect["not_existing"]
+  except IndexError:
+    pass
+  else:
+    assert False, "Expected exception"
+
+  # Using the 'd' alias.
+  d = ctx.d["std"]
+  # CHECK: <Dialect std (class mlir.extensions.std_dialect.Dialect)>
+  print(d)
+
+
+run(testUserDialects)