diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSetWavePriority.cpp b/llvm/lib/Target/AMDGPU/AMDGPUSetWavePriority.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPUSetWavePriority.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUSetWavePriority.cpp
@@ -26,11 +26,18 @@
 #define DEBUG_TYPE "amdgpu-set-wave-priority"
 
+static cl::opt<unsigned> VALUInstsThreshold(
+    "amdgpu-set-wave-priority-valu-insts-threshold",
+    cl::desc("VALU instruction count threshold for adjusting wave priority"),
+    cl::init(100), cl::Hidden);
+
 namespace {
 
 struct MBBInfo {
   MBBInfo() = default;
   bool MayReachVMEMLoad = false;
+  unsigned NumVALUInsts = 0;
+  MachineInstr *LastVMEMLoad = nullptr;
 };
 
 using MBBInfoSet = DenseMap<const MachineBasicBlock *, MBBInfo>;
 
@@ -82,10 +89,6 @@
   return true;
 }
 
-static bool isVMEMLoad(const MachineInstr &MI) {
-  return SIInstrInfo::isVMEM(MI) && MI.mayLoad();
-}
-
 bool AMDGPUSetWavePriority::runOnMachineFunction(MachineFunction &MF) {
   const unsigned HighPriority = 3;
   const unsigned LowPriority = 0;
@@ -97,14 +100,38 @@
   const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
   TII = ST.getInstrInfo();
 
-  MBBInfoSet MBBInfos;
+  // Find VMEM loads that may be executed before the number of executed
+  // VALU instructions hits the specified threshold. We currently assume
+  // that backedges/loops, branch probabilities and other details can be
+  // ignored.
   SmallVector<const MachineBasicBlock *> Worklist;
-  for (MachineBasicBlock &MBB : MF) {
-    if (any_of(MBB, isVMEMLoad))
-      Worklist.push_back(&MBB);
+  MBBInfoSet MBBInfos;
+  ReversePostOrderTraversal<MachineFunction *> RPOT(&MF);
+  for (MachineBasicBlock *MBB : RPOT) {
+    unsigned NumVALUInsts = 0;
+    for (const MachineBasicBlock *Pred : MBB->predecessors())
+      NumVALUInsts = std::max(NumVALUInsts, MBBInfos[Pred].NumVALUInsts);
+
+    MachineInstr *LastVMEMLoad = nullptr;
+    for (MachineInstr &MI : *MBB) {
+      if (NumVALUInsts >= VALUInstsThreshold)
+        break;
+      if (SIInstrInfo::isVMEM(MI) && MI.mayLoad())
+        LastVMEMLoad = &MI;
+      if (SIInstrInfo::isVALU(MI))
+        ++NumVALUInsts;
+    }
+
+    MBBInfo &Info = MBBInfos[MBB];
+    Info.NumVALUInsts = NumVALUInsts;
+
+    if (LastVMEMLoad) {
+      Info.LastVMEMLoad = LastVMEMLoad;
+      Worklist.push_back(MBB);
+    }
   }
 
-  // Mark blocks from which control may reach VMEM loads.
+  // Mark blocks from which control may reach the VMEM loads.
   while (!Worklist.empty()) {
     const MachineBasicBlock *MBB = Worklist.pop_back_val();
     MBBInfo &Info = MBBInfos[MBB];
@@ -125,7 +152,7 @@
   Entry.insert(I, BuildSetprioMI(MF, HighPriority));
 
   // Lower the priority on edges where control leaves blocks from which
-  // VMEM loads are reachable.
+  // the VMEM loads are reachable.
   SmallSet<MachineBasicBlock *, 16> PriorityLoweringBlocks;
   for (MachineBasicBlock &MBB : MF) {
     if (MBBInfos[&MBB].MayReachVMEMLoad) {
@@ -152,14 +179,13 @@
   }
 
   for (MachineBasicBlock *MBB : PriorityLoweringBlocks) {
-    MachineBasicBlock::iterator I = MBB->end(), B = MBB->begin();
-    while (I != B) {
-      if (isVMEMLoad(*--I)) {
-        ++I;
-        break;
-      }
+    MachineInstr *Setprio = BuildSetprioMI(MF, LowPriority);
+    if (MachineInstr *LastVMEMLoad = MBBInfos[MBB].LastVMEMLoad) {
+      MBB->insertAfter(MachineBasicBlock::instr_iterator(LastVMEMLoad),
+                       Setprio);
+      continue;
     }
-    MBB->insert(I, BuildSetprioMI(MF, LowPriority));
+    MBB->insert(MBB->begin(), Setprio);
   }
 
   return true;
diff --git a/llvm/test/CodeGen/AMDGPU/set-wave-priority.ll b/llvm/test/CodeGen/AMDGPU/set-wave-priority.ll
--- a/llvm/test/CodeGen/AMDGPU/set-wave-priority.ll
+++ b/llvm/test/CodeGen/AMDGPU/set-wave-priority.ll
@@ -1,5 +1,5 @@
-; RUN: llc -mtriple=amdgcn -amdgpu-set-wave-priority=true -o - %s | \
-; RUN:   FileCheck %s
+; RUN: llc -mtriple=amdgcn -amdgpu-set-wave-priority=true \
+; RUN:   -amdgpu-set-wave-priority-valu-insts-threshold=4 -o - %s | FileCheck %s
 
 ; CHECK-LABEL: no_setprio:
 ; CHECK-NOT: s_setprio
@@ -150,4 +150,33 @@
   ret <2 x float> %sum
 }
 
+; CHECK-LABEL: valu_insts_threshold:
+; CHECK: s_setprio 3
+; CHECK: buffer_load_dwordx2
+; CHECK-NEXT: s_setprio 0
+; CHECK-COUNT-4: v_add_f32_e32
+; CHECK: s_cbranch_scc0 [[A:.*]]
+; CHECK: {{.*}}: ; %b
+; CHECK-NEXT: buffer_load_dwordx2
+; CHECK: s_branch [[END:.*]]
+; CHECK: [[A]]: ; %a
+; CHECK: s_branch [[END]]
+; CHECK: [[END]]:
+define amdgpu_ps <2 x float> @valu_insts_threshold(<4 x i32> inreg %p, i32 inreg %i) {
+  %v = call <2 x float> @llvm.amdgcn.struct.buffer.load.v2f32(<4 x i32> %p, i32 0, i32 0, i32 0, i32 0)
+  %add = fadd <2 x float> %v, %v
+  %add2 = fadd <2 x float> %add, %add
+
+  %cond = icmp eq i32 %i, 0
+  br i1 %cond, label %a, label %b
+
+a:
+  ret <2 x float> %add2
+
+b:
+  %v2 = call <2 x float> @llvm.amdgcn.struct.buffer.load.v2f32(<4 x i32> %p, i32 0, i32 1, i32 0, i32 0)
+  %sub = fsub <2 x float> %add2, %v2
+  ret <2 x float> %sub
+}
+
 declare <2 x float> @llvm.amdgcn.struct.buffer.load.v2f32(<4 x i32>, i32, i32, i32, i32) nounwind