diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -279,9 +279,9 @@ for (auto en : llvm::enumerate(arguments)) { auto index = builder.create( loc, llvmInt32Type, builder.getI32IntegerAttr(en.index())); - auto fieldPtr = - builder.create(loc, structType.getPointerTo(), structPtr, - ArrayRef{zero, index.getResult()}); + auto fieldPtr = builder.create( + loc, argumentTypes[en.index()].getPointerTo(), structPtr, + ArrayRef{zero, index.getResult()}); builder.create(loc, en.value(), fieldPtr); auto elementPtr = builder.create(loc, llvmPointerPointerType, arrayPtr, index.getResult());