diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -680,8 +680,7 @@ for (const auto &en : llvm::enumerate(arguments)) { Value fieldPtr = builder.create( loc, getTypeConverter()->getPointerType(argumentTypes[en.index()]), - argumentTypes[en.index()], structPtr, - ArrayRef{0, en.index()}); + structType, structPtr, ArrayRef{0, en.index()}); builder.create(loc, en.value(), fieldPtr); auto elementPtr = builder.create( loc, llvmPointerPointerType, llvmPointerType, arrayPtr, diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir --- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir +++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir @@ -40,9 +40,18 @@ // CHECK: [[STREAM:%.*]] = llvm.call @mgpuStreamCreate + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) + // CHECK: %[[MEMREF:.*]] = llvm.alloca %[[ONE]] x !llvm.struct[[STRUCT_BODY:<.*>]] // CHECK: [[NUM_PARAMS:%.*]] = llvm.mlir.constant(6 : i32) : i32 // CHECK-NEXT: [[PARAMS:%.*]] = llvm.alloca [[NUM_PARAMS]] x !llvm.ptr + // CHECK: llvm.getelementptr %[[MEMREF]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]] + // CHECK: llvm.getelementptr %[[MEMREF]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]] + // CHECK: llvm.getelementptr %[[MEMREF]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]] + // CHECK: llvm.getelementptr %[[MEMREF]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]] + // CHECK: llvm.getelementptr %[[MEMREF]][0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]] + // CHECK: llvm.getelementptr %[[MEMREF]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]] + // CHECK: [[EXTRA_PARAMS:%.*]] = llvm.mlir.null : !llvm.ptr // CHECK: llvm.call @mgpuLaunchKernel([[FUNC]], [[C8]], [[C8]], [[C8]],