diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -38,6 +38,16 @@
 using LoweringCallback = std::function<std::unique_ptr<llvm::Module>(
     Operation *, llvm::LLVMContext &, StringRef)>;
 
+/// Creates a pass to convert `gpu.launch_func` to a standard function call.
+/// The call targets a placeholder `__std_launch` function that takes only the
+/// kernel operands; the GPU modules are then erased.
+std::unique_ptr<OperationPass<ModuleOp>>
+createGPULaunchFuncToStandardCallPass();
+/// Creates a pass to emulate a kernel function call in LLVM dialect: buffers
+/// are copied into global variables, the kernel is called as a plain void
+/// function, and the data is copied back.
+std::unique_ptr<OperationPass<ModuleOp>> createEmulateKernelCallInLLVMPass();
+
 /// Creates a pass to convert a gpu.launch_func operation into a sequence of
 /// GPU runtime calls.
 ///
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -91,6 +91,16 @@
   ];
 }
 
+def GPULaunchFuncToStandardCall : Pass<"gpu-launch-to-std-call", "ModuleOp"> {
+  let summary = "Convert gpu.launch_func to a standard dialect call";
+  let constructor = "mlir::createGPULaunchFuncToStandardCallPass()";
+}
+
+def EmulateKernelCallInLLVM : Pass<"emulate-kernel-llvm-call", "ModuleOp"> {
+  let summary = "Emulate a kernel function call in LLVM dialect";
+  let constructor = "mlir::createEmulateKernelCallInLLVMPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // GPUToNVVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -26,6 +26,13 @@
 std::unique_ptr<OperationPass<ModuleOp>>
 createDecorateSPIRVCompositeTypeLayoutPass();
 
+/// Creates a module pass that encodes bind attribute of each
+/// `spv.globalVariable` into its symbolic name: every variable with both a
+/// descriptor set and a binding is renamed to `__var__{i}` (numbered
+/// sequentially) and loses its bind attributes.
+std::unique_ptr<OperationPass<spirv::ModuleOp>>
+createEncodeDescriptorSetsPass();
+
 /// Creates an operation pass that deduces and attaches the minimal version/
 /// capabilities/extensions requirements for spv.module ops.
 /// For each spv.module op, this pass requires a `spv.target_env` attribute on
diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Passes.td
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.td
@@ -17,6 +17,13 @@
   let constructor = "mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass()";
 }
 
