Index: llvm/trunk/include/llvm/ADT/PostOrderIterator.h =================================================================== --- llvm/trunk/include/llvm/ADT/PostOrderIterator.h +++ llvm/trunk/include/llvm/ADT/PostOrderIterator.h @@ -296,12 +296,15 @@ public: using rpo_iterator = typename std::vector::reverse_iterator; + using const_rpo_iterator = typename std::vector::const_reverse_iterator; ReversePostOrderTraversal(GraphT G) { Initialize(GT::getEntryNode(G)); } // Because we want a reverse post order, use reverse iterators from the vector rpo_iterator begin() { return Blocks.rbegin(); } + const_rpo_iterator begin() const { return Blocks.crbegin(); } rpo_iterator end() { return Blocks.rend(); } + const_rpo_iterator end() const { return Blocks.crend(); } }; } // end namespace llvm Index: llvm/trunk/include/llvm/Analysis/DivergenceAnalysis.h =================================================================== --- llvm/trunk/include/llvm/Analysis/DivergenceAnalysis.h +++ llvm/trunk/include/llvm/Analysis/DivergenceAnalysis.h @@ -0,0 +1,178 @@ +//===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// \file +// The divergence analysis determines which instructions and branches are +// divergent given a set of divergent source instructions. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H +#define LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H + +#include "llvm/ADT/DenseSet.h" +#include "llvm/Analysis/SyncDependenceAnalysis.h" +#include "llvm/IR/Function.h" +#include "llvm/Pass.h" +#include + +namespace llvm { +class Module; +class Value; +class Instruction; +class Loop; +class raw_ostream; +class TargetTransformInfo; + +/// \brief Generic divergence analysis for reducible CFGs. +/// +/// This analysis propagates divergence in a data-parallel context from sources +/// of divergence to all users. It requires reducible CFGs. All assignments +/// should be in SSA form. +class DivergenceAnalysis { +public: + /// \brief This instance will analyze the whole function \p F or the loop \p + /// RegionLoop. + /// + /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop. + /// Otherwise the whole function is analyzed. + /// \param IsLCSSAForm whether the analysis may assume that the IR in the + /// region in in LCSSA form. + DivergenceAnalysis(const Function &F, const Loop *RegionLoop, + const DominatorTree &DT, const LoopInfo &LI, + SyncDependenceAnalysis &SDA, bool IsLCSSAForm); + + /// \brief The loop that defines the analyzed region (if any). + const Loop *getRegionLoop() const { return RegionLoop; } + const Function &getFunction() const { return F; } + + /// \brief Whether \p BB is part of the region. + bool inRegion(const BasicBlock &BB) const; + /// \brief Whether \p I is part of the region. + bool inRegion(const Instruction &I) const; + + /// \brief Mark \p UniVal as a value that is always uniform. + void addUniformOverride(const Value &UniVal); + + /// \brief Mark \p DivVal as a value that is always divergent. + void markDivergent(const Value &DivVal); + + /// \brief Propagate divergence to all instructions in the region. + /// Divergence is seeded by calls to \p markDivergent. 
+ void compute(); + + /// \brief Whether any value was marked or analyzed to be divergent. + bool hasDetectedDivergence() const { return !DivergentValues.empty(); } + + /// \brief Whether \p Val will always return a uniform value regardless of its + /// operands + bool isAlwaysUniform(const Value &Val) const; + + /// \brief Whether \p Val is a divergent value + bool isDivergent(const Value &Val) const; + + void print(raw_ostream &OS, const Module *) const; + +private: + bool updateTerminator(const TerminatorInst &Term) const; + bool updatePHINode(const PHINode &Phi) const; + + /// \brief Computes whether \p Inst is divergent based on the + /// divergence of its operands. + /// + /// \returns Whether \p Inst is divergent. + /// + /// This should only be called for non-phi, non-terminator instructions. + bool updateNormalInstruction(const Instruction &Inst) const; + + /// \brief Mark users of live-out users as divergent. + /// + /// \param LoopHeader the header of the divergent loop. + /// + /// Marks all users of live-out values of the loop headed by \p LoopHeader + /// as divergent and puts them on the worklist. + void taintLoopLiveOuts(const BasicBlock &LoopHeader); + + /// \brief Push all users of \p Val (in the region) to the worklist + void pushUsers(const Value &I); + + /// \brief Push all phi nodes in @block to the worklist + void pushPHINodes(const BasicBlock &Block); + + /// \brief Mark \p Block as join divergent + /// + /// A block is join divergent if two threads may reach it from different + /// incoming blocks at the same time. + void markBlockJoinDivergent(const BasicBlock &Block) { + DivergentJoinBlocks.insert(&Block); + } + + /// \brief Whether \p Val is divergent when read in \p ObservingBlock. + bool isTemporalDivergent(const BasicBlock &ObservingBlock, + const Value &Val) const; + + /// \brief Whether \p Block is join divergent + /// + /// (see markBlockJoinDivergent). 
+ bool isJoinDivergent(const BasicBlock &Block) const { + return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end(); + } + + /// \brief Propagate control-induced divergence to users (phi nodes and + /// instructions). + /// + /// \param JoinBlock is a divergent loop exit or join point of two disjoint + /// paths. + /// \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop. + bool propagateJoinDivergence(const BasicBlock &JoinBlock, + const Loop *TermLoop); + + /// \brief Propagate induced value divergence due to control divergence in \p + /// Term. + void propagateBranchDivergence(const TerminatorInst &Term); + + /// \brief Propagate divergence caused by a divergent loop exit. + /// + /// \param ExitingLoop is a divergent loop. + void propagateLoopDivergence(const Loop &ExitingLoop); + +private: + const Function &F; + // If RegionLoop != nullptr, analysis is only performed within \p RegionLoop. + // Otherwise, analyze the whole function. + const Loop *RegionLoop; + + const DominatorTree &DT; + const LoopInfo &LI; + + // Recognized divergent loops + DenseSet DivergentLoops; + + // The SDA links divergent branches to divergent control-flow joins. + SyncDependenceAnalysis &SDA; + + // Use simplified code path for LCSSA form. + bool IsLCSSAForm; + + // Set of known-uniform values. + DenseSet UniformOverrides; + + // Blocks with joining divergent control from different predecessors. + DenseSet DivergentJoinBlocks; + + // Detected/marked divergent values. + DenseSet DivergentValues; + + // Internal worklist for divergence propagation. 
+ std::vector Worklist; +}; + +} // namespace llvm + +#endif // LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H Index: llvm/trunk/include/llvm/Analysis/SyncDependenceAnalysis.h =================================================================== --- llvm/trunk/include/llvm/Analysis/SyncDependenceAnalysis.h +++ llvm/trunk/include/llvm/Analysis/SyncDependenceAnalysis.h @@ -0,0 +1,88 @@ +//===- SyncDependenceAnalysis.h - Divergent Branch Dependence -*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// \file +// This file defines the SyncDependenceAnalysis class, which computes for +// every divergent branch the set of phi nodes that the branch will make +// divergent. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_SYNC_DEPENDENCE_ANALYSIS_H +#define LLVM_ANALYSIS_SYNC_DEPENDENCE_ANALYSIS_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/LoopInfo.h" +#include + +namespace llvm { + +class BasicBlock; +class DominatorTree; +class Loop; +class PostDominatorTree; +class TerminatorInst; +class TerminatorInst; + +using ConstBlockSet = SmallPtrSet; + +/// \brief Relates points of divergent control to join points in +/// reducible CFGs. +/// +/// This analysis relates points of divergent control to points of converging +/// divergent control. The analysis requires all loops to be reducible. 
+class SyncDependenceAnalysis { + void visitSuccessor(const BasicBlock &succBlock, const Loop *termLoop, + const BasicBlock *defBlock); + +public: + bool inRegion(const BasicBlock &BB) const; + + ~SyncDependenceAnalysis(); + SyncDependenceAnalysis(const DominatorTree &DT, const PostDominatorTree &PDT, + const LoopInfo &LI); + + /// \brief Computes divergent join points and loop exits caused by branch + /// divergence in \p Term. + /// + /// The set of blocks which are reachable by disjoint paths from \p Term. + /// The set also contains loop exits if there are two disjoint paths: + /// one from \p Term to the loop exit and another from \p Term to the loop + /// header. Those exit blocks are added to the returned set. + /// If L is the parent loop of \p Term and an exit of L is in the returned + /// set then L is a divergent loop. + const ConstBlockSet &join_blocks(const TerminatorInst &Term); + + /// \brief Computes divergent join points and loop exits (in the surrounding + /// loop) caused by the divergent loop exits of \p Loop. + /// + /// The set of blocks which are reachable by disjoint paths from the + /// loop exits of \p Loop. + /// This treats the loop as a single node in \p Loop's parent loop. + /// The returned set has the same properties as for join_blocks(TermInst&). 
+ const ConstBlockSet &join_blocks(const Loop &Loop); + +private: + static ConstBlockSet EmptyBlockSet; + + ReversePostOrderTraversal FuncRPOT; + const DominatorTree &DT; + const PostDominatorTree &PDT; + const LoopInfo &LI; + + std::map> CachedLoopExitJoins; + std::map> + CachedBranchJoins; +}; + +} // namespace llvm + +#endif // LLVM_ANALYSIS_SYNC_DEPENDENCE_ANALYSIS_H Index: llvm/trunk/lib/Analysis/CMakeLists.txt =================================================================== --- llvm/trunk/lib/Analysis/CMakeLists.txt +++ llvm/trunk/lib/Analysis/CMakeLists.txt @@ -25,6 +25,7 @@ Delinearization.cpp DemandedBits.cpp DependenceAnalysis.cpp + DivergenceAnalysis.cpp DomPrinter.cpp DominanceFrontier.cpp EHPersonalities.cpp @@ -80,6 +81,7 @@ ScalarEvolutionAliasAnalysis.cpp ScalarEvolutionExpander.cpp ScalarEvolutionNormalization.cpp + SyncDependenceAnalysis.cpp SyntheticCountsUtils.cpp TargetLibraryInfo.cpp TargetTransformInfo.cpp Index: llvm/trunk/lib/Analysis/DivergenceAnalysis.cpp =================================================================== --- llvm/trunk/lib/Analysis/DivergenceAnalysis.cpp +++ llvm/trunk/lib/Analysis/DivergenceAnalysis.cpp @@ -0,0 +1,425 @@ +//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a general divergence analysis for loop vectorization +// and GPU programs. It determines which branches and values in a loop or GPU +// program are divergent. It can help branch optimizations such as jump +// threading and loop unswitching to make better decisions. +// +// GPU programs typically use the SIMD execution model, where multiple threads +// in the same execution group have to execute in lock-step. 
Therefore, if the +// code contains divergent branches (i.e., threads in a group do not agree on +// which path of the branch to take), the group of threads has to execute all +// the paths from that branch with different subsets of threads enabled until +// they re-converge. +// +// Due to this execution model, some optimizations such as jump +// threading and loop unswitching can interfere with thread re-convergence. +// Therefore, an analysis that computes which branches in a GPU program are +// divergent can help the compiler to selectively run these optimizations. +// +// This implementation is derived from the Vectorization Analysis of the +// Region Vectorizer (RV). That implementation in turn is based on the approach +// described in +// +// Improving Performance of OpenCL on CPUs +// Ralf Karrenberg and Sebastian Hack +// CC '12 +// +// This DivergenceAnalysis implementation is generic in the sense that it does +// not itself identify original sources of divergence. +// Instead specialized adapter classes, (LoopDivergenceAnalysis) for loops and +// (GPUDivergenceAnalysis) for GPU programs, identify the sources of divergence +// (e.g., special variables that hold the thread ID or the iteration variable). +// +// The generic implementation propagates divergence to variables that are data +// or sync dependent on a source of divergence. +// +// While data dependency is a well-known concept, the notion of sync dependency +// is worth more explanation. Sync dependence characterizes the control flow +// aspect of the propagation of branch divergence. For example, +// +// %cond = icmp slt i32 %tid, 10 +// br i1 %cond, label %then, label %else +// then: +// br label %merge +// else: +// br label %merge +// merge: +// %a = phi i32 [ 0, %then ], [ 1, %else ] +// +// Suppose %tid holds the thread ID. 
Although %a is not data dependent on %tid +// because %tid is not on its use-def chains, %a is sync dependent on %tid +// because the branch "br i1 %cond" depends on %tid and affects which value %a +// is assigned to. +// +// The sync dependence detection (which branch induces divergence in which join +// points) is implemented in the SyncDependenceAnalysis. +// +// The current DivergenceAnalysis implementation has the following limitations: +// 1. intra-procedural. It conservatively considers the arguments of a +// non-kernel-entry function and the return value of a function call as +// divergent. +// 2. memory as black box. It conservatively considers values loaded from +// generic or local address as divergent. This can be improved by leveraging +// pointer analysis and/or by modelling non-escaping memory objects in SSA +// as done in RV. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace llvm; + +#define DEBUG_TYPE "divergence-analysis" + +// class DivergenceAnalysis +DivergenceAnalysis::DivergenceAnalysis( + const Function &F, const Loop *RegionLoop, const DominatorTree &DT, + const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) + : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), + IsLCSSAForm(IsLCSSAForm) {} + +void DivergenceAnalysis::markDivergent(const Value &DivVal) { + assert(isa(DivVal) || isa(DivVal)); + assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); + DivergentValues.insert(&DivVal); +} + +void 
DivergenceAnalysis::addUniformOverride(const Value &UniVal) { + UniformOverrides.insert(&UniVal); +} + +bool DivergenceAnalysis::updateTerminator(const TerminatorInst &Term) const { + if (Term.getNumSuccessors() <= 1) + return false; + if (auto *BranchTerm = dyn_cast(&Term)) { + assert(BranchTerm->isConditional()); + return isDivergent(*BranchTerm->getCondition()); + } + if (auto *SwitchTerm = dyn_cast(&Term)) { + return isDivergent(*SwitchTerm->getCondition()); + } + if (isa(Term)) { + return false; // ignore abnormal executions through landingpad + } + + llvm_unreachable("unexpected terminator"); +} + +bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const { + // TODO function calls with side effects, etc + for (const auto &Op : I.operands()) { + if (isDivergent(*Op)) + return true; + } + return false; +} + +bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock, + const Value &Val) const { + const auto *Inst = dyn_cast(&Val); + if (!Inst) + return false; + // check whether any divergent loop carrying Val terminates before control + // proceeds to ObservingBlock + for (const auto *Loop = LI.getLoopFor(Inst->getParent()); + Loop != RegionLoop && !Loop->contains(&ObservingBlock); + Loop = Loop->getParentLoop()) { + if (DivergentLoops.find(Loop) != DivergentLoops.end()) + return true; + } + + return false; +} + +bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const { + // joining divergent disjoint path in Phi parent block + if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) { + return true; + } + + // An incoming value could be divergent by itself. + // Otherwise, an incoming value could be uniform within the loop + // that carries its definition but it may appear divergent + // from outside the loop. This happens when divergent loop exits + // drop definitions of that uniform value in different iterations. 
+ // + // for (int i = 0; i < n; ++i) { // 'i' is uniform inside the loop + // if (i % thread_id == 0) break; // divergent loop exit + // } + // int divI = i; // divI is divergent + for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) { + const auto *InVal = Phi.getIncomingValue(i); + if (isDivergent(*Phi.getIncomingValue(i)) || + isTemporalDivergent(*Phi.getParent(), *InVal)) { + return true; + } + } + return false; +} + +bool DivergenceAnalysis::inRegion(const Instruction &I) const { + return I.getParent() && inRegion(*I.getParent()); +} + +bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const { + return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB); +} + +// marks all users of loop-carried values of the loop headed by LoopHeader as +// divergent +void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) { + auto *DivLoop = LI.getLoopFor(&LoopHeader); + assert(DivLoop && "loopHeader is not actually part of a loop"); + + SmallVector TaintStack; + DivLoop->getExitBlocks(TaintStack); + + // Otherwise potential users of loop-carried values could be anywhere in the + // dominance region of DivLoop (including its fringes for phi nodes) + DenseSet Visited; + for (auto *Block : TaintStack) { + Visited.insert(Block); + } + Visited.insert(&LoopHeader); + + while (!TaintStack.empty()) { + auto *UserBlock = TaintStack.back(); + TaintStack.pop_back(); + + // don't spread divergence beyond the region + if (!inRegion(*UserBlock)) + continue; + + assert(!DivLoop->contains(UserBlock) && + "irreducible control flow detected"); + + // phi nodes at the fringes of the dominance region + if (!DT.dominates(&LoopHeader, UserBlock)) { + // all PHI nodes of UserBlock become divergent + for (auto &Phi : UserBlock->phis()) { + Worklist.push_back(&Phi); + } + continue; + } + + // taint outside users of values carried by DivLoop + for (auto &I : *UserBlock) { + if (isAlwaysUniform(I)) + continue; + if (isDivergent(I)) + continue; + + for (auto 
&Op : I.operands()) { + auto *OpInst = dyn_cast(&Op); + if (!OpInst) + continue; + if (DivLoop->contains(OpInst->getParent())) { + markDivergent(I); + pushUsers(I); + break; + } + } + } + + // visit all blocks in the dominance region + for (auto *SuccBlock : successors(UserBlock)) { + if (!Visited.insert(SuccBlock).second) { + continue; + } + TaintStack.push_back(SuccBlock); + } + } +} + +void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) { + for (const auto &Phi : Block.phis()) { + if (isDivergent(Phi)) + continue; + Worklist.push_back(&Phi); + } +} + +void DivergenceAnalysis::pushUsers(const Value &V) { + for (const auto *User : V.users()) { + const auto *UserInst = dyn_cast(User); + if (!UserInst) + continue; + + if (isDivergent(*UserInst)) + continue; + + // only compute divergent inside loop + if (!inRegion(*UserInst)) + continue; + Worklist.push_back(UserInst); + } +} + +bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock, + const Loop *BranchLoop) { + LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n"); + + // ignore divergence outside the region + if (!inRegion(JoinBlock)) { + return false; + } + + // push non-divergent phi nodes in JoinBlock to the worklist + pushPHINodes(JoinBlock); + + // JoinBlock is a divergent loop exit + if (BranchLoop && !BranchLoop->contains(&JoinBlock)) { + return true; + } + + // disjoint-paths divergent at JoinBlock + markBlockJoinDivergent(JoinBlock); + return false; +} + +void DivergenceAnalysis::propagateBranchDivergence(const TerminatorInst &Term) { + LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n"); + + markDivergent(Term); + + const auto *BranchLoop = LI.getLoopFor(Term.getParent()); + + // whether there is a divergent loop exit from BranchLoop (if any) + bool IsBranchLoopDivergent = false; + + // iterate over all blocks reachable by disjoint from Term within the loop + // also iterates over loop exits that become divergent due to Term. 
+ for (const auto *JoinBlock : SDA.join_blocks(Term)) { + IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop); + } + + // Branch loop is a divergent loop due to the divergent branch in Term + if (IsBranchLoopDivergent) { + assert(BranchLoop); + if (!DivergentLoops.insert(BranchLoop).second) { + return; + } + propagateLoopDivergence(*BranchLoop); + } +} + +void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) { + LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n"); + + // don't propagate beyond region + if (!inRegion(*ExitingLoop.getHeader())) + return; + + const auto *BranchLoop = ExitingLoop.getParentLoop(); + + // Uses of loop-carried values could occur anywhere + // within the dominance region of the definition. All loop-carried + // definitions are dominated by the loop header (reducible control). + // Thus all users have to be in the dominance region of the loop header, + // except PHI nodes that can also live at the fringes of the dom region + // (incoming defining value). + if (!IsLCSSAForm) + taintLoopLiveOuts(*ExitingLoop.getHeader()); + + // whether there is a divergent loop exit from BranchLoop (if any) + bool IsBranchLoopDivergent = false; + + // iterate over all blocks reachable by disjoint paths from exits of + // ExitingLoop also iterates over loop exits (of BranchLoop) that in turn + // become divergent. 
+ for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) { + IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop); + } + + // Branch loop is a divergent due to divergent loop exit in ExitingLoop + if (IsBranchLoopDivergent) { + assert(BranchLoop); + if (!DivergentLoops.insert(BranchLoop).second) { + return; + } + propagateLoopDivergence(*BranchLoop); + } +} + +void DivergenceAnalysis::compute() { + for (auto *DivVal : DivergentValues) { + pushUsers(*DivVal); + } + + // propagate divergence + while (!Worklist.empty()) { + const Instruction &I = *Worklist.back(); + Worklist.pop_back(); + + // maintain uniformity of overrides + if (isAlwaysUniform(I)) + continue; + + bool WasDivergent = isDivergent(I); + if (WasDivergent) + continue; + + // propagate divergence caused by terminator + if (isa(I)) { + auto &Term = cast(I); + if (updateTerminator(Term)) { + // propagate control divergence to affected instructions + propagateBranchDivergence(Term); + continue; + } + } + + // update divergence of I due to divergent operands + bool DivergentUpd = false; + const auto *Phi = dyn_cast(&I); + if (Phi) { + DivergentUpd = updatePHINode(*Phi); + } else { + DivergentUpd = updateNormalInstruction(I); + } + + // propagate value divergence to users + if (DivergentUpd) { + markDivergent(I); + pushUsers(I); + } + } +} + +bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const { + return UniformOverrides.find(&V) != UniformOverrides.end(); +} + +bool DivergenceAnalysis::isDivergent(const Value &V) const { + return DivergentValues.find(&V) != DivergentValues.end(); +} + +void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const { + if (DivergentValues.empty()) + return; + // iterate instructions using instructions() to ensure a deterministic order. 
+ for (auto &I : instructions(F)) { + if (isDivergent(I)) + OS << "DIVERGENT:" << I << '\n'; + } +} Index: llvm/trunk/lib/Analysis/SyncDependenceAnalysis.cpp =================================================================== --- llvm/trunk/lib/Analysis/SyncDependenceAnalysis.cpp +++ llvm/trunk/lib/Analysis/SyncDependenceAnalysis.cpp @@ -0,0 +1,380 @@ +//===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation +//--===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements an algorithm that returns for a divergent branch +// the set of basic blocks whose phi nodes become divergent due to divergent +// control. These are the blocks that are reachable by two disjoint paths from +// the branch or loop exits that have a reaching path that is disjoint from a +// path to the loop latch. +// +// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model +// control-induced divergence in phi nodes. +// +// -- Summary -- +// The SyncDependenceAnalysis lazily computes sync dependences [3]. +// The analysis evaluates the disjoint path criterion [2] by a reduction +// to SSA construction. The SSA construction algorithm is implemented as +// a simple data-flow analysis [1]. +// +// [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy +// [2] "Efficiently Computing Static Single Assignment Form +// and the Control Dependence Graph", TOPLAS '91, +// Cytron, Ferrante, Rosen, Wegman and Zadeck +// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack +// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira +// +// -- Sync dependence -- +// Sync dependence [4] characterizes the control flow aspect of the +// propagation of branch divergence. 
For example, +// +// %cond = icmp slt i32 %tid, 10 +// br i1 %cond, label %then, label %else +// then: +// br label %merge +// else: +// br label %merge +// merge: +// %a = phi i32 [ 0, %then ], [ 1, %else ] +// +// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid +// because %tid is not on its use-def chains, %a is sync dependent on %tid +// because the branch "br i1 %cond" depends on %tid and affects which value %a +// is assigned to. +// +// -- Reduction to SSA construction -- +// There are two disjoint paths from A to X, if a certain variant of SSA +// construction places a phi node in X under the following set-up scheme [2]. +// +// This variant of SSA construction ignores incoming undef values. +// That is, paths from the entry without a definition do not result in +// phi nodes. +// +// entry +// / \ +// A \ +// / \ Y +// B C / +// \ / \ / +// D E +// \ / +// F +// Assume that A contains a divergent branch. We are interested +// in the set of all blocks where each block is reachable from A +// via two disjoint paths. This would be the set {D, F} in this +// case. +// To generally reduce this query to SSA construction we introduce +// a virtual variable x and assign to x different values in each +// successor block of A. +// entry +// / \ +// A \ +// / \ Y +// x = 0 x = 1 / +// \ / \ / +// D E +// \ / +// F +// Our flavor of SSA construction for x will construct the following +// entry +// / \ +// A \ +// / \ Y +// x0 = 0 x1 = 1 / +// \ / \ / +// x2=phi E +// \ / +// x3=phi +// The blocks D and F contain phi nodes and are thus each reachable +// by two disjoint paths from A. +// +// -- Remarks -- +// In case of loop exits we need to check the disjoint path criterion for loops +// [2]. To this end, we check whether the definition of x differs between the +// loop exit and the loop header (_after_ SSA construction).
+// +//===----------------------------------------------------------------------===// +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/SyncDependenceAnalysis.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" + +#include +#include + +#define DEBUG_TYPE "sync-dependence" + +namespace llvm { + +ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet; + +SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, + const PostDominatorTree &PDT, + const LoopInfo &LI) + : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {} + +SyncDependenceAnalysis::~SyncDependenceAnalysis() {} + +using FunctionRPOT = ReversePostOrderTraversal; + +// divergence propagator for reducible CFGs +struct DivergencePropagator { + const FunctionRPOT &FuncRPOT; + const DominatorTree &DT; + const PostDominatorTree &PDT; + const LoopInfo &LI; + + // identified join points + std::unique_ptr JoinBlocks; + + // reached loop exits (by a path disjoint to a path to the loop header) + SmallPtrSet ReachedLoopExits; + + // if DefMap[B] == C then C is the dominating definition at block B + // if DefMap[B] ~ undef then we haven't seen B yet + // if DefMap[B] == B then B is a join point of disjoint paths from X or B is + // an immediate successor of X (initial value). 
+ using DefiningBlockMap = std::map; + DefiningBlockMap DefMap; + + // all blocks with pending visits + std::unordered_set PendingUpdates; + + DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT, + const PostDominatorTree &PDT, const LoopInfo &LI) + : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI), + JoinBlocks(new ConstBlockSet) {} + + // set the definition at @block and mark @block as pending for a visit + void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) { + bool WasAdded = DefMap.emplace(&Block, &DefBlock).second; + if (WasAdded) + PendingUpdates.insert(&Block); + } + + void printDefs(raw_ostream &Out) { + Out << "Propagator::DefMap {\n"; + for (const auto *Block : FuncRPOT) { + auto It = DefMap.find(Block); + Out << Block->getName() << " : "; + if (It == DefMap.end()) { + Out << "\n"; + } else { + const auto *DefBlock = It->second; + Out << (DefBlock ? DefBlock->getName() : "") << "\n"; + } + } + Out << "}\n"; + } + + // process @succBlock with reaching definition @defBlock + // the original divergent branch was in @parentLoop (if any) + void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop, + const BasicBlock &DefBlock) { + + // @succBlock is a loop exit + if (ParentLoop && !ParentLoop->contains(&SuccBlock)) { + DefMap.emplace(&SuccBlock, &DefBlock); + ReachedLoopExits.insert(&SuccBlock); + return; + } + + // first reaching def? + auto ItLastDef = DefMap.find(&SuccBlock); + if (ItLastDef == DefMap.end()) { + addPending(SuccBlock, DefBlock); + return; + } + + // a join of at least two definitions + if (ItLastDef->second != &DefBlock) { + // do we know this join already? + if (!JoinBlocks->insert(&SuccBlock).second) + return; + + // update the definition + addPending(SuccBlock, SuccBlock); + } + } + + // find all blocks reachable by two disjoint paths from @rootTerm. + // This method works for both divergent TerminatorInsts and loops with + // divergent exits. 
+ // @rootBlock is either the block containing the branch or the header of the + // divergent loop. + // @nodeSuccessors is the set of successors of the node (Loop or Terminator) + // headed by @rootBlock. + // @parentLoop is the parent loop of the Loop or the loop that contains the + // Terminator. + template + std::unique_ptr + computeJoinPoints(const BasicBlock &RootBlock, + SuccessorIterable NodeSuccessors, const Loop *ParentLoop) { + assert(JoinBlocks); + + // immediate post dominator (no join block beyond that block) + const auto *PdNode = PDT.getNode(const_cast(&RootBlock)); + const auto *IpdNode = PdNode->getIDom(); + const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr; + + // bootstrap with branch targets + for (const auto *SuccBlock : NodeSuccessors) { + DefMap.emplace(SuccBlock, SuccBlock); + + if (ParentLoop && !ParentLoop->contains(SuccBlock)) { + // immediate loop exit from node. + ReachedLoopExits.insert(SuccBlock); + continue; + } else { + // regular successor + PendingUpdates.insert(SuccBlock); + } + } + + auto ItBeginRPO = FuncRPOT.begin(); + + // skip until term (TODO RPOT won't let us start at @term directly) + for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {} + + auto ItEndRPO = FuncRPOT.end(); + assert(ItBeginRPO != ItEndRPO); + + // propagate definitions at the immediate successors of the node in RPO + auto ItBlockRPO = ItBeginRPO; + while (++ItBlockRPO != ItEndRPO && *ItBlockRPO != PdBoundBlock) { + const auto *Block = *ItBlockRPO; + + // skip @block if not pending update + auto ItPending = PendingUpdates.find(Block); + if (ItPending == PendingUpdates.end()) + continue; + PendingUpdates.erase(ItPending); + + // propagate definition at @block to its successors + auto ItDef = DefMap.find(Block); + const auto *DefBlock = ItDef->second; + assert(DefBlock); + + auto *BlockLoop = LI.getLoopFor(Block); + if (ParentLoop && + (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) { + // if the successor is the header of a 
nested loop pretend its a + // single node with the loop's exits as successors + SmallVector BlockLoopExits; + BlockLoop->getExitBlocks(BlockLoopExits); + for (const auto *BlockLoopExit : BlockLoopExits) { + visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock); + } + + } else { + // the successors are either on the same loop level or loop exits + for (const auto *SuccBlock : successors(Block)) { + visitSuccessor(*SuccBlock, ParentLoop, *DefBlock); + } + } + } + + // We need to know the definition at the parent loop header to decide + // whether the definition at the header is different from the definition at + // the loop exits, which would indicate a divergent loop exits. + // + // A // loop header + // | + // B // nested loop header + // | + // C -> X (exit from B loop) -..-> (A latch) + // | + // D -> back to B (B latch) + // | + // proper exit from both loops + // + // D post-dominates B as it is the only proper exit from the "A loop". + // If C has a divergent branch, propagation will therefore stop at D. + // That implies that B will never receive a definition. + // But that definition can only be the same as at D (D itself in thise case) + // because all paths to anywhere have to pass through D. + // + const BasicBlock *ParentLoopHeader = + ParentLoop ? 
ParentLoop->getHeader() : nullptr; + if (ParentLoop && ParentLoop->contains(PdBoundBlock)) { + DefMap[ParentLoopHeader] = DefMap[PdBoundBlock]; + } + + // analyze reached loop exits + if (!ReachedLoopExits.empty()) { + assert(ParentLoop); + const auto *HeaderDefBlock = DefMap[ParentLoopHeader]; + LLVM_DEBUG(printDefs(dbgs())); + assert(HeaderDefBlock && "no definition in header of carrying loop"); + + for (const auto *ExitBlock : ReachedLoopExits) { + auto ItExitDef = DefMap.find(ExitBlock); + assert((ItExitDef != DefMap.end()) && + "no reaching def at reachable loop exit"); + if (ItExitDef->second != HeaderDefBlock) { + JoinBlocks->insert(ExitBlock); + } + } + } + + return std::move(JoinBlocks); + } +}; + +const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) { + using LoopExitVec = SmallVector; + LoopExitVec LoopExits; + Loop.getExitBlocks(LoopExits); + if (LoopExits.size() < 1) { + return EmptyBlockSet; + } + + // already available in cache? + auto ItCached = CachedLoopExitJoins.find(&Loop); + if (ItCached != CachedLoopExitJoins.end()) + return *ItCached->second; + + // compute all join points + DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; + auto JoinBlocks = Propagator.computeJoinPoints( + *Loop.getHeader(), LoopExits, Loop.getParentLoop()); + + auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks)); + assert(ItInserted.second); + return *ItInserted.first->second; +} + +const ConstBlockSet & +SyncDependenceAnalysis::join_blocks(const TerminatorInst &Term) { + // trivial case + if (Term.getNumSuccessors() < 1) { + return EmptyBlockSet; + } + + // already available in cache? 
+ auto ItCached = CachedBranchJoins.find(&Term); + if (ItCached != CachedBranchJoins.end()) + return *ItCached->second; + + // compute all join points + DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; + const auto &TermBlock = *Term.getParent(); + auto JoinBlocks = Propagator.computeJoinPoints( + TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock)); + + auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks)); + assert(ItInserted.second); + return *ItInserted.first->second; +} + +} // namespace llvm Index: llvm/trunk/unittests/Analysis/CMakeLists.txt =================================================================== --- llvm/trunk/unittests/Analysis/CMakeLists.txt +++ llvm/trunk/unittests/Analysis/CMakeLists.txt @@ -14,6 +14,7 @@ CallGraphTest.cpp CFGTest.cpp CGSCCPassManagerTest.cpp + DivergenceAnalysisTest.cpp GlobalsModRefTest.cpp ValueLatticeTest.cpp LazyCallGraphTest.cpp Index: llvm/trunk/unittests/Analysis/DivergenceAnalysisTest.cpp =================================================================== --- llvm/trunk/unittests/Analysis/DivergenceAnalysisTest.cpp +++ llvm/trunk/unittests/Analysis/DivergenceAnalysisTest.cpp @@ -0,0 +1,431 @@ +//===- DivergenceAnalysisTest.cpp - DivergenceAnalysis unit tests ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/SyncDependenceAnalysis.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +namespace llvm { +namespace { + +BasicBlock *GetBlockByName(StringRef BlockName, Function &F) { + for (auto &BB : F) { + if (BB.getName() != BlockName) + continue; + return &BB; + } + return nullptr; +} + +// We use this fixture to ensure that we clean up DivergenceAnalysis before +// deleting the PassManager. 
+class DivergenceAnalysisTest : public testing::Test { +protected: + LLVMContext Context; + Module M; + TargetLibraryInfoImpl TLII; + TargetLibraryInfo TLI; + + std::unique_ptr DT; + std::unique_ptr PDT; + std::unique_ptr LI; + std::unique_ptr SDA; + + DivergenceAnalysisTest() : M("", Context), TLII(), TLI(TLII) {} + + DivergenceAnalysis buildDA(Function &F, bool IsLCSSA) { + DT.reset(new DominatorTree(F)); + PDT.reset(new PostDominatorTree(F)); + LI.reset(new LoopInfo(*DT)); + SDA.reset(new SyncDependenceAnalysis(*DT, *PDT, *LI)); + return DivergenceAnalysis(F, nullptr, *DT, *LI, *SDA, IsLCSSA); + } + + void runWithDA( + Module &M, StringRef FuncName, bool IsLCSSA, + function_ref + Test) { + auto *F = M.getFunction(FuncName); + ASSERT_NE(F, nullptr) << "Could not find " << FuncName; + DivergenceAnalysis DA = buildDA(*F, IsLCSSA); + Test(*F, *LI, DA); + } +}; + +// Simple initial state test +TEST_F(DivergenceAnalysisTest, DAInitialState) { + IntegerType *IntTy = IntegerType::getInt32Ty(Context); + FunctionType *FTy = + FunctionType::get(Type::getVoidTy(Context), {IntTy}, false); + Function *F = cast(M.getOrInsertFunction("f", FTy)); + BasicBlock *BB = BasicBlock::Create(Context, "entry", F); + ReturnInst::Create(Context, nullptr, BB); + + DivergenceAnalysis DA = buildDA(*F, false); + + // Whole function region + EXPECT_EQ(DA.getRegionLoop(), nullptr); + + // No divergence in initial state + EXPECT_FALSE(DA.hasDetectedDivergence()); + + // No spurious divergence + DA.compute(); + EXPECT_FALSE(DA.hasDetectedDivergence()); + + // Detected divergence after marking + Argument &arg = *F->arg_begin(); + DA.markDivergent(arg); + + EXPECT_TRUE(DA.hasDetectedDivergence()); + EXPECT_TRUE(DA.isDivergent(arg)); + + DA.compute(); + EXPECT_TRUE(DA.hasDetectedDivergence()); + EXPECT_TRUE(DA.isDivergent(arg)); +} + +TEST_F(DivergenceAnalysisTest, DANoLCSSA) { + LLVMContext C; + SMDiagnostic Err; + + std::unique_ptr M = parseAssemblyString( + "target datalayout = 
\"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" " + " " + "define i32 @f_1(i8* nocapture %arr, i32 %n, i32* %A, i32* %B) " + " local_unnamed_addr { " + "entry: " + " br label %loop.ph " + " " + "loop.ph: " + " br label %loop " + " " + "loop: " + " %iv0 = phi i32 [ %iv0.inc, %loop ], [ 0, %loop.ph ] " + " %iv1 = phi i32 [ %iv1.inc, %loop ], [ -2147483648, %loop.ph ] " + " %iv0.inc = add i32 %iv0, 1 " + " %iv1.inc = add i32 %iv1, 3 " + " %cond.cont = icmp slt i32 %iv0, %n " + " br i1 %cond.cont, label %loop, label %for.end.loopexit " + " " + "for.end.loopexit: " + " ret i32 %iv0 " + "} ", + Err, C); + + Function *F = M->getFunction("f_1"); + DivergenceAnalysis DA = buildDA(*F, false); + EXPECT_FALSE(DA.hasDetectedDivergence()); + + auto ItArg = F->arg_begin(); + ItArg++; + auto &NArg = *ItArg; + + // Seed divergence in argument %n + DA.markDivergent(NArg); + + DA.compute(); + EXPECT_TRUE(DA.hasDetectedDivergence()); + + // Verify that "ret %iv.0" is divergent + auto ItBlock = F->begin(); + std::advance(ItBlock, 3); + auto &ExitBlock = *GetBlockByName("for.end.loopexit", *F); + auto &RetInst = *cast(ExitBlock.begin()); + EXPECT_TRUE(DA.isDivergent(RetInst)); +} + +TEST_F(DivergenceAnalysisTest, DALCSSA) { + LLVMContext C; + SMDiagnostic Err; + + std::unique_ptr M = parseAssemblyString( + "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" " + " " + "define i32 @f_lcssa(i8* nocapture %arr, i32 %n, i32* %A, i32* %B) " + " local_unnamed_addr { " + "entry: " + " br label %loop.ph " + " " + "loop.ph: " + " br label %loop " + " " + "loop: " + " %iv0 = phi i32 [ %iv0.inc, %loop ], [ 0, %loop.ph ] " + " %iv1 = phi i32 [ %iv1.inc, %loop ], [ -2147483648, %loop.ph ] " + " %iv0.inc = add i32 %iv0, 1 " + " %iv1.inc = add i32 %iv1, 3 " + " %cond.cont = icmp slt i32 %iv0, %n " + " br i1 %cond.cont, label %loop, label %for.end.loopexit " + " " + "for.end.loopexit: " + " %val.ret = phi i32 [ %iv0, %loop ] " + " br label %detached.return " + " " + "detached.return: 
" + " ret i32 %val.ret " + "} ", + Err, C); + + Function *F = M->getFunction("f_lcssa"); + DivergenceAnalysis DA = buildDA(*F, true); + EXPECT_FALSE(DA.hasDetectedDivergence()); + + auto ItArg = F->arg_begin(); + ItArg++; + auto &NArg = *ItArg; + + // Seed divergence in argument %n + DA.markDivergent(NArg); + + DA.compute(); + EXPECT_TRUE(DA.hasDetectedDivergence()); + + // Verify that "ret %iv.0" is divergent + auto ItBlock = F->begin(); + std::advance(ItBlock, 4); + auto &ExitBlock = *GetBlockByName("detached.return", *F); + auto &RetInst = *cast(ExitBlock.begin()); + EXPECT_TRUE(DA.isDivergent(RetInst)); +} + +TEST_F(DivergenceAnalysisTest, DAJoinDivergence) { + LLVMContext C; + SMDiagnostic Err; + + std::unique_ptr M = parseAssemblyString( + "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" " + " " + "define void @f_1(i1 %a, i1 %b, i1 %c) " + " local_unnamed_addr { " + "A: " + " br i1 %a, label %B, label %C " + " " + "B: " + " br i1 %b, label %C, label %D " + " " + "C: " + " %c.join = phi i32 [ 0, %A ], [ 1, %B ] " + " br i1 %c, label %D, label %E " + " " + "D: " + " %d.join = phi i32 [ 0, %B ], [ 1, %C ] " + " br label %E " + " " + "E: " + " %e.join = phi i32 [ 0, %C ], [ 1, %D ] " + " ret void " + "} " + " " + "define void @f_2(i1 %a, i1 %b, i1 %c) " + " local_unnamed_addr { " + "A: " + " br i1 %a, label %B, label %E " + " " + "B: " + " br i1 %b, label %C, label %D " + " " + "C: " + " br label %D " + " " + "D: " + " %d.join = phi i32 [ 0, %B ], [ 1, %C ] " + " br label %E " + " " + "E: " + " %e.join = phi i32 [ 0, %A ], [ 1, %D ] " + " ret void " + "} " + " " + "define void @f_3(i1 %a, i1 %b, i1 %c)" + " local_unnamed_addr { " + "A: " + " br i1 %a, label %B, label %C " + " " + "B: " + " br label %C " + " " + "C: " + " %c.join = phi i32 [ 0, %A ], [ 1, %B ] " + " br i1 %c, label %D, label %E " + " " + "D: " + " br label %E " + " " + "E: " + " %e.join = phi i32 [ 0, %C ], [ 1, %D ] " + " ret void " + "} ", + Err, C); + + // Maps divergent 
conditions to the basic blocks whose Phi nodes become + // divergent. Blocks need to be listed in IR order. + using SmallBlockVec = SmallVector; + using InducedDivJoinMap = std::map; + + // Actual function performing the checks. + auto CheckDivergenceFunc = [this](Function &F, + InducedDivJoinMap &ExpectedDivJoins) { + for (auto &ItCase : ExpectedDivJoins) { + auto *DivVal = ItCase.first; + auto DA = buildDA(F, false); + DA.markDivergent(*DivVal); + DA.compute(); + + // List of basic blocks that shall host divergent Phi nodes. + auto ItDivJoins = ItCase.second.begin(); + + for (auto &BB : F) { + auto *Phi = dyn_cast(BB.begin()); + if (!Phi) + continue; + + if (&BB == *ItDivJoins) { + EXPECT_TRUE(DA.isDivergent(*Phi)); + // Advance to next block with expected divergent PHI node. + ++ItDivJoins; + } else { + EXPECT_FALSE(DA.isDivergent(*Phi)); + } + } + } + }; + + { + auto *F = M->getFunction("f_1"); + auto ItBlocks = F->begin(); + ItBlocks++; // Skip A + ItBlocks++; // Skip B + auto *C = &*ItBlocks++; + auto *D = &*ItBlocks++; + auto *E = &*ItBlocks; + + auto ItArg = F->arg_begin(); + auto *AArg = &*ItArg++; + auto *BArg = &*ItArg++; + auto *CArg = &*ItArg; + + InducedDivJoinMap DivJoins; + DivJoins.emplace(AArg, SmallBlockVec({C, D, E})); + DivJoins.emplace(BArg, SmallBlockVec({D, E})); + DivJoins.emplace(CArg, SmallBlockVec({E})); + + CheckDivergenceFunc(*F, DivJoins); + } + + { + auto *F = M->getFunction("f_2"); + auto ItBlocks = F->begin(); + ItBlocks++; // Skip A + ItBlocks++; // Skip B + ItBlocks++; // Skip C + auto *D = &*ItBlocks++; + auto *E = &*ItBlocks; + + auto ItArg = F->arg_begin(); + auto *AArg = &*ItArg++; + auto *BArg = &*ItArg++; + auto *CArg = &*ItArg; + + InducedDivJoinMap DivJoins; + DivJoins.emplace(AArg, SmallBlockVec({E})); + DivJoins.emplace(BArg, SmallBlockVec({D})); + DivJoins.emplace(CArg, SmallBlockVec({})); + + CheckDivergenceFunc(*F, DivJoins); + } + + { + auto *F = M->getFunction("f_3"); + auto ItBlocks = F->begin(); + ItBlocks++; // 
Skip A + ItBlocks++; // Skip B + auto *C = &*ItBlocks++; + ItBlocks++; // Skip D + auto *E = &*ItBlocks; + + auto ItArg = F->arg_begin(); + auto *AArg = &*ItArg++; + auto *BArg = &*ItArg++; + auto *CArg = &*ItArg; + + InducedDivJoinMap DivJoins; + DivJoins.emplace(AArg, SmallBlockVec({C})); + DivJoins.emplace(BArg, SmallBlockVec({})); + DivJoins.emplace(CArg, SmallBlockVec({E})); + + CheckDivergenceFunc(*F, DivJoins); + } +} + +TEST_F(DivergenceAnalysisTest, DASwitchUnreachableDefault) { + LLVMContext C; + SMDiagnostic Err; + + std::unique_ptr M = parseAssemblyString( + "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" " + " " + "define void @switch_unreachable_default(i32 %cond) local_unnamed_addr { " + "entry: " + " switch i32 %cond, label %sw.default [ " + " i32 0, label %sw.bb0 " + " i32 1, label %sw.bb1 " + " ] " + " " + "sw.bb0: " + " br label %sw.epilog " + " " + "sw.bb1: " + " br label %sw.epilog " + " " + "sw.default: " + " unreachable " + " " + "sw.epilog: " + " %div.dbl = phi double [ 0.0, %sw.bb0], [ -1.0, %sw.bb1 ] " + " ret void " + "}", + Err, C); + + auto *F = M->getFunction("switch_unreachable_default"); + auto &CondArg = *F->arg_begin(); + auto DA = buildDA(*F, false); + + EXPECT_FALSE(DA.hasDetectedDivergence()); + + DA.markDivergent(CondArg); + DA.compute(); + + // Still %CondArg is divergent. + EXPECT_TRUE(DA.hasDetectedDivergence()); + + // The join uni.dbl is not divergent (see D52221) + auto &ExitBlock = *GetBlockByName("sw.epilog", *F); + auto &DivDblPhi = *cast(ExitBlock.begin()); + EXPECT_TRUE(DA.isDivergent(DivDblPhi)); +} + +} // end anonymous namespace +} // end namespace llvm