diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -99,6 +99,8 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Pass.h"
 
+#define DEBUG_TYPE "nvptx-lower-args"
+
 using namespace llvm;
 
 namespace llvm {
@@ -166,40 +168,61 @@
     Value *NewParam;
   };
   SmallVector<IP> ItemsToConvert = {{I, Param}};
-  SmallVector<GetElementPtrInst *> GEPsToDelete;
-  while (!ItemsToConvert.empty()) {
-    IP I = ItemsToConvert.pop_back_val();
-    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction))
+  SmallVector<Instruction *> InstructionsToDelete;
+
+  auto CloneInstInParamAS = [](const IP &I) -> Value * {
+    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
       LI->setOperand(0, I.NewParam);
-    else if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
+      return LI;
+    }
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
       SmallVector<Value *, 4> Indices(GEP->indices());
       auto *NewGEP = GetElementPtrInst::Create(nullptr, I.NewParam, Indices,
                                                GEP->getName(), GEP);
       NewGEP->setIsInBounds(GEP->isInBounds());
-      llvm::for_each(GEP->users(), [NewGEP, &ItemsToConvert](Value *V) {
-        ItemsToConvert.push_back({cast<Instruction>(V), NewGEP});
-      });
-      GEPsToDelete.push_back(GEP);
-    } else
-      llvm_unreachable("Only Load and GEP can be converted to param AS.");
-  }
-  llvm::for_each(GEPsToDelete,
-                 [](GetElementPtrInst *GEP) { GEP->eraseFromParent(); });
-}
+      return NewGEP;
+    }
+    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
+      auto *NewBCType = BC->getType()->getPointerElementType()->getPointerTo(
+          ADDRESS_SPACE_PARAM);
+      return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
+                                 BC->getName(), BC);
+    }
+    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
+      assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
+      (void)ASC; // Keep ASC referenced when asserts are compiled out.
+      // Just pass through the argument, the old ASC is no longer needed.
+      return I.NewParam;
+    }
+    llvm_unreachable("Unsupported instruction");
+  };
+
+  while (!ItemsToConvert.empty()) {
+    IP I = ItemsToConvert.pop_back_val();
+    Value *NewInst = CloneInstInParamAS(I);
 
-static bool isALoadChain(Value *Start) {
-  SmallVector<Value *, 16> ValuesToCheck = {Start};
-  while (!ValuesToCheck.empty()) {
-    Value *V = ValuesToCheck.pop_back_val();
-    Instruction *I = dyn_cast<Instruction>(V);
-    if (!I)
-      return false;
-    if (isa<GetElementPtrInst>(I))
-      ValuesToCheck.append(I->user_begin(), I->user_end());
-    else if (!isa<LoadInst>(I))
-      return false;
+    if (NewInst && NewInst != I.OldInstruction) {
+      // We've created a new instruction. Queue users of the old instruction to
+      // be converted and the instruction itself to be deleted. We can't delete
+      // the old instruction yet, because it's still in use by a load somewhere.
+      llvm::for_each(
+          I.OldInstruction->users(), [NewInst, &ItemsToConvert](Value *V) {
+            ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+          });
+
+      InstructionsToDelete.push_back(I.OldInstruction);
+    }
   }
-  return true;
+
+  // Now we know that all argument loads are using addresses in parameter space
+  // and we can finally remove the old instructions in generic AS. Instructions
+  // scheduled for removal should be processed in reverse order so the ones
+  // closest to the load are deleted first. Otherwise they may still be in use.
+  // E.g. if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
+  // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
+  // the BitCast.
+  llvm::for_each(reverse(InstructionsToDelete),
+                 [](Instruction *I) { I->eraseFromParent(); });
 }
 
 void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
@@ -211,9 +234,36 @@
 
   Type *StructType = PType->getElementType();
 
-  if (llvm::all_of(Arg->users(), isALoadChain)) {
-    // Replace all loads with the loads in param AS. This allows loading the Arg
-    // directly from parameter AS, without making a temporary copy.
+  auto IsALoadChain = [Arg](Value *Start) {
+    SmallVector<Value *, 16> ValuesToCheck = {Start};
+    auto IsALoadChainInstr = [](Value *V) -> bool {
+      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
+        return true;
+      // ASCs to param space are OK, too -- we'll just strip them.
+      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
+        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
+          return true;
+      }
+      return false;
+    };
+
+    while (!ValuesToCheck.empty()) {
+      Value *V = ValuesToCheck.pop_back_val();
+      if (!IsALoadChainInstr(V)) {
+        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
+                          << "\n");
+        (void)Arg; // Arg is otherwise unused when LLVM_DEBUG is compiled out.
+        return false;
+      }
+      if (!isa<LoadInst>(V))
+        llvm::append_range(ValuesToCheck, V->users());
+    }
+    return true;
+  };
+
+  if (llvm::all_of(Arg->users(), IsALoadChain)) {
+    // Convert all loads and intermediate operations to use parameter AS and
+    // skip creation of a local copy of the argument.
     SmallVector<User *, 16> UsersToUpdate(Arg->users());
     Value *ArgInParamAS = new AddrSpaceCastInst(
         Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
@@ -221,6 +271,7 @@
     llvm::for_each(UsersToUpdate, [ArgInParamAS](Value *V) {
       convertToParamAS(V, ArgInParamAS);
     });
+    LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
     return;
   }
 
@@ -297,6 +348,7 @@
     }
   }
 
