diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -4,6 +4,9 @@ set(PY_SRC_FILES mlir/__init__.py + mlir/ir.py + mlir/dialects/__init__.py + mlir/dialects/std.py ) add_custom_target(MLIRBindingsPythonSources ALL diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/Globals.h @@ -0,0 +1,93 @@ +//===- Globals.h - MLIR Python extension globals --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H +#define MLIR_BINDINGS_PYTHON_GLOBALS_H + +#include +#include + +#include "PybindUtils.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace python { + +/// Globals that are always accessible once the extension has been initialized. +class PyGlobals { +public: + PyGlobals(); + ~PyGlobals(); + + /// Most code should get the globals via this static accessor. + static PyGlobals &get() { + assert(instance && "PyGlobals is null"); + return *instance; + } + + /// Get and set the list of parent modules to search for dialect + /// implementation classes. + std::vector &getDialectSearchPrefixes() { + return dialectSearchPrefixes; + } + void setDialectSearchPrefixes(std::vector newValues) { + dialectSearchPrefixes.swap(newValues); + } + + /// Loads or gets an already loaded module containing dialect implementation + /// classes. Raises an error if there is a problem loading. Returns None + /// if not found. 
Note that loading a dialect module can re-enter anywhere + /// in the API and any handles to global state should not be relied on + /// across this call. + pybind11::object getOrLoadDialectModule(const std::string &dialectNamespace); + + /// Decorator for registering a custom Dialect class. The class object must + /// have a DIALECT_NAMESPACE attribute. + pybind11::object registerDialectDecorator(pybind11::object pyClass); + + /// Adds a concrete implementation dialect class. + /// Raises an exception if the mapping already exists. + /// This is intended to be called by implementation code. + void registerDialectImpl(const std::string &dialectNamespace, + pybind11::object pyClass); + + /// Adds a concrete implementation operation class. + /// Raises an exception if the mapping already exists. + /// This is intended to be called by implementation code. + void registerOperationImpl(const std::string &operationName, + pybind11::object pyClass, + pybind11::object rawClass); + + /// Looks up a registered dialect class by namespace. Note that this may + /// trigger loading of the defining module and can arbitrarily re-enter. + llvm::Optional + lookupDialectClass(const std::string &dialectNamespace); + +private: + static PyGlobals *instance; + /// Module name prefixes to search under for dialect implementation modules. + std::vector dialectSearchPrefixes; + /// Map of dialect namespace to parent module holding both the custom Dialect + /// and Op wrappers. + llvm::StringMap dialectParentModuleMap; + /// Map of dialect namespace to external dialect class object. + llvm::StringMap dialectClassMap; + /// Map of full operation name to external operation class object. + llvm::StringMap operationClassMap; + /// Map of operation name to custom subclass that directly initializes + /// the OpView base class (bypassing the user class constructor).
+ llvm::StringMap rawOperationClassMap; +}; + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_GLOBALS_H diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -132,6 +132,7 @@ /// Creates an operation. See corresponding python docstring. pybind11::object createOperation(std::string name, PyLocation location, + llvm::Optional> operands, llvm::Optional> results, llvm::Optional attributes, llvm::Optional> successors, @@ -187,6 +188,45 @@ PyMlirContextRef contextRef; }; +/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in +/// order to differentiate it from the `Dialect` base class which is extended by +/// plugins which extend dialect functionality through extension python code. +/// This should be seen as the "low-level" object and `Dialect` as the +/// high-level, user facing object. +class PyDialectDescriptor : public BaseContextObject { +public: + PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) + : BaseContextObject(std::move(contextRef)), dialect(dialect) {} + + MlirDialect get() { return dialect; } + +private: + MlirDialect dialect; +}; + +/// User-level object for accessing dialects with dotted syntax such as: +/// ctx.dialect.std +class PyDialects : public BaseContextObject { +public: + PyDialects(PyMlirContextRef contextRef) + : BaseContextObject(std::move(contextRef)) {} + + MlirDialect getDialectForKey(const std::string &key, bool attrError); +}; + +/// User-level dialect object. For dialects that have a registered extension, +/// this will be the base class of the extension dialect type. For un-extended, +/// objects of this type will be returned directly. 
+class PyDialect { +public: + PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} + + pybind11::object getDescriptor() { return descriptor; } + +private: + pybind11::object descriptor; +}; + /// Wrapper around an MlirLocation. class PyLocation : public BaseContextObject { public: @@ -305,6 +345,24 @@ bool valid = true; }; +/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for +/// providing more instance-specific accessors and serve as the base class for +/// custom ODS-style operation classes. Since this class is subclassed on the +/// python side, it must present an __init__ method that operates in pure +/// python types. +class PyOpView { +public: + PyOpView(pybind11::object operation); + + static pybind11::object createRawSubclass(pybind11::object userClass); + + pybind11::object getOperationObject() { return operationObject; } + +private: + pybind11::object operationObject; // Holds the reference. + PyOperation *operation; // For efficient, cast-free access from C++ +}; + /// Wrapper around an MlirRegion. /// Regions are managed completely by their containing operation. Unlike the /// C++ API, the python API does not support detached regions. @@ -412,7 +470,7 @@ MlirValue value; }; -void populateIRSubmodule(pybind11::module &m); +void populateIRSubmodule(pybind11::module m); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "IRModules.h" + +#include "Globals.h" #include "PybindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" @@ -209,19 +211,27 @@ } // namespace //------------------------------------------------------------------------------ -// Type-checking utilities. +// Utilities.
//------------------------------------------------------------------------------ -namespace { - /// Checks whether the given type is an integer or float type. -int mlirTypeIsAIntegerOrFloat(MlirType type) { +static int mlirTypeIsAIntegerOrFloat(MlirType type) { return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); } -} // namespace +static py::object +createCustomDialectWrapper(const std::string &dialectNamespace, + py::object dialectDescriptor) { + auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); + if (!dialectClass) { + // Use the base class. + return py::cast(PyDialect(std::move(dialectDescriptor))); + } + // Create the custom implementation. + return (*dialectClass)(std::move(dialectDescriptor)); +} //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -567,9 +577,11 @@ py::object PyMlirContext::createOperation( std::string name, PyLocation location, + llvm::Optional> operands, llvm::Optional> results, llvm::Optional attributes, llvm::Optional> successors, int regions) { + llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; llvm::SmallVector, 4> mlirAttributes; @@ -578,6 +590,16 @@ if (regions < 0) throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); + // Unpack/validate operands. + if (operands) { + mlirOperands.reserve(operands->size()); + for (PyValue *operand : *operands) { + if (!operand) + throw SetPyError(PyExc_ValueError, "operand value cannot be None"); + mlirOperands.push_back(operand->get()); + } + } + // Unpack/validate results. if (results) { mlirResults.reserve(results->size()); @@ -614,6 +636,9 @@ // Apply unpacked/validated to the operation state. Beyond this // point, exceptions cannot be thrown or else the state will leak. 
MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc); + if (!mlirOperands.empty()) + mlirOperationStateAddOperands(&state, mlirOperands.size(), + mlirOperands.data()); if (!mlirResults.empty()) mlirOperationStateAddResults(&state, mlirResults.size(), mlirResults.data()); @@ -646,6 +671,24 @@ return PyOperation::createDetached(getRef(), operation).releaseObject(); } +//------------------------------------------------------------------------------ +// PyDialect, PyDialectDescriptor, PyDialects +//------------------------------------------------------------------------------ + +MlirDialect PyDialects::getDialectForKey(const std::string &key, + bool attrError) { + // If the "std" dialect was asked for, substitute the empty namespace :( + static const std::string emptyKey; + const std::string *canonKey = key == "std" ? &emptyKey : &key; + MlirDialect dialect = mlirContextGetOrLoadDialect( + getContext()->get(), {canonKey->data(), canonKey->size()}); + if (mlirDialectIsNull(dialect)) { + throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, + llvm::Twine("Dialect '") + key + "' not found"); + } + return dialect; +} + //------------------------------------------------------------------------------ // PyModule //------------------------------------------------------------------------------ @@ -815,6 +858,45 @@ return fileObject.attr("getvalue")(); } +PyOpView::PyOpView(py::object operation) + : operationObject(std::move(operation)), + operation(py::cast(this->operationObject)) {} + +py::object PyOpView::createRawSubclass(py::object userClass) { + // This is... a little gross. The typical pattern is to have a pure python + // class that extends OpView like: + // class AddFOp(_cext.ir.OpView): + // def __init__(self, loc, lhs, rhs): + // operation = loc.context.create_operation( + // "addf", lhs, rhs, results=[lhs.type]) + // super().__init__(operation) + // + // I.e. 
The goal of the user facing type is to provide a nice constructor + // that has complete freedom for the op under construction. This is at odds + // with our other desire to sometimes create this object by just passing an + // operation (to initialize the base class). We could do *args and **kwargs + // munging to try to make it work, but instead, we synthesize a new class + // on the fly which extends this user class (AddFOp in this example) and + // *give it* the base class's __init__ method, thus bypassing the + // intermediate subclass's __init__ method entirely. While slightly + // underhanded, this is safe/legal because the type hierarchy has not changed + // (we just added a new leaf) and we aren't mucking around with __new__. + // Typically, this new class will be stored on the original as "_Raw" and will + // be used for casts and other things that need a variant of the class that + // is initialized purely from an operation. + py::object parentMetaclass = + py::reinterpret_borrow((PyObject *)&PyType_Type); + py::dict attributes; + // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from + // now. + // auto opViewType = py::type::of(); + auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); + attributes["__init__"] = opViewType.attr("__init__"); + py::str origName = userClass.attr("__name__"); + py::str newName = py::str("_") + origName; + return parentMetaclass(newName, py::make_tuple(userClass), attributes); +} + //------------------------------------------------------------------------------ // PyAttribute. //------------------------------------------------------------------------------ @@ -966,6 +1048,41 @@ MlirBlock block; }; +/// A list of operation operands. Internally, these are stored as consecutive +/// elements, random access is cheap. The operand list is associated with the +/// operation whose operands these are, and extends the lifetime of this +/// operation.
+class PyOpOperandList { +public: + PyOpOperandList(PyOperationRef operation) : operation(operation) {} + + /// Returns the length of the operand list. + intptr_t dunderLen() { + operation->checkValid(); + return mlirOperationGetNumOperands(operation->get()); + } + + /// Returns `index`-th element in the operand list. + PyOpResult dunderGetItem(intptr_t index) { + if (index < 0 || index >= dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds operand"); + } + PyValue value(operation, mlirOperationGetOperand(operation->get(), index)); + return PyOpResult(value); // TODO: confirm for block-argument operands. + } + + /// Defines a Python class in the bindings. + static void bind(py::module &m) { + py::class_(m, "OpOperandList") + .def("__len__", &PyOpOperandList::dunderLen) + .def("__getitem__", &PyOpOperandList::dunderGetItem); + } + +private: + PyOperationRef operation; +}; + /// A list of operation results. Internally, these are stored as consecutive /// elements, random access is cheap. The result list is associated with the /// operation whose results these are, and extends the lifetime of this @@ -1913,8 +2030,10 @@ // Populates the pybind11 IR submodule.
//------------------------------------------------------------------------------ -void mlir::python::populateIRSubmodule(py::module &m) { +void mlir::python::populateIRSubmodule(py::module m) { + //---------------------------------------------------------------------------- // Mapping of MlirContext + //---------------------------------------------------------------------------- py::class_(m, "Context") .def(py::init<>(&PyMlirContext::createNewContextForInit)) .def_static("_get_live_count", &PyMlirContext::getLiveCount) @@ -1928,6 +2047,25 @@ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) + .def_property_readonly( + "dialect", + [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Gets a container for accessing dialects by name") + .def_property_readonly( + "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Alias for 'dialect'") + .def( + "get_dialect_descriptor", + [=](PyMlirContext &self, std::string &name) { + MlirDialect dialect = mlirContextGetOrLoadDialect( + self.get(), {name.data(), name.size()}); + if (mlirDialectIsNull(dialect)) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Dialect '") + name + "' not found"); + } + return PyDialectDescriptor(self.getRef(), dialect); + }, + "Gets or loads a dialect by name, returning its descriptor object") .def_property( "allow_unregistered_dialects", [](PyMlirContext &self) -> bool { @@ -1937,8 +2075,8 @@ mlirContextSetAllowUnregisteredDialects(self.get(), value); }) .def("create_operation", &PyMlirContext::createOperation, py::arg("name"), - py::arg("location"), py::arg("results") = py::none(), - py::arg("attributes") = py::none(), + py::arg("location"), py::arg("operands") = py::none(), + py::arg("results") = py::none(), py::arg("attributes") = py::none(), py::arg("successors") = py::none(), py::arg("regions") = 0, kContextCreateOperationDocstring) .def( @@ -2009,6 +2147,62 @@ 
kContextGetFileLocationDocstring, py::arg("filename"), py::arg("line"), py::arg("col")); + //---------------------------------------------------------------------------- + // Mapping of PyDialectDescriptor + //---------------------------------------------------------------------------- + py::class_(m, "DialectDescriptor") + .def_property_readonly("namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = + mlirDialectGetNamespace(self.get()); + return py::str(ns.data, ns.length); + }) + .def("__repr__", [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + std::string repr(""); + return repr; + }); + + //---------------------------------------------------------------------------- + // Mapping of PyDialects + //---------------------------------------------------------------------------- + py::class_(m, "Dialects") + .def("__getitem__", + [=](PyDialects &self, std::string keyName) { + MlirDialect dialect = + self.getDialectForKey(keyName, /*attrError=*/false); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(keyName, std::move(descriptor)); + }) + .def("__getattr__", [=](PyDialects &self, std::string attrName) { + MlirDialect dialect = + self.getDialectForKey(attrName, /*attrError=*/true); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(attrName, std::move(descriptor)); + }); + + //---------------------------------------------------------------------------- + // Mapping of PyDialect + //---------------------------------------------------------------------------- + py::class_(m, "Dialect") + .def(py::init(), "descriptor") + .def_property_readonly( + "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) + .def("__repr__", [](py::object self) { + auto clazz = self.attr("__class__"); + return py::str(""); + }); + + 
//---------------------------------------------------------------------------- + // Mapping of Location + //---------------------------------------------------------------------------- py::class_(m, "Location") .def_property_readonly( "context", @@ -2021,7 +2215,9 @@ return printAccum.join(); }); + //---------------------------------------------------------------------------- // Mapping of Module + //---------------------------------------------------------------------------- py::class_(m, "Module") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) @@ -2055,12 +2251,17 @@ }, kOperationStrDunderDocstring); + //---------------------------------------------------------------------------- // Mapping of Operation. + //---------------------------------------------------------------------------- py::class_(m, "Operation") .def_property_readonly( "context", [](PyOperation &self) { return self.getContext().getObject(); }, "Context that owns the Operation") + .def_property_readonly( + "operands", + [](PyOperation &self) { return PyOpOperandList(self.getRef()); }) .def_property_readonly( "regions", [](PyOperation &self) { return PyRegionList(self.getRef()); }) @@ -2098,7 +2299,15 @@ py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, kOperationGetAsmDocstring); + py::class_(m, "OpView") + .def(py::init()) + .def_property_readonly("operation", &PyOpView::getOperationObject) + .def("__str__", + [](PyOpView &self) { return py::str(self.getOperationObject()); }); + + //---------------------------------------------------------------------------- // Mapping of PyRegion. + //---------------------------------------------------------------------------- py::class_(m, "Region") .def_property_readonly( "blocks", @@ -2123,7 +2332,9 @@ } }); + //---------------------------------------------------------------------------- // Mapping of PyBlock. 
+ //---------------------------------------------------------------------------- py::class_(m, "Block") .def_property_readonly( "arguments", @@ -2167,7 +2378,9 @@ }, "Returns the assembly form of the block."); + //---------------------------------------------------------------------------- // Mapping of PyAttribute. + //---------------------------------------------------------------------------- py::class_(m, "Attribute") .def_property_readonly( "context", @@ -2219,6 +2432,9 @@ return printAccum.join(); }); + //---------------------------------------------------------------------------- + // Mapping of PyNamedAttribute + //---------------------------------------------------------------------------- py::class_(m, "NamedAttribute") .def("__repr__", [](PyNamedAttribute &self) { @@ -2257,7 +2473,9 @@ PyStringAttribute::bind(m); PyDenseElementsAttribute::bind(m); + //---------------------------------------------------------------------------- // Mapping of PyType. + //---------------------------------------------------------------------------- py::class_(m, "Type") .def_property_readonly( "context", [](PyType &self) { return self.getContext().getObject(); }, @@ -2313,7 +2531,9 @@ PyTupleType::bind(m); PyFunctionType::bind(m); + //---------------------------------------------------------------------------- // Mapping of Value. 
+ //---------------------------------------------------------------------------- py::class_(m, "Value") .def_property_readonly( "context", @@ -2346,6 +2566,7 @@ PyBlockList::bind(m); PyOperationIterator::bind(m); PyOperationList::bind(m); + PyOpOperandList::bind(m); PyOpResultList::bind(m); PyRegionIterator::bind(m); PyRegionList::bind(m); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -8,17 +8,179 @@ #include -#include +#include "PybindUtils.h" +#include "Globals.h" #include "IRModules.h" namespace py = pybind11; using namespace mlir; using namespace mlir::python; +// ----------------------------------------------------------------------------- +// PyGlobals +// ----------------------------------------------------------------------------- + +PyGlobals *PyGlobals::instance = nullptr; + +PyGlobals::PyGlobals() { + assert(!instance && "PyGlobals already constructed"); + instance = this; +} + +PyGlobals::~PyGlobals() { instance = nullptr; } + +pybind11::object +PyGlobals::getOrLoadDialectModule(const std::string &dialectNamespace) { + { + auto foundIt = dialectParentModuleMap.find(dialectNamespace); + if (foundIt != dialectParentModuleMap.end()) + return foundIt->second; + } + + // Since re-entrancy is possible, make a copy of the search prefixes. + std::vector localSearchPrefixes = dialectSearchPrefixes; + py::object loaded; + for (std::string moduleName : localSearchPrefixes) { + moduleName.push_back('.'); + moduleName.append(dialectNamespace); + + try { + loaded = py::module::import(moduleName.c_str()); + } catch (py::error_already_set &e) { + if (e.matches(PyExc_ModuleNotFoundError)) { + continue; + } else { + throw; + } + } + break; + } + + // Cache it if found. Cannot use prior iterator because importing the module + // may have re-entered and could have done anything. 
+ if (loaded) { + py::object &found = dialectParentModuleMap[dialectNamespace]; + if (found) { + throw SetPyError( + PyExc_RuntimeError, + llvm::Twine("Dialect module re-entrantly registered for '") + + dialectNamespace + "'"); + } + found = loaded; + return loaded; + } else { + return py::none(); + } +} + +void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, + py::object pyClass) { + py::object &found = dialectClassMap[dialectNamespace]; + if (found && !found.is_none()) { // None marks a negative-cache entry. + throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + + dialectNamespace + + "' is already registered."); + } + found = std::move(pyClass); +} + +void PyGlobals::registerOperationImpl(const std::string &operationName, + py::object pyClass, py::object rawClass) { + py::object &found = operationClassMap[operationName]; + if (found) { + throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + + operationName + + "' is already registered."); + } + found = std::move(pyClass); + rawOperationClassMap[operationName] = std::move(rawClass); +} + +llvm::Optional +PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { + for (int i = 0; i < 2; ++i) { + // Fast match against the class map first (common case). + const auto foundIt = dialectClassMap.find(dialectNamespace); + if (foundIt != dialectClassMap.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + return foundIt->second; + } + + // Fallback to a load of the dialect module. + if (i == 0) + getOrLoadDialectModule(dialectNamespace); + } + + // Not found and loading did not yield a registration. Negative cache. + dialectClassMap[dialectNamespace] = py::none(); + return llvm::None; +} + +// ----------------------------------------------------------------------------- +// Module initialization.
+// ----------------------------------------------------------------------------- + PYBIND11_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; + py::class_(m, "_Globals") + .def_property("dialect_search_modules", + &PyGlobals::getDialectSearchPrefixes, + &PyGlobals::setDialectSearchPrefixes) + .def("append_dialect_search_prefix", + [](PyGlobals &self, std::string moduleName) { + self.getDialectSearchPrefixes().push_back(std::move(moduleName)); + }) + .def("get_or_load_dialect_module", &PyGlobals::getOrLoadDialectModule, + "Gets the implementation module for the dialect (if one exists)") + .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, + "Testing hook for directly registering a dialect") + .def("_register_operation_impl", &PyGlobals::registerOperationImpl, + "Testing hook for directly registering an operation"); + + // Aside from making the globals accessible to python, having python manage + // it is necessary to make sure it is destroyed (and releases its python + // resources) properly. + m.attr("globals") = + py::cast(new PyGlobals, py::return_value_policy::take_ownership); + + // Registration decorators. + m.def( + "register_dialect", + [](py::object pyClass) { + std::string dialectNamespace = + pyClass.attr("DIALECT_NAMESPACE").cast(); + PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); + return pyClass; + }, + "Class decorator for registering a custom Dialect wrapper"); + m.def( + "register_operation", + [](py::object dialectClass) -> py::cpp_function { + return py::cpp_function( + [dialectClass](py::object opClass) -> py::object { + std::string operationName = + opClass.attr("OPERATION_NAME").cast(); + auto rawSubclass = PyOpView::createRawSubclass(opClass); + PyGlobals::get().registerOperationImpl(operationName, opClass, + rawSubclass); + + // Dict-stuff the new opClass by name onto the dialect class. 
+ py::object opClassName = opClass.attr("__name__"); + dialectClass.attr(opClassName) = opClass; + + // Now create a special "Raw" subclass that passes through + // construction to the OpView parent (bypasses the intermediate + // child's __init__). + opClass.attr("_Raw") = rawSubclass; + return opClass; + }); + }, + "Class decorator for registering a custom Operation wrapper"); + // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); populateIRSubmodule(irModule); diff --git a/mlir/lib/Bindings/Python/mlir/__init__.py b/mlir/lib/Bindings/Python/mlir/__init__.py --- a/mlir/lib/Bindings/Python/mlir/__init__.py +++ b/mlir/lib/Bindings/Python/mlir/__init__.py @@ -8,4 +8,37 @@ # and arbitrate any one-time initialization needed in various shared-library # scenarios. -from _mlir import * +__all__ = [ + "ir", +] + +# Expose the corresponding C-Extension module with a well-known name at this +# top-level module. This allows relative imports like the following to +# function: +# from .. import _cext +# This reduces coupling, allowing embedding of the python sources into another +# project that can just vary based on this top-level loader module. +import _mlir as _cext + +def _reexport_cext(cext_module_name, target_module_name): + """Re-exports a named sub-module of the C-Extension into another module. + + Typically: + from . import _reexport_cext + _reexport_cext("ir", __name__) + del _reexport_cext + """ + import sys + target_module = sys.modules[target_module_name] + source_module = getattr(_cext, cext_module_name) + for attr_name in dir(source_module): + if not attr_name.startswith("__"): + setattr(target_module, attr_name, getattr(source_module, attr_name)) + + +# Import sub-modules. Since these may import from here, this must come after +# any exported definitions. +from . import ir + +# Add our 'dialects' parent module to the search path for implementations. 
+_cext.globals.append_dialect_search_prefix("mlir.dialects") diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Re-export the parent _cext so that every level of the API can get it locally. +from .. import _cext diff --git a/mlir/lib/Bindings/Python/mlir/dialects/std.py b/mlir/lib/Bindings/Python/mlir/dialects/std.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/std.py @@ -0,0 +1,31 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from . import _cext + +@_cext.register_dialect +class _Dialect(_cext.ir.Dialect): + # Special case: 'std' namespace aliases to the empty namespace. + DIALECT_NAMESPACE = "std" + pass + +@_cext.register_operation(_Dialect) +class AddFOp(_cext.ir.OpView): + OPERATION_NAME = "std.addf" + + def __init__(self, loc, lhs, rhs): + super().__init__(loc.context.create_operation( + "std.addf", loc, operands=[lhs, rhs], results=[lhs.type])) + + @property + def lhs(self): + return self.operation.operands[0] + + @property + def rhs(self): + return self.operation.operands[1] + + @property + def result(self): + return self.operation.results[0] diff --git a/mlir/lib/Bindings/Python/mlir/ir.py b/mlir/lib/Bindings/Python/mlir/ir.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/ir.py @@ -0,0 +1,8 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Simply a wrapper around the extension module of the same name. +from . import _reexport_cext +_reexport_cext("ir", __name__) +del _reexport_cext diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/dialects.py @@ -0,0 +1,62 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +import mlir + +def run(f): + print("\nTEST:", f.__name__) + f() + gc.collect() + assert mlir.ir.Context._get_live_count() == 0 + + +# CHECK-LABEL: TEST: testDialectDescriptor +def testDialectDescriptor(): + ctx = mlir.ir.Context() + d = ctx.get_dialect_descriptor("std") + # CHECK: + print(d) + # CHECK: std + print(d.namespace) + try: + _ = ctx.get_dialect_descriptor("not_existing") + except ValueError: + pass + else: + assert False, "Expected exception" + +run(testDialectDescriptor) + + +# CHECK-LABEL: TEST: testUserDialects +def testUserDialects(): + ctx = mlir.ir.Context() + # Access using attribute. + d = ctx.dialect.std + # CHECK: + print(d) + try: + _ = ctx.dialect.not_existing + except AttributeError: + pass + else: + assert False, "Expected exception" + + # Access using index. + d = ctx.dialect["std"] + # CHECK: + print(d) + try: + _ = ctx.dialect["not_existing"] + except IndexError: + pass + else: + assert False, "Expected exception" + + # Using the 'd' alias. + d = ctx.d["std"] + # CHECK: + print(d) + + +run(testUserDialects)