Index: lib/Target/AMDGPU/SIInsertWaitcnts.cpp =================================================================== --- lib/Target/AMDGPU/SIInsertWaitcnts.cpp +++ lib/Target/AMDGPU/SIInsertWaitcnts.cpp @@ -367,7 +367,7 @@ DenseMap> BlockWaitcntBracketsMap; - DenseSet BlockWaitcntProcessedSet; + std::vector BlockWaitcntProcessedSet; DenseMap> LoopWaitcntDataMap; @@ -403,7 +403,8 @@ void updateEventWaitCntAfter(MachineInstr &Inst, BlockWaitcntBrackets *ScoreBrackets); void mergeInputScoreBrackets(MachineBasicBlock &Block); - MachineBasicBlock *loopBottom(const MachineLoop *Loop); + bool isLoopBottom(const MachineLoop *Loop, const MachineBasicBlock *Block); + unsigned countNumBottomBlocks(const MachineLoop *Loop); void insertWaitcntInBlock(MachineFunction &MF, MachineBasicBlock &Block); void insertWaitcntBeforeCF(MachineBasicBlock &Block, MachineInstr *Inst); bool isWaitcntStronger(unsigned LHS, unsigned RHS); @@ -1552,15 +1553,31 @@ } } -/// Return the "bottom" block of a loop. This differs from -/// MachineLoop::getBottomBlock in that it works even if the loop is -/// discontiguous. -MachineBasicBlock *SIInsertWaitcnts::loopBottom(const MachineLoop *Loop) { +/// Return TRUE if the given basic block is a "bottom" block of a loop. This +/// differs from MachineLoop::getBottomBlock in that it works even if the loop +/// is discontiguous. This also handles multiple back-edges for the same +/// "header" block of a loop. +bool SIInsertWaitcnts::isLoopBottom(const MachineLoop *Loop, + const MachineBasicBlock *Block) { + MachineBasicBlock *Bottom = Loop->getHeader(); + for (MachineBasicBlock *MBB : Loop->blocks()) { + if (MBB->getNumber() > Bottom->getNumber() && + MBB->isSuccessor(Loop->getHeader())) { + if (MBB == Block) + return true; + } + } + return false; +} + +/// Count the number of "bottom" basic blocks of a loop. 
+unsigned SIInsertWaitcnts::countNumBottomBlocks(const MachineLoop *Loop) { + unsigned Count = 0; MachineBasicBlock *Bottom = Loop->getHeader(); for (MachineBasicBlock *MBB : Loop->blocks()) - if (MBB->getNumber() > Bottom->getNumber()) - Bottom = MBB; - return Bottom; + if (MBB->isSuccessor(Loop->getHeader())) + Count++; + return Count; } // Generate s_waitcnt instructions where needed. @@ -1669,7 +1686,7 @@ // Check if we need to force convergence at loop footer. MachineLoop *ContainingLoop = MLI->getLoopFor(&Block); - if (ContainingLoop && loopBottom(ContainingLoop) == &Block) { + if (ContainingLoop && isLoopBottom(ContainingLoop, &Block)) { LoopWaitcntData *WaitcntData = LoopWaitcntDataMap[ContainingLoop].get(); WaitcntData->print(); DEBUG(dbgs() << '\n';); @@ -1783,21 +1800,30 @@ // If we are walking into the block from before the loop, then guarantee // at least 1 re-walk over the loop to propagate the information, even if // no S_WAITCNT instructions were generated. - if (ContainingLoop && ContainingLoop->getHeader() == &MBB && J < I && - (!BlockWaitcntProcessedSet.count(&MBB))) { - BlockWaitcntBracketsMap[&MBB]->setRevisitLoop(true); - DEBUG(dbgs() << "set-revisit: Block" - << ContainingLoop->getHeader()->getNumber() << '\n';); + if (ContainingLoop && ContainingLoop->getHeader() == &MBB) { + unsigned Count = countNumBottomBlocks(ContainingLoop); + + // If the loop has multiple back-edges, and so more than one "bottom" + // basic block, we have to guarantee a re-walk over every block. + if ((std::count(BlockWaitcntProcessedSet.begin(), + BlockWaitcntProcessedSet.end(), &MBB) < Count)) { + BlockWaitcntBracketsMap[&MBB]->setRevisitLoop(true); + DEBUG(dbgs() << "set-revisit: Block" + << ContainingLoop->getHeader()->getNumber() << '\n';); + } } // Walk over the instructions. insertWaitcntInBlock(MF, MBB); // Flag that waitcnts have been processed at least once. 
- BlockWaitcntProcessedSet.insert(&MBB); + BlockWaitcntProcessedSet.push_back(&MBB); - // See if we want to revisit the loop. - if (ContainingLoop && loopBottom(ContainingLoop) == &MBB) { + // See if we want to revisit the loop. If a loop has multiple back-edges, + // we shouldn't revisit the same "bottom" basic block. + if (ContainingLoop && isLoopBottom(ContainingLoop, &MBB) && + std::count(BlockWaitcntProcessedSet.begin(), + BlockWaitcntProcessedSet.end(), &MBB) == 1) { MachineBasicBlock *EntryBB = ContainingLoop->getHeader(); BlockWaitcntBrackets *EntrySB = BlockWaitcntBracketsMap[EntryBB].get(); if (EntrySB && EntrySB->getRevisitLoop()) { Index: test/CodeGen/AMDGPU/waitcnt-looptest.ll =================================================================== --- test/CodeGen/AMDGPU/waitcnt-looptest.ll +++ test/CodeGen/AMDGPU/waitcnt-looptest.ll @@ -144,3 +144,74 @@ attributes #0 = { "target-cpu"="fiji" "target-features"="-flat-for-global" } attributes #1 = { nounwind readnone speculatable } + +; Check that the waitcnt insertion algorithm correctly propagates wait counts +; from the loop bottom to the loop header. 
+ +; GCN-LABEL: {{^}}testLoopBottom +; GCN: BB1_2: +; GCN: s_waitcnt vmcnt(0) + +define amdgpu_kernel void @testLoopBottom([0 x i8] addrspace(4)* inreg noalias dereferenceable(18446744073709551615) %arg, [0 x i8] addrspace(4)* inreg noalias dereferenceable(18446744073709551615) %arg1, i32 inreg %arg2) { +bb: + %tmp = getelementptr [0 x i8], [0 x i8] addrspace(4)* %arg1, i64 0, i64 48 + %tmp3 = bitcast i8 addrspace(4)* %tmp to <4 x i32> addrspace(4)*, !amdgpu.uniform !0 + %tmp4 = load <4 x i32>, <4 x i32> addrspace(4)* %tmp3, align 16, !invariant.load !0 + %tmp5 = call float @llvm.amdgcn.buffer.load.f32(<4 x i32> %tmp4, i32 0, i32 0, i1 false, i1 false) + %tmp6 = bitcast [0 x i8] addrspace(4)* %arg to <4 x i32> addrspace(4)*, !amdgpu.uniform !0 + %tmp7 = load <4 x i32>, <4 x i32> addrspace(4)* %tmp6, align 16, !invariant.load !0 + %tmp8 = call float @llvm.SI.load.const.v4i32(<4 x i32> %tmp7, i32 64) + %tmp9 = icmp eq i32 %arg2, 3 + %tmp10 = select i1 %tmp9, i32 1065353216, i32 -1082130432 + br label %bb11 + +bb11: ; preds = %bb26, %bb, %bb32 + %tmp12 = phi float [ %tmp8, %bb ], [ %tmp31, %bb26 ], [ %tmp34, %bb32 ] + %tmp14 = phi float [ %tmp5, %bb ], [ %tmp27, %bb26 ], [ %tmp34, %bb32 ] + %tmp15 = phi float [ %tmp5, %bb ], [ %tmp28, %bb26 ], [ %tmp34, %bb32 ] + %tmp16 = fptosi float %tmp12 to i32 + %tmp17 = icmp eq i32 %tmp16, 0 + br i1 %tmp17, label %bb18, label %bb22 + +bb18: ; preds = %bb11 + %tmp19 = bitcast i32 %tmp10 to float + %tmp20 = fmul float %tmp8, %tmp19 + %tmp21 = fadd float %tmp14, %tmp20 + call void @llvm.amdgcn.exp.f32(i32 12, i32 15, float %tmp21, float undef, float undef, float undef, i1 true, i1 false) + ret void + +bb22: ; preds = %bb11 + %tmp23 = icmp eq i32 %tmp16, 9 + %tmp24 = icmp eq i32 %tmp16, 8 + br i1 %tmp23, label %bb32, label %bb26 + +bb26: ; preds = %bb32, %bb22 + %tmp27 = phi float [ %tmp34, %bb32 ], [ %tmp14, %bb22 ] + %tmp28 = phi float [ %tmp34, %bb32 ], [ %tmp15, %bb22 ] + %tmp29 = zext i1 %tmp24 to i32 + %tmp30 = load <4 x i32>, 
<4 x i32> addrspace(4)* %tmp3, align 16, !invariant.load !0 + %tmp31 = call float @llvm.SI.load.const.v4i32(<4 x i32> %tmp30, i32 %tmp29) + br label %bb11 + +bb32: ; preds = %bb22 + %tmp33 = load <4 x i32>, <4 x i32> addrspace(4)* %tmp3, align 16, !invariant.load !0 + %tmp34 = call float @llvm.SI.load.const.v4i32(<4 x i32> %tmp33, i32 0) + %tmp35 = fptosi float %tmp34 to i32 + %tmp36 = icmp eq i32 %tmp35, 2 + br i1 %tmp36, label %bb26, label %bb11 +} + +; Function Attrs: nounwind readnone +declare float @llvm.SI.load.const.v4i32(<4 x i32>, i32) #0 + +; Function Attrs: nounwind +declare void @llvm.amdgcn.exp.f32(i32, i32, float, float, float, float, i1, i1) #1 + +; Function Attrs: nounwind readonly +declare float @llvm.amdgcn.buffer.load.f32(<4 x i32>, i32, i32, i1, i1) #2 + +attributes #0 = { nounwind readnone } +attributes #1 = { nounwind } +attributes #2 = { nounwind readonly } + +!0 = !{}