diff --git a/mlir/lib/Support/JitRunner.cpp b/mlir/lib/Support/JitRunner.cpp
--- a/mlir/lib/Support/JitRunner.cpp
+++ b/mlir/lib/Support/JitRunner.cpp
@@ -22,7 +22,6 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/InitAllDialects.h"
 #include "mlir/Parser.h"
 #include "mlir/Support/FileUtilities.h"
 
@@ -34,10 +33,8 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FileUtilities.h"
-#include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/StringSaver.h"
-#include "llvm/Support/TargetSelect.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include <numeric>
 
@@ -108,12 +105,6 @@
   return OwningModuleRef(parseSourceFile(sourceMgr, context));
 }
 
-// Initialize the relevant subsystems of LLVM.
-static void initializeLLVM() {
-  llvm::InitializeNativeTarget();
-  llvm::InitializeNativeTargetAsmPrinter();
-}
-
 static inline Error make_string_error(const Twine &message) {
   return llvm::make_error<llvm::StringError>(message.str(),
                                              llvm::inconvertibleErrorCode());
@@ -210,12 +201,8 @@
 int mlir::JitRunnerMain(
     int argc, char **argv,
     function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
-  registerAllDialects();
-  llvm::InitLLVM y(argc, argv);
-
-  initializeLLVM();
-  mlir::initializeLLVMPasses();
-
+  // Callers must register dialects and initialize LLVM (InitLLVM, native
+  // target and asm printer, LLVM passes) before calling JitRunnerMain.
   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
 
   Optional<unsigned> optLevel = getCommandLineOptLevel();
diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
--- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
+++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
@@ -12,8 +12,19 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/InitAllDialects.h"
 #include "mlir/Support/JitRunner.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/TargetSelect.h"
 
 int main(int argc, char **argv) {
+  // JitRunnerMain leaves dialect registration and LLVM setup to the caller.
+  mlir::registerAllDialects();
+  llvm::InitLLVM y(argc, argv);
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+  mlir::initializeLLVMPasses();
+
   return mlir::JitRunnerMain(argc, argv, nullptr);
 }
diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
--- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
+++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
@@ -22,13 +22,17 @@
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
+#include "mlir/InitAllDialects.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/JitRunner.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/TargetSelect.h"
 
 #include "cuda.h"
 
@@ -118,5 +122,12 @@
 
 int main(int argc, char **argv) {
   registerPassManagerCLOptions();
+  // JitRunnerMain leaves dialect registration and LLVM setup to the caller.
+  mlir::registerAllDialects();
+  llvm::InitLLVM y(argc, argv);
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+  mlir::initializeLLVMPasses();
+
   return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
 }
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -19,9 +19,13 @@
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/SPIRV/Passes.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/InitAllDialects.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/JitRunner.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/TargetSelect.h"
 
 using namespace mlir;
 
@@ -42,5 +46,13 @@
 int main(int argc, char **argv) {
   llvm::llvm_shutdown_obj x;
   registerPassManagerCLOptions();
+
+  // JitRunnerMain leaves dialect registration and LLVM setup to the caller.
+  mlir::registerAllDialects();
+  llvm::InitLLVM y(argc, argv);
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+  mlir::initializeLLVMPasses();
+
   return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
 }