diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -54,10 +54,15 @@
 
-  /// Checks where the given type is supported by Vulkan runtime.
+  /// Checks whether the given type is supported by Vulkan runtime.
   bool isSupportedType(Type type) {
-    // TODO(denis0x0D): Handle other types.
-    if (auto memRefType = type.dyn_cast_or_null<MemRefType>())
+    if (auto memRefType = type.dyn_cast_or_null<MemRefType>()) {
+      // The runtime wrappers only provide `bindMemRef*` entry points for f32
+      // ("Float") and i32 ("Int") element types. Reject any other element
+      // type here instead of binding it through a wrapper of the wrong size.
+      auto elementType = memRefType.getElementType();
       return memRefType.hasRank() &&
-             (memRefType.getRank() >= 1 && memRefType.getRank() <= 3);
+             (memRefType.getRank() >= 1 && memRefType.getRank() <= 3) &&
+             (elementType.isF32() || elementType.isInteger(32));
+    }
     return false;
   }
 
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -30,6 +30,9 @@
 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
 static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
 static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
+static constexpr const char *kBindMemRef1DInt = "bindMemRef1DInt";
+static constexpr const char *kBindMemRef2DInt = "bindMemRef2DInt";
+static constexpr const char *kBindMemRef3DInt = "bindMemRef3DInt";
 static constexpr const char *kCInterfaceVulkanLaunch =
     "_mlir_ciface_vulkanLaunch";
 static constexpr const char *kDeinitVulkan = "deinitVulkan";
@@ -73,12 +76,15 @@
     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
-    llvmMemRef1DFloat = getMemRefType(1);
-    llvmMemRef2DFloat = getMemRefType(2);
-    llvmMemRef3DFloat = getMemRefType(3);
+    llvmMemRef1DFloat = getMemRefType(1, llvmFloatType);
+    llvmMemRef2DFloat = getMemRefType(2, llvmFloatType);
+    llvmMemRef3DFloat = getMemRefType(3, llvmFloatType);
+    llvmMemRef1DInt = getMemRefType(1, llvmInt32Type);
+    llvmMemRef2DInt = getMemRefType(2, llvmInt32Type);
+    llvmMemRef3DInt = getMemRefType(3, llvmInt32Type);
   }
 
-  LLVM::LLVMType getMemRefType(uint32_t rank) {
+  LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elementType) {
     // According to the MLIR doc memref argument is converted into a
     // pointer-to-struct argument of type:
     // template <typename Elem, size_t Rank>
@@ -89,15 +95,16 @@
     //   int64_t sizes[Rank]; // omitted when rank == 0
     //   int64_t strides[Rank]; // omitted when rank == 0
     // };
-    auto llvmPtrToFloatType = getFloatType().getPointerTo();
+    auto llvmPtrToElementType = elementType.getPointerTo();
     auto llvmArrayRankElementSizeType =
         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
 
     // Create a type
-    // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`.
+    // `!llvm<"{ `element-type`*, `element-type`*, i64,
+    //   [`rank` x i64], [`rank` x i64]}">`.
     return LLVM::LLVMType::getStructTy(
         llvmDialect,
-        {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
+        {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
   }
 
@@ -109,6 +116,9 @@
   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
   LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
   LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
+  LLVM::LLVMType getMemRef1DInt() { return llvmMemRef1DInt; }
+  LLVM::LLVMType getMemRef2DInt() { return llvmMemRef2DInt; }
+  LLVM::LLVMType getMemRef3DInt() { return llvmMemRef3DInt; }
 
   /// Creates a LLVM global for the given `name`.
   Value createEntryPointNameConstant(StringRef name, Location loc,
@@ -142,8 +152,22 @@
   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
 
-  /// Deduces a rank from the given 'ptrToMemRefDescriptor`.
-  LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
+  /// Deduces a rank and element type from the given `ptrToMemRefDescriptor`.
+  /// Fails if the descriptor is malformed or if its element type is neither
+  /// f32 nor i32, the only element types with `bindMemRef*` runtime wrappers.
+  LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
+                                        uint32_t &rank, LLVM::LLVMType &type);
+
+  /// Returns the element-type suffix ("Float" or "Int") used in the runtime
+  /// `bindMemRef{rank}D{suffix}` symbol names. `type` must be f32 or i32.
+  StringRef stringifyType(LLVM::LLVMType type) {
+    if (type.isFloatTy())
+      return "Float";
+    if (type.isIntegerTy(32))
+      return "Int";
+
+    llvm_unreachable("unsupported type");
+  }
 
 public:
   void runOnOperation() override;
@@ -158,6 +182,9 @@
   LLVM::LLVMType llvmMemRef1DFloat;
   LLVM::LLVMType llvmMemRef2DFloat;
   LLVM::LLVMType llvmMemRef3DFloat;
+  LLVM::LLVMType llvmMemRef1DInt;
+  LLVM::LLVMType llvmMemRef2DInt;
+  LLVM::LLVMType llvmMemRef3DInt;
 
   // TODO: Use an associative array to support multiple vulkan launch calls.
   std::pair<StringAttr, StringAttr> spirvAttributes;
@@ -231,13 +258,15 @@
 
     auto ptrToMemRefDescriptor = en.value();
     uint32_t rank = 0;
-    if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
+    LLVM::LLVMType type;
+    if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
       cInterfaceVulkanLaunchCallOp.emitError()
           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
       return signalPassFailure();
     }
 
-    auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str();
+    auto symbolName =
+        llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
     // Create call to `bindMemRef`.
     builder.create<LLVM::CallOp>(
         loc, ArrayRef<Type>{getVoidType()},
@@ -248,9 +277,8 @@
   }
 }
 
-LogicalResult
-VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
-                                                    uint32_t &rank) {
+LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
+    Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
   auto llvmPtrDescriptorTy =
       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
   if (!llvmPtrDescriptorTy)
@@ -267,11 +295,24 @@
   // };
   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
     return failure();
+
+  // A memref descriptor has at least the `allocated`, `aligned` and `offset`
+  // fields. The pointee of `allocated` is the memref element type; reject
+  // element types without a `bindMemRef*` runtime wrapper (only f32 and i32).
+  if (llvmDescriptorTy.getStructNumElements() < 3)
+    return failure();
+  auto allocatedPtrTy = llvmDescriptorTy.getStructElementType(0);
+  if (!allocatedPtrTy.isPointerTy())
+    return failure();
+  type = allocatedPtrTy.getPointerElementTy();
+  if (!type.isFloatTy() && !type.isIntegerTy(32))
+    return failure();
+
   if (llvmDescriptorTy.getStructNumElements() == 3) {
     rank = 0;
     return success();
   }
 
   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
   return success();
 }
