diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h --- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h +++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h @@ -38,6 +38,16 @@ using LoweringCallback = std::function( Operation *, llvm::LLVMContext &, StringRef)>; +/// Creates a pass to emulate a `gpu.launch_func` call in LLVM dialect and +/// lower the host module code to LLVM. +/// +/// This transformation creates a sequence of global variables that are later +/// linked to the variables in the kernel module, and a series of copies to/from +/// them to emulate the memory transfers between the host and the device. It +/// also converts the remaining Standard dialect into LLVM dialect, emitting C +/// wrappers. +std::unique_ptr> createLowerHostCodeToLLVMPass(); + /// Creates a pass to convert a gpu.launch_func operation into a sequence of /// GPU runtime calls. /// diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -97,6 +97,12 @@ ]; } +def LowerHostCodeToLLVM : Pass<"lower-host-to-llvm", "ModuleOp"> { + let summary = "Lowers the host module code and `gpu.launch_func` to LLVM"; + let constructor = "mlir::createLowerHostCodeToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; +} + //===----------------------------------------------------------------------===// // GPUToNVVM //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt --- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt @@ -15,6 +15,7 @@ endif() add_mlir_conversion_library(MLIRGPUToGPURuntimeTransforms + ConvertLaunchFuncToLLVMCalls.cpp ConvertLaunchFuncToRuntimeCalls.cpp ConvertKernelFuncToBlob.cpp
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp @@ -0,0 +1,307 @@ +//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements passes to convert `gpu.launch_func` op into a sequence +// of LLVM calls that emulate the host and device sides. +// +//===----------------------------------------------------------------------===// + +#include "../PassDetail.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; + +static constexpr const char kSPIRVModule[] = "__spv__"; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +/// Returns the string name of the `DescriptorSet` decoration. +static std::string descriptorSetName() { + return llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::DescriptorSet)); +} + +/// Returns the string name of the `Binding` decoration. 
+static std::string bindingName() {
+  return llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::Binding));
+}
+
+/// Calculates the index of the kernel's operand that is represented by the
+/// given global variable with the `bind` attribute. We assume that the index of
+/// each kernel's operand is mapped to (descriptorSet, binding) by the map:
+///   i -> (0, i)
+/// which is implemented under `LowerABIAttributesPass`.
+///
+/// Precondition: `op` carries a binding attribute (callers check this with
+/// `hasDescriptorSetAndBinding`).
+static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
+  IntegerAttr binding = op.getAttrOfType<IntegerAttr>(bindingName());
+  assert(binding && "expected a global variable with a binding attribute");
+  return binding.getInt();
+}
+
+/// Copies `size` bytes from the `src` pointer to the `dst` pointer by emitting
+/// a non-volatile `llvm.intr.memcpy` at the builder's insertion point.
+static void copy(Location loc, Value dst, Value src, Value size,
+                 OpBuilder &builder) {
+  MLIRContext *context = builder.getContext();
+  auto llvmI1Type = LLVM::LLVMType::getInt1Ty(context);
+  Value isVolatile = builder.create<LLVM::ConstantOp>(
+      loc, llvmI1Type, builder.getBoolAttr(false));
+  builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
+}
+
+/// Encodes the binding and descriptor set numbers into a new symbolic name.
+/// The name is specified by
+///   {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
+/// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
+/// binding numbers. The caller must ensure that both attributes are present.
+static std::string
+createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
+                                 StringRef kernelModuleName) {
+  IntegerAttr descriptorSet =
+      op.getAttrOfType<IntegerAttr>(descriptorSetName());
+  IntegerAttr binding = op.getAttrOfType<IntegerAttr>(bindingName());
+  // `formatv` formats StringRef and integer arguments directly, so there is no
+  // need to materialize intermediate std::string copies of them.
+  return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}", kernelModuleName,
+                       op.sym_name(), descriptorSet.getInt(), binding.getInt())
+      .str();
+}
+
+/// Returns true if the given global variable has both a descriptor set number
+/// and a binding number.
+static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) { + IntegerAttr descriptorSet = + op.getAttrOfType(descriptorSetName()); + IntegerAttr binding = op.getAttrOfType(bindingName()); + return descriptorSet && binding; +} + +/// Fills `globalVariableMap` with SPIR-V global variables that represent kernel +/// arguments from the given SPIR-V module. We assume that the module contains a +/// single entry point function. Hence, all `spv.globalVariable`s with a bind +/// attribute are kernel arguments. +static LogicalResult getKernelGlobalVariables( + spirv::ModuleOp module, + DenseMap &globalVariableMap) { + auto entryPoints = module.getOps(); + if (!llvm::hasSingleElement(entryPoints)) { + return module.emitError( + "The module must contain exactly one entry point function"); + } + auto globalVariables = module.getOps(); + for (auto globalOp : globalVariables) { + if (hasDescriptorSetAndBinding(globalOp)) + globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp; + } + return success(); +} + +/// Encodes the SPIR-V module's symbolic name into the name of the entry point +/// function. +static LogicalResult encodeKernelName(spirv::ModuleOp module) { + StringRef spvModuleName = module.sym_name().getValue(); + // We already know that the module contains exactly one entry point function + // based on `getKernelGlobalVariables()` call. 
Update this function's name + // to: + // {spv_module_name}_{function_name} + auto entryPoint = *module.getOps().begin(); + StringRef funcName = entryPoint.fn(); + auto funcOp = module.lookupSymbol(funcName); + std::string newFuncName = spvModuleName.str() + "_" + funcName.str(); + if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) + return failure(); + SymbolTable::setSymbolName(funcOp, newFuncName); + return success(); +} + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +namespace { + +/// Structure to group information about the variables being copied. +struct CopyInfo { + Value dst; + Value src; + Value size; +}; + +/// This pattern emulates a call to the kernel in LLVM dialect. For that, we +/// copy the data to the global variable (emulating device side), call the +/// kernel as a normal void LLVM function, and copy the data back (emulating the +/// host side). +class GPULaunchLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + gpu::LaunchFuncOp launchOp = cast(op); + MLIRContext *context = rewriter.getContext(); + auto module = launchOp.getParentOfType(); + + // Get the SPIR-V module that represents the gpu kernel module. The module + // is named: + // __spv__{kernel_module_name} + // based on GPU to SPIR-V conversion. 
+ StringRef kernelModuleName = launchOp.getKernelModuleName(); + std::string spvModuleName = kSPIRVModule + kernelModuleName.str(); + auto spvModule = module.lookupSymbol(spvModuleName); + if (!spvModule) { + return launchOp.emitOpError("SPIR-V kernel module '") + << spvModuleName << "' is not found"; + } + + // Declare kernel function in the main module so that it later can be linked + // with its definition from the kernel module. We know that the kernel + // function would have no arguments and the data is passed via global + // variables. The name of the kernel will be + // {spv_module_name}_{kernel_function_name} + // to avoid symbolic name conflicts. + StringRef kernelFuncName = launchOp.getKernelName(); + std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str(); + auto kernelFunc = module.lookupSymbol(newKernelFuncName); + if (!kernelFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + kernelFunc = rewriter.create( + rewriter.getUnknownLoc(), newKernelFuncName, + LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context), + ArrayRef(), + /*isVarArg=*/false)); + rewriter.setInsertionPoint(launchOp); + } + + // Get all global variables associated with the kernel operands. + DenseMap globalVariableMap; + if (failed(getKernelGlobalVariables(spvModule, globalVariableMap))) + return failure(); + + // Traverse kernel operands that were converted to MemRefDescriptors. For + // each operand, create a global variable and copy data from operand to it. + Location loc = launchOp.getLoc(); + SmallVector copyInfo; + auto numKernelOperands = launchOp.getNumKernelOperands(); + auto kernelOperands = operands.take_back(numKernelOperands); + for (auto operand : llvm::enumerate(kernelOperands)) { + // Check if the kernel's opernad is a ranked memref. 
+ auto memRefType = launchOp.getKernelOperand(operand.index()) + .getType() + .dyn_cast(); + if (!memRefType) + return failure(); + + // Calculate the size of the memref and get the pointer to the allocated + // buffer. + SmallVector sizes; + getMemRefDescriptorSizes(loc, memRefType, operand.value(), rewriter, + sizes); + Value size = getCumulativeSizeInBytes(loc, memRefType.getElementType(), + sizes, rewriter); + MemRefDescriptor descriptor(operand.value()); + Value src = descriptor.allocatedPtr(rewriter, loc); + + // Get the global variable in the SPIR-V module that is associated with + // the kernel operand. Construct its new name and create a corresponding + // LLVM dialect global variable. + spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; + auto pointeeType = + spirvGlobal.type().cast().getPointeeType(); + auto dstGlobalType = typeConverter.convertType(pointeeType); + if (!dstGlobalType) + return failure(); + StringRef name = + createGlobalVariableWithBindName(spirvGlobal, spvModuleName); + // Check if this variable has already been created. + auto dstGlobal = module.lookupSymbol(name); + if (!dstGlobal) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + dstGlobal = rewriter.create( + loc, dstGlobalType.cast(), + /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute()); + rewriter.setInsertionPoint(launchOp); + } + + // Copy the data from src operand pointer to dst global variable. Save + // src, dst and size so that we can copy data back after emulating the + // kernel call. + Value dst = rewriter.create(loc, dstGlobal); + copy(loc, dst, src, size, rewriter); + + CopyInfo info; + info.dst = dst; + info.src = src; + info.size = size; + copyInfo.push_back(info); + } + // Create a call to the kernel and copy the data back. 
+ rewriter.replaceOpWithNewOp(op, kernelFunc, + ArrayRef()); + for (CopyInfo info : copyInfo) + copy(loc, info.src, info.dst, info.size, rewriter); + return success(); + } +}; + +class LowerHostCodeToLLVM + : public LowerHostCodeToLLVMBase { +public: + void runOnOperation() override { + ModuleOp module = getOperation(); + + // Erase the GPU module. + for (auto gpuModule : + llvm::make_early_inc_range(module.getOps())) + gpuModule.erase(); + + // Specify options to lower Standard to LLVM and pull in the conversion + // patterns. + LowerToLLVMOptions options = { + /*useBarePtrCallConv=*/false, + /*emitCWrappers=*/true, + /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout}; + auto *context = module.getContext(); + OwningRewritePatternList patterns; + LLVMTypeConverter typeConverter(context, options); + populateStdToLLVMConversionPatterns(typeConverter, patterns); + patterns.insert(typeConverter); + + // Pull in SPIR-V type conversion patterns to convert SPIR-V global + // variable's type to LLVM dialect type. + populateSPIRVToLLVMTypeConversion(typeConverter); + + ConversionTarget target(*context); + target.addLegalDialect(); + if (failed(applyPartialConversion(module, target, patterns))) + signalPassFailure(); + + // Finally, modify the kernel function in SPIR-V modules to avoid symbolic + // conflicts. 
+ for (auto spvModule : module.getOps()) + encodeKernelName(spvModule); + } +}; +} // namespace + +std::unique_ptr> +mlir::createLowerHostCodeToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/test/Conversion/GPUCommon/lower-host-to-llvm-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-host-to-llvm-calls.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUCommon/lower-host-to-llvm-calls.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt --lower-host-to-llvm %s | FileCheck %s + +module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} { + + // CHECK: llvm.mlir.global linkonce @__spv__foo_bar_arg_0_descriptor_set0_binding0() : !llvm.struct<(array<6 x i32>)> + // CHECK: llvm.func @__spv__foo_bar() + + // CHECK: spv.module @__spv__foo + // CHECK: spv.globalVariable @bar_arg_0 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + // CHECK: spv.func @__spv__foo_bar + + // CHECK: spv.EntryPoint "GLCompute" @__spv__foo_bar + // CHECK: spv.ExecutionMode @__spv__foo_bar "LocalSize", 1, 1, 1 + + // CHECK-LABEL: @main + // CHECK: %[[SRC:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %[[DEST:.*]] = llvm.mlir.addressof @__spv__foo_bar_arg_0_descriptor_set0_binding0 : !llvm.ptr)>> + // CHECK-NEXT: llvm.mlir.constant(false) : !llvm.i1 + // CHECK-NEXT: "llvm.intr.memcpy"(%[[DEST]], %[[SRC]], %[[SIZE:.*]], %{{.*}}) : (!llvm.ptr)>>, !llvm.ptr, !llvm.i64, !llvm.i1) -> () + // CHECK-NEXT: llvm.call @__spv__foo_bar() : () -> () + // CHECK-NEXT: llvm.mlir.constant(false) : !llvm.i1 + // CHECK-NEXT: "llvm.intr.memcpy"(%[[SRC]], %[[DEST]], %[[SIZE]], %{{.*}}) : (!llvm.ptr, !llvm.ptr)>>, !llvm.i64, !llvm.i1) -> () + + spv.module @__spv__foo Logical GLSL450 requires #spv.vce { + spv.globalVariable @bar_arg_0 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + spv.func @bar() 
"None" attributes {workgroup_attributions = 0 : i64} { + %0 = spv._address_of @bar_arg_0 : !spv.ptr [0]>, StorageBuffer> + spv.Return + } + spv.EntryPoint "GLCompute" @bar + spv.ExecutionMode @bar "LocalSize", 1, 1, 1 + } + + gpu.module @foo { + gpu.func @bar(%arg0: memref<6xi32>) kernel attributes {spv.entry_point_abi = {local_size = dense<1> : vector<3xi32>}} { + gpu.return + } + } + + func @main() { + %buffer = alloc() : memref<6xi32> + %one = constant 1 : index + "gpu.launch_func"(%one, %one, %one, + %one, %one, %one, + %buffer) {kernel = @foo::@bar} : (index, index, index, index, index, index, memref<6xi32>) -> () + return + } +}