Make markDivergent(const InstructionT &) push onto the worklist itself.

The instruction overload of markDivergent() now returns void and appends
the instruction to Worklist whenever its tracked divergence state changes
(its block is newly recorded as having a divergent terminator, or
markDefsDivergent() reports a change for a non-always-uniform instruction).
Callers drop their `if (markDivergent(X)) Worklist.push_back(&X);`
boilerplate. The value overload markDivergent(ConstValueRefT) still returns
bool, so its doc comment is moved onto it and the void overload gets its own.

diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -357,7 +357,10 @@
 
-  /// \brief Mark \p DivVal as a value that is always divergent.
-  /// \returns Whether the tracked divergence state of \p DivVal changed.
-  bool markDivergent(const InstructionT &I);
+  /// \brief Mark \p I as divergent and, if its tracked divergence state
+  /// changed, push \p I onto the worklist.
+  void markDivergent(const InstructionT &I);
+
+  /// \brief Mark \p DivVal as a value that is always divergent.
+  /// \returns Whether the tracked divergence state of \p DivVal changed.
   bool markDivergent(ConstValueRefT DivVal);
   bool markDefsDivergent(const InstructionT &Instr);
 
@@ -774,21 +777,23 @@
 }
 
 template <typename ContextT>
-bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
+void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
     const InstructionT &I) {
+  bool Marked = false;
   if (I.isTerminator()) {
     if (DivergentTermBlocks.insert(I.getParent()).second) {
       LLVM_DEBUG(dbgs() << "marked divergent term block: "
                         << Context.print(I.getParent()) << "\n");
-      return true;
+      Marked = true;
     }
-    return false;
+  } else if (isAlwaysUniform(I)) {
+    return;
+  } else {
+    Marked = markDefsDivergent(I);
   }
 
-  if (isAlwaysUniform(I))
-    return false;
-
-  return markDefsDivergent(I);
+  if (Marked)
+    Worklist.push_back(&I);
 }
 
 template <typename ContextT>
@@ -828,8 +833,7 @@
   for (auto *Exit : Exits) {
     for (auto &Phi : Exit->phis()) {
       if (usesValueFromCycle(Phi, DefCycle)) {
-        if (markDivergent(Phi))
-          Worklist.push_back(&Phi);
+        markDivergent(Phi);
       }
     }
   }
@@ -889,8 +893,7 @@
     if (I.isTerminator())
       break;
 
-    if (markDivergent(I))
-      Worklist.push_back(&I);
+    markDivergent(I);
   }
 }
 
@@ -910,8 +913,7 @@
     // https://reviews.llvm.org/D19013
     if (ContextT::isConstantOrUndefValuePhi(Phi))
       continue;
-    if (markDivergent(Phi))
-      Worklist.push_back(&Phi);
+    markDivergent(Phi);
   }
 }
 
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -27,7 +27,7 @@
 template <>
 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
     const Instruction &Instr) {
-  return markDivergent(&Instr);
+  return markDivergent(cast<Value>(&Instr));
 }
 
 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
@@ -49,9 +49,7 @@
     const Value *V) {
   for (const auto *User : V->users()) {
     if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
-      if (markDivergent(*UserInstr)) {
-        Worklist.push_back(UserInstr);
-      }
+      markDivergent(*UserInstr);
     }
   }
 }
@@ -86,8 +84,7 @@
     auto *UserInstr = cast<Instruction>(User);
     if (DefCycle.contains(UserInstr->getParent()))
       continue;
-    if (markDivergent(*UserInstr))
-      Worklist.push_back(UserInstr);
+    markDivergent(*UserInstr);
   }
 }
 
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -62,8 +62,7 @@
       }
 
       if (uniformity == InstructionUniformity::NeverUniform) {
-        if (markDivergent(instr))
-          Worklist.push_back(&instr);
+        markDivergent(instr);
       }
     }
   }
@@ -74,8 +73,7 @@
     Register Reg) {
   const auto &RegInfo = F.getRegInfo();
   for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
-    if (markDivergent(UserInstr))
-      Worklist.push_back(&UserInstr);
+    markDivergent(UserInstr);
   }
 }
 
@@ -128,8 +126,7 @@
     for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
       if (DefCycle.contains(UserInstr.getParent()))
         continue;
-      if (markDivergent(UserInstr))
-        Worklist.push_back(&UserInstr);
+      markDivergent(UserInstr);
     }
   }
 }