diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h
--- a/mlir/include/mlir-c/ExecutionEngine.h
+++ b/mlir/include/mlir-c/ExecutionEngine.h
@@ -61,6 +61,12 @@
 MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
                                                    MlirStringRef name);
 
+/// Register a symbol with the jit: this symbol will be accessible to the jitted
+/// code.
+MLIR_CAPI_EXPORTED void
+mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, MlirStringRef name,
+                                  void *sym);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp
--- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp
@@ -83,5 +83,14 @@
                 mlirStringRefCreate(func.c_str(), func.size()));
             return (int64_t)res;
           },
-          "Lookup function `func` in the ExecutionEngine.");
+          "Lookup function `func` in the ExecutionEngine.")
+      .def(
+          "raw_register_runtime",
+          [](PyExecutionEngine &executionEngine, const std::string &name,
+             int64_t sym) {
+            mlirExecutionEngineRegisterSymbol(
+                executionEngine.get(),
+                mlirStringRefCreate(name.c_str(), name.size()), (void *)sym);
+          },
+          "Register the address `sym` as the runtime symbol `name`.");
 }
diff --git a/mlir/lib/Bindings/Python/mlir/execution_engine.py b/mlir/lib/Bindings/Python/mlir/execution_engine.py
--- a/mlir/lib/Bindings/Python/mlir/execution_engine.py
+++ b/mlir/lib/Bindings/Python/mlir/execution_engine.py
@@ -29,3 +29,15 @@
     for argNum in range(len(ctypes_args)):
       packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
     func(packed_args)
+
+  def register_runtime(self, name, ctypes_callback):
+    """Register a runtime function available to the jitted code
+    under the provided `name`. The `ctypes_callback` must be a
+    `CFuncType` that outlives the execution engine.
+
+    The callback is registered under the C-interface symbol name
+    `_mlir_ciface_<name>`, so the MLIR declaration of `name` is expected
+    to carry the `llvm.emit_c_interface` attribute.
+    """
+    callback = ctypes.cast(ctypes_callback, ctypes.c_void_p).value
+    self.raw_register_runtime("_mlir_ciface_" + name, callback)
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -11,6 +11,7 @@
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "llvm/ExecutionEngine/Orc/Mangling.h"
 #include "llvm/Support/TargetSelect.h"
 
 using namespace mlir;
@@ -54,3 +55,14 @@
     return nullptr;
   return reinterpret_cast<void *>(*expectedFPtr);
 }
+
+extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
+                                                  MlirStringRef name,
+                                                  void *sym) {
+  unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
+    llvm::orc::SymbolMap symbolMap;
+    symbolMap[interner(unwrap(name))] =
+        llvm::JITEvaluatedSymbol::fromPointer(sym);
+    return symbolMap;
+  });
+}
diff --git a/mlir/test/Bindings/Python/execution_engine.py b/mlir/test/Bindings/Python/execution_engine.py
--- a/mlir/test/Bindings/Python/execution_engine.py
+++ b/mlir/test/Bindings/Python/execution_engine.py
@@ -97,3 +97,38 @@
     log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))
 
 run(testInvokeFloatAdd)
+
+
+# Test callback
+# CHECK-LABEL: TEST: testBasicCallback
+def testBasicCallback():
+  # Define a callback function that takes a float and an integer and returns a float.
+  @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
+  def callback(a, b):
+    return a/2 + b/2
+
+  with Context():
+    # The module just forwards to a runtime function known as "some_callback_into_python".
+    module = Module.parse(r"""
+func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
+  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
+  return %resf : f32
+}
+func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
+    """)
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.register_runtime("some_callback_into_python", callback)
+
+    # Prepare arguments: one input float, one input int, and one result.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    c_int_p = ctypes.c_int * 1
+    arg0 = c_float_p(42.)
+    arg1 = c_int_p(2)
+    res = c_float_p(-1.)
+    execution_engine.invoke("add", arg0, arg1, res)
+    # The callback halves both inputs, so double the result to recover the sum.
+    # CHECK: 42.0 + 2 = 44.0
+    log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))
+
+run(testBasicCallback)