@@ -312,35 +353,26 @@
                                       /*isVarArg=*/false));
   }
 
-  if (!module.lookupSymbol(kBindMemRef1DFloat)) {
-    builder.create<LLVM::LLVMFuncOp>(
-        loc, kBindMemRef1DFloat,
-        LLVM::LLVMType::getFunctionTy(getVoidType(),
-                                      {getPointerType(), getInt32Type(),
-                                       getInt32Type(),
-                                       getMemRef1DFloat().getPointerTo()},
-                                      /*isVarArg=*/false));
-  }
-
-  if (!module.lookupSymbol(kBindMemRef2DFloat)) {
-    builder.create<LLVM::LLVMFuncOp>(
-        loc, kBindMemRef2DFloat,
-        LLVM::LLVMType::getFunctionTy(getVoidType(),
-                                      {getPointerType(), getInt32Type(),
-                                       getInt32Type(),
-                                       getMemRef2DFloat().getPointerTo()},
-                                      /*isVarArg=*/false));
-  }
-
-  if (!module.lookupSymbol(kBindMemRef3DFloat)) {
-    builder.create<LLVM::LLVMFuncOp>(
-        loc, kBindMemRef3DFloat,
-        LLVM::LLVMType::getFunctionTy(getVoidType(),
-                                      {getPointerType(), getInt32Type(),
-                                       getInt32Type(),
-                                       getMemRef3DFloat().getPointerTo()},
-                                      /*isVarArg=*/false));
-  }
+  // Declares `void name(i8*, i32, i32, memRefType*)` unless a symbol with the
+  // given name already exists in the module.
+  auto declareBindMemRefFunc = [&](StringRef name, LLVM::LLVMType memRefType) {
+    if (module.lookupSymbol(name))
+      return;
+    builder.create<LLVM::LLVMFuncOp>(
+        loc, name,
+        LLVM::LLVMType::getFunctionTy(
+            getVoidType(),
+            {getPointerType(), getInt32Type(), getInt32Type(),
+             memRefType.getPointerTo()},
+            /*isVarArg=*/false));
+  };
+
+  declareBindMemRefFunc(kBindMemRef1DFloat, getMemRef1DFloat());
+  declareBindMemRefFunc(kBindMemRef2DFloat, getMemRef2DFloat());
+  declareBindMemRefFunc(kBindMemRef3DFloat, getMemRef3DFloat());
+  declareBindMemRefFunc(kBindMemRef1DInt, getMemRef1DInt());
+  declareBindMemRefFunc(kBindMemRef2DInt, getMemRef2DInt());
+  declareBindMemRefFunc(kBindMemRef3DInt, getMemRef3DInt());
 
   if (!module.lookupSymbol(kInitVulkan)) {
     builder.create<LLVM::LLVMFuncOp>(
diff --git a/mlir/test/mlir-vulkan-runner/addi.mlir b/mlir/test/mlir-vulkan-runner/addi.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/addi.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_addi(%arg0 : memref<8xi32>, %arg1 : memref<8x8xi32>, %arg2 : memref<8x8x8xi32>)
+      kernel attributes { spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
+      %x = "gpu.block_id"() {dimension = "x"} : () -> index
+      %y = "gpu.block_id"() {dimension = "y"} : () -> index
+      %z = "gpu.block_id"() {dimension = "z"} : () -> index
+      %0 = load %arg0[%x] : memref<8xi32>
+      %1 = load %arg1[%y, %x] : memref<8x8xi32>
+      %2 = addi %0, %1 : i32
+      store %2, %arg2[%z, %y, %x] : memref<8x8x8xi32>
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %arg0 = alloc() : memref<8xi32>
+    %arg1 = alloc() : memref<8x8xi32>
+    %arg2 = alloc() : memref<8x8x8xi32>
+    %value0 = constant 0 : i32
+    %value1 = constant 1 : i32
+    %value2 = constant 2 : i32
+    %arg3 = memref_cast %arg0 : memref<8xi32> to memref<?xi32>
+    %arg4 = memref_cast %arg1 : memref<8x8xi32> to memref<?x?xi32>
+    %arg5 = memref_cast %arg2 : memref<8x8x8xi32> to memref<?x?x?xi32>
+    call @fillResource1DInt(%arg3, %value1) : (memref<?xi32>, i32) -> ()
+    call @fillResource2DInt(%arg4, %value2) : (memref<?x?xi32>, i32) -> ()
+    call @fillResource3DInt(%arg5, %value0) : (memref<?x?x?xi32>, i32) -> ()
+
+    %cst1 = constant 1 : index
+    %cst8 = constant 8 : index
+    "gpu.launch_func"(%cst8, %cst8, %cst8, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = @kernels::@kernel_addi }
+        : (index, index, index, index, index, index, memref<8xi32>, memref<8x8xi32>, memref<8x8x8xi32>) -> ()
+    %arg6 = memref_cast %arg5 : memref<?x?x?xi32> to memref<*xi32>
+    call @print_memref_i32(%arg6) : (memref<*xi32>) -> ()
+    return
+  }
+  func @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+  func @fillResource2DInt(%0 : memref<?x?xi32>, %1 : i32)
+  func @fillResource3DInt(%0 : memref<?x?x?xi32>, %1 : i32)
+  func @print_memref_i32(%ptr : memref<*xi32>)
+}
+
diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -135,6 +135,41 @@
       ->setResourceData(setIndex, bindIndex, memBuffer);
 }
 
