diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
--- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
+++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
@@ -101,9 +101,73 @@
   /// pointer to it. Propagates errors in case of failure.
   llvm::Expected<void (*)(void **)> lookup(StringRef name) const;
 
+  /// Invokes the function with the given name passing it the list of opaque
+  /// pointers to the actual arguments.
+  llvm::Error invokePacked(StringRef name,
+                           MutableArrayRef<void *> args = llvm::None);
+
+  /// Trait that defines how a given type is passed to the JIT code. This
+  /// defaults to passing the address but can be specialized.
+  template <typename T>
+  struct Argument {
+    static void pack(SmallVectorImpl<void *> &args, T &val) {
+      args.push_back(&val);
+    }
+  };
+
+  /// Tag to wrap an output parameter when invoking a jitted function. Only a
+  /// reference is stored: the wrapped object must outlive the `invoke` call.
+  template <typename T>
+  struct Result {
+    Result(T &result) : value(result) {}
+    T &value;
+  };
+
+  /// Wraps `t` as an output parameter for `invoke` without spelling out
+  /// `Result<T>` (class template argument deduction requires C++17).
+  template <typename T>
+  static Result<T> result(T &t) {
+    return Result<T>(t);
+  }
+
+  // Specialization for output parameters: their address is forwarded
+  // directly to the native code.
+  template <typename T>
+  struct Argument<Result<T>> {
+    static void pack(SmallVectorImpl<void *> &args, Result<T> &result) {
+      args.push_back(&result.value);
+    }
+  };
+
   /// Invokes the function with the given name passing it the list of arguments
-  /// as a list of opaque pointers.
-  llvm::Error invoke(StringRef name, MutableArrayRef<void *> args = llvm::None);
+  /// by value. The call goes through the C interface adapter
+  /// `_mlir_ciface_<name>`, so the function must carry the
+  /// `llvm.emit_c_interface` attribute. Function results are obtained through
+  /// output parameters wrapped with `Result` (see also `result`). For example:
+  ///
+  ///   func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface }
+  ///
+  /// can be invoked with:
+  ///
+  ///   int32_t res = 0;
+  ///   llvm::Error error =
+  ///       jit->invoke("foo", 42, ExecutionEngine::result(res));
+  ///
+  /// Each argument is copied into this function and only its address is passed
+  /// to the JIT code; `Result` copies share the caller's reference, so outputs
+  /// are written directly into the caller's objects.
+  template <typename... Args>
+  llvm::Error invoke(StringRef funcName, Args... args) {
+    const std::string adapterName =
+        std::string("_mlir_ciface_") + funcName.str();
+    llvm::SmallVector<void *, 8> argsArray;
+    // Pack every argument in an array of pointers. Delegate the packing to a
+    // trait so that it can be overridden per argument type.
+    // TODO: replace with a fold expression when migrating to C++17.
+    int dummy[] = {0, ((void)Argument<Args>::pack(argsArray, args), 0)...};
+    (void)dummy;
+    return invokePacked(adapterName, argsArray);
+  }
 
   /// Set the target triple on the module. This is implicitly done when creating
   /// the engine.
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -339,7 +339,8 @@
   return fptr;
 }
 
-Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) {
+Error ExecutionEngine::invokePacked(StringRef name,
+                                    MutableArrayRef<void *> args) {
   auto expectedFPtr = lookup(name);
   if (!expectedFPtr)
     return expectedFPtr.takeError();
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -7,6 +7,7 @@
 
 add_subdirectory(Analysis)
 add_subdirectory(Dialect)
+add_subdirectory(ExecutionEngine)
 add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(SDBM)
diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_unittest(MLIRExecutionEngineTests
+  Invoke.cpp
+)
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+
+target_link_libraries(MLIRExecutionEngineTests
+  PRIVATE
+  MLIRExecutionEngine
+  MLIRLinalgToLLVM
+  ${dialect_libs}
+)
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -0,0 +1,91 @@
+//===- Invoke.cpp - Tests for ExecutionEngine::invoke ---------------------===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/RunnerUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "gmock/gmock.h"
+
+using namespace mlir;
+
+static struct LLVMInitializer { // Registers the native target and asm printer.
+  LLVMInitializer() {
+    llvm::InitializeNativeTarget();
+    llvm::InitializeNativeTargetAsmPrinter();
+  }
+} initializer; // Runs during static initialization, ahead of the tests.
+
+/// Lowers `module` (vector, linalg and std ops) to the LLVM dialect so it can
+/// be JIT-compiled. Returns failure if any pass of the pipeline fails.
+static LogicalResult lowerToLLVMDialect(ModuleOp module) {
+  PassManager pm(module.getContext());
+  pm.addNestedPass<FuncOp>(mlir::createConvertVectorToSCFPass());
+  pm.addNestedPass<FuncOp>(mlir::createConvertLinalgToLoopsPass());
+  pm.addPass(mlir::createConvertLinalgToLLVMPass());
+  pm.addPass(mlir::createConvertVectorToLLVMPass());
+  pm.addPass(mlir::createLowerToLLVMPass());
+  return pm.run(module);
+}
+
+TEST(MLIRExecutionEngine, AddInteger) {
+  std::string moduleStr = R"mlir(
+  func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
+    %res = std.addi %arg0, %arg0 : i32
+    return %res : i32
+  }
+  )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  OwningModuleRef module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
+  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+  // The i32 result is returned through an output parameter of matching width.
+  int32_t result = 0;
+  llvm::Error error =
+      jit->invoke("foo", 42, ExecutionEngine::Result<int32_t>(result));
+  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
+  ASSERT_EQ(result, 42 + 42);
+}
+
+TEST(MLIRExecutionEngine, SubtractFloat) {
+  std::string moduleStr = R"mlir(
+  func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
+    %res = std.subf %arg0, %arg1 : f32
+    return %res : f32
+  }
+  )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  OwningModuleRef module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
+  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+  // The f32 result is returned through an output parameter.
+  float result = -1;
+  llvm::Error error =
+      jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::Result<float>(result));
+  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
+  ASSERT_FLOAT_EQ(result, 42.f);
+}