+  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
   for (Argument &Arg : F.args()) {
     if (Arg.getType()->isPointerTy()) {
       if (Arg.hasByValAttr())
@@ -310,6 +362,7 @@
 
 // Device functions only need to copy byval args into local memory.
 bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
+  LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
   for (Argument &Arg : F.args())
     if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
       handleByValParam(&Arg);
diff --git a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
--- a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
@@ -6,6 +6,7 @@
 
 ; // Verify that load with static offset into parameter is done directly.
 ; CHECK-LABEL: .visible .entry static_offset
+; CHECK-NOT: .local
 ; CHECK: ld.param.u64 [[result_addr:%rd[0-9]+]], [{{.*}}_param_0]
 ; CHECK: mov.b64 %[[param_addr:rd[0-9]+]], {{.*}}_param_1
 ; CHECK: mov.u64 %[[param_addr1:rd[0-9]+]], %[[param_addr]]
@@ -30,6 +31,7 @@
 
 ; // Verify that load with dynamic offset into parameter is also done directly.
 ; CHECK-LABEL: .visible .entry dynamic_offset
+; CHECK-NOT: .local
 ; CHECK: ld.param.u64 [[result_addr:%rd[0-9]+]], [{{.*}}_param_0]
 ; CHECK: mov.b64 %[[param_addr:rd[0-9]+]], {{.*}}_param_1
 ; CHECK: mov.u64 %[[param_addr1:rd[0-9]+]], %[[param_addr]]
@@ -48,6 +50,48 @@
   ret void
 }
 
+; Same as above, but with a bitcast present in the chain
+; CHECK-LABEL:.visible .entry gep_bitcast
+; CHECK-NOT: .local
+; CHECK-DAG: ld.param.u64 [[out:%rd[0-9]+]], [gep_bitcast_param_0]
+; CHECK-DAG: mov.b64 {{%rd[0-9]+}}, gep_bitcast_param_1
+; CHECK-DAG: ld.param.u32 {{%r[0-9]+}}, [gep_bitcast_param_2]
+; CHECK: ld.param.u8 [[value:%rs[0-9]+]], [{{%rd[0-9]+}}]
+; CHECK: st.global.u8 [{{%rd[0-9]+}}], [[value]];
+;
+; Function Attrs: nofree norecurse nounwind willreturn mustprogress
+define dso_local void @gep_bitcast(i8* nocapture %out, %struct.ham* nocapture readonly byval(%struct.ham) align 4 %in, i32 %n) local_unnamed_addr #0 {
+bb:
+  %n64 = sext i32 %n to i64
+  %gep = getelementptr inbounds %struct.ham, %struct.ham* %in, i64 0, i32 0, i64 %n64
+  %bc = bitcast i32* %gep to i8*
+  %load = load i8, i8* %bc, align 4
+  store i8 %load, i8* %out, align 4
+  ret void
+}
+
+; Same as above, but with an ASC(101) present in the chain
+; CHECK-LABEL:.visible .entry gep_bitcast_asc
+; CHECK-NOT: .local
+; CHECK-DAG: ld.param.u64 [[out:%rd[0-9]+]], [gep_bitcast_asc_param_0]
+; CHECK-DAG: mov.b64 {{%rd[0-9]+}}, gep_bitcast_asc_param_1
+; CHECK-DAG: ld.param.u32 {{%r[0-9]+}}, [gep_bitcast_asc_param_2]
+; CHECK: ld.param.u8 [[value:%rs[0-9]+]], [{{%rd[0-9]+}}]
+; CHECK: st.global.u8 [{{%rd[0-9]+}}], [[value]];
+;
+; Function Attrs: nofree norecurse nounwind willreturn mustprogress
+define dso_local void @gep_bitcast_asc(i8* nocapture %out, %struct.ham* nocapture readonly byval(%struct.ham) align 4 %in, i32 %n) local_unnamed_addr #0 {
+bb:
+  %n64 = sext i32 %n to i64
+  %gep = getelementptr inbounds %struct.ham, %struct.ham* %in, i64 0, i32 0, i64 %n64
+  %bc = bitcast i32* %gep to i8*
+  %asc = addrspacecast i8* %bc to i8 addrspace(101)*
+  %load = load i8, i8 addrspace(101)* %asc, align 4
+  store i8 %load, i8* %out, align 4
+  ret void
+}
+
+
 ; Verify that if the pointer escapes, then we do fall back onto using a temp copy.
 ; CHECK-LABEL: .visible .entry pointer_escapes
 ; CHECK: .local .align 8 .b8 __local_depot{{.*}}
@@ -82,7 +126,7 @@
 
 
 !llvm.module.flags = !{!0, !1, !2}
-!nvvm.annotations = !{!3, !4, !5}
+!nvvm.annotations = !{!3, !4, !5, !6, !7}
 
 !0 = !{i32 2, !"SDK Version", [2 x i32] [i32 9, i32 1]}
 !1 = !{i32 1, !"wchar_size", i32 4}
@@ -90,3 +134,5 @@
 !3 = !{void (i32*, %struct.ham*, i32)* @static_offset, !"kernel", i32 1}
 !4 = !{void (i32*, %struct.ham*, i32)* @dynamic_offset, !"kernel", i32 1}
 !5 = !{void (i32*, %struct.ham*, i32)* @pointer_escapes, !"kernel", i32 1}
+!6 = !{void (i8*, %struct.ham*, i32)* @gep_bitcast, !"kernel", i32 1}
+!7 = !{void (i8*, %struct.ham*, i32)* @gep_bitcast_asc, !"kernel", i32 1}