diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -25,6 +25,7 @@
 
 #include "mlir-c/AffineExpr.h"
 #include "mlir-c/AffineMap.h"
+#include "mlir-c/ExecutionEngine.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Pass.h"
@@ -33,6 +34,7 @@
 #define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE "mlir.jit.Jit._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
@@ -261,6 +263,27 @@
   return integerSet;
 }
 
+/** Creates a capsule object encapsulating the raw C-API MlirExecutionEngine.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the execution engine in any way. */
+static inline PyObject *
+mlirPythonExecutionEngineToCapsule(MlirExecutionEngine jit) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(jit),
+                       MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE, NULL);
+}
+
+/** Extracts an MlirExecutionEngine from a capsule as produced from
+ * mlirPythonExecutionEngineToCapsule. If the capsule is not of the right type,
+ * then a null engine is returned (as checked via mlirExecutionEngineIsNull). In
+ * such a case, the Python APIs will have already set an error. */
+static inline MlirExecutionEngine
+mlirPythonCapsuleToExecutionEngine(PyObject *capsule) {
+  void *ptr =
+      PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE);
+  MlirExecutionEngine jit = {ptr};
+  return jit;
+}
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h
--- a/mlir/include/mlir-c/ExecutionEngine.h
+++ b/mlir/include/mlir-c/ExecutionEngine.h
@@ -56,6 +56,11 @@
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked(
     MlirExecutionEngine jit, MlirStringRef name, void **arguments);
 
+/// Looks up a native function in the execution engine by name. Returns nullptr
+/// if the name can't be looked up.
+MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
+                                                   MlirStringRef name);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -8,11 +8,12 @@
 set(PY_SRC_FILES
   mlir/__init__.py
   mlir/_dlloader.py
-  mlir/ir.py
+  mlir/conversions/__init__.py
   mlir/dialects/__init__.py
   mlir/dialects/_linalg.py
   mlir/dialects/_builtin.py
   mlir/ir.py
+  mlir/jit.py
   mlir/passmanager.py
   mlir/transforms/__init__.py
 )
@@ -74,6 +75,7 @@
     IRModules.cpp
     PybindUtils.cpp
     Pass.cpp
+    Jit.cpp
 )
 
 add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension)
@@ -114,3 +116,4 @@
 endif()
 
 add_subdirectory(Transforms)
