diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -27,12 +27,6 @@
 
 using namespace mlir;
 
-static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
-static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
-static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
-static constexpr const char *kBindMemRef1DInt = "bindMemRef1DInt";
-static constexpr const char *kBindMemRef2DInt = "bindMemRef2DInt";
-static constexpr const char *kBindMemRef3DInt = "bindMemRef3DInt";
 static constexpr const char *kCInterfaceVulkanLaunch =
     "_mlir_ciface_vulkanLaunch";
 static constexpr const char *kDeinitVulkan = "deinitVulkan";
@@ -76,12 +70,6 @@
     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
-    llvmMemRef1DFloat = getMemRefType(1, llvmFloatType);
-    llvmMemRef2DFloat = getMemRefType(2, llvmFloatType);
-    llvmMemRef3DFloat = getMemRefType(3, llvmFloatType);
-    llvmMemRef1DInt = getMemRefType(1, llvmInt32Type);
-    llvmMemRef2DInt = getMemRefType(2, llvmInt32Type);
-    llvmMemRef3DInt = getMemRefType(3, llvmInt32Type);
   }
 
   LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
@@ -108,17 +96,10 @@
          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
   }
 
-  LLVM::LLVMType getFloatType() { return llvmFloatType; }
   LLVM::LLVMType getVoidType() { return llvmVoidType; }
   LLVM::LLVMType getPointerType() { return llvmPointerType; }
   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
-  LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
-  LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
-  LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
-  LLVM::LLVMType getMemRef1DInt() { return llvmMemRef1DInt; }
-  LLVM::LLVMType getMemRef2DInt() { return llvmMemRef2DInt; }
-  LLVM::LLVMType getMemRef3DInt() { return llvmMemRef3DInt; }
 
   /// Creates a LLVM global for the given `name`.
   Value createEntryPointNameConstant(StringRef name, Location loc,
@@ -160,8 +141,14 @@
   StringRef stringifyType(LLVM::LLVMType type) {
     if (type.isFloatTy())
       return "Float";
-    if (type.isIntegerTy())
-      return "Int";
+    if (type.isHalfTy())
+      return "Half";
+    if (type.isIntegerTy(32))
+      return "Int32";
+    if (type.isIntegerTy(16))
+      return "Int16";
+    if (type.isIntegerTy(8))
+      return "Int8";
     llvm_unreachable("unsupported type");
   }
 
@@ -176,12 +163,6 @@
   LLVM::LLVMType llvmPointerType;
   LLVM::LLVMType llvmInt32Type;
   LLVM::LLVMType llvmInt64Type;
-  LLVM::LLVMType llvmMemRef1DFloat;
-  LLVM::LLVMType llvmMemRef2DFloat;
-  LLVM::LLVMType llvmMemRef3DFloat;
-  LLVM::LLVMType llvmMemRef1DInt;
-  LLVM::LLVMType llvmMemRef2DInt;
-  LLVM::LLVMType llvmMemRef3DInt;
 
   // TODO: Use an associative array to support multiple vulkan launch calls.
   std::pair spirvAttributes;
@@ -264,6 +245,14 @@
 
     auto symbolName =
         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
+    // Special case for fp16 type. Since it is not a supported type in C we use
+    // int16_t and bitcast the descriptor.
+    if (type.isHalfTy()) {
+      auto memRefTy =
+          getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect));
+      ptrToMemRefDescriptor = builder.create(
+          loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
+    }
     // Create call to `bindMemRef`.
     builder.create(
         loc, ArrayRef{getVoidType()},
@@ -338,24 +327,29 @@
                                       /*isVarArg=*/false));
   }
 
