diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -31,12 +31,21 @@
   }
 
   MlirExecutionEngine get() { return executionEngine; }
-  void release() { executionEngine.ptr = nullptr; }
+  void release() {
+    executionEngine.ptr = nullptr;
+    referencedObjects.clear();
+  }
   pybind11::object getCapsule() {
     return py::reinterpret_steal<py::object>(
         mlirPythonExecutionEngineToCapsule(get()));
   }
 
+  // Add an object to the list of referenced objects whose lifetime must exceed
+  // that of the ExecutionEngine.
+  void addReferencedObject(pybind11::object obj) {
+    referencedObjects.push_back(obj);
+  }
+
   static pybind11::object createFromCapsule(pybind11::object capsule) {
     MlirExecutionEngine rawPm =
         mlirPythonCapsuleToExecutionEngine(capsule.ptr());
@@ -47,6 +56,10 @@
 
 private:
   MlirExecutionEngine executionEngine;
+  // We support Python ctypes closures as callbacks. Keep a list of the objects
+  // so that they don't get garbage collected. (The ExecutionEngine itself
+  // just holds raw pointers with no lifetime semantics.)
+  std::vector<py::object> referencedObjects;
 };
 
 } // anonymous namespace
@@ -96,13 +109,19 @@
       .def(
           "raw_register_runtime",
           [](PyExecutionEngine &executionEngine, const std::string &name,
-             uintptr_t sym) {
+             py::object callbackObj) {
+            // Extract the address before retaining the object so that a
+            // callback whose address cannot be read is not pinned needlessly.
+            uintptr_t rawSym =
+                py::cast<uintptr_t>(py::getattr(callbackObj, "value"));
+            executionEngine.addReferencedObject(callbackObj);
             mlirExecutionEngineRegisterSymbol(
                 executionEngine.get(),
                 mlirStringRefCreate(name.c_str(), name.size()),
-                reinterpret_cast<void *>(sym));
+                reinterpret_cast<void *>(rawSym));
           },
-          "Lookup function `func` in the ExecutionEngine.")
+          py::arg("name"), py::arg("callback"),
+          "Register `callback` as the runtime symbol `name`.")
       .def(
           "dump_to_object_file",
           [](PyExecutionEngine &executionEngine, const std::string &fileName) {
diff --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py
--- a/mlir/python/mlir/execution_engine.py
+++ b/mlir/python/mlir/execution_engine.py
@@ -39,5 +39,5 @@
     under the provided `name`. The `ctypes_callback` must be a `CFuncType` that
     outlives the execution engine.
     """
-    callback = ctypes.cast(ctypes_callback, ctypes.c_void_p).value
+    callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
     self.raw_register_runtime("_mlir_ciface_" + name, callback)