diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h --- a/llvm/include/llvm/ADT/GenericUniformityImpl.h +++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h @@ -456,15 +456,13 @@ /// \brief Mark all instruction as divergent that use a value defined in \p /// OuterDivCycle. Push their users on the worklist. - void analyzeTemporalDivergence(const InstructionT &I, - const CycleT &OuterDivCycle); + void propagateTemporalDivergence(const InstructionT &I, + const CycleT &OuterDivCycle); /// \brief Push all users of \p Val (in the region) to the worklist. void pushUsers(const InstructionT &I); void pushUsers(ConstValueRefT V); - bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const; - /// \brief Whether \p Def is divergent when read in \p ObservingBlock. bool isTemporalDivergent(const BlockT &ObservingBlock, const InstructionT &Def) const; @@ -809,24 +807,6 @@ UniformOverrides.insert(&Instr); } -template -void GenericUniformityAnalysisImpl::analyzeTemporalDivergence( - const InstructionT &I, const CycleT &OuterDivCycle) { - if (isDivergent(I)) - return; - - LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << Context.print(&I) - << "\n"); - if (isAlwaysUniform(I)) - return; - - if (!usesValueFromCycle(I, OuterDivCycle)) - return; - - if (markDivergent(I)) - Worklist.push_back(&I); -} - // Mark all external users of values defined inside \param // OuterDivCycle as divergent. // @@ -841,74 +821,9 @@ template void GenericUniformityAnalysisImpl::analyzeCycleExitDivergence( const CycleT &OuterDivCycle) { - // Set of blocks that are dominated by the cycle, i.e., each is only - // reachable from paths that pass through the cycle. - SmallPtrSet DomRegion; - - // The boundary of DomRegion, formed by blocks that are not - // dominated by the cycle. - SmallVector DomFrontier; - OuterDivCycle.getExitBlocks(DomFrontier); - - // Returns true if BB is dominated by the cycle. - auto isInDomRegion = [&](BlockT *BB) { - for (auto *P : predecessors(BB)) { - if (OuterDivCycle.contains(P)) - continue; - if (DomRegion.count(P)) - continue; - return false; - } - return true; - }; - - // Keep advancing the frontier along successor edges, while - // promoting blocks to DomRegion. - while (true) { - bool Promoted = false; - SmallVector Temp; - for (auto *W : DomFrontier) { - if (!isInDomRegion(W)) { - Temp.push_back(W); - continue; - } - DomRegion.insert(W); - Promoted = true; - for (auto *Succ : successors(W)) { - if (DomRegion.contains(Succ)) - continue; - Temp.push_back(Succ); - } - } - if (!Promoted) - break; - - // Restore the set property for the temporary vector - llvm::sort(Temp); - Temp.erase(std::unique(Temp.begin(), Temp.end()), Temp.end()); - - DomFrontier = Temp; - } - - // At DomFrontier, only the PHI nodes are affected by temporal - // divergence. - for (const auto *UserBlock : DomFrontier) { - LLVM_DEBUG(dbgs() << "Analyze phis after cycle exit: " - << Context.print(UserBlock) << "\n"); - for (const auto &Phi : UserBlock->phis()) { - LLVM_DEBUG(dbgs() << " " << Context.print(&Phi) << "\n"); - analyzeTemporalDivergence(Phi, OuterDivCycle); - } - } - - // All instructions inside the dominance region are affected by - // temporal divergence. - for (const auto *UserBlock : DomRegion) { - LLVM_DEBUG(dbgs() << "Analyze non-phi users after cycle exit: " - << Context.print(UserBlock) << "\n"); - for (const auto &I : *UserBlock) { - LLVM_DEBUG(dbgs() << " " << Context.print(&I) << "\n"); - analyzeTemporalDivergence(I, OuterDivCycle); + for (auto *BB : OuterDivCycle.blocks()) { + for (auto &II : *BB) { + propagateTemporalDivergence(II, OuterDivCycle); } } } diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp --- a/llvm/lib/Analysis/UniformityAnalysis.cpp +++ b/llvm/lib/Analysis/UniformityAnalysis.cpp @@ -66,16 +66,16 @@ } template <> -bool llvm::GenericUniformityAnalysisImpl::usesValueFromCycle( - const Instruction &I, const Cycle &DefCycle) const { - assert(!isAlwaysUniform(I)); - for (const Use &U : I.operands()) { - if (auto *I = dyn_cast(&U)) { - if (DefCycle.contains(I->getParent())) - return true; - } +void llvm::GenericUniformityAnalysisImpl< + SSAContext>::propagateTemporalDivergence(const Instruction &I, + const Cycle &DefCycle) { + for (const Use &U : I.uses()) { + auto *UserInstr = cast(U.getUser()); + if (DefCycle.contains(UserInstr->getParent())) + continue; + if (markDivergent(*UserInstr)) + Worklist.push_back(UserInstr); } - return false; } template <> diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp --- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp +++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp @@ -93,24 +93,23 @@ } template <> -bool llvm::GenericUniformityAnalysisImpl::usesValueFromCycle( - const MachineInstr &I, const MachineCycle &DefCycle) const { - assert(!isAlwaysUniform(I)); +void llvm::GenericUniformityAnalysisImpl:: + propagateTemporalDivergence(const MachineInstr &I, + const MachineCycle &DefCycle) { + const auto &RegInfo = F.getRegInfo(); for (auto &Op : I.operands()) { - if (!Op.isReg() || !Op.readsReg()) + if (!Op.isReg() || !Op.isDef()) + continue; + if (!Op.getReg().isVirtual()) continue; auto Reg = Op.getReg(); - - // FIXME: Physical registers need to be properly checked instead of always - // returning true - if (Reg.isPhysical()) - return true; - - auto *Def = F.getRegInfo().getVRegDef(Reg); - if (DefCycle.contains(Def->getParent())) - return true; + for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { + if (DefCycle.contains(UserInstr.getParent())) + continue; + if (markDivergent(UserInstr)) + Worklist.push_back(&UserInstr); + } } - return false; } template <> diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll --- a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll +++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll @@ -85,6 +85,7 @@ Y: %div.alsouser = add i32 %uni.inc, 5 ret void +; CHECK: DIVERGENT: %div.alsouser = } @@ -113,6 +114,7 @@ G: %div.user = add i32 %uni.inc, 5 br i1 %uni.cond, label %G, label %Y +; CHECK: DIVERGENT: %div.user = Y: ret void @@ -127,10 +129,13 @@ entry: %tid = call i32 @llvm.amdgcn.workitem.id.x() %uni.cond = icmp slt i32 %a, 0 + br label %G + +G: br label %H H: - %uni.merge.h = phi i32 [ 0, %entry ], [ %uni.inc, %H ] + %uni.merge.h = phi i32 [ 0, %G ], [ %uni.inc, %H ] %uni.inc = add i32 %uni.merge.h, 1 %div.exitx = icmp slt i32 %tid, 0 br i1 %div.exitx, label %X, label %H ; divergent branch @@ -138,11 +143,9 @@ ; CHECK: DIVERGENT: br i1 %div.exitx, X: - br label %G - -G: +; CHECK: DIVERGENT: %div.user = %div.user = add i32 %uni.inc, 5 - br i1 %uni.cond, label %G, label %Y + br i1 %uni.cond, label %X, label %G Y: ret void