-#define CREATE_VULKAN_BIND_FUNC(MemRefType)                                    \
-  if (!module.lookupSymbol(kBind##MemRefType)) {                               \
-    builder.create(                                                            \
-        loc, kBind##MemRefType,                                                \
-        LLVM::LLVMType::getFunctionTy(getVoidType(),                           \
-                                      {getPointerType(), getInt32Type(),       \
-                                       getInt32Type(),                         \
-                                       get##MemRefType().getPointerTo()},      \
-                                      /*isVarArg=*/false));                    \
+  for (unsigned i = 1; i <= 3; i++) {
+    for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect),
+                                LLVM::LLVMType::getInt32Ty(llvmDialect),
+                                LLVM::LLVMType::getInt16Ty(llvmDialect),
+                                LLVM::LLVMType::getInt8Ty(llvmDialect),
+                                LLVM::LLVMType::getHalfTy(llvmDialect)}) {
+      std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
+                           std::string(stringifyType(type));
+      // fp16 is passed to the runtime as int16_t, so only the element type is
+      // swapped here; the descriptor itself is built once, below.
+      if (type.isHalfTy())
+        type = LLVM::LLVMType::getInt16Ty(llvmDialect);
+      if (!module.lookupSymbol(fnName)) {
+        auto fnType = LLVM::LLVMType::getFunctionTy(
+            getVoidType(),
+            {getPointerType(), getInt32Type(), getInt32Type(),
+             getMemRefType(i, type).getPointerTo()},
+            /*isVarArg=*/false);
+        builder.create(loc, fnName, fnType);
+      }
+    }
   }
 
-  CREATE_VULKAN_BIND_FUNC(MemRef1DFloat);
-  CREATE_VULKAN_BIND_FUNC(MemRef2DFloat);
-  CREATE_VULKAN_BIND_FUNC(MemRef3DFloat);
-  CREATE_VULKAN_BIND_FUNC(MemRef1DInt);
-  CREATE_VULKAN_BIND_FUNC(MemRef2DInt);
-  CREATE_VULKAN_BIND_FUNC(MemRef3DInt);
-
   if (!module.lookupSymbol(kInitVulkan)) {
     builder.create(
         loc, kInitVulkan,
diff --git a/mlir/test/mlir-vulkan-runner/addi8.mlir b/mlir/test/mlir-vulkan-runner/addi8.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/addi8.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_addi(%arg0 : memref<8xi8>, %arg1 : memref<8x8xi8>, %arg2 : memref<8x8x8xi32>)
+      kernel attributes { spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
+      %x = "gpu.block_id"() {dimension = "x"} : () -> index
+      %y = "gpu.block_id"() {dimension = "y"} : () -> index
+      %z = "gpu.block_id"() {dimension = "z"} : () -> index
+      %0 = load %arg0[%x] : memref<8xi8>
+      %1 = load %arg1[%y, %x] : memref<8x8xi8>
+      %2 = addi %0, %1 : i8
+      %3 = zexti %2 : i8 to i32
+      store %3, %arg2[%z, %y, %x] : memref<8x8x8xi32>
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %arg0 = alloc() : memref<8xi8>
+    %arg1 = alloc() : memref<8x8xi8>
+    %arg2 = alloc() : memref<8x8x8xi32>
+    %value0 = constant 0 : i32
+    %value1 = constant 1 : i8
+    %value2 = constant 2 : i8
+    %arg3 = memref_cast %arg0 : memref<8xi8> to memref
+    %arg4 = memref_cast %arg1 : memref<8x8xi8> to memref
+    %arg5 = memref_cast %arg2 : memref<8x8x8xi32> to memref
+    call @fillResource1DInt8(%arg3, %value1) : (memref, i8) -> ()
+    call @fillResource2DInt8(%arg4, %value2) : (memref, i8) -> ()
+    call @fillResource3DInt(%arg5, %value0) : (memref, i32) -> ()
+
+    %cst1 = constant 1 : index
+    %cst8 = constant 8 : index
+    "gpu.launch_func"(%cst8, %cst8, %cst8, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = @kernels::@kernel_addi }
+        : (index, index, index, index, index, index, memref<8xi8>, memref<8x8xi8>, memref<8x8x8xi32>) -> ()
+    %arg6 = memref_cast %arg5 : memref to memref<*xi32>
+    call @print_memref_i32(%arg6) : (memref<*xi32>) -> ()
+    return
+  }
+  func @fillResource1DInt8(%0 : memref, %1 : i8)
+  func @fillResource2DInt8(%0 : memref, %1 : i8)
+  func @fillResource3DInt(%0 : memref, %1 : i32)
+  func @print_memref_i32(%ptr : memref<*xi32>)
+}
diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -71,6 +71,17 @@
   int64_t strides[N];
 };
 
