diff --git a/llvm/include/llvm/Analysis/DivergenceAnalysis.h b/llvm/include/llvm/Analysis/DivergenceAnalysis.h --- a/llvm/include/llvm/Analysis/DivergenceAnalysis.h +++ b/llvm/include/llvm/Analysis/DivergenceAnalysis.h @@ -59,8 +59,10 @@ /// \brief Mark \p UniVal as a value that is always uniform. void addUniformOverride(const Value &UniVal); - /// \brief Mark \p DivVal as a value that is always divergent. - void markDivergent(const Value &DivVal); + /// \brief Mark \p DivVal as a value that is always divergent. Will not do so + /// if `isAlwaysUniform(DivVal)`. + /// \returns Whether the tracked divergence state of \p DivVal changed. + bool markDivergent(const Value &DivVal); /// \brief Propagate divergence to all instructions in the region. /// Divergence is seeded by calls to \p markDivergent. @@ -76,45 +78,38 @@ /// \brief Whether \p Val is divergent at its definition. bool isDivergent(const Value &Val) const; - /// \brief Whether \p U is divergent. Uses of a uniform value can be divergent. + /// \brief Whether \p U is divergent. Uses of a uniform value can be + /// divergent. bool isDivergentUse(const Use &U) const; void print(raw_ostream &OS, const Module *) const; private: - bool updateTerminator(const Instruction &Term) const; - bool updatePHINode(const PHINode &Phi) const; - - /// \brief Computes whether \p Inst is divergent based on the - /// divergence of its operands. - /// - /// \returns Whether \p Inst is divergent. - /// - /// This should only be called for non-phi, non-terminator instructions. - bool updateNormalInstruction(const Instruction &Inst) const; - - /// \brief Mark users of live-out users as divergent. - /// - /// \param LoopHeader the header of the divergent loop. - /// - /// Marks all users of live-out values of the loop headed by \p LoopHeader - /// as divergent and puts them on the worklist. 
- void taintLoopLiveOuts(const BasicBlock &LoopHeader); - - /// \brief Push all users of \p Val (in the region) to the worklist + /// \brief Mark \p Term as divergent and push all Instructions that become + /// divergent as a result on the worklist. + void analyzeControlDivergence(const Instruction &Term); + /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on + /// the worklist. + void taintAndPushPhiNodes(const BasicBlock &JoinBlock); + + /// \brief Identify all Instructions that become divergent because \p DivExit + /// is a divergent loop exit of \p DivLoop. Mark those instructions as + /// divergent and push them on the worklist. + void propagateLoopExitDivergence(const BasicBlock &DivExit, + const Loop &DivLoop); + + /// \brief Internal implementation function for propagateLoopExitDivergence. + void analyzeLoopExitDivergence(const BasicBlock &DivExit, + const Loop &OuterDivLoop); + + /// \brief Mark \p I as divergent if it uses a value defined in \p + /// OuterDivLoop. Push its users on the worklist. + void analyzeTemporalDivergence(const Instruction &I, + const Loop &OuterDivLoop); + + /// \brief Push all users of \p Val (in the region) to the worklist. void pushUsers(const Value &I); - /// \brief Push all phi nodes in @block to the worklist - void pushPHINodes(const BasicBlock &Block); - - /// \brief Mark \p Block as join divergent - /// - /// A block is join divergent if two threads may reach it from different - /// incoming blocks at the same time. - void markBlockJoinDivergent(const BasicBlock &Block) { - DivergentJoinBlocks.insert(&Block); - } - /// \brief Whether \p Val is divergent when read in \p ObservingBlock. bool isTemporalDivergent(const BasicBlock &ObservingBlock, const Value &Val) const; @@ -126,24 +121,6 @@ return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end(); } - /// \brief Propagate control-induced divergence to users (phi nodes and - /// instructions). 
- // - // \param JoinBlock is a divergent loop exit or join point of two disjoint - // paths. - // \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop. - bool propagateJoinDivergence(const BasicBlock &JoinBlock, - const Loop *TermLoop); - - /// \brief Propagate induced value divergence due to control divergence in \p - /// Term. - void propagateBranchDivergence(const Instruction &Term); - - /// \brief Propagate divergent caused by a divergent loop exit. - /// - /// \param ExitingLoop is a divergent loop. - void propagateLoopDivergence(const Loop &ExitingLoop); - private: const Function &F; // If regionLoop != nullptr, analysis is only performed within \p RegionLoop. @@ -166,7 +143,7 @@ DenseSet UniformOverrides; // Blocks with joining divergent control from different predecessors. - DenseSet DivergentJoinBlocks; + DenseSet DivergentJoinBlocks; // FIXME Deprecated // Detected/marked divergent values. DenseSet DivergentValues; diff --git a/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h b/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h --- a/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h +++ b/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h @@ -21,6 +21,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/LoopInfo.h" #include +#include namespace llvm { @@ -30,6 +31,26 @@ class PostDominatorTree; using ConstBlockSet = SmallPtrSet; +struct ControlDivergenceDesc { + // Join points of divergent disjoint paths. 
+ ConstBlockSet JoinDivBlocks; + // Divergent loop exits + ConstBlockSet LoopDivBlocks; +}; + +struct ModifiedPO { + std::vector LoopPO; + std::unordered_map POIndex; + void appendBlock(const BasicBlock &BB) { + POIndex[&BB] = LoopPO.size(); + LoopPO.push_back(&BB); + } + unsigned getIndexOf(const BasicBlock &BB) const { + return POIndex.find(&BB)->second; + } + unsigned size() const { return LoopPO.size(); } + const BasicBlock *getBlockAt(unsigned Idx) const { return LoopPO[Idx]; } +}; /// \brief Relates points of divergent control to join points in /// reducible CFGs. @@ -51,28 +72,19 @@ /// header. Those exit blocks are added to the returned set. /// If L is the parent loop of \p Term and an exit of L is in the returned /// set then L is a divergent loop. - const ConstBlockSet &join_blocks(const Instruction &Term); - - /// \brief Computes divergent join points and loop exits (in the surrounding - /// loop) caused by the divergent loop exits of\p Loop. - /// - /// The set of blocks which are reachable by disjoint paths from the - /// loop exits of \p Loop. - /// This treats the loop as a single node in \p Loop's parent loop. - /// The returned set has the same properties as for join_blocks(TermInst&). 
- const ConstBlockSet &join_blocks(const Loop &Loop); + const ControlDivergenceDesc &getJoinBlocks(const Instruction &Term); private: - static ConstBlockSet EmptyBlockSet; + static ControlDivergenceDesc EmptyDivergenceDesc; + + ModifiedPO LoopPO; - ReversePostOrderTraversal FuncRPOT; const DominatorTree &DT; const PostDominatorTree &PDT; const LoopInfo &LI; - std::map> CachedLoopExitJoins; - std::map> - CachedBranchJoins; + std::map> + CachedControlDivDescs; }; } // namespace llvm diff --git a/llvm/lib/Analysis/DivergenceAnalysis.cpp b/llvm/lib/Analysis/DivergenceAnalysis.cpp --- a/llvm/lib/Analysis/DivergenceAnalysis.cpp +++ b/llvm/lib/Analysis/DivergenceAnalysis.cpp @@ -1,4 +1,4 @@ -//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==// +//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -97,42 +97,18 @@ : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), IsLCSSAForm(IsLCSSAForm) {} -void DivergenceAnalysis::markDivergent(const Value &DivVal) { +bool DivergenceAnalysis::markDivergent(const Value &DivVal) { + if (isAlwaysUniform(DivVal)) + return false; assert(isa(DivVal) || isa(DivVal)); assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); - DivergentValues.insert(&DivVal); + return DivergentValues.insert(&DivVal).second; } void DivergenceAnalysis::addUniformOverride(const Value &UniVal) { UniformOverrides.insert(&UniVal); } -bool DivergenceAnalysis::updateTerminator(const Instruction &Term) const { - if (Term.getNumSuccessors() <= 1) - return false; - if (auto *BranchTerm = dyn_cast(&Term)) { - assert(BranchTerm->isConditional()); - return isDivergent(*BranchTerm->getCondition()); - } - if (auto *SwitchTerm = dyn_cast(&Term)) { - return isDivergent(*SwitchTerm->getCondition()); - } - if (isa(Term)) { - return false; // ignore abnormal executions through landingpad - } - - llvm_unreachable("unexpected terminator"); -} - -bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const { - // TODO function calls with side effects, etc - for (const auto &Op : I.operands()) { - if (isDivergent(*Op)) - return true; - } - return false; -} - bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock, const Value &Val) const { const auto *Inst = dyn_cast(&Val); @@ -150,32 +126,6 @@ return false; } -bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const { - // joining divergent disjoint path in Phi parent block - if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) { - return true; - } - - // An incoming value could be divergent by itself. - // Otherwise, an incoming value could be uniform within the loop - // that carries its definition but it may appear divergent - // from outside the loop. 
This happens when divergent loop exits - // drop definitions of that uniform value in different iterations. - // - // for (int i = 0; i < n; ++i) { // 'i' is uniform inside the loop - // if (i % thread_id == 0) break; // divergent loop exit - // } - // int divI = i; // divI is divergent - for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) { - const auto *InVal = Phi.getIncomingValue(i); - if (isDivergent(*Phi.getIncomingValue(i)) || - isTemporalDivergent(*Phi.getParent(), *InVal)) { - return true; - } - } - return false; -} - bool DivergenceAnalysis::inRegion(const Instruction &I) const { return I.getParent() && inRegion(*I.getParent()); } @@ -184,35 +134,82 @@ return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB); } -static bool usesLiveOut(const Instruction &I, const Loop *DivLoop) { - for (auto &Op : I.operands()) { - auto *OpInst = dyn_cast(&Op); +void DivergenceAnalysis::pushUsers(const Value &V) { + const auto *I = dyn_cast(&V); + + if (I && I->isTerminator()) { + analyzeControlDivergence(*I); + return; + } + + for (const auto *User : V.users()) { + const auto *UserInst = dyn_cast(User); + if (!UserInst) + continue; + + // only compute divergence inside the region + if (!inRegion(*UserInst)) + continue; + + // All users of divergent values are immediately divergent + if (markDivergent(*UserInst)) + Worklist.push_back(UserInst); + } +} + +static const Instruction *getIfCarriedInstruction(const Use &U, + const Loop &DivLoop) { + const auto *I = dyn_cast(&U); + if (!I) + return nullptr; + if (!DivLoop.contains(I)) + return nullptr; + return I; +} + +void DivergenceAnalysis::analyzeTemporalDivergence(const Instruction &I, + const Loop &OuterDivLoop) { + if (isAlwaysUniform(I)) + return; + if (isDivergent(I)) + return; + + LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n"); + assert((isa(I) || !IsLCSSAForm) && + "In LCSSA form all users of loop-exiting defs are Phi nodes."); + for (const Use &Op : I.operands()) { + const 
auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop); if (!OpInst) continue; - if (DivLoop->contains(OpInst->getParent())) - return true; + if (markDivergent(I)) + pushUsers(I); + return; } - return false; } // marks all users of loop-carried values of the loop headed by LoopHeader as // divergent -void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) { - auto *DivLoop = LI.getLoopFor(&LoopHeader); - assert(DivLoop && "loopHeader is not actually part of a loop"); +void DivergenceAnalysis::analyzeLoopExitDivergence(const BasicBlock &DivExit, + const Loop &OuterDivLoop) { + // All users are in immediate exit blocks + if (IsLCSSAForm) { + for (const auto &Phi : DivExit.phis()) { + analyzeTemporalDivergence(Phi, OuterDivLoop); + } + return; + } - SmallVector TaintStack; - DivLoop->getExitBlocks(TaintStack); + // For non-LCSSA we have to follow all live out edges wherever they may lead. + const BasicBlock &LoopHeader = *OuterDivLoop.getHeader(); + SmallVector TaintStack; + TaintStack.push_back(&DivExit); // Otherwise potential users of loop-carried values could be anywhere in the // dominance region of DivLoop (including its fringes for phi nodes) DenseSet Visited; - for (auto *Block : TaintStack) { - Visited.insert(Block); - } - Visited.insert(&LoopHeader); + Visited.insert(&DivExit); - while (!TaintStack.empty()) { + do { auto *UserBlock = TaintStack.back(); TaintStack.pop_back(); @@ -220,33 +217,21 @@ if (!inRegion(*UserBlock)) continue; - assert(!DivLoop->contains(UserBlock) && + assert(!OuterDivLoop.contains(UserBlock) && "irreducible control flow detected"); // phi nodes at the fringes of the dominance region if (!DT.dominates(&LoopHeader, UserBlock)) { // all PHI nodes of UserBlock become divergent for (auto &Phi : UserBlock->phis()) { - Worklist.push_back(&Phi); + analyzeTemporalDivergence(Phi, OuterDivLoop); } continue; } - // taint outside users of values carried by DivLoop + // Taint outside users of values carried by OuterDivLoop. 
for (auto &I : *UserBlock) { - if (isAlwaysUniform(I)) - continue; - if (isDivergent(I)) - continue; - if (!usesLiveOut(I, DivLoop)) - continue; - - markDivergent(I); - if (I.isTerminator()) { - propagateBranchDivergence(I); - } else { - pushUsers(I); - } + analyzeTemporalDivergence(I, OuterDivLoop); } // visit all blocks in the dominance region @@ -256,56 +241,57 @@ } TaintStack.push_back(SuccBlock); } - } + } while (!TaintStack.empty()); } -void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) { - for (const auto &Phi : Block.phis()) { - if (isDivergent(Phi)) - continue; - Worklist.push_back(&Phi); +void DivergenceAnalysis::propagateLoopExitDivergence(const BasicBlock &DivExit, + const Loop &InnerDivLoop) { + LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n"); + + // Find outer-most loop that does not contain \p DivExit + const Loop *DivLoop = &InnerDivLoop; + const Loop *OuterDivLoop = DivLoop; + const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit); + const unsigned LoopExitDepth = + ExitLevelLoop ? 
ExitLevelLoop->getLoopDepth() : 0; + while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) { + DivergentLoops.insert(DivLoop); // all crossed loops are divergent + OuterDivLoop = DivLoop; + DivLoop = DivLoop->getParentLoop(); } -} - -void DivergenceAnalysis::pushUsers(const Value &V) { - for (const auto *User : V.users()) { - const auto *UserInst = dyn_cast(User); - if (!UserInst) - continue; - - if (isDivergent(*UserInst)) - continue; + LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName() + << "\n"); - // only compute divergent inside loop - if (!inRegion(*UserInst)) - continue; - Worklist.push_back(UserInst); - } + analyzeLoopExitDivergence(DivExit, *OuterDivLoop); } -bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock, - const Loop *BranchLoop) { - LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n"); +// this is a divergent join point - mark all phi nodes as divergent and push +// them onto the stack. +void DivergenceAnalysis::taintAndPushPhiNodes(const BasicBlock &JoinBlock) { + LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName() + << "\n"); // ignore divergence outside the region if (!inRegion(JoinBlock)) { - return false; + return; } // push non-divergent phi nodes in JoinBlock to the worklist - pushPHINodes(JoinBlock); - - // disjoint-paths divergent at JoinBlock - markBlockJoinDivergent(JoinBlock); - - // JoinBlock is a divergent loop exit - return BranchLoop && !BranchLoop->contains(&JoinBlock); + for (const auto &Phi : JoinBlock.phis()) { + if (isDivergent(Phi)) + continue; + // FIXME Theoretically, the 'undef' value could be replaced by any other + // value causing spurious divergence. 
+ if (Phi.hasConstantOrUndefValue()) + continue; + if (markDivergent(Phi)) + Worklist.push_back(&Phi); + } } -void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) { - LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n"); - - markDivergent(Term); +void DivergenceAnalysis::analyzeControlDivergence(const Instruction &Term) { + LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName() + << "\n"); // Don't propagate divergence from unreachable blocks. if (!DT.isReachableFromEntry(Term.getParent())) @@ -313,104 +299,36 @@ const auto *BranchLoop = LI.getLoopFor(Term.getParent()); - // whether there is a divergent loop exit from BranchLoop (if any) - bool IsBranchLoopDivergent = false; + const auto &DivDesc = SDA.getJoinBlocks(Term); - // iterate over all blocks reachable by disjoint from Term within the loop - // also iterates over loop exits that become divergent due to Term. - for (const auto *JoinBlock : SDA.join_blocks(Term)) { - IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop); + // Iterate over all blocks now reachable by a disjoint path join + for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { + taintAndPushPhiNodes(*JoinBlock); } - // Branch loop is a divergent loop due to the divergent branch in Term - if (IsBranchLoopDivergent) { - assert(BranchLoop); - if (!DivergentLoops.insert(BranchLoop).second) { - return; - } - propagateLoopDivergence(*BranchLoop); - } -} - -void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) { - LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n"); - - // don't propagate beyond region - if (!inRegion(*ExitingLoop.getHeader())) - return; - - const auto *BranchLoop = ExitingLoop.getParentLoop(); - - // Uses of loop-carried values could occur anywhere - // within the dominance region of the definition. All loop-carried - // definitions are dominated by the loop header (reducible control). 
- // Thus all users have to be in the dominance region of the loop header, - // except PHI nodes that can also live at the fringes of the dom region - // (incoming defining value). - if (!IsLCSSAForm) - taintLoopLiveOuts(*ExitingLoop.getHeader()); - - // whether there is a divergent loop exit from BranchLoop (if any) - bool IsBranchLoopDivergent = false; - - // iterate over all blocks reachable by disjoint paths from exits of - // ExitingLoop also iterates over loop exits (of BranchLoop) that in turn - // become divergent. - for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) { - IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop); - } - - // Branch loop is a divergent due to divergent loop exit in ExitingLoop - if (IsBranchLoopDivergent) { - assert(BranchLoop); - if (!DivergentLoops.insert(BranchLoop).second) { - return; - } - propagateLoopDivergence(*BranchLoop); + assert(DivDesc.LoopDivBlocks.empty() || BranchLoop); + for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) { + propagateLoopExitDivergence(*DivExitBlock, *BranchLoop); } } void DivergenceAnalysis::compute() { - for (auto *DivVal : DivergentValues) { + // Initialize worklist. + auto DivValuesCopy = DivergentValues; + for (const auto *DivVal : DivValuesCopy) { + assert(isDivergent(*DivVal) && "Worklist invariant violated!"); pushUsers(*DivVal); } - // propagate divergence + // All values on the Worklist are divergent. + // Their users may not have been updated yet. 
while (!Worklist.empty()) { const Instruction &I = *Worklist.back(); Worklist.pop_back(); - // maintain uniformity of overrides - if (isAlwaysUniform(I)) - continue; - - bool WasDivergent = isDivergent(I); - if (WasDivergent) - continue; - - // propagate divergence caused by terminator - if (I.isTerminator()) { - if (updateTerminator(I)) { - // propagate control divergence to affected instructions - propagateBranchDivergence(I); - continue; - } - } - - // update divergence of I due to divergent operands - bool DivergentUpd = false; - const auto *Phi = dyn_cast(&I); - if (Phi) { - DivergentUpd = updatePHINode(*Phi); - } else { - DivergentUpd = updateNormalInstruction(I); - } - // propagate value divergence to users - if (DivergentUpd) { - markDivergent(I); - pushUsers(I); - } + assert(isDivergent(I) && "Worklist invariant violated!"); + pushUsers(I); } } @@ -444,7 +362,7 @@ const PostDominatorTree &PDT, const LoopInfo &LI, const TargetTransformInfo &TTI) - : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, false) { + : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, /* LCSSA */ false) { for (auto &I : instructions(F)) { if (TTI.isSourceOfDivergence(&I)) { DA.markDivergent(I); diff --git a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp --- a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp +++ b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp @@ -1,4 +1,4 @@ -//==- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation -==// +//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -107,271 +107,353 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include #include #include #define DEBUG_TYPE "sync-dependence" +// The SDA algorithm operates on a modified CFG - we modify the edges leaving +// loop headers as follows: +// +// * We remove all edges leaving all loop headers. +// * We add additional edges from the loop headers to their exit blocks. +// +// The modification is virtual, that is whenever we visit a loop header we +// pretend it had different successors. +namespace { +using namespace llvm; + +// Custom Post-Order Traversal +// +// We cannot use the vanilla (R)PO computation of LLVM because: +// * We (virtually) modify the CFG. +// * We want a loop-compact block enumeration, that is the numbers assigned by +// the traversal to the blocks of a loop are an interval. +using POCB = std::function; +using VisitedSet = std::set; +using BlockStack = std::vector; + +// forward +static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, + VisitedSet &Finalized); + +// for a nested region (top-level loop or nested loop) +static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop, + POCB CallBack, VisitedSet &Finalized) { + const auto *LoopHeader = Loop ? 
Loop->getHeader() : nullptr; + while (!Stack.empty()) { + const auto *NextBB = Stack.back(); + + auto *NestedLoop = LI.getLoopFor(NextBB); + bool IsNestedLoop = NestedLoop != Loop; + + // Treat the loop as a node + if (IsNestedLoop) { + SmallVector NestedExits; + NestedLoop->getUniqueExitBlocks(NestedExits); + bool PushedNodes = false; + for (const auto *NestedExitBB : NestedExits) { + if (NestedExitBB == LoopHeader) + continue; + if (Loop && !Loop->contains(NestedExitBB)) + continue; + if (Finalized.count(NestedExitBB)) + continue; + PushedNodes = true; + Stack.push_back(NestedExitBB); + } + if (!PushedNodes) { + // All loop exits finalized -> finish this node + Stack.pop_back(); + computeLoopPO(LI, *NestedLoop, CallBack, Finalized); + } + continue; + } + + // DAG-style + bool PushedNodes = false; + for (const auto *SuccBB : successors(NextBB)) { + if (SuccBB == LoopHeader) + continue; + if (Loop && !Loop->contains(SuccBB)) + continue; + if (Finalized.count(SuccBB)) + continue; + PushedNodes = true; + Stack.push_back(SuccBB); + } + if (!PushedNodes) { + // Never push nodes twice + Stack.pop_back(); + if (!Finalized.insert(NextBB).second) + continue; + CallBack(*NextBB); + } + } +} + +static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) { + VisitedSet Finalized; + BlockStack Stack; + Stack.reserve(24); // FIXME made-up number + Stack.push_back(&F.getEntryBlock()); + computeStackPO(Stack, LI, nullptr, CallBack, Finalized); +} + +static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, + VisitedSet &Finalized) { + /// Call CallBack on all loop blocks. 
+ std::vector Stack; + const auto *LoopHeader = Loop.getHeader(); + + // Visit the header last + Finalized.insert(LoopHeader); + CallBack(*LoopHeader); + + // Initialize with immediate successors + for (const auto *BB : successors(LoopHeader)) { + if (!Loop.contains(BB)) + continue; + if (BB == LoopHeader) + continue; + Stack.push_back(BB); + } + + // Compute PO inside region + computeStackPO(Stack, LI, &Loop, CallBack, Finalized); +} + +} // namespace + namespace llvm { -ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet; +ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc; SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI) - : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {} + : DT(DT), PDT(PDT), LI(LI) { + computeTopLevelPO(*DT.getRoot()->getParent(), LI, + [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); +} SyncDependenceAnalysis::~SyncDependenceAnalysis() {} -using FunctionRPOT = ReversePostOrderTraversal; - // divergence propagator for reducible CFGs struct DivergencePropagator { - const FunctionRPOT &FuncRPOT; + const ModifiedPO &LoopPOT; const DominatorTree &DT; const PostDominatorTree &PDT; const LoopInfo &LI; - - // identified join points - std::unique_ptr JoinBlocks; - - // reached loop exits (by a path disjoint to a path to the loop header) - SmallPtrSet ReachedLoopExits; - - // if DefMap[B] == C then C is the dominating definition at block B - // if DefMap[B] ~ undef then we haven't seen B yet - // if DefMap[B] == B then B is a join point of disjoint paths from X or B is - // an immediate successor of X (initial value). 
- using DefiningBlockMap = std::map; - DefiningBlockMap DefMap; - - // all blocks with pending visits - std::unordered_set PendingUpdates; - - DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT, - const PostDominatorTree &PDT, const LoopInfo &LI) - : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI), - JoinBlocks(new ConstBlockSet) {} - - // set the definition at @block and mark @block as pending for a visit - void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) { - bool WasAdded = DefMap.emplace(&Block, &DefBlock).second; - if (WasAdded) - PendingUpdates.insert(&Block); - } + const BasicBlock &DivTermBlock; + + // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at + // block B + // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet + // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths + // from X or B is an immediate successor of X (initial value). + using BlockLabelVec = std::vector; + BlockLabelVec BlockLabels; + // divergent join and loop exit descriptor. + std::unique_ptr DivDesc; + + DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT, + const PostDominatorTree &PDT, const LoopInfo &LI, + const BasicBlock &DivTermBlock) + : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock), + BlockLabels(LoopPOT.size(), nullptr), + DivDesc(new ControlDivergenceDesc) {} void printDefs(raw_ostream &Out) { - Out << "Propagator::DefMap {\n"; - for (const auto *Block : FuncRPOT) { - auto It = DefMap.find(Block); - Out << Block->getName() << " : "; - if (It == DefMap.end()) { - Out << "\n"; + Out << "Propagator::BlockLabels {\n"; + for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) { + const auto *Label = BlockLabels[BlockIdx]; + Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx + << ") : "; + if (!Label) { + Out << "\n"; } else { - const auto *DefBlock = It->second; - Out << (DefBlock ? 
DefBlock->getName() : "") << "\n"; + Out << Label->getName() << "\n"; } } Out << "}\n"; } - // process @succBlock with reaching definition @defBlock - // the original divergent branch was in @parentLoop (if any) - void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop, - const BasicBlock &DefBlock) { + // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this + // causes a divergent join. + bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) { + auto SuccIdx = LoopPOT.getIndexOf(SuccBlock); - // @succBlock is a loop exit - if (ParentLoop && !ParentLoop->contains(&SuccBlock)) { - DefMap.emplace(&SuccBlock, &DefBlock); - ReachedLoopExits.insert(&SuccBlock); - return; + // unset or same reaching label + const auto *OldLabel = BlockLabels[SuccIdx]; + if (!OldLabel || (OldLabel == &PushedLabel)) { + BlockLabels[SuccIdx] = &PushedLabel; + return false; } - // first reaching def? - auto ItLastDef = DefMap.find(&SuccBlock); - if (ItLastDef == DefMap.end()) { - addPending(SuccBlock, DefBlock); - return; - } + // Update the definition + BlockLabels[SuccIdx] = &SuccBlock; + return true; + } - // a join of at least two definitions - if (ItLastDef->second != &DefBlock) { - // do we know this join already? - if (!JoinBlocks->insert(&SuccBlock).second) - return; + // visiting a virtual loop exit edge from the loop header --> temporal + // divergence on join + bool visitLoopExitEdge(const BasicBlock &ExitBlock, + const BasicBlock &DefBlock, bool FromParentLoop) { + // Pushing from a non-parent loop cannot cause temporal divergence. 
+ if (!FromParentLoop) + return visitEdge(ExitBlock, DefBlock); - // update the definition - addPending(SuccBlock, SuccBlock); - } + if (!computeJoin(ExitBlock, DefBlock)) + return false; + + // Identified a divergent loop exit + DivDesc->LoopDivBlocks.insert(&ExitBlock); + LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName() + << "\n"); + return true; } - // find all blocks reachable by two disjoint paths from @rootTerm. - // This method works for both divergent terminators and loops with - // divergent exits. - // @rootBlock is either the block containing the branch or the header of the - // divergent loop. - // @nodeSuccessors is the set of successors of the node (Loop or Terminator) - // headed by @rootBlock. - // @parentLoop is the parent loop of the Loop or the loop that contains the - // Terminator. - template - std::unique_ptr - computeJoinPoints(const BasicBlock &RootBlock, - SuccessorIterable NodeSuccessors, const Loop *ParentLoop) { - assert(JoinBlocks); - - LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: " - << (ParentLoop ? ParentLoop->getName() : "") + // process \p SuccBlock with reaching definition \p DefBlock + bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) { + if (!computeJoin(SuccBlock, DefBlock)) + return false; + + // Divergent, disjoint paths join. 
+ DivDesc->JoinDivBlocks.insert(&SuccBlock); + LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName()); + return true; + } + + std::unique_ptr computeJoinPoints() { + assert(DivDesc); + + LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName() << "\n"); + const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock); + + // Early stopping criterion + int FloorIdx = LoopPOT.size() - 1; + const BasicBlock *FloorLabel = nullptr; + // bootstrap with branch targets - for (const auto *SuccBlock : NodeSuccessors) { - DefMap.emplace(SuccBlock, SuccBlock); + int BlockIdx = 0; - if (ParentLoop && !ParentLoop->contains(SuccBlock)) { - // immediate loop exit from node. - ReachedLoopExits.insert(SuccBlock); - } else { - // regular successor - PendingUpdates.insert(SuccBlock); - } - } + for (const auto *SuccBlock : successors(&DivTermBlock)) { + auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock); + BlockLabels[SuccIdx] = SuccBlock; - LLVM_DEBUG(dbgs() << "SDA: rpo order:\n"; for (const auto *RpoBlock - : FuncRPOT) { - dbgs() << "- " << RpoBlock->getName() << "\n"; - }); + // Find the successor with the highest index to start with + BlockIdx = std::max(BlockIdx, SuccIdx); + FloorIdx = std::min(FloorIdx, SuccIdx); - auto ItBeginRPO = FuncRPOT.begin(); - auto ItEndRPO = FuncRPOT.end(); + // Identify immediate divergent loop exits + if (!DivBlockLoop) + continue; - // skip until term (TODO RPOT won't let us start at @term directly) - for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) { - assert(ItBeginRPO != ItEndRPO && "Unable to find RootBlock"); + const auto *BlockLoop = LI.getLoopFor(SuccBlock); + if (BlockLoop && DivBlockLoop->contains(BlockLoop)) + continue; + DivDesc->LoopDivBlocks.insert(SuccBlock); + LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: " + << SuccBlock->getName() << "\n"); } // propagate definitions at the immediate successors of the node in RPO - auto ItBlockRPO = ItBeginRPO; - while ((++ItBlockRPO != ItEndRPO) && !PendingUpdates.empty()) { 
- const auto *Block = *ItBlockRPO; - LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); + for (; BlockIdx >= FloorIdx; --BlockIdx) { + LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs())); - // skip Block if not pending update - auto ItPending = PendingUpdates.find(Block); - if (ItPending == PendingUpdates.end()) + // Any label available here + const auto *Label = BlockLabels[BlockIdx]; + if (!Label) continue; - PendingUpdates.erase(ItPending); - // propagate definition at Block to its successors - auto ItDef = DefMap.find(Block); - const auto *DefBlock = ItDef->second; - assert(DefBlock); + // Ok. Get the block + const auto *Block = LoopPOT.getBlockAt(BlockIdx); + LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); auto *BlockLoop = LI.getLoopFor(Block); - if (ParentLoop && - (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) { - // if the successor is the header of a nested loop pretend its a - // single node with the loop's exits as successors + bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block; + bool CausedJoin = false; + int LoweredFloorIdx = FloorIdx; + if (IsLoopHeader) { + // Disconnect from immediate successors and propagate directly to loop + // exits. 
SmallVector BlockLoopExits; BlockLoop->getExitBlocks(BlockLoopExits); + + bool IsParentLoop = BlockLoop->contains(&DivTermBlock); for (const auto *BlockLoopExit : BlockLoopExits) { - visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock); + CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop); + LoweredFloorIdx = std::min(LoweredFloorIdx, + LoopPOT.getIndexOf(*BlockLoopExit)); } - } else { - // the successors are either on the same loop level or loop exits + // Acyclic successor case for (const auto *SuccBlock : successors(Block)) { - visitSuccessor(*SuccBlock, ParentLoop, *DefBlock); + CausedJoin |= visitEdge(*SuccBlock, *Label); + LoweredFloorIdx = + std::min(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock)); } } - } - LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs())); - - // We need to know the definition at the parent loop header to decide - // whether the definition at the header is different from the definition at - // the loop exits, which would indicate a divergent loop exits. - // - // A // loop header - // | - // B // nested loop header - // | - // C -> X (exit from B loop) -..-> (A latch) - // | - // D -> back to B (B latch) - // | - // proper exit from both loops - // - // analyze reached loop exits - if (!ReachedLoopExits.empty()) { - const BasicBlock *ParentLoopHeader = - ParentLoop ? ParentLoop->getHeader() : nullptr; - - assert(ParentLoop); - auto ItHeaderDef = DefMap.find(ParentLoopHeader); - const auto *HeaderDefBlock = - (ItHeaderDef == DefMap.end()) ? nullptr : ItHeaderDef->second; - - LLVM_DEBUG(printDefs(dbgs())); - assert(HeaderDefBlock && "no definition at header of carrying loop"); - - for (const auto *ExitBlock : ReachedLoopExits) { - auto ItExitDef = DefMap.find(ExitBlock); - assert((ItExitDef != DefMap.end()) && - "no reaching def at reachable loop exit"); - if (ItExitDef->second != HeaderDefBlock) { - JoinBlocks->insert(ExitBlock); - } + // Floor update + if (CausedJoin) { + // 1. 
Different labels pushed to successors + FloorIdx = LoweredFloorIdx; + } else if (FloorLabel != Label) { + // 2. No join caused BUT we pushed a label that is different than the + // last pushed label + FloorIdx = LoweredFloorIdx; + FloorLabel = Label; } } - return std::move(JoinBlocks); + LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs())); + + return std::move(DivDesc); } }; -const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) { - using LoopExitVec = SmallVector; - LoopExitVec LoopExits; - Loop.getExitBlocks(LoopExits); - if (LoopExits.size() < 1) { - return EmptyBlockSet; +static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) { + Out << "["; + bool First = true; + for (const auto *BB : Blocks) { + if (!First) + Out << ", "; + First = false; + Out << BB->getName(); } - - // already available in cache? - auto ItCached = CachedLoopExitJoins.find(&Loop); - if (ItCached != CachedLoopExitJoins.end()) { - return *ItCached->second; - } - - // compute all join points - DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; - auto JoinBlocks = Propagator.computeJoinPoints( - *Loop.getHeader(), LoopExits, Loop.getParentLoop()); - - auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks)); - assert(ItInserted.second); - return *ItInserted.first->second; + Out << "]"; } -const ConstBlockSet & -SyncDependenceAnalysis::join_blocks(const Instruction &Term) { +const ControlDivergenceDesc & +SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) { // trivial case - if (Term.getNumSuccessors() < 1) { - return EmptyBlockSet; + if (Term.getNumSuccessors() <= 1) { + return EmptyDivergenceDesc; } // already available in cache? 
- auto ItCached = CachedBranchJoins.find(&Term); - if (ItCached != CachedBranchJoins.end()) + auto ItCached = CachedControlDivDescs.find(&Term); + if (ItCached != CachedControlDivDescs.end()) return *ItCached->second; // compute all join points - DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; + // Special handling of divergent loop exits is not needed for LCSSA const auto &TermBlock = *Term.getParent(); - auto JoinBlocks = Propagator.computeJoinPoints( - TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock)); + DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock); + auto DivDesc = Propagator.computeJoinPoints(); + + LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n"; + dbgs() << "JoinDivBlocks: "; + printBlockSet(DivDesc->JoinDivBlocks, dbgs()); + dbgs() << "\nLoopDivBlocks: "; + printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";); - auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks)); + auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc)); assert(ItInserted.second); return *ItInserted.first->second; } diff --git a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll --- a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll +++ b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll @@ -119,9 +119,8 @@ br i1 %uni.cond, label %D, label %G X: - %div.merge.x = phi i32 [ %a, %entry ], [ %uni.merge.h, %B ] ; temporal divergent phi + %uni.merge.x = phi i32 [ %a, %entry ], [ %uni.merge.h, %B ] br i1 %uni.cond, label %Y, label %exit -; CHECK: DIVERGENT: %div.merge.x = Y: %div.merge.y = phi i32 [ 42, %X ], [ %b, %C ] diff --git a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll --- a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll 
+++ b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll @@ -1,7 +1,4 @@ ; RUN: opt -mtriple amdgcn-unknown-amdhsa -analyze -divergence -use-gpu-divergence-analysis %s | FileCheck %s -; XFAIL: * - -; https://bugs.llvm.org/show_bug.cgi?id=46372 ; CHECK: bb2: ; CHECK-NOT: DIVERGENT: %Guard.bb2 = phi i1 [ true, %bb1 ], [ false, %bb0 ]