+add_subdirectory(Conversions)
diff --git a/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt b/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt
@@ -0,0 +1,10 @@
+################################################################################
+# Build python extension
+################################################################################
+
+add_mlir_python_extension(MLIRConversionsBindingsPythonExtension _mlirConversions
+  INSTALL_DIR
+    python
+  SOURCES
+    Conversions.cpp
+)
diff --git a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp b/mlir/lib/Bindings/Python/Conversions/Conversions.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Conversions/Conversions.cpp
@@ -0,0 +1,24 @@
+//===- Conversions.cpp - Pybind module for the Conversions library --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Conversion.h"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+PYBIND11_MODULE(_mlirConversions, m) {
+  m.doc() = "MLIR Conversions library";
+
+  // Register all the passes in the Conversions library on load.
+  mlirRegisterConversionPasses();
+}
diff --git a/mlir/lib/Bindings/Python/Jit.h b/mlir/lib/Bindings/Python/Jit.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Jit.h
@@ -0,0 +1,22 @@
+//===- Jit.h - Python MLIR JIT Bindings -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_JIT_H
+#define MLIR_BINDINGS_PYTHON_JIT_H
+
+#include "PybindUtils.h"
+
+namespace mlir {
+namespace python {
+
+void populateJitSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_JIT_H
diff --git a/mlir/lib/Bindings/Python/Jit.cpp b/mlir/lib/Bindings/Python/Jit.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Jit.cpp
@@ -0,0 +1,79 @@
+//===- Jit.cpp - Python MLIR JIT Bindings ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Jit.h"
+
+#include "IRModules.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/ExecutionEngine.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+namespace {
+
+/// Owning wrapper around an MlirExecutionEngine (destroyed with the wrapper).
+class PyJit {
+public:
+  PyJit(MlirExecutionEngine jit) : jit(jit) {}
+  PyJit(PyJit &&other) : jit(other.jit) { other.jit.ptr = nullptr; }
+  ~PyJit() {
+    if (!mlirExecutionEngineIsNull(jit))
+      mlirExecutionEngineDestroy(jit);
+  }
+  MlirExecutionEngine get() { return jit; }
+
+  void release() { jit.ptr = nullptr; }
+  pybind11::object getCapsule() {
+    return py::reinterpret_steal<py::object>(
+        mlirPythonExecutionEngineToCapsule(get()));
+  }
+
+  static pybind11::object createFromCapsule(pybind11::object capsule) {
+    MlirExecutionEngine rawJit =
+        mlirPythonCapsuleToExecutionEngine(capsule.ptr());
+    if (mlirExecutionEngineIsNull(rawJit))
+      throw py::error_already_set();
+    return py::cast(PyJit(rawJit), py::return_value_policy::move);
+  }
+
+private:
+  MlirExecutionEngine jit;
+};
+
+} // anonymous namespace
+
+/// Populates the `mlir.jit` submodule `m` with the `Jit` class.
+void mlir::python::populateJitSubmodule(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the top-level ExecutionEngine (exposed as `Jit`)
+  //----------------------------------------------------------------------------
+  py::class_<PyJit>(m, "Jit")
+      .def(py::init<>([](PyModule &module) {
+             MlirExecutionEngine jit = mlirExecutionEngineCreate(module.get());
+             if (mlirExecutionEngineIsNull(jit))
+               throw SetPyError(PyExc_RuntimeError,
+                                "Failure while creating the JIT.");
+             return new PyJit(jit);
+           }),
+           "Create a new jit instance for the given Module. The module must "
+           "contain only dialects that can be translated to LLVM.")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyJit::getCapsule)
+      .def("_testing_release", &PyJit::release,
+           "Releases (leaks) the backing jit (for testing purpose)")
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyJit::createFromCapsule)
+      .def(
+          "raw_lookup",
+          [](PyJit &jit, const std::string &func) {
+            auto *res = mlirExecutionEngineLookup(
+                jit.get(), mlirStringRefCreate(func.c_str(), func.size()));
+            return reinterpret_cast<uintptr_t>(res);
+          },
+          "Look up `func` in the jit; returns its address or 0.");
+}
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,7 @@
 
 #include "Globals.h"
 #include "IRModules.h"
+#include "Jit.h"
 #include "Pass.h"
 
 namespace py = pybind11;
@@ -216,4 +217,8 @@
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passModule);
+
+  // Define and populate JIT submodule.
+  auto jitModule = m.def_submodule("jit", "MLIR JIT Execution Engine");
+  populateJitSubmodule(jitModule);
 }
diff --git a/mlir/lib/Bindings/Python/mlir/__init__.py b/mlir/lib/Bindings/Python/mlir/__init__.py
--- a/mlir/lib/Bindings/Python/mlir/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/__init__.py
@@ -10,6 +10,7 @@
 
 __all__ = [
   "ir",
+  "jit",
   "passmanager",
 ]
 
@@ -61,7 +62,7 @@
 
 # Import sub-modules. Since these may import from here, this must come after
 # any exported definitions.
-from . import ir, passmanager
+from . import ir, jit, passmanager
 
 # Add our 'dialects' parent module to the search path for implementations.
 _cext.globals.append_dialect_search_prefix("mlir.dialects")
diff --git a/mlir/lib/Bindings/Python/mlir/conversions/__init__.py b/mlir/lib/Bindings/Python/mlir/conversions/__init__.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/conversions/__init__.py
@@ -0,0 +1,9 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Expose the corresponding C-Extension module with a well-known name at this
+# level. Loading the extension registers the MLIR conversion passes, which
+# makes them available to `PassManager.parse`.
+from .. import _load_extension
+_cextConversions = _load_extension("_mlirConversions")
diff --git a/mlir/lib/Bindings/Python/mlir/jit.py b/mlir/lib/Bindings/Python/mlir/jit.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/jit.py
@@ -0,0 +1,35 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Extends the native `_cext.jit.Jit` class with ctypes-based helpers for
+# calling functions compiled by the JIT.
+from . import _cext
+import ctypes
+
+class Jit(_cext.jit.Jit):
+  '''JIT execution engine for a module translatable to LLVM IR.'''
+
+  def lookup(self, name):
+    '''Look up a function emitted with the `llvm.emit_c_interface`
+       attribute and return a ctypes callable taking one pointer argument.
+       Raise a RuntimeError if the function isn't found.
+    '''
+    func = self.raw_lookup("_mlir_ciface_" + name)
+    if not func:
+      raise RuntimeError("Unknown function " + name)
+    prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+    return prototype(func)
+
+  def invoke(self, name, *ctypes_args):
+    '''Invoke a function with the list of ctypes arguments.
+       All arguments must be pointers.
+       Raise a RuntimeError if the function isn't found.
+    '''
+    func = self.lookup(name)
+    packed_args = (ctypes.c_void_p * len(ctypes_args))()
+    for arg_num, arg in enumerate(ctypes_args):
+      packed_args[arg_num] = ctypes.cast(arg, ctypes.c_void_p)
+    func(packed_args)
+
+
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -10,6 +10,7 @@
 #include "mlir/CAPI/ExecutionEngine.h"
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
+#include "mlir/Target/LLVMIR.h"
 #include "llvm/Support/TargetSelect.h"
 
 using namespace mlir;
