diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -159,6 +159,27 @@
   assert(PType && "Expecting pointer type in handleByValParam");
 
   Type *StructType = PType->getElementType();
+
+  if (Arg->onlyReadsMemory()) {
+    // The byval argument is never written through, so there is no need to
+    // copy it into a local `alloca`. Cast it into the parameter space and
+    // back to the generic space so that address space inference can later
+    // infer the correct address space for its users.
+    Value *ArgInParam = new AddrSpaceCastInst(
+        Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
+        FirstInst);
+    Value *ArgInGeneric = new AddrSpaceCastInst(
+        ArgInParam, PType, Arg->getName() + ".addrspacecast", FirstInst);
+    // Redirect every use of Arg except the param-space cast itself. Do not
+    // call setOperand inside a range-for over Arg->uses(): rewriting a use
+    // unlinks it from the list being walked, so the walk ends early and
+    // leaves the remaining uses pointing at Arg.
+    Arg->replaceUsesWithIf(ArgInGeneric, [ArgInParam](Use &U) {
+      return U.getUser() != ArgInParam;
+    });
+    return;
+  }
+
   const DataLayout &DL = Func->getParent()->getDataLayout();
   unsigned AS = DL.getAllocaAddrSpace();
   AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
diff --git a/llvm/test/CodeGen/NVPTX/lower-args.ll b/llvm/test/CodeGen/NVPTX/lower-args.ll
--- a/llvm/test/CodeGen/NVPTX/lower-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args.ll
@@ -10,6 +10,27 @@
 ; Check that nvptx-lower-args preserves arg alignment
 define void @load_alignment(%class.outer* nocapture readonly byval(%class.outer) align 8 %arg) {
 entry:
+; IR-LABEL: @load_alignment
+; IR: addrspacecast %class.outer* %arg to %class.outer addrspace(101)*
+; IR-NEXT: addrspacecast %class.outer addrspace(101)* %arg1 to %class.outer*
+; PTX: ld.param.u64
+; PTX-NOT: ld.param.u8
+  %arg.idx = getelementptr %class.outer, %class.outer* %arg, i64 0, i32 0, i32 0
+  %arg.idx.val = load i32*, i32** %arg.idx, align 8
+  %arg.idx1 = getelementptr %class.outer, %class.outer* %arg, i64 0, i32 0, i32 1
+  %arg.idx1.val = load i32*, i32** %arg.idx1, align 8
+  %arg.idx2 = getelementptr %class.outer, %class.outer* %arg, i64 0, i32 1
+  %arg.idx2.val = load i32, i32* %arg.idx2, align 8
+  %arg.idx.val.val = load i32, i32* %arg.idx.val, align 4
+  %add.i = add nsw i32 %arg.idx.val.val, %arg.idx2.val
+  store i32 %add.i, i32* %arg.idx1.val, align 4
+  ret void
+}
+
+; Check that nvptx-lower-args preserves arg alignment
+define void @load_alignment_without_readonly(%class.outer* nocapture byval(%class.outer) align 8 %arg) {
+entry:
+; IR-LABEL: @load_alignment_without_readonly
 ; IR: load %class.outer, %class.outer addrspace(101)*
 ; IR-SAME: align 8
 ; PTX: ld.param.u64