+def SPIRVEncodeDescriptorSets
+    : Pass<"spirv-encode-descriptor-sets", "spirv::ModuleOp"> {
+  let summary = "Encode `spv.globalVariable`'s bind attribute into its "
+                "symbolic name";
+  let constructor = "mlir::spirv::createEncodeDescriptorSetsPass()";
+}
+
 def SPIRVLowerABIAttributes : Pass<"spirv-lower-abi-attrs", "spirv::ModuleOp"> {
   let summary = "Decorate SPIR-V composite type with layout info";
   let constructor = "mlir::spirv::createLowerABIAttributesPass()";
diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
--- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
@@ -15,6 +15,7 @@
 endif()
 
 add_mlir_conversion_library(MLIRGPUToGPURuntimeTransforms
+  ConvertLaunchFuncToLLVMCalls.cpp
   ConvertLaunchFuncToRuntimeCalls.cpp
   ConvertKernelFuncToBlob.cpp
 
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToLLVMCalls.cpp
@@ -0,0 +1,370 @@
+//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements passes to convert `gpu.launch_func` op into a sequence
+// of LLVM calls that emulate the host and device sides.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+static constexpr const char kLLVMLaunchName[] = "__llvm_launch";
+static constexpr const char kStandardLaunchName[] = "__std_launch";
+
+// A kernel memref argument reaches the launch call as a 1-D MemRefDescriptor
+// unpacked into scalar operands:
+//   {allocated ptr, aligned ptr, offset, size[0], stride[0]}.
+static constexpr const unsigned memRefDescriptorNumElements = 5;
+static constexpr const unsigned alignedPtrOffset = 1;
+static constexpr const unsigned dataOffsetOffset = 2;
+static constexpr const unsigned numElementsOffset = 3;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Erases the placeholder launch function and its C-interface wrapper, which
+/// are no longer called once the kernel launch has been emulated.
+static void clearFunctionDefinitions(ModuleOp module) {
+  std::string cInterfaceName =
+      llvm::formatv("_mlir_ciface_{0}", kStandardLaunchName).str();
+  for (auto funcOp :
+       llvm::make_early_inc_range(module.getOps<LLVM::LLVMFuncOp>())) {
+    StringRef name = funcOp.getName();
+    if (name == kStandardLaunchName || name == cInterfaceName)
+      funcOp.erase();
+  }
+}
+
+/// Emits a non-volatile memcpy of `size` bytes from `src` to `dst`.
+static void copyData(Location loc, Value dst, Value src, Value size,
+                     OpBuilder &builder) {
+  MLIRContext *context = builder.getContext();
+  auto llvmI1Type = LLVM::LLVMType::getInt1Ty(context);
+  Value isVolatile = builder.create<LLVM::ConstantOp>(
+      loc, llvmI1Type, builder.getBoolAttr(false));
+  builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
+}
+
+/// Returns true if the function call is ex-`gpu.launch_func` op.
+static bool isLaunchCallOp(LLVM::CallOp callOp) {
+  return callOp.callee() && callOp.callee().getValue() == kStandardLaunchName;
+}
+
+/// Verifies if the module contains exactly one nested module with exactly one
+/// (kernel) function.
+static bool hasOneNestedModuleAndOneKernel(ModuleOp module) {
+  bool hasOneNestedModule = false;
+  auto walkResult = module.walk([&](ModuleOp moduleOp) -> WalkResult {
+    // The walk also visits `module` itself; only nested modules matter.
+    if (moduleOp.getOperation() == module.getOperation())
+      return WalkResult::advance();
+
+    // Interrupt if more than one nested module has been found, or if the
+    // nested module does not contain exactly one function.
+    if (hasOneNestedModule ||
+        !llvm::hasSingleElement(moduleOp.getOps<LLVM::LLVMFuncOp>()))
+      return WalkResult::interrupt();
+    hasOneNestedModule = true;
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted() && hasOneNestedModule;
+}
+
+/// Returns the first module nested directly in `module` (the converted GPU
+/// kernel module), or a null op if there is none.
+static ModuleOp getKernelModule(ModuleOp module) {
+  auto nestedModules = module.getOps<ModuleOp>();
+  if (nestedModules.begin() == nestedModules.end())
+    return ModuleOp();
+  return *nestedModules.begin();
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+class GPULaunchFuncToStandardCall
+    : public GPULaunchFuncToStandardCallBase<GPULaunchFuncToStandardCall> {
+public:
+  void runOnOperation() override;
+};
+
+class EmulateKernelCallInLLVM
+    : public EmulateKernelCallInLLVMBase<EmulateKernelCallInLLVM> {
+public:
+  void runOnOperation() override;
+};
+
+/// This pattern prepares the conversion of kernel launch to LLVM dialect. It:
+/// - Declares (once per module) a placeholder function with the kernel
+///   operand types as arguments and a void result type.
+/// - Replaces `gpu.launch_func` with a call to placeholder.
+class KernelCallPattern : public OpRewritePattern<gpu::LaunchFuncOp> {
+public:
+  using OpRewritePattern<gpu::LaunchFuncOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(gpu::LaunchFuncOp launchOp,
+                                PatternRewriter &rewriter) const override {
+    MLIRContext *context = rewriter.getContext();
+    auto module = launchOp.getParentOfType<ModuleOp>();
+
+    // Only the kernel operands are forwarded; the launch configuration (grid
+    // and block sizes) is dropped.
+    SmallVector<Value, 4> operands;
+    SmallVector<Type, 4> operandTypes;
+    for (unsigned i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) {
+      operands.push_back(launchOp.getKernelOperand(i));
+      operandTypes.push_back(operands.back().getType());
+    }
+
+    // Declare a standard function that will keep kernel memref arguments.
+    // Reuse an existing declaration so that several launches do not create
+    // conflicting symbols.
+    auto kernelFunc = module.lookupSymbol<FuncOp>(kStandardLaunchName);
+    if (!kernelFunc) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(module.getBody());
+      kernelFunc = rewriter.create<FuncOp>(
+          rewriter.getUnknownLoc(), kStandardLaunchName,
+          FunctionType::get(operandTypes, {}, context));
+    } else if (kernelFunc.getType().getInputs() !=
+               ArrayRef<Type>(operandTypes)) {
+      return failure();
+    }
+
+    // Replace `gpu.launch_func` with a standard function call.
+    rewriter.setInsertionPoint(launchOp);
+    rewriter.replaceOpWithNewOp<CallOp>(launchOp, kernelFunc, operands);
+    return success();
+  }
+};
+
+/// Static information about one kernel buffer passed to the launch call.
+struct BufferInfo {
+  /// Index of the first call operand of the buffer's MemRefDescriptor.
+  unsigned descriptorStart;
+  /// Size of a single buffer element in bytes.
+  unsigned elementSizeInBytes;
+  /// Name of the global variable that emulates the device-side buffer.
+  std::string symName;
+  /// Type of that global variable, as declared in the kernel module.
+  LLVM::LLVMType globalType;
+};
+
+/// Collects a BufferInfo for every MemRefDescriptor passed to `callOp`
+/// without modifying the IR. Fails if the operands do not form whole 1-D
+/// descriptors, if an element type is not a byte-sized integer, if the kernel
+/// module lacks the matching `__var__{i}` global, or if `module` already has
+/// a global of that name with a different type.
+LogicalResult collectBuffers(LLVM::CallOp callOp, ModuleOp module,
+                             ModuleOp kernelModule,
+                             SmallVectorImpl<BufferInfo> &buffers) {
+  unsigned numOperands = callOp.getNumOperands();
+  if (numOperands % memRefDescriptorNumElements != 0)
+    return failure();
+
+  for (unsigned i = 0; i < numOperands; i += memRefDescriptorNumElements) {
+    // Support only integers with a whole number of bytes for now for easier
+    // size calculation.
+    auto ptrType = callOp.getOperand(i + alignedPtrOffset)
+                       .getType()
+                       .dyn_cast<LLVM::LLVMType>();
+    if (!ptrType || !ptrType.isPointerTy())
+      return failure();
+    auto integerType =
+        ptrType.getPointerElementTy().dyn_cast<LLVM::LLVMIntegerType>();
+    if (!integerType || integerType.getBitWidth() % 8 != 0)
+      return failure();
+
+    // Look up the global variable from the kernel module. We follow the
+    // convention taken in EncodeDescriptorSetsPass to label globals as
+    // __var__{i}, where i is incremented sequentially.
+    std::string symName =
+        "__var__" + std::to_string(i / memRefDescriptorNumElements);
+    auto kernelGlobal = kernelModule.lookupSymbol<LLVM::GlobalOp>(symName);
+    if (!kernelGlobal)
+      return failure();
+    if (auto hostGlobal = module.lookupSymbol<LLVM::GlobalOp>(symName))
+      if (hostGlobal.type() != kernelGlobal.type())
+        return failure();
+
+    buffers.push_back({i, integerType.getBitWidth() / 8, symName,
+                       kernelGlobal.type().cast<LLVM::LLVMType>()});
+  }
+  return success();
+}
+
+/// This pattern emulates a call to kernel in LLVM dialect. For that, we
+/// copy the data to the global variable (emulating device side), call
+/// the kernel as a normal void LLVM function, and copy the data back
+/// (emulating host side).
+class StandardCallPattern : public OpRewritePattern<LLVM::CallOp> {
+public:
+  using OpRewritePattern<LLVM::CallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LLVM::CallOp op,
+                                PatternRewriter &rewriter) const override {
+    MLIRContext *context = rewriter.getContext();
+    auto module = op.getParentOfType<ModuleOp>();
+    Location loc = op.getLoc();
+
+    // Find ex-GPU module (now a nested LLVM module) and its single kernel.
+    ModuleOp kernelModule = getKernelModule(module);
+    if (!kernelModule)
+      return failure();
+    auto kernels = kernelModule.getOps<LLVM::LLVMFuncOp>();
+    if (!llvm::hasSingleElement(kernels))
+      return failure();
+    LLVM::LLVMFuncOp kernel = *kernels.begin();
+
+    // Validate every buffer before touching the IR, so that a failed match
+    // leaves the module unchanged.
+    SmallVector<BufferInfo, 4> buffers;
+    if (failed(collectBuffers(op, module, kernelModule, buffers)))
+      return failure();
+
+    // Rename the kernel to adhere to the convention, so that the declaration
+    // below can later be linked with its definition from the kernel module.
+    rewriter.updateRootInPlace(
+        kernel, [&] { SymbolTable::setSymbolName(kernel, kLLVMLaunchName); });
+
+    // Declare the kernel function in the main module (once). It takes no
+    // arguments: the data is passed via global variables.
+    auto kernelDecl = module.lookupSymbol<LLVM::LLVMFuncOp>(kLLVMLaunchName);
+    if (!kernelDecl) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(module.getBody());
+      kernelDecl = rewriter.create<LLVM::LLVMFuncOp>(
+          rewriter.getUnknownLoc(), kLLVMLaunchName,
+          LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context),
+                                        ArrayRef<LLVM::LLVMType>(),
+                                        /*isVarArg=*/false));
+    }
+
+    // Copies from the globals back to the host buffers, done after the call.
+    struct PendingCopy {
+      Value dst, src, size;
+    };
+    SmallVector<PendingCopy, 4> copiesBack;
+    auto llvmI64Type = LLVM::LLVMType::getInt64Ty(context);
+    rewriter.setInsertionPoint(op);
+    for (const BufferInfo &buffer : buffers) {
+      // The host data starts at the aligned pointer advanced by the offset.
+      unsigned start = buffer.descriptorStart;
+      Value alignedPtr = op.getOperand(start + alignedPtrOffset);
+      Value dataOffset = op.getOperand(start + dataOffsetOffset);
+      Value src = rewriter.create<LLVM::GEPOp>(
+          loc, alignedPtr.getType(), alignedPtr, ArrayRef<Value>(dataOffset));
+
+      // Calculate the size of the buffer in bytes.
+      Value elementSize = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmI64Type,
+          rewriter.getI64IntegerAttr(buffer.elementSizeInBytes));
+      Value numElements = op.getOperand(start + numElementsOffset);
+      Value size = rewriter.create<LLVM::MulOp>(loc, numElements, elementSize);
+
+      // Create (once) a global that will be linked with the global from the
+      // kernel module, then copy the host data into it.
+      auto global = module.lookupSymbol<LLVM::GlobalOp>(buffer.symName);
+      if (!global) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(module.getBody());
+        global = rewriter.create<LLVM::GlobalOp>(
+            loc, buffer.globalType, /*isConstant=*/false,
+            LLVM::Linkage::Linkonce, buffer.symName, Attribute());
+      }
+      Value dst = rewriter.create<LLVM::AddressOfOp>(loc, global);
+      copyData(loc, dst, src, size, rewriter);
+      copiesBack.push_back({/*dst=*/src, /*src=*/dst, size});
+    }
+
+    // Emulate the kernel call and then copy the data back from global
+    // variables to buffers.
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelDecl,
+                                              ArrayRef<Value>());
+    for (const PendingCopy &copy : copiesBack)
+      copyData(loc, copy.dst, copy.src, copy.size, rewriter);
+    return success();
+  }
+};
+} // namespace
+
+void GPULaunchFuncToStandardCall::runOnOperation() {
+  ModuleOp module = getOperation();
+  auto *context = module.getContext();
+  OwningRewritePatternList patterns;
+  patterns.insert<KernelCallPattern>(context);
+
+  // Convert `gpu.launch_func` to standard function call.
+  ConversionTarget target(*context);
+  target.addIllegalOp<gpu::LaunchFuncOp>();
+  target.addLegalOp<CallOp>();
+  target.addLegalOp<FuncOp>();
+  if (failed(applyPartialConversion(module, target, patterns))) {
+    signalPassFailure();
+    return;
+  }
+
+  // No launch references the GPU modules anymore; erase them.
+  for (auto gpuModule :
+       llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
+    gpuModule.erase();
+}
+
+void EmulateKernelCallInLLVM::runOnOperation() {
+  ModuleOp module = getOperation();
+
+  // Check if the module structure can be supported by the current conversion.
+  // For now, only 2 modules (main and kernel) are supported.
+  if (!hasOneNestedModuleAndOneKernel(module)) {
+    module.emitError(
+        "Module should contain exactly one nested module with one function");
+    signalPassFailure();
+    return;
+  }
+
+  // Emulate kernel call.
+  auto *context = module.getContext();
+  OwningRewritePatternList patterns;
+  patterns.insert<StandardCallPattern>(context);
+  ConversionTarget target(*context);
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addDynamicallyLegalOp<LLVM::CallOp>(
+      [&](LLVM::CallOp op) { return !isLaunchCallOp(op); });
+  if (failed(applyPartialConversion(module, target, patterns))) {
+    signalPassFailure();
+    return;
+  }
+
+  // Remove unused functions that were generated during the previous pass.
+  clearFunctionDefinitions(module);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createGPULaunchFuncToStandardCallPass() {
+  return std::make_unique<GPULaunchFuncToStandardCall>();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createEmulateKernelCallInLLVMPass() {
+  return std::make_unique<EmulateKernelCallInLLVM>();
+}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSPIRVTransforms
   DecorateSPIRVCompositeTypeLayoutPass.cpp
+  EncodeDescriptorSetsPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
   UpdateVCEPass.cpp
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/EncodeDescriptorSetsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/EncodeDescriptorSetsPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/EncodeDescriptorSetsPass.cpp
@@ -0,0 +1,75 @@
+//===- EncodeDescriptorSetsPass.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass encodes set and binding of the global variable into its symbolic
+// name.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#define BINDING_NAME                                                           \
+  llvm::convertToSnakeFromCamelCase(                                           \
+      stringifyDecoration(spirv::Decoration::Binding))
+#define DESCRIPTOR_SET_NAME                                                    \
+  llvm::convertToSnakeFromCamelCase(                                           \
+      stringifyDecoration(spirv::Decoration::DescriptorSet))
+
+using namespace mlir;
+
+/// Returns true if the given global variable has both a descriptor set number
+/// and a binding number.
+static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
+  IntegerAttr descriptorSet =
+      op.getAttrOfType<IntegerAttr>(DESCRIPTOR_SET_NAME);
+  IntegerAttr binding = op.getAttrOfType<IntegerAttr>(BINDING_NAME);
+  return descriptorSet && binding;
+}
+
+namespace {
+
+class EncodeDescriptorSetsPass
+    : public SPIRVEncodeDescriptorSetsBase<EncodeDescriptorSetsPass> {
+public:
+  void runOnOperation() override {
+    spirv::ModuleOp module = getOperation();
+    unsigned i = 0;
+
+    // Walk over all `spv.globalVariable` ops within the module. If the variable
+    // has a bind attribute, rename it to `__var__{i}` and update all its uses.
+    module.walk([&](spirv::GlobalVariableOp op) -> WalkResult {
+      if (!hasDescriptorSetAndBinding(op))
+        return WalkResult::advance();
+      // TODO: encode the descriptor set and binding numbers in the global
+      // variable; for now they are dropped. On failure, stop the walk.
+      std::string name = llvm::formatv("__var__{0}", i++);
+      if (failed(SymbolTable::replaceAllSymbolUses(op, name, module))) {
+        signalPassFailure();
+        return WalkResult::interrupt();
+      }
+      // Set new symbol name and remove the now-unused bind attributes.
+      SymbolTable::setSymbolName(op, name);
+      op.removeAttr(DESCRIPTOR_SET_NAME);
+      op.removeAttr(BINDING_NAME);
+      return WalkResult::advance();
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<spirv::ModuleOp>>
+mlir::spirv::createEncodeDescriptorSetsPass() {
+  return std::make_unique<EncodeDescriptorSetsPass>();
+}
diff --git a/mlir/test/Conversion/GPUCommon/emulate-kernel-call.mlir b/mlir/test/Conversion/GPUCommon/emulate-kernel-call.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUCommon/emulate-kernel-call.mlir @@ -0,0 +1,79 @@ +// RUN: mlir-opt --emulate-kernel-llvm-call %s | FileCheck %s + +// CHECK: module +// CHECK: llvm.mlir.global linkonce @__var__0() : !llvm.struct)> +// CHECK-LABEL: @__llvm_launch +// CHECK: module +// CHECK-LABEL: @__llvm_launch + +// CHECK-LABEL: @main +// CHECK: %[[ELEMENT_SIZE:.*]] = llvm.mlir.constant(4 : i64) : !llvm.i64 +// CHECK: %[[SIZE:.*]] = llvm.mul %{{.*}}, %[[ELEMENT_SIZE]] : !llvm.i64 +// CHECK: %[[DST:.*]] = llvm.mlir.addressof @__var__0 : !llvm.ptr)>> +// CHECK: llvm.mlir.constant(false) : !llvm.i1 +// CHECK: "llvm.intr.memcpy"(%[[DST]], %{{.*}}, %[[SIZE]], %{{.*}}) : (!llvm.ptr)>>, !llvm.ptr, !llvm.i64, !llvm.i1) -> () +// CHECK: llvm.call @__llvm_launch() : () -> () +// CHECK: llvm.mlir.constant(false) : !llvm.i1 +// CHECK: "llvm.intr.memcpy"(%{{.*}}, %[[DST]], %[[SIZE]], %{{.*}}) : (!llvm.ptr, !llvm.ptr)>>, !llvm.i64, !llvm.i1) -> () + +module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} { + llvm.func @__std_launch(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64) { + llvm.return + } + llvm.func @_mlir_ciface___std_launch(!llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) + module { + llvm.mlir.global external @__var__0() : !llvm.struct)> + llvm.func @simple() { + llvm.return + } + } + llvm.func @main() { + %0 = llvm.mlir.constant(6 : index) : !llvm.i64 + %1 = llvm.mlir.null :
!llvm.ptr
+    %2 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %3 = llvm.getelementptr %1[%2] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr
+    %4 = llvm.ptrtoint %3 : !llvm.ptr to !llvm.i64
+    %5 = llvm.mul %0, %4 : !llvm.i64
+    %6 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %7 = llvm.call @malloc(%5) : (!llvm.i64) -> !llvm.ptr
+    %8 = llvm.bitcast %7 : !llvm.ptr to !llvm.ptr
+    %9 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %11 = llvm.insertvalue %8, %10[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %12 = llvm.mlir.constant(0 : index) : !llvm.i64
+    %13 = llvm.insertvalue %12, %11[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %14 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %15 = llvm.insertvalue %0, %13[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %16 = llvm.insertvalue %14, %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %17 = llvm.mlir.constant(4 : i32) : !llvm.i32
+    %18 = llvm.bitcast %16 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %19 = llvm.extractvalue %18[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %20 = llvm.extractvalue %18[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %21 = llvm.extractvalue %18[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %22 = llvm.extractvalue %18[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %23 = llvm.extractvalue %18[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    llvm.call @fillI32Buffer(%19, %20, %21, %22, %23, %17) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i32) -> ()
+    %24 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %25 = llvm.extractvalue %16[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %26 = llvm.extractvalue %16[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %27 = llvm.extractvalue %16[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %28 = llvm.extractvalue %16[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %29 = llvm.extractvalue %16[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    llvm.call @__std_launch(%25, %26, %27, %28, %29) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i64, !llvm.i64) -> ()  // Launch placeholder; operands are the unpacked fields of %16. The pass rewrites it into a copy to @__var__0, a call to @__llvm_launch, and a copy back.
+    %30 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %31 = llvm.alloca %30 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (!llvm.i64) -> !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>
+    llvm.store %16, %31 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>
+    %32 = llvm.bitcast %31 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> to !llvm.ptr
+    %33 = llvm.mlir.constant(1 : i64) : !llvm.i64
+    %34 = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+    %35 = llvm.insertvalue %33, %34[0] : !llvm.struct<(i64, ptr)>
+    %36 = llvm.insertvalue %32, %35[1] : !llvm.struct<(i64, ptr)>
+    %37 = llvm.extractvalue %36[0] : !llvm.struct<(i64, ptr)>
+    %38 = llvm.extractvalue %36[1] : !llvm.struct<(i64, ptr)>
+    llvm.call @print_memref_i32(%37, %38) : (!llvm.i64, !llvm.ptr) -> ()
+    llvm.return
+  }
+
+  llvm.func @fillI32Buffer(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i32)  // External helper, declaration only.
+  llvm.func @print_memref_i32(%arg0: !llvm.i64, %arg1: !llvm.ptr)  // External helper, declaration only.
+}
diff --git a/mlir/test/Conversion/GPUCommon/gpu-launch-to-std-call.mlir b/mlir/test/Conversion/GPUCommon/gpu-launch-to-std-call.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/gpu-launch-to-std-call.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt --gpu-launch-to-std-call %s | FileCheck %s
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+
+  // CHECK: func @__std_launch(memref<6xi32>, memref<6xi32>)
+
+  // CHECK-LABEL: @main
+  // CHECK: %[[BUFFER1:.*]] = alloc() : memref<6xi32>
+  // CHECK: %[[BUFFER2:.*]] = alloc() : memref<6xi32>
+  // CHECK: call @__std_launch(%[[BUFFER1]], %[[BUFFER2]]) : (memref<6xi32>, memref<6xi32>) -> ()
+
+  spv.module Logical GLSL450 requires #spv.vce {  // Not a GPU module: left untouched by the pass.
+    spv.globalVariable @__var__0 : !spv.ptr [0]>, StorageBuffer>
+    spv.globalVariable @__var__1 : !spv.ptr [0]>, StorageBuffer>
+    spv.func @simple() "None" attributes {workgroup_attributions = 0 : i64} {
+      %0 = spv._address_of @__var__1 : !spv.ptr [0]>, StorageBuffer>
+      %1 = spv._address_of @__var__0 : !spv.ptr [0]>, StorageBuffer>
+      spv.Return
+    }
+    spv.EntryPoint "GLCompute" @simple
+    spv.ExecutionMode @simple "LocalSize", 1, 1, 1
+  }
+
+  gpu.module @kernels {  // Erased by the pass after the launch is converted.
+    gpu.func @simple(%arg0: memref<6xi32>, %arg1: memref<6xi32>) kernel attributes {spv.entry_point_abi = {local_size = dense<1> : vector<3xi32>}} {
+      %c5_i32 = constant 5 : i32
+      %c0 = constant 0 : index
+      %0 = load %arg0[%c0] : memref<6xi32>
+      %1 = addi %0, %c5_i32 : i32
+      store %1, %arg1[%c0] : memref<6xi32>
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %0 = alloc() : memref<6xi32>
+    %1 = alloc() : memref<6xi32>
+    %c4_i32 = constant 4 : i32
+    %c3_i32 = constant 3 : i32
+    %2 = memref_cast %0 : memref<6xi32> to memref
+    %3 = memref_cast %1 : memref<6xi32> to memref
+    call @fillI32Buffer(%2, %c3_i32) : (memref, i32) -> ()
+    call @fillI32Buffer(%3, %c4_i32) : (memref, i32) -> ()
+    %c1 = constant 1 : index
+    "gpu.launch_func"(%c1, %c1, %c1, %c1, %c1, %c1, %0, %1) {kernel = @kernels::@simple} : (index, index, index, index, index, index, memref<6xi32>, memref<6xi32>) -> ()  // The six launch-configuration operands are dropped; only %0 and %1 are passed on.
+    %4 = memref_cast %1 : memref<6xi32> to memref<*xi32>
+    call @print_memref_i32(%4) : (memref<*xi32>) -> ()
+    return
+  }
+  func @fillI32Buffer(memref, i32)
+  func @print_memref_i32(memref<*xi32>)
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/descriptor-sets-encoding.mlir b/mlir/test/Dialect/SPIRV/Transforms/descriptor-sets-encoding.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/descriptor-sets-encoding.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -spirv-encode-descriptor-sets -verify-diagnostics %s | FileCheck %s
+
+spv.module Logical GLSL450 {
+
+  // CHECK: spv.module
+  // CHECK: spv.globalVariable [[VAR0:@.*__var__0]] : !spv.ptr
+  // CHECK: spv.globalVariable [[VAR1:@.*__var__1]] : !spv.ptr
+  // CHECK: spv.globalVariable [[VAR2:@.*]] : !spv.ptr
+  // CHECK: spv.func @func0
+  // CHECK: spv._address_of [[VAR0]]
+  // CHECK: spv._address_of [[VAR1]]
+  // CHECK: spv.func @func1
+  // CHECK: spv._address_of [[VAR1]]
+  // CHECK: spv._address_of [[VAR2]]
+
+  spv.globalVariable @var0 bind(0,0) : !spv.ptr  // Renamed to @__var__0; bind attribute removed.
+  spv.globalVariable @var1 bind(0,1) : !spv.ptr  // Renamed to @__var__1; bind attribute removed.
+  spv.globalVariable @var2 : !spv.ptr  // No bind attribute: keeps its name.
+  spv.func @func0() -> () "None" {
+    %ptr0 = spv._address_of @var0 : !spv.ptr
+    %ptr1 = spv._address_of @var1 : !spv.ptr
+    spv.Return
+  }
+  spv.func @func1() -> () "None" {
+    %ptr1 = spv._address_of @var1 : !spv.ptr
+    %ptr2 = spv._address_of @var2 : !spv.ptr
+    spv.Return
+  }
+}