diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.h b/mlir/lib/Bindings/Python/ExecutionEngine.h deleted file mode 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngine.h +++ /dev/null @@ -1,22 +0,0 @@ -//===- ExecutionEngine.h - ExecutionEngine submodule of pybind module -----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H -#define MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H - -#include "PybindUtils.h" - -namespace mlir { -namespace python { - -void populateExecutionEngineSubmodule(pybind11::module &m); - -} // namespace python -} // namespace mlir - -#endif // MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp rename from mlir/lib/Bindings/Python/ExecutionEngine.cpp rename to mlir/lib/Bindings/Python/ExecutionEngineModule.cpp --- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -1,4 +1,4 @@ -//===- ExecutionEngine.cpp - Python MLIR ExecutionEngine Bindings ---------===// +//===- ExecutionEngineModule.cpp - Python module for execution engine -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "ExecutionEngine.h" - -#include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/ExecutionEngine.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" namespace py = pybind11; using namespace mlir; @@ -54,18 +52,20 @@ } // anonymous namespace /// Create the `mlir.execution_engine` module here. -void mlir::python::populateExecutionEngineSubmodule(py::module &m) { +PYBIND11_MODULE(_mlirExecutionEngine, m) { + m.doc() = "MLIR Execution Engine"; + //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- py::class_(m, "ExecutionEngine") - .def(py::init<>([](PyModule &module, int optLevel, + .def(py::init<>([](MlirModule module, int optLevel, const std::vector &sharedLibPaths) { llvm::SmallVector libPaths; for (const std::string &path : sharedLibPaths) libPaths.push_back({path.c_str(), path.length()}); MlirExecutionEngine executionEngine = mlirExecutionEngineCreate( - module.get(), optLevel, libPaths.size(), libPaths.data()); + module, optLevel, libPaths.size(), libPaths.data()); if (mlirExecutionEngineIsNull(executionEngine)) throw std::runtime_error( "Failure while creating the ExecutionEngine."); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -11,7 +11,6 @@ #include "PybindUtils.h" #include "Dialects.h" -#include "ExecutionEngine.h" #include "Globals.h" #include "IRModule.h" #include "Pass.h" @@ -93,11 +92,6 @@ m.def_submodule("passmanager", "MLIR Pass Management Bindings"); populatePassManagerSubmodule(passModule); - // Define and populate ExecutionEngine submodule. - auto executionEngineModule = - m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); - populateExecutionEngineSubmodule(executionEngineModule); - // Define and populate dialect submodules. auto dialectsModule = m.def_submodule("dialects"); auto linalgModule = dialectsModule.def_submodule("linalg"); diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -169,7 +169,6 @@ ${PYTHON_SOURCE_DIR}/IRTypes.cpp ${PYTHON_SOURCE_DIR}/PybindUtils.cpp ${PYTHON_SOURCE_DIR}/Pass.cpp - ${PYTHON_SOURCE_DIR}/ExecutionEngine.cpp # TODO: Break this out. PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -181,9 +180,6 @@ MLIRCAPILinalg # TODO: Remove when above is removed. MLIRCAPISparseTensor # TODO: Remove when above is removed. MLIRCAPIStandard - - # Execution engine (remove once disaggregated). - MLIRCEXECUTIONENGINE ) declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration @@ -219,6 +215,17 @@ MLIRCAPIConversion ) +declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine + MODULE_NAME _mlirExecutionEngine + ADD_TO_PARENT MLIRPythonSources.ExecutionEngine + SOURCES + ${PYTHON_SOURCE_DIR}/ExecutionEngineModule.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCEXECUTIONENGINE +) + declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses MODULE_NAME _mlirGPUPasses ADD_TO_PARENT MLIRPythonSources.Dialects.gpu diff --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py --- a/mlir/python/mlir/execution_engine.py +++ b/mlir/python/mlir/execution_engine.py @@ -3,10 +3,15 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Simply a wrapper around the extension module of the same name. -from ._cext_loader import _cext +from ._cext_loader import load_extension +_execution_engine = load_extension("_mlirExecutionEngine") import ctypes -class ExecutionEngine(_cext.execution_engine.ExecutionEngine): +__all__ = [ + "ExecutionEngine", +] + +class ExecutionEngine(_execution_engine.ExecutionEngine): def lookup(self, name): """Lookup a function emitted with the `llvm.emit_c_interface` diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -33,7 +33,7 @@ log(repr(execution_engine_capsule)) execution_engine._testing_release() execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule) - # CHECK: _mlir.execution_engine.ExecutionEngine + # CHECK: _mlirExecutionEngine.ExecutionEngine log(repr(execution_engine1)) run(testCapsule) @@ -68,7 +68,7 @@ func @void() attributes { llvm.emit_c_interface } { return } - """) + """) execution_engine = ExecutionEngine(lowerToLLVM(module)) # Nothing to check other than no exception thrown here. execution_engine.invoke("void") @@ -157,7 +157,7 @@ execution_engine = ExecutionEngine(lowerToLLVM(module)) execution_engine.register_runtime("some_callback_into_python", callback) inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32) - # CHECK: Inside callback: + # CHECK: Inside callback: # CHECK{LITERAL}: [[1. 2.] # CHECK{LITERAL}: [3. 4.]] execution_engine.invoke( @@ -168,7 +168,7 @@ strided_arr = np.lib.stride_tricks.as_strided( inp_arr_1, strides=(4, 0), shape=(3, 4) ) - # CHECK: Inside callback: + # CHECK: Inside callback: # CHECK{LITERAL}: [[5. 5. 5. 5.] # CHECK{LITERAL}: [6. 6. 6. 6.] # CHECK{LITERAL}: [7. 7. 7. 7.]]