diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -328,7 +328,7 @@ std::string fnName = "bindMemRef" + std::to_string(i) + "D" + std::string(stringifyType(type)); if (type.isHalfTy()) - type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(&getContext())); + type = LLVM::LLVMType::getInt16Ty(&getContext()); if (!module.lookupSymbol(fnName)) { auto fnType = LLVM::LLVMType::getFunctionTy( getVoidType(), diff --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir --- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir +++ b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir @@ -15,6 +15,8 @@ // CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> !llvm.void // CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> !llvm.void +// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr, !llvm.i32, !llvm.i32, !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) + module attributes {gpu.container_module} { llvm.func @malloc(!llvm.i64) -> !llvm.ptr llvm.func @foo() {