diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -36,6 +36,7 @@
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/StringSaver.h"
 #include "llvm/Support/ToolOutputFile.h"
+#include <cstdint>
 #include <numeric>
 
 using namespace mlir;
@@ -54,7 +55,7 @@
   llvm::cl::opt<std::string> mainFuncType{
       "entry-point-result",
       llvm::cl::desc("Textual description of the function type to be called"),
-      llvm::cl::value_desc("f32 | void"), llvm::cl::init("f32")};
+      llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
 
   llvm::cl::OptionCategory optFlags{"opt-like flags"};
 
@@ -172,7 +173,36 @@
   return compileAndExecute(options, module, entryPoint, transformer, &empty);
 }
 
-static Error compileAndExecuteSingleFloatReturnFunction(
+/// Checks that the result type of `mainFunction` matches the C++ type `T` that
+/// will receive it. Only the explicit specializations below are defined.
+template <typename T>
+static Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
+
+template <>
+Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
+  if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32))
+    return make_string_error("only single llvm.i32 function result supported");
+  return Error::success();
+}
+
+template <>
+Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
+  if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64))
+    return make_string_error("only single llvm.i64 function result supported");
+  return Error::success();
+}
+
+template <>
+Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
+  if (!mainFunction.getType().getFunctionResultType().isFloatTy())
+    return make_string_error("only single llvm.f32 function result supported");
+  return Error::success();
+}
+
+/// Compiles and runs the zero-argument entry point `entryPoint`, receives its
+/// single result as a `T`, and prints it to stdout followed by a newline.
+template <typename T>
+static Error compileAndExecuteSingleReturnFunction(
     Options &options, ModuleOp module, StringRef entryPoint,
     std::function<llvm::Error(llvm::Module *)> transformer) {
   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
@@ -182,10 +212,10 @@
   if (mainFunction.getType().getFunctionNumParams() != 0)
     return make_string_error("function inputs not supported");
 
-  if (!mainFunction.getType().getFunctionResultType().isFloatTy())
-    return make_string_error("only single llvm.f32 function result supported");
+  if (Error error = checkCompatibleReturnType<T>(mainFunction))
+    return error;
 
-  float res;
+  T res;
   struct {
     void *data;
   } data;
@@ -267,7 +297,9 @@
                 std::function<llvm::Error(llvm::Module *)>);
   auto compileAndExecuteFn =
       llvm::StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
-          .Case("f32", compileAndExecuteSingleFloatReturnFunction)
+          .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
+          .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
+          .Case("f32", compileAndExecuteSingleReturnFunction<float>)
           .Case("void", compileAndExecuteVoidFunction)
           .Default(nullptr);
 
diff --git a/mlir/test/mlir-cpu-runner/simple.mlir b/mlir/test/mlir-cpu-runner/simple.mlir
--- a/mlir/test/mlir-cpu-runner/simple.mlir
+++ b/mlir/test/mlir-cpu-runner/simple.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-cpu-runner %s | FileCheck %s
 // RUN: mlir-cpu-runner %s -e foo | FileCheck -check-prefix=NOMAIN %s
+// RUN: mlir-cpu-runner %s --entry-point-result=i32 -e int32_main | FileCheck -check-prefix=INT32MAIN %s
+// RUN: mlir-cpu-runner %s --entry-point-result=i64 -e int64_main | FileCheck -check-prefix=INT64MAIN %s
 // RUN: mlir-cpu-runner %s -O3 | FileCheck %s
 
 // RUN: cp %s %t
@@ -51,3 +53,17 @@
   llvm.return %5 : !llvm.float
 }
 // NOMAIN: 1.234000e+03
+
+// Check that i32 return type works
+llvm.func @int32_main() -> !llvm.i32 {
+  %0 = llvm.mlir.constant(42 : i32) : !llvm.i32
+  llvm.return %0 : !llvm.i32
+}
+// INT32MAIN: 42
+
+// Check that i64 return type works; the value needs more than 32 bits.
+llvm.func @int64_main() -> !llvm.i64 {
+  %0 = llvm.mlir.constant(4294967338 : i64) : !llvm.i64
+  llvm.return %0 : !llvm.i64
+}
+// INT64MAIN: 4294967338