diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -37,6 +37,7 @@
 #include "llvm/Support/StringSaver.h"
 #include "llvm/Support/ToolOutputFile.h"
+#include <cstdint>
 #include <numeric>

 using namespace mlir;
 using llvm::Error;
@@ -54,7 +55,7 @@
   llvm::cl::opt<std::string> mainFuncType{
       "entry-point-result",
       llvm::cl::desc("Textual description of the function type to be called"),
-      llvm::cl::value_desc("f32 | void"), llvm::cl::init("f32")};
+      llvm::cl::value_desc("f32 | i32 | void"), llvm::cl::init("f32")};

   llvm::cl::OptionCategory optFlags{"opt-like flags"};

@@ -172,7 +173,29 @@
   return compileAndExecute(options, module, entryPoint, transformer, &empty);
 }

-static Error compileAndExecuteSingleFloatReturnFunction(
+/// Returns an error unless the result type of `mainFunction` matches the C++
+/// type `Type`. Only the explicit specializations below are defined.
+template <typename Type>
+static Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
+
+template <>
+Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
+  if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32))
+    return make_string_error("only single llvm.i32 function result supported");
+  return Error::success();
+}
+
+template <>
+Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
+  if (!mainFunction.getType().getFunctionResultType().isFloatTy())
+    return make_string_error("only single llvm.f32 function result supported");
+  return Error::success();
+}
+
+/// Compiles and runs the nullary function `entryPoint`, whose single result
+/// must have C++ type `Type`, and prints the returned value to llvm::outs().
+template <typename Type>
+static Error compileAndExecuteSingleReturnFunction(
     Options &options, ModuleOp module, StringRef entryPoint,
     std::function<llvm::Error(llvm::Module *)> transformer) {
   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
@@ -182,10 +205,10 @@
   if (mainFunction.getType().getFunctionNumParams() != 0)
     return make_string_error("function inputs not supported");

-  if (!mainFunction.getType().getFunctionResultType().isFloatTy())
-    return make_string_error("only single llvm.f32 function result supported");
+  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
+    return error;

-  float res;
+  Type res;
   struct {
     void *data;
   } data;
@@ -196,6 +219,7 @@

   // Intentional printing of the output so we can test.
   llvm::outs() << res << '\n';
+
   return Error::success();
 }

@@ -266,9 +290,10 @@
       Error (*)(Options &, ModuleOp, StringRef,
                 std::function<llvm::Error(llvm::Module *)>);
   auto compileAndExecuteFn =
       llvm::StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
-          .Case("f32", compileAndExecuteSingleFloatReturnFunction)
+          .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
+          .Case("f32", compileAndExecuteSingleReturnFunction<float>)
           .Case("void", compileAndExecuteVoidFunction)
           .Default(nullptr);

   Error error =
diff --git a/mlir/test/mlir-cpu-runner/simple.mlir b/mlir/test/mlir-cpu-runner/simple.mlir
--- a/mlir/test/mlir-cpu-runner/simple.mlir
+++ b/mlir/test/mlir-cpu-runner/simple.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-cpu-runner %s | FileCheck %s
 // RUN: mlir-cpu-runner %s -e foo | FileCheck -check-prefix=NOMAIN %s
+// RUN: mlir-cpu-runner %s --entry-point-result=i32 -e int_main | FileCheck -check-prefix=INTMAIN %s
 // RUN: mlir-cpu-runner %s -O3 | FileCheck %s

 // RUN: cp %s %t
@@ -51,3 +52,10 @@
   llvm.return %5 : !llvm.float
 }
 // NOMAIN: 1.234000e+03
+
+// Check that i32 return type works
+llvm.func @int_main() -> !llvm.i32 {
+  %0 = llvm.mlir.constant(42 : i32) : !llvm.i32
+  llvm.return %0 : !llvm.i32
+}
+// INTMAIN: 42