diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -74,6 +74,7 @@ IRAffine.cpp IRAttributes.cpp IRCore.cpp + IRModule.cpp IRTypes.cpp PybindUtils.cpp Pass.cpp diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp copy from mlir/lib/Bindings/Python/MainModule.cpp copy to mlir/lib/Bindings/Python/IRModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -1,4 +1,4 @@ -//===- MainModule.cpp - Main pybind module --------------------------------===// +//===- IRModule.cpp - IR pybind module ------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,15 +6,11 @@ // //===----------------------------------------------------------------------===// -#include - +#include "IRModule.h" +#include "Globals.h" #include "PybindUtils.h" -#include "DialectLinalg.h" -#include "ExecutionEngine.h" -#include "Globals.h" -#include "IRModule.h" -#include "Pass.h" +#include namespace py = pybind11; using namespace mlir; @@ -148,87 +144,3 @@ loadedDialectModulesCache.clear(); rawOpViewClassMapCache.clear(); } - -// ----------------------------------------------------------------------------- -// Module initialization.
-// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlir, m) { - m.doc() = "MLIR Python Native Extension"; - - py::class_(m, "_Globals") - .def_property("dialect_search_modules", - &PyGlobals::getDialectSearchPrefixes, - &PyGlobals::setDialectSearchPrefixes) - .def("append_dialect_search_prefix", - [](PyGlobals &self, std::string moduleName) { - self.getDialectSearchPrefixes().push_back(std::move(moduleName)); - self.clearImportCache(); - }) - .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, - "Testing hook for directly registering a dialect") - .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "Testing hook for directly registering an operation"); - - // Aside from making the globals accessible to python, having python manage - // it is necessary to make sure it is destroyed (and releases its python - // resources) properly. - m.attr("globals") = - py::cast(new PyGlobals, py::return_value_policy::take_ownership); - - // Registration decorators. - m.def( - "register_dialect", - [](py::object pyClass) { - std::string dialectNamespace = - pyClass.attr("DIALECT_NAMESPACE").cast(); - PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); - return pyClass; - }, - "Class decorator for registering a custom Dialect wrapper"); - m.def( - "register_operation", - [](py::object dialectClass) -> py::cpp_function { - return py::cpp_function( - [dialectClass](py::object opClass) -> py::object { - std::string operationName = - opClass.attr("OPERATION_NAME").cast(); - auto rawSubclass = PyOpView::createRawSubclass(opClass); - PyGlobals::get().registerOperationImpl(operationName, opClass, - rawSubclass); - - // Dict-stuff the new opClass by name onto the dialect class. 
- py::object opClassName = opClass.attr("__name__"); - dialectClass.attr(opClassName) = opClass; - - // Now create a special "Raw" subclass that passes through - // construction to the OpView parent (bypasses the intermediate - // child's __init__). - opClass.attr("_Raw") = rawSubclass; - return opClass; - }); - }, - "Class decorator for registering a custom Operation wrapper"); - - // Define and populate IR submodule. - auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); - populateIRCore(irModule); - populateIRAffine(irModule); - populateIRAttributes(irModule); - populateIRTypes(irModule); - - // Define and populate PassManager submodule. - auto passModule = - m.def_submodule("passmanager", "MLIR Pass Management Bindings"); - populatePassManagerSubmodule(passModule); - - // Define and populate ExecutionEngine submodule. - auto executionEngineModule = - m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); - populateExecutionEngineSubmodule(executionEngineModule); - - // Define and populate Linalg submodule. 
- auto dialectsModule = m.def_submodule("dialects"); - auto linalgModule = dialectsModule.def_submodule("linalg"); - populateDialectLinalgSubmodule(linalgModule); -} diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -20,135 +20,6 @@ using namespace mlir; using namespace mlir::python; -// ----------------------------------------------------------------------------- -// PyGlobals -// ----------------------------------------------------------------------------- - -PyGlobals *PyGlobals::instance = nullptr; - -PyGlobals::PyGlobals() { - assert(!instance && "PyGlobals already constructed"); - instance = this; -} - -PyGlobals::~PyGlobals() { instance = nullptr; } - -void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - py::gil_scoped_acquire(); - if (loadedDialectModulesCache.contains(dialectNamespace)) - return; - // Since re-entrancy is possible, make a copy of the search prefixes. - std::vector localSearchPrefixes = dialectSearchPrefixes; - py::object loaded; - for (std::string moduleName : localSearchPrefixes) { - moduleName.push_back('.'); - moduleName.append(dialectNamespace.data(), dialectNamespace.size()); - - try { - py::gil_scoped_release(); - loaded = py::module::import(moduleName.c_str()); - } catch (py::error_already_set &e) { - if (e.matches(PyExc_ModuleNotFoundError)) { - continue; - } else { - throw; - } - } - break; - } - - // Note: Iterator cannot be shared from prior to loading, since re-entrancy - // may have occurred, which may do anything. 
- loadedDialectModulesCache.insert(dialectNamespace); -} - -void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, - py::object pyClass) { - py::gil_scoped_acquire(); - py::object &found = dialectClassMap[dialectNamespace]; - if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + - dialectNamespace + - "' is already registered."); - } - found = std::move(pyClass); -} - -void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, - py::object rawOpViewClass) { - py::gil_scoped_acquire(); - py::object &found = operationClassMap[operationName]; - if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + - operationName + - "' is already registered."); - } - found = std::move(pyClass); - rawOpViewClassMap[operationName] = std::move(rawOpViewClass); -} - -llvm::Optional -PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { - py::gil_scoped_acquire(); - loadDialectModule(dialectNamespace); - // Fast match against the class map first (common case). - const auto foundIt = dialectClassMap.find(dialectNamespace); - if (foundIt != dialectClassMap.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - return foundIt->second; - } - - // Not found and loading did not yield a registration. Negative cache. - dialectClassMap[dialectNamespace] = py::none(); - return llvm::None; -} - -llvm::Optional -PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { - { - py::gil_scoped_acquire(); - auto foundIt = rawOpViewClassMapCache.find(operationName); - if (foundIt != rawOpViewClassMapCache.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - return foundIt->second; - } - } - - // Not found. Load the dialect namespace. 
- auto split = operationName.split('.'); - llvm::StringRef dialectNamespace = split.first; - loadDialectModule(dialectNamespace); - - // Attempt to find from the canonical map and cache. - { - py::gil_scoped_acquire(); - auto foundIt = rawOpViewClassMap.find(operationName); - if (foundIt != rawOpViewClassMap.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - // Positive cache. - rawOpViewClassMapCache[operationName] = foundIt->second; - return foundIt->second; - } else { - // Negative cache. - rawOpViewClassMap[operationName] = py::none(); - return llvm::None; - } - } -} - -void PyGlobals::clearImportCache() { - py::gil_scoped_acquire(); - loadedDialectModulesCache.clear(); - rawOpViewClassMapCache.clear(); -} - // ----------------------------------------------------------------------------- // Module initialization. // -----------------------------------------------------------------------------