diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -42,6 +42,7 @@ set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner") set(MLIR_ROCM_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir ROCm runner") +set(MLIR_SPIRV_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir SPIR-V runner") set(MLIR_VULKAN_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir Vulkan runner") option(MLIR_INCLUDE_TESTS diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -65,6 +65,9 @@ /// Creates an execution engine for the given module. /// + /// If `llvmModuleBuilder` is provided, it will be used to create LLVM module + /// from the given MLIR module. + /// /// If `transformer` is provided, it will be called on the LLVM module during /// JIT-compilation and can be used, e.g., for reporting or optimization. /// @@ -84,6 +87,9 @@ /// the llvm's global Perf notification listener. static llvm::Expected> create(ModuleOp m, + llvm::function_ref(ModuleOp, + llvm::LLVMContext &)> + llvmModuleBuilder = nullptr, llvm::function_ref transformer = {}, Optional jitCodeGenOptLevel = llvm::None, ArrayRef sharedLibPaths = {}, bool enableObjectCache = true, diff --git a/mlir/include/mlir/ExecutionEngine/JitRunner.h b/mlir/include/mlir/ExecutionEngine/JitRunner.h --- a/mlir/include/mlir/ExecutionEngine/JitRunner.h +++ b/mlir/include/mlir/ExecutionEngine/JitRunner.h @@ -19,6 +19,7 @@ #define MLIR_SUPPORT_JITRUNNER_H_ #include "llvm/ADT/STLExtras.h" +#include "llvm/IR/Module.h" namespace mlir { @@ -26,12 +27,17 @@ struct LogicalResult; // Entry point for all CPU runners. Expects the common argc/argv arguments for -// standard C++ main functions and an mlirTransformer. 
-// The latter is applied after parsing the input into MLIR IR and before passing -// the MLIR module to the ExecutionEngine. +// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`. +// `mlirTransformer` is applied after parsing the input into MLIR IR and before +// passing the MLIR module to the ExecutionEngine. +// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine. +// It processes the MLIR module and creates an LLVM IR module. int JitRunnerMain( int argc, char **argv, - llvm::function_ref mlirTransformer); + llvm::function_ref mlirTransformer, + llvm::function_ref(ModuleOp, + llvm::LLVMContext &)> + llvmModuleBuilder = nullptr); } // namespace mlir diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -214,7 +214,11 @@ : nullptr) {} Expected> ExecutionEngine::create( - ModuleOp m, llvm::function_ref transformer, + ModuleOp m, + llvm::function_ref(ModuleOp, + llvm::LLVMContext &)> + llvmModuleBuilder, + llvm::function_ref transformer, Optional jitCodeGenOptLevel, ArrayRef sharedLibPaths, bool enableObjectCache, bool enableGDBNotificationListener, bool enablePerfNotificationListener) { @@ -223,7 +227,8 @@ enablePerfNotificationListener); std::unique_ptr ctx(new llvm::LLVMContext); - auto llvmModule = translateModuleToLLVMIR(m, *ctx); + auto llvmModule = llvmModuleBuilder ? 
llvmModuleBuilder(m, *ctx) + : translateModuleToLLVMIR(m, *ctx); if (!llvmModule) return make_string_error("could not convert to LLVM IR"); // FIXME: the triple should be passed to the translation or dialect conversion diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -30,7 +30,6 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassNameParser.h" -#include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileUtilities.h" #include "llvm/Support/SourceMgr.h" @@ -132,18 +131,20 @@ } // JIT-compile the given module and run "entryPoint" with "args" as arguments. -static Error -compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint, - std::function transformer, - void **args) { +static Error compileAndExecute( + Options &options, ModuleOp module, + function_ref(ModuleOp, llvm::LLVMContext &)> + llvmModuleBuilder, + StringRef entryPoint, + std::function transformer, void **args) { Optional jitCodeGenOptLevel; if (auto clOptLevel = getCommandLineOptLevel(options)) jitCodeGenOptLevel = static_cast(clOptLevel.getValue()); SmallVector libs(options.clSharedLibs.begin(), options.clSharedLibs.end()); - auto expectedEngine = mlir::ExecutionEngine::create(module, transformer, - jitCodeGenOptLevel, libs); + auto expectedEngine = mlir::ExecutionEngine::create( + module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs); if (!expectedEngine) return expectedEngine.takeError(); @@ -164,13 +165,17 @@ } static Error compileAndExecuteVoidFunction( - Options &options, ModuleOp module, StringRef entryPoint, + Options &options, ModuleOp module, + function_ref(ModuleOp, llvm::LLVMContext &)> + llvmModuleBuilder, + StringRef entryPoint, std::function transformer) { auto mainFunction = module.lookupSymbol(entryPoint); if (!mainFunction || mainFunction.empty()) 
return make_string_error("entry point not found"); void *empty = nullptr; - return compileAndExecute(options, module, entryPoint, transformer, &empty); + return compileAndExecute(options, module, llvmModuleBuilder, entryPoint, + transformer, &empty); } template @@ -195,7 +200,10 @@ } template Error compileAndExecuteSingleReturnFunction( - Options &options, ModuleOp module, StringRef entryPoint, + Options &options, ModuleOp module, + function_ref(ModuleOp, llvm::LLVMContext &)> + llvmModuleBuilder, + StringRef entryPoint, std::function transformer) { auto mainFunction = module.lookupSymbol(entryPoint); if (!mainFunction || mainFunction.isExternal()) @@ -212,8 +220,8 @@ void *data; } data; data.data = &res; - if (auto error = compileAndExecute(options, module, entryPoint, transformer, - (void **)&data)) + if (auto error = compileAndExecute(options, module, llvmModuleBuilder, + entryPoint, transformer, (void **)&data)) return error; // Intentional printing of the output so we can test. @@ -222,13 +230,17 @@ return Error::success(); } -/// Entry point for all CPU runners. Expects the common argc/argv -/// arguments for standard C++ main functions and an mlirTransformer. -/// The latter is applied after parsing the input into MLIR IR and -/// before passing the MLIR module to the ExecutionEngine. +/// Entry point for all CPU runners. Expects the common argc/argv arguments for +/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`. +/// `mlirTransformer` is applied after parsing the input into MLIR IR and before +/// passing the MLIR module to the ExecutionEngine. +/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine. +/// It processes the MLIR module and creates an LLVM IR module. 
int mlir::JitRunnerMain( int argc, char **argv, - function_ref mlirTransformer) { + function_ref mlirTransformer, + function_ref(ModuleOp, llvm::LLVMContext &)> + llvmModuleBuilder) { // Create the options struct containing the command line options for the // runner. This must come before the command line options are parsed. Options options; @@ -286,8 +298,10 @@ // Get the function used to compile and execute the module. using CompileAndExecuteFnT = - Error (*)(Options &, ModuleOp, StringRef, - std::function); + Error (*)(Options &, ModuleOp, + function_ref( + ModuleOp, llvm::LLVMContext &)>, + StringRef, std::function); auto compileAndExecuteFn = llvm::StringSwitch(options.mainFuncType.getValue()) .Case("i32", compileAndExecuteSingleReturnFunction) @@ -298,7 +312,7 @@ Error error = compileAndExecuteFn - ? compileAndExecuteFn(options, m.get(), + ? compileAndExecuteFn(options, m.get(), llvmModuleBuilder, options.mainFuncName.getValue(), transformer) : make_string_error("unsupported function type"); diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -23,6 +23,7 @@ # for the mlir cuda / rocm / vulkan runner tests. 
set(MLIR_CUDA_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) set(MLIR_ROCM_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) +set(MLIR_SPIRV_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) set(MLIR_VULKAN_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) configure_lit_site_cfg( @@ -80,6 +81,12 @@ ) endif() +if(MLIR_SPIRV_RUNNER_ENABLED) + list(APPEND MLIR_TEST_DEPENDS + mlir-spirv-runner + ) +endif() + if(MLIR_VULKAN_RUNNER_ENABLED) list(APPEND MLIR_TEST_DEPENDS mlir-vulkan-runner diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -74,6 +74,7 @@ ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'), ToolSubst('%mlir_runner_utils_dir', config.mlir_runner_utils_dir, unresolved='ignore'), ToolSubst('%rocm_wrapper_library_dir', config.rocm_wrapper_library_dir, unresolved='ignore'), + ToolSubst('%spirv_wrapper_library_dir', config.spirv_wrapper_library_dir, unresolved='ignore'), ToolSubst('%vulkan_wrapper_library_dir', config.vulkan_wrapper_library_dir, unresolved='ignore'), ]) diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -41,6 +41,8 @@ config.run_rocm_tests = @MLIR_ROCM_CONVERSIONS_ENABLED@ config.rocm_wrapper_library_dir = "@MLIR_ROCM_WRAPPER_LIBRARY_DIR@" config.enable_rocm_runner = @MLIR_ROCM_RUNNER_ENABLED@ +config.spirv_wrapper_library_dir = "@MLIR_SPIRV_WRAPPER_LIBRARY_DIR@" +config.enable_spirv_runner = @MLIR_SPIRV_RUNNER_ENABLED@ config.vulkan_wrapper_library_dir = "@MLIR_VULKAN_WRAPPER_LIBRARY_DIR@" config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@ config.enable_bindings_python = @MLIR_BINDINGS_PYTHON_ENABLED@ diff --git a/mlir/test/mlir-spirv-runner/lit.local.cfg b/mlir/test/mlir-spirv-runner/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-spirv-runner/lit.local.cfg @@ -0,0 +1,2 @@ +if not 
config.enable_spirv_runner: + config.unsupported = True diff --git a/mlir/test/mlir-spirv-runner/simple.mlir b/mlir/test/mlir-spirv-runner/simple.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-spirv-runner/simple.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-spirv-runner %s -e main --entry-point-result=void --shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%spirv_wrapper_library_dir/libruntime-wrappers%shlibext + +// CHECK: [8, 8, 8, 8, 8, 8] +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + gpu.module @kernels { + gpu.func @double(%arg0 : memref<6xi32>, %arg1 : memref<6xi32>) + kernel attributes { spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} { + %factor = constant 2 : i32 + + %i0 = constant 0 : index + %i1 = constant 1 : index + %i2 = constant 2 : index + %i3 = constant 3 : index + %i4 = constant 4 : index + %i5 = constant 5 : index + + %x0 = load %arg0[%i0] : memref<6xi32> + %x1 = load %arg0[%i1] : memref<6xi32> + %x2 = load %arg0[%i2] : memref<6xi32> + %x3 = load %arg0[%i3] : memref<6xi32> + %x4 = load %arg0[%i4] : memref<6xi32> + %x5 = load %arg0[%i5] : memref<6xi32> + + %y0 = muli %x0, %factor : i32 + %y1 = muli %x1, %factor : i32 + %y2 = muli %x2, %factor : i32 + %y3 = muli %x3, %factor : i32 + %y4 = muli %x4, %factor : i32 + %y5 = muli %x5, %factor : i32 + + store %y0, %arg1[%i0] : memref<6xi32> + store %y1, %arg1[%i1] : memref<6xi32> + store %y2, %arg1[%i2] : memref<6xi32> + store %y3, %arg1[%i3] : memref<6xi32> + store %y4, %arg1[%i4] : memref<6xi32> + store %y5, %arg1[%i5] : memref<6xi32> + gpu.return + } + } + func @main() { + %input = alloc() : memref<6xi32> + %output = alloc() : memref<6xi32> + %four = constant 4 : i32 + %zero = constant 0 : i32 + %input_casted = memref_cast %input : memref<6xi32> to memref + %output_casted = memref_cast 
%output : memref<6xi32> to memref + call @fillI32Buffer(%input_casted, %four) : (memref, i32) -> () + call @fillI32Buffer(%output_casted, %zero) : (memref, i32) -> () + + %one = constant 1 : index + "gpu.launch_func"(%one, %one, %one, + %one, %one, %one, + %input, %output) { kernel = @kernels::@double } + : (index, index, index, index, index, index, memref<6xi32>, memref<6xi32>) -> () + %result = memref_cast %output : memref<6xi32> to memref<*xi32> + call @print_memref_i32(%result) : (memref<*xi32>) -> () + return + } + + func @fillI32Buffer(%arg0 : memref, %arg1 : i32) + func @print_memref_i32(%ptr : memref<*xi32>) +} diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt --- a/mlir/tools/CMakeLists.txt +++ b/mlir/tools/CMakeLists.txt @@ -5,5 +5,6 @@ add_subdirectory(mlir-reduce) add_subdirectory(mlir-rocm-runner) add_subdirectory(mlir-shlib) +add_subdirectory(mlir-spirv-runner) add_subdirectory(mlir-translate) add_subdirectory(mlir-vulkan-runner) \ No newline at end of file diff --git a/mlir/tools/mlir-spirv-runner/CMakeLists.txt b/mlir/tools/mlir-spirv-runner/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-spirv-runner/CMakeLists.txt @@ -0,0 +1,42 @@ +set(LLVM_OPTIONAL_SOURCES + mlir-spirv-runner.cpp + runtime-wrappers.cpp + ) + +if (MLIR_SPIRV_RUNNER_ENABLED) + message(STATUS "Building SPIR-V runner") + + add_llvm_library(runtime-wrappers SHARED + runtime-wrappers.cpp + ) + + target_include_directories(runtime-wrappers + PUBLIC + ) + + add_llvm_tool(mlir-spirv-runner + mlir-spirv-runner.cpp + ) + + add_dependencies(mlir-spirv-runner runtime-wrappers) + llvm_update_compile_flags(mlir-spirv-runner) + + get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + + target_link_libraries(mlir-spirv-runner PRIVATE + ${conversion_libs} + ${dialect_libs} + MLIRAnalysis + MLIREDSC + MLIRExecutionEngine + MLIRIR + MLIRJitRunner + MLIRLLVMIR + MLIRParser + 
MLIRTargetLLVMIR + MLIRTransforms + MLIRTranslation + MLIRSupport + ) +endif() diff --git a/mlir/tools/mlir-spirv-runner/mlir-spirv-runner.cpp b/mlir/tools/mlir-spirv-runner/mlir-spirv-runner.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-spirv-runner/mlir-spirv-runner.cpp @@ -0,0 +1,97 @@ +//===- mlir-spirv-runner.cpp - MLIR SPIR-V Execution on CPU ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Main entry point to a command line utility that executes an MLIR file on the +// CPU by translating MLIR GPU module and host part to LLVM IR before +// JIT-compiling and executing. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h" +#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/GPU/Passes.h" +#include "mlir/Dialect/SPIRV/Passes.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/ExecutionEngine/JitRunner.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/TargetSelect.h" + +using namespace mlir; + +/// A utility function that builds llvm::Module from two nested MLIR modules. 
+/// +/// module @main { +/// module @kernel { +/// // Some ops +/// } +/// // Some other ops +/// } +/// +/// Each of these two modules is translated to LLVM IR module, then they are +/// linked together and returned. +static std::unique_ptr +convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) { + std::unique_ptr kernelModule; + module.walk([&](ModuleOp nested) -> WalkResult { + if (nested.getParentOp()) { + // Found kernel module. Translate it to LLVM IR, remove from the main + // module and terminate. + kernelModule = translateModuleToLLVMIR(nested, context); + nested.erase(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + std::unique_ptr mainModule = + translateModuleToLLVMIR(module, context); + llvm::Linker::linkModules(*mainModule, std::move(kernelModule)); + return mainModule; +} + +static LogicalResult runMLIRPasses(ModuleOp module) { + PassManager passManager(module.getContext()); + applyPassManagerCLOptions(passManager); + passManager.addPass(createGpuKernelOutliningPass()); + passManager.addPass(createConvertGPUToSPIRVPass()); + + OpPassManager &nestedPM = passManager.nest(); + nestedPM.addPass(spirv::createLowerABIAttributesPass()); + nestedPM.addPass(spirv::createUpdateVersionCapabilityExtensionPass()); + nestedPM.addPass(spirv::createEncodeDescriptorSetsPass()); + + passManager.addPass(createGPULaunchFuncToStandardCallPass()); + LowerToLLVMOptions options = { + /*useBarePtrCallConv=*/false, + /*emitCWrappers=*/true, + /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout}; + passManager.addPass(createLowerToLLVMPass(options)); + passManager.addPass(createConvertSPIRVToLLVMPass()); + passManager.addPass(createEmulateKernelCallInLLVMPass()); + return passManager.run(module); +} + +int main(int argc, char **argv) { + mlir::registerAllDialects(); + llvm::InitLLVM y(argc, argv); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::initializeLLVMPasses(); + + return 
mlir::JitRunnerMain(argc, argv, &runMLIRPasses, &convertMLIRModule); +} diff --git a/mlir/tools/mlir-spirv-runner/runtime-wrappers.cpp b/mlir/tools/mlir-spirv-runner/runtime-wrappers.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-spirv-runner/runtime-wrappers.cpp @@ -0,0 +1,30 @@ +//===- runtime-wrappers.cpp - MLIR SPIR-V runner wrapper library ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A small library for SPIR-V runner. +// +//===----------------------------------------------------------------------===// + +#include "mlir/ExecutionEngine/RunnerUtils.h" + +// A struct corresponding to a MemRef argument after C-compatible wrapper +// emission. +template +struct MemRefDescriptor { + T *allocated; + T *aligned; + intptr_t offset; + intptr_t sizes[N]; + intptr_t strides[N]; +}; + +extern "C" void +_mlir_ciface_fillI32Buffer(MemRefDescriptor *mem_ref, + int32_t value) { + std::fill_n(mem_ref->allocated, mem_ref->sizes[0], value); +}