diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -42,6 +42,7 @@ set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner") set(MLIR_ROCM_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir ROCm runner") +set(MLIR_SPIRV_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir SPIR-V runner") set(MLIR_VULKAN_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir Vulkan runner") option(MLIR_INCLUDE_TESTS diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h --- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h +++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h @@ -38,6 +38,12 @@ using LoweringCallback = std::function( Operation *, llvm::LLVMContext &, StringRef)>; +/// Creates a pass to convert `gpu.launch_func` to a standard function call. +std::unique_ptr> +createGPULaunchFuncToStandardCallPass(); +/// Creates a pass to emulate a kernel function call in LLVM dialect. +std::unique_ptr> createEmulateKernelCallInLLVMPass(); + /// Creates a pass to convert a gpu.launch_func operation into a sequence of /// GPU runtime calls. 
/// diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -91,6 +91,16 @@ ]; } +def GPULaunchFuncToStandardCall : Pass<"gpu-launch-to-std-call", "ModuleOp"> { + let summary = "Convert gpu.launch_func to a standard dialect call"; + let constructor = "mlir::createGPULaunchFuncToStandardCallPass()"; +} + +def EmulateKernelCallInLLVM : Pass<"emulate-kernel-llvm-call", "ModuleOp"> { + let summary = "Emulate a kernel function call in LLVM dialect"; + let constructor = "mlir::createEmulateKernelCallInLLVMPass()"; +} + //===----------------------------------------------------------------------===// // GPUToNVVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h --- a/mlir/include/mlir/Dialect/SPIRV/Passes.h +++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h @@ -26,6 +26,11 @@ std::unique_ptr> createDecorateSPIRVCompositeTypeLayoutPass(); +/// Creates a module pass that encodes bind attribute of each +/// `spv.globalVariable` into its symbolic name. +std::unique_ptr> +createEncodeDescriptorSetsPass(); + /// Creates an operation pass that deduces and attaches the minimal version/ /// capabilities/extensions requirements for spv.module ops. 
/// For each spv.module op, this pass requires a `spv.target_env` attribute on diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Passes.td --- a/mlir/include/mlir/Dialect/SPIRV/Passes.td +++ b/mlir/include/mlir/Dialect/SPIRV/Passes.td @@ -17,6 +17,13 @@ let constructor = "mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass()"; } +def SPIRVEncodeDescriptorSets + : Pass<"spirv-encode-descriptor-sets", "spirv::ModuleOp"> { + let summary = "Encode `spv.globalVariable`'s bind attribute into its " + "symbolic name"; + let constructor = "mlir::spirv::createEncodeDescriptorSetsPass()"; +} + def SPIRVLowerABIAttributes : Pass<"spirv-lower-abi-attrs", "spirv::ModuleOp"> { let summary = "Decorate SPIR-V composite type with layout info"; let constructor = "mlir::spirv::createLowerABIAttributesPass()"; diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -65,6 +65,9 @@ /// Creates an execution engine for the given module. /// + /// If `llvmModuleBuilder` is provided, it will be used to create LLVM module + /// from the given MLIR module. + /// /// If `transformer` is provided, it will be called on the LLVM module during /// JIT-compilation and can be used, e.g., for reporting or optimization. /// @@ -84,6 +87,9 @@ /// the llvm's global Perf notification listener. 
static llvm::Expected> create(ModuleOp m, + llvm::function_ref(ModuleOp, + llvm::LLVMContext &)> + llvmModuleBuilder = nullptr, llvm::function_ref transformer = {}, Optional jitCodeGenOptLevel = llvm::None, ArrayRef sharedLibPaths = {}, bool enableObjectCache = true, diff --git a/mlir/include/mlir/ExecutionEngine/JitRunner.h b/mlir/include/mlir/ExecutionEngine/JitRunner.h --- a/mlir/include/mlir/ExecutionEngine/JitRunner.h +++ b/mlir/include/mlir/ExecutionEngine/JitRunner.h @@ -19,6 +19,7 @@ #define MLIR_SUPPORT_JITRUNNER_H_ #include "llvm/ADT/STLExtras.h" +#include "llvm/IR/Module.h" namespace mlir { @@ -26,12 +27,17 @@ struct LogicalResult; // Entry point for all CPU runners. Expects the common argc/argv arguments for -// standard C++ main functions and an mlirTransformer. -// The latter is applied after parsing the input into MLIR IR and before passing -// the MLIR module to the ExecutionEngine. +// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`. +/// `mlirTransformer` is applied after parsing the input into MLIR IR and before +/// passing the MLIR module to the ExecutionEngine. +/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine. +/// It processes MLIR module and creates LLVM IR module. 
int JitRunnerMain( int argc, char **argv, - llvm::function_ref mlirTransformer); + llvm::function_ref mlirTransformer, + llvm::function_ref(ModuleOp, + llvm::LLVMContext &)> + llvmModuleBuilder = nullptr); } // namespace mlir diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt --- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt @@ -15,6 +15,7 @@ endif() add_mlir_conversion_library(MLIRGPUToGPURuntimeTransforms + ConvertLaunchFuncToLLVMCalls.cpp ConvertLaunchFuncToRuntimeCalls.cpp ConvertKernelFuncToBlob.cpp diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp @@ -0,0 +1,300 @@ +//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements passes to convert `gpu.launch_func` op into a sequence +// of LLVM calls that emulate the host and device sides. 
+// +//===----------------------------------------------------------------------===// + +#include "../PassDetail.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; + +static constexpr const char kLLVMLaunchName[] = "__llvm_launch"; +static constexpr const char kStandardLaunchName[] = "__std_launch"; +static constexpr const unsigned memRefDescriptorNumElements = 5; +static constexpr const unsigned numElementsOffset = 3; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +/// Erases redundant functions from the module. +static void clearFunctionDefinitions(ModuleOp module) { + for (auto funcOp : + llvm::make_early_inc_range(module.getOps())) { + StringRef name = funcOp.getName(); + if (name == kStandardLaunchName || + name == + StringRef(llvm::formatv("_mlir_ciface_{0}", kStandardLaunchName))) + funcOp.erase(); + } +} + +/// Emits code to copy the given number of bytes from src to dst pointers. +static void copyData(Location loc, Value dst, Value src, Value size, + OpBuilder &builder) { + MLIRContext *context = builder.getContext(); + auto llvmI1Type = LLVM::LLVMType::getInt1Ty(context); + Value isVolatile = builder.create( + loc, llvmI1Type, builder.getBoolAttr(false)); + builder.create(loc, dst, src, size, isVolatile); +} + +/// Returns true if the function call is ex-`gpu.launch_func` op. 
+static bool isLaunchCallOp(LLVM::CallOp callOp) { + return callOp.callee() && callOp.callee().getValue() == kStandardLaunchName; +} + +/// Verifies if the module contains exactly one nested module with exactly one +/// (kernel) function. +static bool hasOneNestedModuleAndOneKernel(ModuleOp module) { + bool hasOneNestedModule = false; + auto walkResult = + module.walk([&hasOneNestedModule](ModuleOp moduleOp) -> WalkResult { + // Interrupt if more than one nested module has been found. + if (moduleOp.getParentOp() && hasOneNestedModule) + return WalkResult::interrupt(); + + // If there is a parent operation, it means we walked to a nested + // module. Verify there is only a single function in it. + if (moduleOp.getParentOp()) { + auto funcs = moduleOp.getOps(); + SmallVector funcVector(funcs.begin(), + funcs.end()); + if (funcVector.size() != 1) + return WalkResult::interrupt(); + hasOneNestedModule = true; + } + // Otherwise, advance. + return WalkResult::advance(); + }); + + if (walkResult.wasInterrupted()) + return false; + return hasOneNestedModule; +} + +static LogicalResult renameKernel(ModuleOp kernelModule) { + kernelModule.walk([&](LLVM::LLVMFuncOp funcOp) { + SymbolTable::setSymbolName(funcOp, kLLVMLaunchName); + }); + return success(); +} + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +namespace { +class GPULaunchFuncToStandardCall + : public GPULaunchFuncToStandardCallBase { +public: + void runOnOperation() override; +}; + +class EmulateKernelCallInLLVM + : public EmulateKernelCallInLLVMBase { +public: + void runOnOperation() override; +}; + +/// This pattern prepares the conversion of kernel launch to LLVM dialect. It: +/// - Declares a placeholder function with the same argument types and a void +/// result type. +/// - Replaces `gpu.launch_func` with a call to placeholder. 
+class KernelCallPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gpu::LaunchFuncOp launchOp, + PatternRewriter &rewriter) const override { + MLIRContext *context = rewriter.getContext(); + auto module = launchOp.getParentOfType(); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + + // Declare a standard function that will keep kernel memref arguments. + SmallVector gpuLaunchTypes(launchOp.getOperandTypes()); + SmallVector kernelOperands(gpuLaunchTypes.begin() + + gpu::LaunchOp::kNumConfigOperands, + gpuLaunchTypes.end()); + auto kernelFunc = + rewriter.create(rewriter.getUnknownLoc(), kStandardLaunchName, + FunctionType::get(kernelOperands, {}, context)); + + // Replace `gpu.launch_func` with a standard function call. + rewriter.setInsertionPoint(launchOp); + SmallVector operands; + for (unsigned i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) + operands.push_back(launchOp.getKernelOperand(i)); + rewriter.replaceOpWithNewOp(launchOp, kernelFunc, operands); + return success(); + } +}; + +/// This pattern emulates a call to kernel in LLVM dialect. For that, we +/// copy the data to the global variable (emulating device side), call +/// the kernel as a normal void LLVM function, and copy the data back +/// (emulating host side). +class StandardCallPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp op, + PatternRewriter &rewriter) const override { + MLIRContext *context = rewriter.getContext(); + auto module = op.getParentOfType(); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + + // Declare kernel function in the main module so that it later can be linked + // with its definition from the kernel module. 
We know that the kernel + // function would have no arguments and the data is passed via global + // variables. + auto kernelFunc = rewriter.create( + rewriter.getUnknownLoc(), kLLVMLaunchName, + LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context), + ArrayRef(), + /*isVarArg=*/false)); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + // Find ex-GPU module (now a nested LLVM module). Rename the kernel function + // to adhere to convention. + ModuleOp kernelModule; + auto walkResult = module.walk([&](ModuleOp nested) -> WalkResult { + if (nested.getParentOp()) { + kernelModule = nested; + renameKernel(kernelModule); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (!walkResult.wasInterrupted()) + return failure(); + + // Kernel's memref arguments are converted into MemRefDescriptors in LLVM. + // Walk over pointers to the allocated data and the number of elements to + // calculate the number of bytes to copy. + SmallVector copyTriples; + for (unsigned i = 0, e = op.getNumOperands(); i < e; + i += memRefDescriptorNumElements) { + Value src = op.getOperand(i); + + // Calculate a size of buffer in bytes. Support only integers for now for + // easier calculation. + auto elementType = + src.getType().cast().getPointerElementTy(); + auto integerType = elementType.dyn_cast(); + if (!integerType) + return failure(); + unsigned elementSizeInBytes = integerType.getBitWidth() / 8; + auto llvmI64Type = LLVM::LLVMType::getInt64Ty(context); + Value constantSize = rewriter.create( + loc, llvmI64Type, rewriter.getI64IntegerAttr(elementSizeInBytes)); + Value numElements = op.getOperand(i + numElementsOffset); + Value size = rewriter.create(loc, numElements, constantSize); + + // Look up the global variable from the kernel module. For each, create a + // global variable that will be linked with the global from the kernel + // module. 
We follow convention taken in EncodeDescriptorSetsPass to label + // globals as __var__{i}, where i is incremented sequentially. + std::string sym_name = + "__var__" + std::to_string(i / memRefDescriptorNumElements); + auto externalGlobal = kernelModule.lookupSymbol(sym_name); + rewriter.setInsertionPointToStart(module.getBody()); + auto global = rewriter.create( + loc, externalGlobal.type().cast(), + /*isConstant=*/false, LLVM::Linkage::Linkonce, sym_name, Attribute()); + rewriter.setInsertionPoint(op); + + // Get the address of created global. Then copy data from src pointer to + // dst pointer. + Value dst = rewriter.create(loc, global); + copyData(loc, dst, src, size, rewriter); + + // Store src and dst pointers to reuse after the kernel call. + copyTriples.push_back({dst, src, size}); + } + + // Emulate the kernel call and then copy the data back from global variables + // to buffers. + rewriter.replaceOpWithNewOp(op, kernelFunc, + ArrayRef()); + for (auto triple : copyTriples) { + copyData(loc, triple[1], triple[0], triple[2], rewriter); + } + return success(); + } +}; +} // namespace + +void GPULaunchFuncToStandardCall::runOnOperation() { + ModuleOp module = getOperation(); + auto *context = module.getContext(); + OwningRewritePatternList patterns; + patterns.insert(context); + + // Convert `gpu.launch_func` to standard function call. + ConversionTarget target(*context); + target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + if (failed(applyPartialConversion(module, target, patterns))) { + signalPassFailure(); + } + + // Erase GPU module. + for (auto gpuModule : + llvm::make_early_inc_range(getOperation().getOps())) + gpuModule.erase(); +} + +void EmulateKernelCallInLLVM::runOnOperation() { + ModuleOp module = getOperation(); + + // Check if the module structure can be supported by the current conversion. + // For now, only 2 modules (main and kernel) are supported. 
+ if (!hasOneNestedModuleAndOneKernel(module)) + module.emitError("Module should contain exactly one nested module"); + + // Emulate kernel call. + auto *context = module.getContext(); + OwningRewritePatternList patterns; + patterns.insert(context); + ConversionTarget target(*context); + target.addLegalDialect(); + target.addDynamicallyLegalOp( + [&](LLVM::CallOp op) { return !isLaunchCallOp(op); }); + if (failed(applyPartialConversion(module, target, patterns))) { + signalPassFailure(); + } + + // Remove unused functions that were generated during the previous pass. + clearFunctionDefinitions(module); +} + +std::unique_ptr> +mlir::createGPULaunchFuncToStandardCallPass() { + return std::make_unique(); +} + +std::unique_ptr> +mlir::createEmulateKernelCallInLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRSPIRVTransforms DecorateSPIRVCompositeTypeLayoutPass.cpp + EncodeDescriptorSetsPass.cpp LowerABIAttributesPass.cpp RewriteInsertsPass.cpp UpdateVCEPass.cpp diff --git a/mlir/lib/Dialect/SPIRV/Transforms/EncodeDescriptorSetsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/EncodeDescriptorSetsPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/Transforms/EncodeDescriptorSetsPass.cpp @@ -0,0 +1,75 @@ +//===- EncodeDescriptorSetsPass.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass encodes set and binding of the global variable into its symbolic +// name. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/SPIRV/Passes.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" + +#define BINDING_NAME \ + llvm::convertToSnakeFromCamelCase( \ + stringifyDecoration(spirv::Decoration::Binding)) +#define DESCRIPTOR_SET_NAME \ + llvm::convertToSnakeFromCamelCase( \ + stringifyDecoration(spirv::Decoration::DescriptorSet)) + +using namespace mlir; + +/// Returns true if the given global variable has both a descriptor set number +/// and a binding number. +static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) { + IntegerAttr descriptorSet = + op.getAttrOfType(DESCRIPTOR_SET_NAME); + IntegerAttr binding = op.getAttrOfType(BINDING_NAME); + return descriptorSet && binding; +} + +namespace { + +class EncodeDescriptorSetsPass + : public SPIRVEncodeDescriptorSetsBase { +public: + void runOnOperation() override { + spirv::ModuleOp module = getOperation(); + unsigned i = 0; + + // Walk over all `spv.globalVariable` ops within the module. If the variable + // has a bind attribute, update the symbol and all symbol's uses. + module.walk([&](spirv::GlobalVariableOp op) { + if (hasDescriptorSetAndBinding(op)) { + // For now, we do not need to store bind attribute info. May need + // to revisit in the future. + // TODO: encode the data from the descriptor set and binding + // numbers in the global variable. + std::string name = llvm::formatv("__var__{0}", i++); + if (failed(SymbolTable::replaceAllSymbolUses(op, name, module))) + return signalPassFailure(); + + // Set new symbol name and remove unused attributes. 
+ SymbolTable::setSymbolName(op, name); + op.removeAttr(DESCRIPTOR_SET_NAME); + op.removeAttr(BINDING_NAME); + } + }); + } +}; +} // namespace + +std::unique_ptr> +mlir::spirv::createEncodeDescriptorSetsPass() { + return std::make_unique(); +} diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -214,7 +214,11 @@ : nullptr) {} Expected> ExecutionEngine::create( - ModuleOp m, llvm::function_ref transformer, + ModuleOp m, + llvm::function_ref(ModuleOp, + llvm::LLVMContext &)> + llvmModuleBuilder, + llvm::function_ref transformer, Optional jitCodeGenOptLevel, ArrayRef sharedLibPaths, bool enableObjectCache, bool enableGDBNotificationListener, bool enablePerfNotificationListener) { @@ -223,7 +227,8 @@ enablePerfNotificationListener); std::unique_ptr ctx(new llvm::LLVMContext); - auto llvmModule = translateModuleToLLVMIR(m, *ctx); + auto llvmModule = llvmModuleBuilder ? llvmModuleBuilder(m, *ctx) + : translateModuleToLLVMIR(m, *ctx); if (!llvmModule) return make_string_error("could not convert to LLVM IR"); // FIXME: the triple should be passed to the translation or dialect conversion diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -30,7 +30,6 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassNameParser.h" -#include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileUtilities.h" #include "llvm/Support/SourceMgr.h" @@ -132,18 +131,20 @@ } // JIT-compile the given module and run "entryPoint" with "args" as arguments. 
-static Error -compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint, - std::function transformer, - void **args) { +static Error compileAndExecute( + Options &options, ModuleOp module, + function_ref(ModuleOp, llvm::LLVMContext &)> + llvmModuleBuilder, + StringRef entryPoint, + std::function transformer, void **args) { Optional jitCodeGenOptLevel; if (auto clOptLevel = getCommandLineOptLevel(options)) jitCodeGenOptLevel = static_cast(clOptLevel.getValue()); SmallVector libs(options.clSharedLibs.begin(), options.clSharedLibs.end()); - auto expectedEngine = mlir::ExecutionEngine::create(module, transformer, - jitCodeGenOptLevel, libs); + auto expectedEngine = mlir::ExecutionEngine::create( + module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs); if (!expectedEngine) return expectedEngine.takeError(); @@ -164,13 +165,17 @@ } static Error compileAndExecuteVoidFunction( - Options &options, ModuleOp module, StringRef entryPoint, + Options &options, ModuleOp module, + function_ref(ModuleOp, llvm::LLVMContext &)> + llvmModuleBuilder, + StringRef entryPoint, std::function transformer) { auto mainFunction = module.lookupSymbol(entryPoint); if (!mainFunction || mainFunction.empty()) return make_string_error("entry point not found"); void *empty = nullptr; - return compileAndExecute(options, module, entryPoint, transformer, &empty); + return compileAndExecute(options, module, llvmModuleBuilder, entryPoint, + transformer, &empty); } template @@ -195,7 +200,10 @@ } template Error compileAndExecuteSingleReturnFunction( - Options &options, ModuleOp module, StringRef entryPoint, + Options &options, ModuleOp module, + function_ref(ModuleOp, llvm::LLVMContext &)> + llvmModuleBuilder, + StringRef entryPoint, std::function transformer) { auto mainFunction = module.lookupSymbol(entryPoint); if (!mainFunction || mainFunction.isExternal()) @@ -212,8 +220,8 @@ void *data; } data; data.data = &res; - if (auto error = compileAndExecute(options, module, 
entryPoint, transformer, - (void **)&data)) + if (auto error = compileAndExecute(options, module, llvmModuleBuilder, + entryPoint, transformer, (void **)&data)) return error; // Intentional printing of the output so we can test. @@ -222,13 +230,17 @@ return Error::success(); } -/// Entry point for all CPU runners. Expects the common argc/argv -/// arguments for standard C++ main functions and an mlirTransformer. -/// The latter is applied after parsing the input into MLIR IR and -/// before passing the MLIR module to the ExecutionEngine. +/// Entry point for all CPU runners. Expects the common argc/argv arguments for +/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`. +/// `mlirTransformer` is applied after parsing the input into MLIR IR and before +/// passing the MLIR module to the ExecutionEngine. +/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine. +/// It processes MLIR module and creates LLVM IR module. int mlir::JitRunnerMain( int argc, char **argv, - function_ref mlirTransformer) { + function_ref mlirTransformer, + function_ref(ModuleOp, llvm::LLVMContext &)> + llvmModuleBuilder) { // Create the options struct containing the command line options for the // runner. This must come before the command line options are parsed. Options options; @@ -286,8 +298,10 @@ // Get the function used to compile and execute the module. using CompileAndExecuteFnT = - Error (*)(Options &, ModuleOp, StringRef, - std::function); + Error (*)(Options &, ModuleOp, + function_ref( + ModuleOp, llvm::LLVMContext &)>, + StringRef, std::function); auto compileAndExecuteFn = llvm::StringSwitch(options.mainFuncType.getValue()) .Case("i32", compileAndExecuteSingleReturnFunction) @@ -298,7 +312,7 @@ Error error = compileAndExecuteFn - ? compileAndExecuteFn(options, m.get(), + ? 
compileAndExecuteFn(options, m.get(), llvmModuleBuilder, options.mainFuncName.getValue(), transformer) : make_string_error("unsupported function type"); diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -23,6 +23,7 @@ # for the mlir cuda / rocm / vulkan runner tests. set(MLIR_CUDA_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) set(MLIR_ROCM_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) +set(MLIR_SPIRV_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) set(MLIR_VULKAN_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) configure_lit_site_cfg( @@ -80,6 +81,12 @@ ) endif() +if(MLIR_SPIRV_RUNNER_ENABLED) + list(APPEND MLIR_TEST_DEPENDS + mlir-spirv-runner + ) +endif() + if(MLIR_VULKAN_RUNNER_ENABLED) list(APPEND MLIR_TEST_DEPENDS mlir-vulkan-runner diff --git a/mlir/test/Conversion/GPUCommon/emulate-kernel-call.mlir b/mlir/test/Conversion/GPUCommon/emulate-kernel-call.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUCommon/emulate-kernel-call.mlir @@ -0,0 +1,79 @@ +// RUN: mlir-opt --emulate-kernel-llvm-call %s | FileCheck %s + +// CHECK: module +// CHECK: llvm.mlir.global linkonce @__var__0() : !llvm.struct)> +// CHECK-LABEL: @__llvm_launch +// CHECK: module +// CHECK-LABEL: @__llvm_launch + +// CHECK-LABEL: @main +// CHECK: %[[ELEMENT_SIZE:.*]] = llvm.mlir.constant(4 : i64) : !llvm.i64 +// CHECK: %[[SIZE:.*]] = llvm.mul %{{.*}}, %[[ELEMENT_SIZE]] : !llvm.i64 +// CHECK: %[[DST:.*]] = llvm.mlir.addressof @__var__0 : !llvm.ptr)>> +// CHECK: llvm.mlir.constant(false) : !llvm.i1 +// CHECK: "llvm.intr.memcpy"(%[[DST]], %{{.*}}, %[[SIZE]], %{{.*}}) : (!llvm.ptr)>>, !llvm.ptr, !llvm.i64, !llvm.i1) -> () +// CHECK: llvm.call @__llvm_launch() : () -> () +// CHECK: llvm.mlir.constant(false) : !llvm.i1 +// CHECK: "llvm.intr.memcpy"(%{{.*}}, %[[DST]], %[[SIZE]], %{{.*}}) : (!llvm.ptr, !llvm.ptr)>>, !llvm.i64, !llvm.i1) -> () + +module attributes 
{gpu.container_module, spv.target_env = #spv.target_env<#spv.vce, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} { + llvm.func @__std_launch(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64) { + llvm.return + } + llvm.func @_mlir_ciface___std_launch(!llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) + module { + llvm.mlir.global external @__var__0() : !llvm.struct)> + llvm.func @simple() { + llvm.return + } + } + llvm.func @main() { + %0 = llvm.mlir.constant(6 : index) : !llvm.i64 + %1 = llvm.mlir.null : !llvm.ptr + %2 = llvm.mlir.constant(1 : index) : !llvm.i64 + %3 = llvm.getelementptr %1[%2] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + %4 = llvm.ptrtoint %3 : !llvm.ptr to !llvm.i64 + %5 = llvm.mul %0, %4 : !llvm.i64 + %6 = llvm.mlir.constant(1 : index) : !llvm.i64 + %7 = llvm.call @malloc(%5) : (!llvm.i64) -> !llvm.ptr + %8 = llvm.bitcast %7 : !llvm.ptr to !llvm.ptr + %9 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %11 = llvm.insertvalue %8, %10[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %12 = llvm.mlir.constant(0 : index) : !llvm.i64 + %13 = llvm.insertvalue %12, %11[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %14 = llvm.mlir.constant(1 : index) : !llvm.i64 + %15 = llvm.insertvalue %0, %13[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %16 = llvm.insertvalue %14, %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %17 = llvm.mlir.constant(4 : i32) : !llvm.i32 + %18 = llvm.bitcast %16 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %19 = llvm.extractvalue %18[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %20 = 
llvm.extractvalue %18[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %21 = llvm.extractvalue %18[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %22 = llvm.extractvalue %18[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %23 = llvm.extractvalue %18[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + llvm.call @fillI32Buffer(%19, %20, %21, %22, %23, %17) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i32) -> () + %24 = llvm.mlir.constant(1 : index) : !llvm.i64 + %25 = llvm.extractvalue %16[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %26 = llvm.extractvalue %16[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %27 = llvm.extractvalue %16[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %28 = llvm.extractvalue %16[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %29 = llvm.extractvalue %16[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + llvm.call @__std_launch(%25, %26, %27, %28, %29) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i64, !llvm.i64) -> () + %30 = llvm.mlir.constant(1 : index) : !llvm.i64 + %31 = llvm.alloca %30 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (!llvm.i64) -> !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> + llvm.store %16, %31 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> + %32 = llvm.bitcast %31 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> to !llvm.ptr + %33 = llvm.mlir.constant(1 : i64) : !llvm.i64 + %34 = llvm.mlir.undef : !llvm.struct<(i64, ptr)> + %35 = llvm.insertvalue %33, %34[0] : !llvm.struct<(i64, ptr)> + %36 = llvm.insertvalue %32, %35[1] : !llvm.struct<(i64, ptr)> + %37 = llvm.extractvalue %36[0] : !llvm.struct<(i64, ptr)> + %38 = llvm.extractvalue %36[1] : !llvm.struct<(i64, ptr)> + llvm.call @print_memref_i32(%37, %38) : (!llvm.i64, !llvm.ptr) -> () + llvm.return + } + + 
llvm.func @fillI32Buffer(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i32) + llvm.func @print_memref_i32(%arg0: !llvm.i64, %arg1: !llvm.ptr) +} diff --git a/mlir/test/Conversion/GPUCommon/gpu-launch-to-std-call.mlir b/mlir/test/Conversion/GPUCommon/gpu-launch-to-std-call.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUCommon/gpu-launch-to-std-call.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-opt --gpu-launch-to-std-call %s | FileCheck %s + +module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} { + + // CHECK: func @__std_launch(memref<6xi32>, memref<6xi32>) + + // CHECK-LABEL: @main + // CHECK: %[[BUFFER1:.*]] = alloc() : memref<6xi32> + // CHECK: %[[BUFFER2:.*]] = alloc() : memref<6xi32> + // CHECK: call @__std_launch(%[[BUFFER1]], %[[BUFFER2]]) : (memref<6xi32>, memref<6xi32>) -> () + + spv.module Logical GLSL450 requires #spv.vce { + spv.globalVariable @__var__0 : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @__var__1 : !spv.ptr [0]>, StorageBuffer> + spv.func @simple() "None" attributes {workgroup_attributions = 0 : i64} { + %0 = spv._address_of @__var__1 : !spv.ptr [0]>, StorageBuffer> + %1 = spv._address_of @__var__0 : !spv.ptr [0]>, StorageBuffer> + spv.Return + } + spv.EntryPoint "GLCompute" @simple + spv.ExecutionMode @simple "LocalSize", 1, 1, 1 + } + + gpu.module @kernels { + gpu.func @simple(%arg0: memref<6xi32>, %arg1: memref<6xi32>) kernel attributes {spv.entry_point_abi = {local_size = dense<1> : vector<3xi32>}} { + %c5_i32 = constant 5 : i32 + %c0 = constant 0 : index + %0 = load %arg0[%c0] : memref<6xi32> + %1 = addi %0, %c5_i32 : i32 + store %1, %arg1[%c0] : memref<6xi32> + gpu.return + } + } + + func @main() { + %0 = alloc() : memref<6xi32> + %1 = alloc() : memref<6xi32> + %c4_i32 = constant 4 : i32 + %c3_i32 = constant 3 : 
i32 + %2 = memref_cast %0 : memref<6xi32> to memref + %3 = memref_cast %1 : memref<6xi32> to memref + call @fillI32Buffer(%2, %c3_i32) : (memref, i32) -> () + call @fillI32Buffer(%3, %c4_i32) : (memref, i32) -> () + %c1 = constant 1 : index + "gpu.launch_func"(%c1, %c1, %c1, %c1, %c1, %c1, %0, %1) {kernel = @kernels::@simple} : (index, index, index, index, index, index, memref<6xi32>, memref<6xi32>) -> () + %4 = memref_cast %1 : memref<6xi32> to memref<*xi32> + call @print_memref_i32(%4) : (memref<*xi32>) -> () + return + } + func @fillI32Buffer(memref, i32) + func @print_memref_i32(memref<*xi32>) +} \ No newline at end of file diff --git a/mlir/test/Dialect/SPIRV/Transforms/descriptor-sets-encoding.mlir b/mlir/test/Dialect/SPIRV/Transforms/descriptor-sets-encoding.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Transforms/descriptor-sets-encoding.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt -spirv-encode-descriptor-sets -verify-diagnostics %s | FileCheck %s + +spv.module Logical GLSL450 { + + // CHECK: spv.module + // CHECK: spv.globalVariable [[VAR0:@.*__var__0]] : !spv.ptr + // CHECK: spv.globalVariable [[VAR1:@.*__var__1]] : !spv.ptr + // CHECK: spv.globalVariable [[VAR2:@.*]] : !spv.ptr + // CHECK: spv.func @func0 + // CHECK: spv._address_of [[VAR0]] + // CHECK: spv._address_of [[VAR1]] + // CHECK: spv.func @func1 + // CHECK: spv._address_of [[VAR1]] + // CHECK: spv._address_of [[VAR2]] + + spv.globalVariable @var0 bind(0,0) : !spv.ptr + spv.globalVariable @var1 bind(0,1) : !spv.ptr + spv.globalVariable @var2 : !spv.ptr + spv.func @func0() -> () "None" { + %ptr0 = spv._address_of @var0 : !spv.ptr + %ptr1 = spv._address_of @var1 : !spv.ptr + spv.Return + } + spv.func @func1() -> () "None" { + %ptr1 = spv._address_of @var1 : !spv.ptr + %ptr2 = spv._address_of @var2 : !spv.ptr + spv.Return + } +} \ No newline at end of file diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -74,6 
+74,7 @@ ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'), ToolSubst('%mlir_runner_utils_dir', config.mlir_runner_utils_dir, unresolved='ignore'), ToolSubst('%rocm_wrapper_library_dir', config.rocm_wrapper_library_dir, unresolved='ignore'), + ToolSubst('%spirv_wrapper_library_dir', config.spirv_wrapper_library_dir, unresolved='ignore'), ToolSubst('%vulkan_wrapper_library_dir', config.vulkan_wrapper_library_dir, unresolved='ignore'), ]) diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -41,6 +41,8 @@ config.run_rocm_tests = @MLIR_ROCM_CONVERSIONS_ENABLED@ config.rocm_wrapper_library_dir = "@MLIR_ROCM_WRAPPER_LIBRARY_DIR@" config.enable_rocm_runner = @MLIR_ROCM_RUNNER_ENABLED@ +config.spirv_wrapper_library_dir = "@MLIR_SPIRV_WRAPPER_LIBRARY_DIR@" +config.enable_spirv_runner = @MLIR_SPIRV_RUNNER_ENABLED@ config.vulkan_wrapper_library_dir = "@MLIR_VULKAN_WRAPPER_LIBRARY_DIR@" config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@ config.enable_bindings_python = @MLIR_BINDINGS_PYTHON_ENABLED@ diff --git a/mlir/test/mlir-spirv-runner/lit.local.cfg b/mlir/test/mlir-spirv-runner/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-spirv-runner/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.enable_spirv_runner: + config.unsupported = True \ No newline at end of file diff --git a/mlir/test/mlir-spirv-runner/simple.mlir b/mlir/test/mlir-spirv-runner/simple.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-spirv-runner/simple.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-spirv-runner %s -e main --entry-point-result=void --shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%spirv_wrapper_library_dir/libruntime-wrappers%shlibext + +// CHECK: [8, 8, 8, 8, 8, 8] +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + 
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + gpu.module @kernels { + gpu.func @double(%arg0 : memref<6xi32>, %arg1 : memref<6xi32>) + kernel attributes { spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} { + %factor = constant 2 : i32 + + %i0 = constant 0 : index + %i1 = constant 1 : index + %i2 = constant 2 : index + %i3 = constant 3 : index + %i4 = constant 4 : index + %i5 = constant 5 : index + + %x0 = load %arg0[%i0] : memref<6xi32> + %x1 = load %arg0[%i1] : memref<6xi32> + %x2 = load %arg0[%i2] : memref<6xi32> + %x3 = load %arg0[%i3] : memref<6xi32> + %x4 = load %arg0[%i4] : memref<6xi32> + %x5 = load %arg0[%i5] : memref<6xi32> + + %y0 = muli %x0, %factor : i32 + %y1 = muli %x1, %factor : i32 + %y2 = muli %x2, %factor : i32 + %y3 = muli %x3, %factor : i32 + %y4 = muli %x4, %factor : i32 + %y5 = muli %x5, %factor : i32 + + store %y0, %arg1[%i0] : memref<6xi32> + store %y1, %arg1[%i1] : memref<6xi32> + store %y2, %arg1[%i2] : memref<6xi32> + store %y3, %arg1[%i3] : memref<6xi32> + store %y4, %arg1[%i4] : memref<6xi32> + store %y5, %arg1[%i5] : memref<6xi32> + gpu.return + } + } + func @main() { + %input = alloc() : memref<6xi32> + %output = alloc() : memref<6xi32> + %four = constant 4 : i32 + %zero = constant 0 : i32 + %input_casted = memref_cast %input : memref<6xi32> to memref + %output_casted = memref_cast %output : memref<6xi32> to memref + call @fillI32Buffer(%input_casted, %four) : (memref, i32) -> () + call @fillI32Buffer(%output_casted, %zero) : (memref, i32) -> () + + %one = constant 1 : index + "gpu.launch_func"(%one, %one, %one, + %one, %one, %one, + %input, %output) { kernel = @kernels::@double } + : (index, index, index, index, index, index, memref<6xi32>, memref<6xi32>) -> () + %result = memref_cast %output : memref<6xi32> to memref<*xi32> + call @print_memref_i32(%result) : (memref<*xi32>) -> () + return + } + + func @fillI32Buffer(%arg0 : memref, %arg1 : i32) + func @print_memref_i32(%ptr : 
memref<*xi32>) +} diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt --- a/mlir/tools/CMakeLists.txt +++ b/mlir/tools/CMakeLists.txt @@ -5,5 +5,6 @@ add_subdirectory(mlir-reduce) add_subdirectory(mlir-rocm-runner) add_subdirectory(mlir-shlib) +add_subdirectory(mlir-spirv-runner) add_subdirectory(mlir-translate) add_subdirectory(mlir-vulkan-runner) \ No newline at end of file diff --git a/mlir/tools/mlir-spirv-runner/CMakeLists.txt b/mlir/tools/mlir-spirv-runner/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-spirv-runner/CMakeLists.txt @@ -0,0 +1,42 @@ +set(LLVM_OPTIONAL_SOURCES + mlir-spirv-runner.cpp + runtime-wrappers.cpp + ) + +if (MLIR_SPIRV_RUNNER_ENABLED) + message(STATUS "Building SPIR-V runner") + + add_llvm_library(runtime-wrappers SHARED + runtime-wrappers.cpp + ) + + target_include_directories(runtime-wrappers + PUBLIC + ) + + add_llvm_tool(mlir-spirv-runner + mlir-spirv-runner.cpp + ) + + add_dependencies(mlir-spirv-runner runtime-wrappers) + llvm_update_compile_flags(mlir-spirv-runner) + + get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + + target_link_libraries(mlir-spirv-runner PRIVATE + ${conversion_libs} + ${dialect_libs} + MLIRAnalysis + MLIREDSC + MLIRExecutionEngine + MLIRIR + MLIRJitRunner + MLIRLLVMIR + MLIRParser + MLIRTargetLLVMIR + MLIRTransforms + MLIRTranslation + MLIRSupport + ) +endif() diff --git a/mlir/tools/mlir-spirv-runner/mlir-spirv-runner.cpp b/mlir/tools/mlir-spirv-runner/mlir-spirv-runner.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-spirv-runner/mlir-spirv-runner.cpp @@ -0,0 +1,97 @@ +//===- mlir-spirv-runner.cpp - MLIR SPIR-V Execution on CPU ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Main entry point to a command line utility that executes an MLIR file on the +// CPU by translating MLIR GPU module and host part to LLVM IR before +// JIT-compiling and executing. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h" +#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/GPU/Passes.h" +#include "mlir/Dialect/SPIRV/Passes.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/ExecutionEngine/JitRunner.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/TargetSelect.h" + +using namespace mlir; + +/// A utility function that builds llvm::Module from two nested MLIR modules. +/// +/// module @main { +/// module @kernel { +/// // Some ops +/// } +/// // Some other ops +/// } +/// +/// Each of these two modules is translated to LLVM IR module, then they are +/// linked together and returned. +static std::unique_ptr +convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) { + std::unique_ptr kernelModule; + module.walk([&](ModuleOp nested) -> WalkResult { + if (nested.getParentOp()) { + // Found kernel module. Translate it to LLVM IR, remove from the main + // module and terminate. 
+ kernelModule = translateModuleToLLVMIR(nested, context); + nested.erase(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + std::unique_ptr mainModule = + translateModuleToLLVMIR(module, context); + llvm::Linker::linkModules(*mainModule, std::move(kernelModule)); + return mainModule; +} + +static LogicalResult runMLIRPasses(ModuleOp module) { + PassManager passManager(module.getContext()); + applyPassManagerCLOptions(passManager); + passManager.addPass(createGpuKernelOutliningPass()); + passManager.addPass(createConvertGPUToSPIRVPass()); + + OpPassManager &nestedPM = passManager.nest(); + nestedPM.addPass(spirv::createLowerABIAttributesPass()); + nestedPM.addPass(spirv::createUpdateVersionCapabilityExtensionPass()); + nestedPM.addPass(spirv::createEncodeDescriptorSetsPass()); + + passManager.addPass(createGPULaunchFuncToStandardCallPass()); + LowerToLLVMOptions options = { + /*useBarePtrCallConv=*/false, + /*emitCWrappers=*/true, + /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout}; + passManager.addPass(createLowerToLLVMPass(options)); + passManager.addPass(createConvertSPIRVToLLVMPass()); + passManager.addPass(createEmulateKernelCallInLLVMPass()); + return passManager.run(module); +} + +int main(int argc, char **argv) { + mlir::registerAllDialects(); + llvm::InitLLVM y(argc, argv); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::initializeLLVMPasses(); + + return mlir::JitRunnerMain(argc, argv, &runMLIRPasses, &convertMLIRModule); +} \ No newline at end of file diff --git a/mlir/tools/mlir-spirv-runner/runtime-wrappers.cpp b/mlir/tools/mlir-spirv-runner/runtime-wrappers.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-spirv-runner/runtime-wrappers.cpp @@ -0,0 +1,30 @@ +//===- runtime-wrappers.cpp - MLIR SPIR-V runner wrapper library ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A small library for SPIR-V runner. +// +//===----------------------------------------------------------------------===// + +#include "mlir/ExecutionEngine/RunnerUtils.h" + +// A struct corresponding to a MemRef argument after C-compatible wrapper +// emission. +template +struct MemRefDescriptor { + T *allocated; + T *aligned; + intptr_t offset; + intptr_t sizes[N]; + intptr_t strides[N]; +}; + +extern "C" void +_mlir_ciface_fillI32Buffer(MemRefDescriptor *mem_ref, + int32_t value) { + std::fill_n(mem_ref->allocated, mem_ref->sizes[0], value); +}