+/// Binds the given 1D int memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef1DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                     BindingIndex bindIndex,
+                     MemRefDescriptor<int32_t, 1> *ptr) {
+  VulkanHostMemoryBuffer memBuffer{
+      ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * sizeof(int32_t))};
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
+/// Binds the given 2D int memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef2DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                     BindingIndex bindIndex,
+                     MemRefDescriptor<int32_t, 2> *ptr) {
+  VulkanHostMemoryBuffer memBuffer{
+      ptr->allocated,
+      static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] * sizeof(int32_t))};
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
+/// Binds the given 3D int memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef3DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                     BindingIndex bindIndex,
+                     MemRefDescriptor<int32_t, 3> *ptr) {
+  VulkanHostMemoryBuffer memBuffer{
+      ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] *
+                                            ptr->sizes[2] * sizeof(int32_t))};
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
 /// Fills the given 1D float memref with the given float value.
 void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
                                       float value) {
@@ -153,4 +188,23 @@
   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
               value);
 }
+
+/// Fills the given 1D int memref with the given int value.
+void _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
+                                    int32_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0], value);
+}
+
+/// Fills the given 2D int memref with the given int value.
+void _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
+                                    int32_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
+}
+
+/// Fills the given 3D int memref with the given int value.
+void _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
+                                    int32_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
+              value);
+}
 }