@@ -22,6 +23,7 @@
   }();
   (void)init_once;
 
+  mlir::registerLLVMDialectTranslation(*unwrap(op)->getContext());
   auto jitOrError = ExecutionEngine::create(unwrap(op));
   if (!jitOrError) {
     consumeError(jitOrError.takeError());
@@ -44,3 +46,15 @@
     return wrap(failure());
   return wrap(success());
 }
+
+extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
+                                           MlirStringRef name) {
+  auto expectedFPtr = unwrap(jit)->lookup(unwrap(name));
+  if (!expectedFPtr) {
+    // The failure is reported to the caller as nullptr; the llvm::Error must
+    // still be consumed, or LLVM's checked-error machinery aborts.
+    consumeError(expectedFPtr.takeError());
+    return nullptr;
+  }
+  return reinterpret_cast<void *>(*expectedFPtr);
+}
diff --git a/mlir/test/Bindings/Python/jit.py b/mlir/test/Bindings/Python/jit.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/jit.py
@@ -0,0 +1,82 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import ctypes, gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.jit import *
+
+# Log everything to stderr and flush so that we have a unified stream to match
+# errors/info emitted by MLIR to stderr.
+def log(*args):
+  print(*args, file=sys.stderr)
+  sys.stderr.flush()
+
+def run(f):
+  log("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+
+# Verify capsule interop.
+# CHECK-LABEL: TEST: testCapsule
+def testCapsule():
+  with Context():
+    module = Module.parse(r"""
+llvm.func @none() {
+  llvm.return
+}
+    """)
+    jit = Jit(module)
+    jit_capsule = jit._CAPIPtr
+    # CHECK: mlir.jit.Jit._CAPIPtr
+    log(repr(jit_capsule))
+    jit._testing_release()
+    jit1 = Jit._CAPICreate(jit_capsule)
+    # CHECK: _mlir.jit.Jit
+    log(repr(jit1))
+run(testCapsule)
+
+
+def lowerToLLVM(module):
+  import mlir.conversions  # Registers the conversion passes used below.
+  pm = PassManager.parse("convert-std-to-llvm")
+  pm.run(module)
+  return module
+
+# Test simple JIT execution
+# CHECK-LABEL: TEST: testInvokeVoid
+def testInvokeVoid():
+  with Context():
+    module = Module.parse(r"""
+func @void() attributes { llvm.emit_c_interface } {
+  return
+}
+    """)
+    jit = Jit(lowerToLLVM(module))
+    # Nothing to check other than no exception thrown here.
+    jit.invoke("void")
+run(testInvokeVoid)
+
+
+# Test argument passing and result with a simple float addition.
+# CHECK-LABEL: TEST: testInvokeFloatAdd
+def testInvokeFloatAdd():
+  with Context():
+    module = Module.parse(r"""
+func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
+  %add = std.addf %arg0, %arg1 : f32
+  return %add : f32
+}
+    """)
+    jit = Jit(lowerToLLVM(module))
+    # Prepare arguments: two input floats and one result.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    arg0 = c_float_p(42.)
+    arg1 = c_float_p(2.)
+    res = c_float_p(-1.)
+    jit.invoke("add", arg0, arg1, res)
+    # CHECK: 42.0 + 2.0 = 44.0
+    log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))
+
+run(testInvokeFloatAdd)
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -116,6 +116,7 @@
     MLIRBindingsPythonExtension
     MLIRBindingsPythonTestOps
     MLIRTransformsBindingsPythonExtension
+    MLIRConversionsBindingsPythonExtension
   )
 endif()
 