diff --git a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h --- a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h +++ b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h @@ -23,14 +23,12 @@ class ModuleOp; template class OperationPass; +class Pass; -#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS +#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS #define GEN_PASS_DECL_CONVERTGPULAUNCHFUNCTOVULKANLAUNCHFUNC #include "mlir/Conversion/Passes.h.inc" -std::unique_ptr> -createConvertVulkanLaunchFuncToVulkanCallsPass(); - std::unique_ptr> createConvertGpuLaunchFuncToVulkanLaunchFuncPass(); diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -492,14 +492,20 @@ let dependentDialects = ["spirv::SPIRVDialect"]; } -def ConvertVulkanLaunchFuncToVulkanCalls +def ConvertVulkanLaunchFuncToVulkanCallsPass : Pass<"launch-func-to-vulkan", "ModuleOp"> { let summary = "Convert vulkanLaunch external call to Vulkan runtime external " "calls"; let description = [{ This pass is only intended for the mlir-vulkan-runner. }]; - let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()"; + + let options = [ + Option<"useOpaquePointers", "use-opaque-pointers", "bool", + /*default=*/"false", "Generate LLVM IR using opaque pointers " + "instead of typed pointers"> + ]; + let dependentDialects = ["LLVM::LLVMDialect"]; } diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -35,6 +35,7 @@ static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; +static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types"; static constexpr const char *kVulkanLaunch = "vulkanLaunch"; namespace { @@ -189,6 +190,18 @@ vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName, launchOp.getKernelName()); + // Add MemRef element types before they're lost when lowering to LLVM. + SmallVector elementTypes; + for (Type type : llvm::drop_begin(launchOp.getOperandTypes(), + gpu::LaunchOp::kNumConfigOperands)) { + // The below cast always succeeds as it has already been verified in + // 'declareVulkanLaunchFunc' that these are MemRefs with compatible element + // types. + elementTypes.push_back(type.cast().getElementType()); + } + vulkanLaunchCallOp->setAttr(kSPIRVElementTypesAttrName, + builder.getTypeArrayAttr(elementTypes)); + launchOp.erase(); } diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -25,7 +25,7 @@ #include "llvm/Support/FormatVariadic.h" namespace mlir { -#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS +#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir @@ -42,6 +42,7 @@ static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; +static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types"; static constexpr const char *kVulkanLaunch = "vulkanLaunch"; namespace { @@ -58,14 +59,17 @@ /// * deinitVulkan -- deinitializes vulkan runtime /// class VulkanLaunchFuncToVulkanCallsPass - : public impl::ConvertVulkanLaunchFuncToVulkanCallsBase< + : public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase< VulkanLaunchFuncToVulkanCallsPass> { private: void initializeCachedTypes() { llvmFloatType = Float32Type::get(&getContext()); llvmVoidType = LLVM::LLVMVoidType::get(&getContext()); - llvmPointerType = - LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8)); + if (useOpaquePointers) + llvmPointerType = LLVM::LLVMPointerType::get(&getContext()); + else + llvmPointerType = + LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8)); llvmInt32Type = IntegerType::get(&getContext(), 32); llvmInt64Type = IntegerType::get(&getContext(), 64); } @@ -81,7 +85,9 @@ // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; - auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType); + auto llvmPtrToElementType = useOpaquePointers + ? llvmPointerType + : LLVM::LLVMPointerType::get(elemenType); auto llvmArrayRankElementSizeType = LLVM::LLVMArrayType::get(getInt64Type(), rank); @@ -131,9 +137,8 @@ /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); - /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`. - LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor, - uint32_t &rank, Type &type); + /// Deduces a rank from the given 'launchCallArg`. + LogicalResult deduceMemRefRank(Value launchCallArg, uint32_t &rank); /// Returns a string representation from the given `type`. StringRef stringifyType(Type type) { @@ -154,6 +159,8 @@ } public: + using Base::Base; + void runOnOperation() override; private: @@ -163,8 +170,14 @@ Type llvmInt32Type; Type llvmInt64Type; + struct SPIRVAttributes { + StringAttr blob; + StringAttr entryPoint; + SmallVector elementTypes; + }; + // TODO: Use an associative array to support multiple vulkan launch calls. - std::pair spirvAttributes; + SPIRVAttributes spirvAttributes; /// The number of vulkan launch configuration operands, placed at the leading /// positions of the operand list. static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; @@ -209,7 +222,24 @@ return signalPassFailure(); } - spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); + auto spirvElementTypesAttr = + vulkanLaunchCallOp->getAttrOfType(kSPIRVElementTypesAttrName); + if (!spirvElementTypesAttr) { + vulkanLaunchCallOp.emitError() + << "missing " << kSPIRVElementTypesAttrName << " attribute"; + return signalPassFailure(); + } + if (llvm::any_of(spirvElementTypesAttr, + [](Attribute attr) { return !isa(attr); })) { + vulkanLaunchCallOp.emitError() + << "expected " << spirvElementTypesAttr << " to be an array of types"; + return signalPassFailure(); + } + + spirvAttributes.blob = spirvBlobAttr; + spirvAttributes.entryPoint = spirvEntryPointNameAttr; + spirvAttributes.elementTypes = + llvm::to_vector(spirvElementTypesAttr.getAsValueRange()); } void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( @@ -226,17 +256,23 @@ Value descriptorSet = builder.create(loc, getInt32Type(), 0); - for (const auto &en : + for (auto [index, ptrToMemRefDescriptor] : llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( kVulkanLaunchNumConfigOperands))) { // Create LLVM constant for the descriptor binding index. Value descriptorBinding = - builder.create(loc, getInt32Type(), en.index()); + builder.create(loc, getInt32Type(), index); + + if (index >= spirvAttributes.elementTypes.size()) { + cInterfaceVulkanLaunchCallOp.emitError() + << kSPIRVElementTypesAttrName << " missing element type for " + << ptrToMemRefDescriptor; + return signalPassFailure(); + } - auto ptrToMemRefDescriptor = en.value(); uint32_t rank = 0; - Type type; - if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) { + Type type = spirvAttributes.elementTypes[index]; + if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) { cInterfaceVulkanLaunchCallOp.emitError() << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); return signalPassFailure(); @@ -246,7 +282,7 @@ llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); // Special case for fp16 type. Since it is not a supported type in C we use // int16_t and bitcast the descriptor. - if (type.isa()) { + if (!useOpaquePointers && type.isa()) { auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16)); ptrToMemRefDescriptor = builder.create( loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor); @@ -259,15 +295,24 @@ } } -LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( - Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) { - auto llvmPtrDescriptorTy = - ptrToMemRefDescriptor.getType().dyn_cast(); - if (!llvmPtrDescriptorTy) +LogicalResult +VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg, + uint32_t &rank) { + // Deduce the rank from the type used to allocate the lowered MemRef. + auto alloca = launchCallArg.getDefiningOp(); + if (!alloca) return failure(); - auto llvmDescriptorTy = - llvmPtrDescriptorTy.getElementType().dyn_cast(); + LLVM::LLVMStructType llvmDescriptorTy; + if (std::optional elementType = alloca.getElemType()) { + llvmDescriptorTy = dyn_cast(*elementType); + } else { + // This case is only possible if we are not using opaque pointers + // since opaque pointer producing allocas require an element type. + llvmDescriptorTy = dyn_cast( + alloca.getRes().getType().getElementType()); + } + // template // struct { // Elem *allocated; @@ -279,9 +324,6 @@ if (!llvmDescriptorTy) return failure(); - type = llvmDescriptorTy.getBody()[0] - .cast() - .getElementType(); if (llvmDescriptorTy.getBody().size() == 3) { rank = 0; return success(); @@ -339,7 +381,9 @@ auto fnType = LLVM::LLVMFunctionType::get( getVoidType(), {getPointerType(), getInt32Type(), getInt32Type(), - LLVM::LLVMPointerType::get(getMemRefType(i, type))}, + useOpaquePointers + ? llvmPointerType + : LLVM::LLVMPointerType::get(getMemRefType(i, type))}, /*isVarArg=*/false); builder.create(loc, fnName, fnType); } @@ -368,7 +412,7 @@ std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); return LLVM::createGlobalString(loc, builder, entryPointGlobalName, shaderName, LLVM::Linkage::Internal, - /*TODO:useOpaquePointers=*/false); + useOpaquePointers); } void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( @@ -385,12 +429,12 @@ // Create LLVM global with SPIR-V binary data, so we can pass a pointer with // that data to runtime call. Value ptrToSPIRVBinary = LLVM::createGlobalString( - loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), - LLVM::Linkage::Internal, /*TODO:useOpaquePointers=*/false); + loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(), + LLVM::Linkage::Internal, useOpaquePointers); // Create LLVM constant for the size of SPIR-V binary shader. Value binarySize = builder.create( - loc, getInt32Type(), spirvAttributes.first.getValue().size()); + loc, getInt32Type(), spirvAttributes.blob.getValue().size()); // Create call to `bindMemRef` for each memref operand. createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); @@ -402,7 +446,7 @@ ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize}); // Create LLVM global with entry point name. Value entryPointName = createEntryPointNameConstant( - spirvAttributes.second.getValue(), loc, builder); + spirvAttributes.entryPoint.getValue(), loc, builder); // Create call to `setEntryPoint` runtime function with the given pointer to // entry point name. builder.create(loc, TypeRange(), kSetEntryPoint, @@ -428,8 +472,3 @@ cInterfaceVulkanLaunchCallOp.erase(); } - -std::unique_ptr> -mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { - return std::make_unique(); -} diff --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir --- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir +++ b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir @@ -1,63 +1,62 @@ -// RUN: mlir-opt %s -launch-func-to-vulkan | FileCheck %s +// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=1' | FileCheck %s // CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name // CHECK: llvm.mlir.global internal constant @SPIRV_BIN -// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr +// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr // CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN // CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]] // CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant -// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) -> () -// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr, !llvm.ptr, i32) -> () +// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr) -> () +// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr, !llvm.ptr, i32) -> () // CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name // CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]] -// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr, !llvm.ptr) -> () -// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i64, i64, i64) -> () -// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> () -// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> () +// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr, !llvm.ptr) -> () +// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i64, i64, i64) -> () +// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> () +// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> () -// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr, i32, i32, !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) +// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr, i32, i32, !llvm.ptr) module attributes {gpu.container_module} { - llvm.func @malloc(i64) -> !llvm.ptr + llvm.func @malloc(i64) -> !llvm.ptr llvm.func @foo() { %0 = llvm.mlir.constant(12 : index) : i64 - %1 = llvm.mlir.null : !llvm.ptr + %1 = llvm.mlir.null : !llvm.ptr %2 = llvm.mlir.constant(1 : index) : i64 - %3 = llvm.getelementptr %1[%2] : (!llvm.ptr, i64) -> !llvm.ptr - %4 = llvm.ptrtoint %3 : !llvm.ptr to i64 + %3 = llvm.getelementptr %1[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + %4 = llvm.ptrtoint %3 : !llvm.ptr to i64 %5 = llvm.mul %0, %4 : i64 - %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr - %7 = llvm.bitcast %6 : !llvm.ptr to !llvm.ptr - %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr + %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %9 = llvm.insertvalue %6, %8[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %10 = llvm.insertvalue %6, %9[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %11 = llvm.mlir.constant(0 : index) : i64 - %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %13 = llvm.mlir.constant(1 : index) : i64 - %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %16 = llvm.mlir.constant(1 : index) : i64 - %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_entry_point = "kernel"} - : (i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> () + %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"} + : (i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> () llvm.return } - llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64) { - %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64) { + %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %6 = llvm.mlir.constant(1 : index) : i64 - %7 = llvm.alloca %6 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> - llvm.store %5, %7 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> - llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) -> () + %7 = llvm.alloca %6 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr + llvm.store %5, %7 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr + llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr) -> () llvm.return } - llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) + llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr) } diff --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir --- a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir +++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir @@ -2,7 +2,7 @@ // CHECK: %[[resource:.*]] = memref.alloc() : memref<12xf32> // CHECK: %[[index:.*]] = arith.constant 1 : index -// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_entry_point = "kernel"} +// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_element_types = [f32], spirv_entry_point = "kernel"} module attributes {gpu.container_module} { spirv.module Logical GLSL450 requires #spirv.vce { diff --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir copy from mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir copy to mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir --- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir +++ b/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -launch-func-to-vulkan | FileCheck %s +// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=0' | FileCheck %s // CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name // CHECK: llvm.mlir.global internal constant @SPIRV_BIN @@ -42,7 +42,7 @@ %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_entry_point = "kernel"} + llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"} : (i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> () llvm.return } diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -68,16 +68,25 @@ if (options.spirvWebGPUPrepare) modulePM.addPass(spirv::createSPIRVWebGPUPreparePass()); + auto enableOpaquePointers = [](auto passOption) { + passOption.useOpaquePointers = true; + return passOption; + }; + passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass()); - passManager.addPass(createFinalizeMemRefToLLVMConversionPass()); - passManager.addPass(createConvertVectorToLLVMPass()); + passManager.addPass(createFinalizeMemRefToLLVMConversionPass( + enableOpaquePointers(FinalizeMemRefToLLVMConversionPassOptions{}))); + passManager.addPass(createConvertVectorToLLVMPass( + enableOpaquePointers(ConvertVectorToLLVMPassOptions{}))); passManager.nest().addPass(LLVM::createRequestCWrappersPass()); ConvertFuncToLLVMPassOptions funcToLLVMOptions{}; funcToLLVMOptions.indexBitwidth = DataLayout(module).getTypeSizeInBits(IndexType::get(module.getContext())); - passManager.addPass(createConvertFuncToLLVMPass(funcToLLVMOptions)); + passManager.addPass( + createConvertFuncToLLVMPass(enableOpaquePointers(funcToLLVMOptions))); passManager.addPass(createReconcileUnrealizedCastsPass()); - passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass()); + passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass( + enableOpaquePointers(ConvertVulkanLaunchFuncToVulkanCallsPassOptions{}))); return passManager.run(module); }