Patch summary (free-form preamble; patch tools skip everything that precedes
the first "diff --git" line).

GenericUniformityAnalysisImpl::markDivergent(const InstructionT &) now returns
void and pushes the instruction onto Worklist itself whenever it newly marks
something, so every caller's
    if (markDivergent(X)) Worklist.push_back(&X);
collapses to a plain
    markDivergent(X);
The isAlwaysUniform() early-out moves ahead of the terminator check. The
SSAContext markDefsDivergent() casts to Value so that the value overload of
markDivergent() is selected. In MachineUniformityAnalysis.cpp,
pushUsers(Register) asserts that the register is divergent, and the operand
loop in the following hunk only forwards defs for which isDivergent(Reg) holds
(this replaces the former isVirtual() filter). The MIR test drops its FIXME
and adds CHECK-NOT lines for %1 and %4.

NOTE(review): this patch arrived with its line breaks collapsed. Line breaks,
indentation and blank context lines below were reconstructed so that every
hunk matches the line counts in its @@ header; multi-space alignment inside
lines (for example after "name:" in the .mir file) and the exact position of
the blank lines around "..." could not be recovered. Compare against the
upstream files before applying.

diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -355,10 +355,15 @@
   /// \brief Mark \p UniVal as a value that is always uniform.
   void addUniformOverride(const InstructionT &Instr);
 
-  /// \brief Mark \p DivVal as a value that is always divergent.
+  /// \brief Examine \p I for divergent outputs and add to the worklist.
+  void markDivergent(const InstructionT &I);
+
+  /// \brief Mark \p DivVal as a divergent value.
   /// \returns Whether the tracked divergence state of \p DivVal changed.
-  bool markDivergent(const InstructionT &I);
   bool markDivergent(ConstValueRefT DivVal);
+
+  /// \brief Mark outputs of \p Instr as divergent.
+  /// \returns Whether the tracked divergence state of any output has changed.
   bool markDefsDivergent(const InstructionT &Instr);
 
   /// \brief Propagate divergence to all instructions in the region.
@@ -774,21 +779,23 @@
 }
 
 template <typename ContextT>
-bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
+void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
     const InstructionT &I) {
+  if (isAlwaysUniform(I))
+    return;
+  bool Marked = false;
   if (I.isTerminator()) {
-    if (DivergentTermBlocks.insert(I.getParent()).second) {
+    Marked = DivergentTermBlocks.insert(I.getParent()).second;
+    if (Marked) {
       LLVM_DEBUG(dbgs() << "marked divergent term block: "
                         << Context.print(I.getParent()) << "\n");
-      return true;
     }
-    return false;
+  } else {
+    Marked = markDefsDivergent(I);
   }
 
-  if (isAlwaysUniform(I))
-    return false;
-
-  return markDefsDivergent(I);
+  if (Marked)
+    Worklist.push_back(&I);
 }
 
 template <typename ContextT>
@@ -828,8 +835,7 @@
   for (auto *Exit : Exits) {
     for (auto &Phi : Exit->phis()) {
       if (usesValueFromCycle(Phi, DefCycle)) {
-        if (markDivergent(Phi))
-          Worklist.push_back(&Phi);
+        markDivergent(Phi);
       }
     }
   }
@@ -889,8 +895,7 @@
     if (I.isTerminator())
       break;
 
-    if (markDivergent(I))
-      Worklist.push_back(&I);
+    markDivergent(I);
   }
 }
 
@@ -910,8 +915,7 @@
     // https://reviews.llvm.org/D19013
     if (ContextT::isConstantOrUndefValuePhi(Phi))
       continue;
-    if (markDivergent(Phi))
-      Worklist.push_back(&Phi);
+    markDivergent(Phi);
   }
 }
 
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -27,7 +27,7 @@
 template <>
 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
     const Instruction &Instr) {
-  return markDivergent(&Instr);
+  return markDivergent(cast<Value>(&Instr));
 }
 
 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
@@ -49,9 +49,7 @@
     const Value *V) {
   for (const auto *User : V->users()) {
     if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
-      if (markDivergent(*UserInstr)) {
-        Worklist.push_back(UserInstr);
-      }
+      markDivergent(*UserInstr);
     }
   }
 }
@@ -88,8 +86,7 @@
     auto *UserInstr = cast<Instruction>(User);
     if (DefCycle.contains(UserInstr->getParent()))
       continue;
-    if (markDivergent(*UserInstr))
-      Worklist.push_back(UserInstr);
+    markDivergent(*UserInstr);
   }
 }
 
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -62,8 +62,7 @@
       }
 
       if (uniformity == InstructionUniformity::NeverUniform) {
-        if (markDivergent(instr))
-          Worklist.push_back(&instr);
+        markDivergent(instr);
       }
     }
   }
@@ -72,10 +71,10 @@
 template <>
 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
     Register Reg) {
+  assert(isDivergent(Reg));
   const auto &RegInfo = F.getRegInfo();
   for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
-    if (markDivergent(UserInstr))
-      Worklist.push_back(&UserInstr);
+    markDivergent(UserInstr);
   }
 }
 
@@ -86,8 +85,11 @@
   if (Instr.isTerminator())
     return;
   for (const MachineOperand &op : Instr.operands()) {
-    if (op.isReg() && op.isDef() && op.getReg().isVirtual())
-      pushUsers(op.getReg());
+    if (!op.isReg() || !op.isDef())
+      continue;
+    auto Reg = op.getReg();
+    if (isDivergent(Reg))
+      pushUsers(Reg);
   }
 }
 
@@ -128,8 +130,7 @@
     for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
       if (DefCycle.contains(UserInstr.getParent()))
         continue;
-      if (markDivergent(UserInstr))
-        Worklist.push_back(&UserInstr);
+      markDivergent(UserInstr);
     }
   }
 }
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
--- a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
@@ -97,7 +97,6 @@
 
 ...
 
-# FIXME :: BELOW INLINE ASM SHOULD BE DIVERGENT
 ---
 name: asm_mixed_sgpr_vgpr
 registers:
@@ -116,7 +115,9 @@
     ; CHECK-LABEL: MachineUniformityInfo for function: asm_mixed_sgpr_vgpr
     ; CHECK: DIVERGENT: %0:
     ; CHECK: DIVERGENT: %3:
+    ; CHECK-NOT: DIVERGENT: %1:
     ; CHECK: DIVERGENT: %2:
+    ; CHECK-NOT: DIVERGENT: %4:
     ; CHECK: DIVERGENT: %5:
     %0:_(s32) = COPY $vgpr0
     %6:_(p1) = G_IMPLICIT_DEF