diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -38,6 +38,16 @@
 using LoweringCallback = std::function<std::unique_ptr<llvm::Module>(
     Operation *, llvm::LLVMContext &, StringRef)>;
 
+/// Creates a pass to emulate `gpu.launch_func` calls in the LLVM dialect and
+/// lower the host module code to LLVM.
+///
+/// This transformation creates a sequence of global variables that are later
+/// linked to the variables in the kernel module, and a series of copies to and
+/// from them to emulate the memory transfer between the host and the device
+/// sides. It also converts the remaining Standard dialect into LLVM dialect,
+/// emitting C wrappers.
+std::unique_ptr<OperationPass<ModuleOp>> createLowerHostCodeToLLVMPass();
+
 /// Creates a pass to convert a gpu.launch_func operation into a sequence of
 /// GPU runtime calls.
 ///
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -97,6 +97,12 @@
   ];
 }
 
+def LowerHostCodeToLLVM : Pass<"lower-host-to-llvm", "ModuleOp"> {
+  let summary = "Lowers the host module code to LLVM dialect and converts "
+                "`gpu.launch_func` op into `llvm.call` op";
+  let constructor = "mlir::createLowerHostCodeToLLVMPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // GPUToNVVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -26,6 +26,11 @@
 std::unique_ptr<OperationPass<spirv::ModuleOp>>
 createDecorateSPIRVCompositeTypeLayoutPass();
 
+/// Creates a module pass that encodes the bind attribute of each
+/// `spv.globalVariable` into its symbolic name.
+std::unique_ptr<OperationPass<spirv::ModuleOp>>
+createEncodeDescriptorSetsPass();
+
 /// Creates an operation pass that deduces and attaches the minimal version/
 /// capabilities/extensions requirements for spv.module ops.
 /// For each spv.module op, this pass requires a `spv.target_env` attribute on
diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Passes.td
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.td
@@ -17,6 +17,13 @@
   let constructor = "mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass()";
 }
 
+def SPIRVEncodeDescriptorSets
+    : Pass<"spirv-encode-descriptor-sets", "spirv::ModuleOp"> {
+  let summary = "Encode `spv.globalVariable`'s bind attribute into its "
+                "symbolic name";
+  let constructor = "mlir::spirv::createEncodeDescriptorSetsPass()";
+}
+
 def SPIRVLowerABIAttributes : Pass<"spirv-lower-abi-attrs", "spirv::ModuleOp"> {
   let summary = "Decorate SPIR-V composite type with layout info";
   let constructor = "mlir::spirv::createLowerABIAttributesPass()";
diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
--- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
@@ -15,6 +15,7 @@
 endif()
 
 add_mlir_conversion_library(MLIRGPUToGPURuntimeTransforms
+  ConvertLaunchFuncToLLVMCalls.cpp
   ConvertLaunchFuncToRuntimeCalls.cpp
   ConvertKernelFuncToBlob.cpp
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp
@@ -0,0 +1,218 @@
+//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert the `gpu.launch_func` op into a
+// sequence of LLVM calls that emulate the host and device sides.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Copies the given number of bytes from the src pointer to the dst pointer.
+static void copy(Location loc, Value dst, Value src, Value size,
+                 OpBuilder &builder) {
+  MLIRContext *context = builder.getContext();
+  auto llvmI1Type = LLVM::LLVMType::getInt1Ty(context);
+  Value isVolatile = builder.create<LLVM::ConstantOp>(
+      loc, llvmI1Type, builder.getBoolAttr(false));
+  builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
+}
+
+/// Returns true if the module contains exactly one nested module with exactly
+/// one (kernel) function.
+static bool hasOneNestedModuleAndOneKernel(ModuleOp module) {
+  bool hasOneNestedModule = false;
+  auto walkResult =
+      module.walk([&hasOneNestedModule](ModuleOp moduleOp) -> WalkResult {
+        // Interrupt if more than one nested module has been found.
+        if (moduleOp.getParentOp() && hasOneNestedModule)
+          return WalkResult::interrupt();
+
+        // If there is a parent operation, it means we walked to a nested
+        // module. Verify there is only a single function in it.
+        if (moduleOp.getParentOp()) {
+          auto funcs = moduleOp.getOps<LLVM::LLVMFuncOp>();
+          SmallVector<LLVM::LLVMFuncOp, 2> funcVector(funcs.begin(),
+                                                      funcs.end());
+          if (funcVector.size() != 1)
+            return WalkResult::interrupt();
+          hasOneNestedModule = true;
+        }
+        // Otherwise, advance.
+        return WalkResult::advance();
+      });
+
+  if (walkResult.wasInterrupted())
+    return false;
+  return hasOneNestedModule;
+}
+
+/// Walks the main module and looks for the kernel module. This function is
+/// likely to be changed.
+static LogicalResult findKernelModule(ModuleOp mainModule,
+                                      ModuleOp &kernelModule) {
+  auto walkResult = mainModule.walk([&](ModuleOp nested) -> WalkResult {
+    if (nested.getParentOp()) {
+      kernelModule = nested;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  if (!walkResult.wasInterrupted())
+    return failure();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// This pattern emulates a call to the kernel in LLVM dialect. For that, we
+/// copy the data to the global variable (emulating the device side), call the
+/// kernel as a normal void LLVM function, and copy the data back (emulating
+/// the host side).
+class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
+  using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    gpu::LaunchFuncOp launchOp = cast<gpu::LaunchFuncOp>(op);
+    MLIRContext *context = rewriter.getContext();
+    auto module = launchOp.getParentOfType<ModuleOp>();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+
+    // Declare the kernel function in the main module so that it can later be
+    // linked with its definition from the kernel module. We know that the
+    // kernel function has no arguments and that the data is passed via global
+    // variables.
+    auto kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
+        rewriter.getUnknownLoc(), launchOp.getKernelName(),
+        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context),
+                                      ArrayRef<LLVM::LLVMType>(),
+                                      /*isVarArg=*/false));
+    rewriter.setInsertionPoint(launchOp);
+    Location loc = launchOp.getLoc();
+    SmallVector<std::tuple<Value, Value, Value>, 4> copyInfo;
+
+    // Find the kernel module.
+    // TODO: support this properly.
+    ModuleOp kernelModule;
+    if (failed(findKernelModule(module, kernelModule)))
+      return failure();
+
+    // Traverse the kernel operands that were converted to MemRefDescriptors.
+    // For each operand, create a global variable and copy the data from the
+    // operand to it.
+    auto numKernelOperands = launchOp.getNumKernelOperands();
+    auto kernelOperands = operands.take_back(numKernelOperands);
+    for (auto operand : llvm::enumerate(kernelOperands)) {
+      // Check that the kernel's operand is a ranked memref. Calculate its size
+      // and get the pointer to the allocated buffer.
+      auto memRefType = launchOp.getKernelOperand(operand.index())
+                            .getType()
+                            .dyn_cast<MemRefType>();
+      if (!memRefType)
+        return failure();
+
+      SmallVector<Value, 4> sizes;
+      getMemRefDescriptorSizes(loc, memRefType, operand.value(), rewriter,
+                               sizes);
+      Value size = getCumulativeSizeInBytes(loc, memRefType.getElementType(),
+                                            sizes, rewriter);
+      MemRefDescriptor descriptor(operand.value());
+      Value src = descriptor.allocatedPtr(rewriter, loc);
+
+      // Construct the global variable's name, which adheres to the convention
+      // "__var__{i}", where i is the index of the kernel argument. Then look
+      // up the global in the kernel module and construct another one of the
+      // same type in the host module.
+      rewriter.setInsertionPointToStart(module.getBody());
+      std::string name = "__var__" + std::to_string(operand.index());
+      auto kernelGlobal = kernelModule.lookupSymbol<LLVM::GlobalOp>(name);
+      auto hostGlobal = rewriter.create<LLVM::GlobalOp>(
+          loc, kernelGlobal.type().cast<LLVM::LLVMType>(),
+          /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute());
+      rewriter.setInsertionPoint(launchOp);
+
+      // Copy the data from the src operand pointer to the dst global variable.
+      // Save the (src, dst, size) triple so that we can copy the data back
+      // after emulating the kernel call.
+      Value dst = rewriter.create<LLVM::AddressOfOp>(loc, hostGlobal);
+      copy(loc, dst, src, size, rewriter);
+      copyInfo.push_back(std::make_tuple(src, dst, size));
+    }
+    // Create a call to the kernel and copy the data back.
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
+                                              ArrayRef<Value>());
+    for (auto triple : copyInfo)
+      copy(loc, std::get<0>(triple), std::get<1>(triple), std::get<2>(triple),
+           rewriter);
+    return success();
+  }
+};
+
+class LowerHostCodeToLLVM
+    : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
+public:
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // Erase the GPU module and check that the main module satisfies the
+    // requirements for `mlir-spirv-cpu-runner`:
+    // - a main module with exactly one nested module representing the kernel
+    // - exactly one kernel function in the kernel module
+    for (auto gpuModule :
+         llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
+      gpuModule.erase();
+    if (!hasOneNestedModuleAndOneKernel(module))
+      module.emitError("Module should contain exactly one nested kernel "
+                       "module with exactly one kernel function");
+
+    // Specify options to lower Standard to LLVM and pull in the conversion
+    // patterns.
+    LowerToLLVMOptions options = {
+        /*useBarePtrCallConv=*/false,
+        /*emitCWrappers=*/true,
+        /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout};
+    auto *context = module.getContext();
+    OwningRewritePatternList patterns;
+    LLVMTypeConverter typeConverter(context, options);
+    populateStdToLLVMConversionPatterns(typeConverter, patterns);
+    patterns.insert<GPULaunchLowering>(typeConverter);
+
+    ConversionTarget target(*context);
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    if (failed(applyPartialConversion(module, target, patterns))) {
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createLowerHostCodeToLLVMPass() {
+  return std::make_unique<LowerHostCodeToLLVM>();
+}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSPIRVTransforms
   DecorateSPIRVCompositeTypeLayoutPass.cpp
+  EncodeDescriptorSetsPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
   UpdateVCEPass.cpp
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/EncodeDescriptorSetsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/EncodeDescriptorSetsPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/EncodeDescriptorSetsPass.cpp
@@ -0,0 +1,75 @@
+//===- EncodeDescriptorSetsPass.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass encodes the descriptor set and binding number of a global variable
+// into its symbolic name.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#define BINDING_NAME                                                           \
+  llvm::convertToSnakeFromCamelCase(                                           \
+      stringifyDecoration(spirv::Decoration::Binding))
+#define DESCRIPTOR_SET_NAME                                                    \
+  llvm::convertToSnakeFromCamelCase(                                           \
+      stringifyDecoration(spirv::Decoration::DescriptorSet))
+
+using namespace mlir;
+
+/// Returns true if the given global variable has both a descriptor set number
+/// and a binding number.
+static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
+  IntegerAttr descriptorSet =
+      op.getAttrOfType<IntegerAttr>(DESCRIPTOR_SET_NAME);
+  IntegerAttr binding = op.getAttrOfType<IntegerAttr>(BINDING_NAME);
+  return descriptorSet && binding;
+}
+
+namespace {
+
+class EncodeDescriptorSetsPass
+    : public SPIRVEncodeDescriptorSetsBase<EncodeDescriptorSetsPass> {
+public:
+  void runOnOperation() override {
+    spirv::ModuleOp module = getOperation();
+    unsigned i = 0;
+
+    // Walk over all `spv.globalVariable` ops within the module. If a variable
+    // has a bind attribute, update the symbol and all of the symbol's uses.
+    module.walk([&](spirv::GlobalVariableOp op) {
+      if (hasDescriptorSetAndBinding(op)) {
+        // For now, we do not need to store the bind attribute info. This may
+        // need to be revisited in the future.
+        // TODO: encode the data from the descriptor set and binding
+        // numbers in the global variable.
+        std::string name = llvm::formatv("__var__{0}", i++);
+        if (failed(SymbolTable::replaceAllSymbolUses(op, name, module)))
+          return signalPassFailure();
+
+        // Set the new symbol name and remove the unused attributes.
+        SymbolTable::setSymbolName(op, name);
+        op.removeAttr(DESCRIPTOR_SET_NAME);
+        op.removeAttr(BINDING_NAME);
+      }
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<spirv::ModuleOp>>
+mlir::spirv::createEncodeDescriptorSetsPass() {
+  return std::make_unique<EncodeDescriptorSetsPass>();
+}
diff --git a/mlir/test/Conversion/GPUCommon/lower-host-to-llvm-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-host-to-llvm-calls.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-host-to-llvm-calls.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt --lower-host-to-llvm %s | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+
+  // CHECK: llvm.mlir.global linkonce @[[VAR_NAME:.*]]() : !llvm.struct<(array<6 x i32>)>
+  // CHECK: llvm.func @[[KERNEL:.*]]()
+
+  // CHECK: module
+  // CHECK-NEXT: llvm.mlir.global external @[[VAR_NAME]]() : !llvm.struct<(array<6 x i32>)>
+  // CHECK-NEXT: llvm.func @[[KERNEL]]()
+  // CHECK-NEXT: llvm.return
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: llvm.func @main()
+
+  // CHECK: %[[DST:.*]] = llvm.mlir.addressof @[[VAR_NAME]] : !llvm.ptr<struct<(array<6 x i32>)>>
+  // CHECK-NEXT: llvm.mlir.constant(false) : !llvm.i1
+  // CHECK-NEXT: "llvm.intr.memcpy"(%[[DST]], %[[SRC:.*]], %[[SIZE:.*]], %{{.*}}) : (!llvm.ptr<struct<(array<6 x i32>)>>, !llvm.ptr<i32>, !llvm.i64, !llvm.i1) -> ()
+  // CHECK-NEXT: llvm.call @[[KERNEL]]()
+  // CHECK-NEXT: llvm.mlir.constant(false) : !llvm.i1
+  // CHECK-NEXT: "llvm.intr.memcpy"(%[[SRC]], %[[DST]], %[[SIZE]], %{{.*}}) : (!llvm.ptr<i32>, !llvm.ptr<struct<(array<6 x i32>)>>, !llvm.i64, !llvm.i1) -> ()
+
+  module {
+    llvm.mlir.global external @__var__0() : !llvm.struct<(array<6 x i32>)>
+    llvm.func @kernel() {
+      llvm.return
+    }
+  }
+
+  gpu.module @kernels {
+    gpu.func @kernel(%arg0: memref<6xi32>) kernel attributes {spv.entry_point_abi = {local_size = dense<1> : vector<3xi32>}} {
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %buffer = alloc() : memref<6xi32>
+    %four = constant 4 : i32
+    %buffer_casted = memref_cast %buffer : memref<6xi32> to memref<?xi32>
+    call @fillI32Buffer(%buffer_casted, %four) : (memref<?xi32>, i32) -> ()
+    %one = constant 1 : index
+    "gpu.launch_func"(%one, %one, %one, %one, %one, %one, %buffer)
+        {kernel = @kernels::@kernel}
+        : (index, index, index, index, index, index, memref<6xi32>) -> ()
+    return
+  }
+
+  func @fillI32Buffer(memref<?xi32>, i32)
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/descriptor-sets-encoding.mlir b/mlir/test/Dialect/SPIRV/Transforms/descriptor-sets-encoding.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/descriptor-sets-encoding.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -spirv-encode-descriptor-sets -verify-diagnostics %s | FileCheck %s
+
+spv.module Logical GLSL450 {
+
+  // CHECK: spv.module
+  // CHECK: spv.globalVariable [[VAR0:@.*__var__0]] : !spv.ptr<f32, Uniform>
+  // CHECK: spv.globalVariable [[VAR1:@.*__var__1]] : !spv.ptr<f32, Uniform>
+  // CHECK: spv.globalVariable [[VAR2:@.*]] : !spv.ptr<f32, Uniform>
+  // CHECK: spv.func @func0
+  // CHECK: spv._address_of [[VAR0]]
+  // CHECK: spv._address_of [[VAR1]]
+  // CHECK: spv.func @func1
+  // CHECK: spv._address_of [[VAR1]]
+  // CHECK: spv._address_of [[VAR2]]
+
+  spv.globalVariable @var0 bind(0,0) : !spv.ptr<f32, Uniform>
+  spv.globalVariable @var1 bind(0,1) : !spv.ptr<f32, Uniform>
+  spv.globalVariable @var2 : !spv.ptr<f32, Uniform>
+
+  spv.func @func0() -> () "None" {
+    %ptr0 = spv._address_of @var0 : !spv.ptr<f32, Uniform>
+    %ptr1 = spv._address_of @var1 : !spv.ptr<f32, Uniform>
+    spv.Return
+  }
+  spv.func @func1() -> () "None" {
+    %ptr1 = spv._address_of @var1 : !spv.ptr<f32, Uniform>
+    %ptr2 = spv._address_of @var2 : !spv.ptr<f32, Uniform>
+    spv.Return
+  }
+}
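
Usage note (not part of the patch): the two passes above are meant to be composed in an mlir-spirv-cpu-runner-style flow, after the kernel has been converted to a nested spv.module and then to the LLVM dialect. The C++ sketch below shows one way they might be added to a pass manager; the helper name addSPIRVCPURunnerPasses and the assumption that earlier passes already produced the nested kernel module are illustrative, not taken from this patch.

// Hypothetical pipeline assembly using the pass constructors added above.
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Appends the new passes to an existing pass manager. The nested spv.module
// (produced by a GPU-to-SPIR-V conversion outside this patch) is expected to
// already be present in the top-level module.
static void addSPIRVCPURunnerPasses(PassManager &pm) {
  // Rename descriptor-based spv.globalVariables to "__var__{i}" so that the
  // host-side globals created by LowerHostCodeToLLVM link against them.
  pm.addNestedPass<spirv::ModuleOp>(spirv::createEncodeDescriptorSetsPass());
  // Replace gpu.launch_func with an llvm.call and lower the host code to the
  // LLVM dialect, emitting C wrappers.
  pm.addPass(createLowerHostCodeToLLVMPass());
}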