diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -423,17 +423,32 @@
 // =============================================================================
 bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
                                          Function &F) {
+  // Handles pointers smuggled through integers: if every user of the integer
+  // value V is an inttoptr cast, mark each resulting pointer as global. If V
+  // has any other kind of user, nothing is changed.
+  auto HandleIntToPtr = [this](Value &V) {
+    if (llvm::all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); })) {
+      // Iterate over a copy of the user list rather than V.users() directly,
+      // in case markPointerAsGlobal changes V's uses while we walk them.
+      SmallVector<User *, 16> UsersToUpdate(V.users());
+      for (User *U : UsersToUpdate)
+        markPointerAsGlobal(U);
+    }
+  };
   if (TM.getDrvInterface() == NVPTX::CUDA) {
     // Mark pointers in byval structs as global.
     for (auto &B : F) {
       for (auto &I : B) {
         if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
-          if (LI->getType()->isPointerTy()) {
+          if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
             Value *UO = getUnderlyingObject(LI->getPointerOperand());
             if (Argument *Arg = dyn_cast<Argument>(UO)) {
               if (Arg->hasByValAttr()) {
                 // LI is a load from a pointer within a byval kernel parameter.
-                markPointerAsGlobal(LI);
+                if (LI->getType()->isPointerTy())
+                  markPointerAsGlobal(LI);
+                else
+                  HandleIntToPtr(*LI);
               }
             }
           }
@@ -449,6 +464,9 @@
         handleByValParam(TM, &Arg);
       else if (TM.getDrvInterface() == NVPTX::CUDA)
         markPointerAsGlobal(&Arg);
+    } else if (Arg.getType()->isIntegerTy() &&
+               TM.getDrvInterface() == NVPTX::CUDA) {
+      HandleIntToPtr(Arg);
     }
   }
   return true;
diff --git a/llvm/test/CodeGen/NVPTX/lower-args.ll b/llvm/test/CodeGen/NVPTX/lower-args.ll
--- a/llvm/test/CodeGen/NVPTX/lower-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args.ll
@@ -67,9 +67,28 @@
   ret void
 }
 
+; COMMON-LABEL: ptr_as_int
+define void @ptr_as_int(i64 noundef %0, i32 noundef %1) {
+  %3 = inttoptr i64 %0 to ptr
+  store i32 %1, ptr %3, align 4
+  ret void
+}
+
+%struct.S = type { i64 }
+
+; COMMON-LABEL: ptr_as_int_aggr
+define void @ptr_as_int_aggr(ptr nocapture noundef readonly byval(%struct.S) align 8 %0, i32 noundef %1) {
+  %3 = load i64, ptr %0, align 8
+  %4 = inttoptr i64 %3 to ptr
+  store i32 %1, ptr %4, align 4
+  ret void
+}
+
 ; Function Attrs: convergent nounwind
 declare dso_local ptr @escape(ptr) local_unnamed_addr
 
-!nvvm.annotations = !{!0, !1}
+!nvvm.annotations = !{!0, !1, !2, !3}
 !0 = !{ptr @ptr_generic, !"kernel", i32 1}
 !1 = !{ptr @ptr_nongeneric, !"kernel", i32 1}
+!2 = !{ptr @ptr_as_int, !"kernel", i32 1}
+!3 = !{ptr @ptr_as_int_aggr, !"kernel", i32 1}