diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h
--- a/mlir/include/mlir-c/ExecutionEngine.h
+++ b/mlir/include/mlir-c/ExecutionEngine.h
@@ -62,6 +62,11 @@
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked(
     MlirExecutionEngine jit, MlirStringRef name, void **arguments);
 
+/// Lookup the wrapper of the native function in the execution engine with the
+/// given name, returns nullptr if the function can't be looked-up.
+MLIR_CAPI_EXPORTED void *
+mlirExecutionEngineLookupPacked(MlirExecutionEngine jit, MlirStringRef name);
+
 /// Lookup a native function in the execution engine by name, returns nullptr
 /// if the name can't be looked-up.
 MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
--- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
+++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
@@ -97,9 +97,14 @@
          bool enableGDBNotificationListener = true,
          bool enablePerfNotificationListener = true);
 
-  /// Looks up a packed-argument function with the given name and returns a
-  /// pointer to it. Propagates errors in case of failure.
-  llvm::Expected<void (*)(void **)> lookup(StringRef name) const;
+  /// Looks up a packed-argument function wrapping the function with the given
+  /// name and returns a pointer to it. Propagates errors in case of failure.
+  llvm::Expected<void (*)(void **)> lookupPacked(StringRef name) const;
+
+  /// Looks up the original function with the given name and returns a
+  /// pointer to it. This is not necessarily a packed function. Propagates
+  /// errors in case of failure.
+  llvm::Expected<void *> lookup(StringRef name) const;
 
   /// Invokes the function with the given name passing it the list of opaque
   /// pointers to the actual arguments.
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -100,7 +100,7 @@
       .def(
           "raw_lookup",
           [](PyExecutionEngine &executionEngine, const std::string &func) {
-            auto *res = mlirExecutionEngineLookup(
+            auto *res = mlirExecutionEngineLookupPacked(
                 executionEngine.get(),
                 mlirStringRefCreate(func.c_str(), func.size()));
             return reinterpret_cast<uintptr_t>(res);
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -75,6 +75,19 @@
   return wrap(success());
 }
 
+extern "C" void *mlirExecutionEngineLookupPacked(MlirExecutionEngine jit,
+                                                 MlirStringRef name) {
+  auto expectedFPtr = unwrap(jit)->lookupPacked(unwrap(name));
+  if (!expectedFPtr) {
+    // The C API reports failure only through the nullptr return, but the
+    // error payload must still be consumed: destroying an Expected holding an
+    // unhandled error is fatal when LLVM_ENABLE_ABI_BREAKING_CHECKS is on.
+    llvm::consumeError(expectedFPtr.takeError());
+    return nullptr;
+  }
+  return reinterpret_cast<void *>(*expectedFPtr);
+}
+
 extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
                                            MlirStringRef name) {
   auto expectedFPtr = unwrap(jit)->lookup(unwrap(name));
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -328,8 +328,16 @@
   return std::move(engine);
 }
 
-Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
-  auto expectedSymbol = jit->lookup(makePackedFunctionName(name));
+Expected<void (*)(void **)>
+ExecutionEngine::lookupPacked(StringRef name) const {
+  auto result = lookup(makePackedFunctionName(name));
+  if (!result)
+    return result.takeError();
+  return reinterpret_cast<void (*)(void **)>(result.get());
+}
+
+Expected<void *> ExecutionEngine::lookup(StringRef name) const {
+  auto expectedSymbol = jit->lookup(name);
 
   // JIT lookup may return an Error referring to strings stored internally by
   // the JIT. If the Error outlives the ExecutionEngine, it would want have a
@@ -346,7 +354,7 @@
   }
 
   auto rawFPtr = expectedSymbol->getAddress();
-  auto fptr = reinterpret_cast<void (*)(void **)>(rawFPtr);
+  auto *fptr = reinterpret_cast<void *>(rawFPtr);
   if (!fptr)
     return make_string_error("looked up function is null");
   return fptr;
@@ -354,7 +362,7 @@
 
 Error ExecutionEngine::invokePacked(StringRef name,
                                     MutableArrayRef<void *> args) {
-  auto expectedFPtr = lookup(name);
+  auto expectedFPtr = lookupPacked(name);
   if (!expectedFPtr)
     return expectedFPtr.takeError();
   auto fptr = *expectedFPtr;
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -216,7 +216,7 @@
   auto engine = std::move(*expectedEngine);
   engine->registerSymbols(runtimeSymbolMap);
 
-  auto expectedFPtr = engine->lookup(entryPoint);
+  auto expectedFPtr = engine->lookupPacked(entryPoint);
   if (!expectedFPtr)
     return expectedFPtr.takeError();
 