+template
+void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                BindingIndex bindIndex, MemRefDescriptor *ptr) {
+  uint32_t size = sizeof(T);
+  for (unsigned i = 0; i < S; i++)
+    size *= ptr->sizes[i];
+  VulkanHostMemoryBuffer memBuffer{ptr->allocated, size};
+  reinterpret_cast(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
 extern "C" {
 /// Initializes `VulkanRuntimeManager` and returns a pointer to it.
 void *initVulkan() { return new VulkanRuntimeManager(); }
@@ -100,75 +111,30 @@
       ->setShaderModule(shader, size);
 }
 
-/// Binds the given 1D float memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef1DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                       BindingIndex bindIndex,
-                       MemRefDescriptor *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated, static_cast(ptr->sizes[0] * sizeof(float))};
-  reinterpret_cast(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
-
-/// Binds the given 2D float memref to the given descriptor set and descriptor
+/// Binds the given memref to the given descriptor set and descriptor
 /// index.
-void bindMemRef2DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                       BindingIndex bindIndex,
-                       MemRefDescriptor *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated,
-      static_cast(ptr->sizes[0] * ptr->sizes[1] * sizeof(float))};
-  reinterpret_cast(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
-
-/// Binds the given 3D float memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef3DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                       BindingIndex bindIndex,
-                       MemRefDescriptor *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated, static_cast(ptr->sizes[0] * ptr->sizes[1] *
-                                  ptr->sizes[2] * sizeof(float))};
-  reinterpret_cast(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
+#define DECLARE_BIND_MEMREF(size, type, typeName)                              \
+  void bindMemRef##size##D##typeName(                                          \
+      void *vkRuntimeManager, DescriptorSetIndex setIndex,                     \
+      BindingIndex bindIndex, MemRefDescriptor *ptr) {                         \
+    bindMemRef(vkRuntimeManager, setIndex, bindIndex, ptr);                    \
+  }
 
-/// Binds the given 1D int memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef1DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                     BindingIndex bindIndex,
-                     MemRefDescriptor *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated, static_cast(ptr->sizes[0] * sizeof(int32_t))};
-  reinterpret_cast(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
-
-/// Binds the given 2D int memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef2DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                     BindingIndex bindIndex,
-                     MemRefDescriptor *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated,
-      static_cast(ptr->sizes[0] * ptr->sizes[1] * sizeof(int32_t))};
-  reinterpret_cast(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
-
-/// Binds the given 3D int memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef3DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                     BindingIndex bindIndex,
-                     MemRefDescriptor *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated, static_cast(ptr->sizes[0] * ptr->sizes[1] *
-                                  ptr->sizes[2] * sizeof(int32_t))};
-  reinterpret_cast(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
+DECLARE_BIND_MEMREF(1, float, Float)
+DECLARE_BIND_MEMREF(2, float, Float)
+DECLARE_BIND_MEMREF(3, float, Float)
+DECLARE_BIND_MEMREF(1, int32_t, Int32)
+DECLARE_BIND_MEMREF(2, int32_t, Int32)
+DECLARE_BIND_MEMREF(3, int32_t, Int32)
+DECLARE_BIND_MEMREF(1, int16_t, Int16)
+DECLARE_BIND_MEMREF(2, int16_t, Int16)
+DECLARE_BIND_MEMREF(3, int16_t, Int16)
+DECLARE_BIND_MEMREF(1, int8_t, Int8)
+DECLARE_BIND_MEMREF(2, int8_t, Int8)
+DECLARE_BIND_MEMREF(3, int8_t, Int8)
+DECLARE_BIND_MEMREF(1, int16_t, Half)
+DECLARE_BIND_MEMREF(2, int16_t, Half)
+DECLARE_BIND_MEMREF(3, int16_t, Half)
 
 /// Fills the given 1D float memref with the given float value.
 void _mlir_ciface_fillResource1DFloat(MemRefDescriptor *ptr, // NOLINT
@@ -207,4 +173,23 @@
   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
               value);
 }
+
+/// Fills the given 1D int memref with the given int8 value.
+void _mlir_ciface_fillResource1DInt8(MemRefDescriptor *ptr, // NOLINT
+                                     int8_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0], value);
 }
+
+/// Fills the given 2D int memref with the given int8 value.
+void _mlir_ciface_fillResource2DInt8(MemRefDescriptor *ptr, // NOLINT
+                                     int8_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
+}
+
+/// Fills the given 3D int memref with the given int8 value.
+void _mlir_ciface_fillResource3DInt8(MemRefDescriptor *ptr, // NOLINT
+                                     int8_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
+              value);
+}
+}