Index: include/llvm/Analysis/Passes.h
===================================================================
--- include/llvm/Analysis/Passes.h
+++ include/llvm/Analysis/Passes.h
@@ -138,6 +138,13 @@
   //===--------------------------------------------------------------------===//
   //
+  // createDivergenceAnalysisPass - This pass determines which branches in a GPU
+  // program are divergent.
+  //
+  FunctionPass *createDivergenceAnalysisPass();
+
+  //===--------------------------------------------------------------------===//
+  //
   // Minor pass prototypes, allowing us to expose them through bugpoint and
   // analyze.
   FunctionPass *createInstCountPass();

Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -190,12 +190,21 @@
   /// comments for a detailed explanation of the cost values.
   unsigned getUserCost(const User *U) const;

-  /// \brief hasBranchDivergence - Return true if branch divergence exists.
+  /// \brief Return true if branch divergence exists.
+  ///
   /// Branch divergence has a significantly negative impact on GPU performance
   /// when threads in the same wavefront take different paths due to conditional
   /// branches.
   bool hasBranchDivergence() const;

+  /// \brief Returns whether V is a source of divergence.
+  ///
+  /// This function provides the target-dependent information for
+  /// the target-independent DivergenceAnalysis. DivergenceAnalysis first
+  /// builds the dependency graph, and then runs the reachability algorithm
+  /// starting with the sources of divergence.
+  bool isSourceOfDivergence(const Value *V) const;
+
   /// \brief Test whether calls to a function lower to actual program function
   /// calls.
   ///
@@ -520,6 +529,7 @@
                                     ArrayRef<const Value *> Arguments) = 0;
   virtual unsigned getUserCost(const User *U) = 0;
   virtual bool hasBranchDivergence() = 0;
+  virtual bool isSourceOfDivergence(const Value *V) = 0;
   virtual bool isLoweredToCall(const Function *F) = 0;
   virtual void getUnrollingPreferences(Loop *L, UnrollingPreferences &UP) = 0;
   virtual bool isLegalAddImmediate(int64_t Imm) = 0;
@@ -619,6 +629,9 @@
   }
   unsigned getUserCost(const User *U) override { return Impl.getUserCost(U); }
   bool hasBranchDivergence() override { return Impl.hasBranchDivergence(); }
+  bool isSourceOfDivergence(const Value *V) override {
+    return Impl.isSourceOfDivergence(V);
+  }
   bool isLoweredToCall(const Function *F) override {
     return Impl.isLoweredToCall(F);
   }

Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -164,6 +164,8 @@

   bool hasBranchDivergence() { return false; }

+  bool isSourceOfDivergence(const Value *V) { return false; }
+
   bool isLoweredToCall(const Function *F) {
     // FIXME: These should almost certainly not be handled here, and instead
     // handled with the help of TLI or the target itself.
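The hook is deliberately thin on the TTI side: the default implementations above simply return false, and a target only has to classify which values it considers divergent; the reachability walk itself lives entirely in the target-independent pass. To make the intended division of labor concrete, the sketch below shows how a client transform could schedule the analysis and consult isBranchDivergent() before rewriting a branch. This is only a sketch: it assumes DivergenceAnalysis is later exported through a public header (in this patch the class is still file-local to DivergenceAnalysis.cpp), and the GPUBranchGate pass, its name, and the assumed header path are hypothetical, not part of this change.

// Sketch of a hypothetical client pass (not part of this patch).
#include "llvm/Analysis/DivergenceAnalysis.h" // assumed future header
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

namespace {
struct GPUBranchGate : public FunctionPass {
  static char ID;
  GPUBranchGate() : FunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DivergenceAnalysis>(); // schedules the analysis before us
    AU.setPreservesAll();
  }

  bool runOnFunction(Function &F) override {
    const DivergenceAnalysis &DA = getAnalysis<DivergenceAnalysis>();
    for (BasicBlock &BB : F) {
      TerminatorInst *TI = BB.getTerminator();
      // A transform such as jump threading or loop unswitching would bail out
      // here instead of rewriting the divergent branch.
      if (TI->getNumSuccessors() > 1 && DA.isBranchDivergent(TI))
        errs() << "skipping divergent branch in " << BB.getName() << "\n";
    }
    return false;
  }
};
char GPUBranchGate::ID = 0;
} // end of anonymous namespace

Because the analysis preserves everything and only caches a set of divergent terminators, requiring it this way is cheap; the policy decision of what to do with a divergent branch stays entirely in the client.
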
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -114,6 +114,8 @@

   bool hasBranchDivergence() { return false; }

+  bool isSourceOfDivergence(const Value *V) { return false; }
+
   bool isLegalAddImmediate(int64_t imm) {
     return getTLI()->isLegalAddImmediate(imm);
   }

Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -110,6 +110,7 @@
 void initializeDeadMachineInstructionElimPass(PassRegistry&);
 void initializeDelinearizationPass(PassRegistry &);
 void initializeDependenceAnalysisPass(PassRegistry&);
+void initializeDivergenceAnalysisPass(PassRegistry&);
 void initializeDomOnlyPrinterPass(PassRegistry&);
 void initializeDomOnlyViewerPass(PassRegistry&);
 void initializeDomPrinterPass(PassRegistry&);

Index: include/llvm/LinkAllPasses.h
===================================================================
--- include/llvm/LinkAllPasses.h
+++ include/llvm/LinkAllPasses.h
@@ -74,6 +74,7 @@
       (void) llvm::createDeadInstEliminationPass();
       (void) llvm::createDeadStoreEliminationPass();
       (void) llvm::createDependenceAnalysisPass();
+      (void) llvm::createDivergenceAnalysisPass();
       (void) llvm::createDomOnlyPrinterPass();
       (void) llvm::createDomPrinterPass();
       (void) llvm::createDomOnlyViewerPass();

Index: lib/Analysis/Analysis.cpp
===================================================================
--- lib/Analysis/Analysis.cpp
+++ lib/Analysis/Analysis.cpp
@@ -37,6 +37,7 @@
   initializeCFLAliasAnalysisPass(Registry);
   initializeDependenceAnalysisPass(Registry);
   initializeDelinearizationPass(Registry);
+  initializeDivergenceAnalysisPass(Registry);
   initializeDominanceFrontierPass(Registry);
   initializeDomViewerPass(Registry);
   initializeDomPrinterPass(Registry);

Index: lib/Analysis/CMakeLists.txt
===================================================================
--- lib/Analysis/CMakeLists.txt
+++ lib/Analysis/CMakeLists.txt
@@ -20,6 +20,7 @@
   ConstantFolding.cpp
   Delinearization.cpp
   DependenceAnalysis.cpp
+  DivergenceAnalysis.cpp
   DomPrinter.cpp
   DominanceFrontier.cpp
   IVUsers.cpp

Index: lib/Analysis/DivergenceAnalysis.cpp
===================================================================
--- /dev/null
+++ lib/Analysis/DivergenceAnalysis.cpp
@@ -0,0 +1,253 @@
+//===- DivergenceAnalysis.cpp ------ Divergence Analysis ------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines divergence analysis, which determines whether a branch in
+// a GPU program is divergent. It can help branch optimizations such as jump
+// threading and loop unswitching to make better decisions.
+//
+// GPU programs typically use the SIMD execution model, where multiple threads
+// in the same execution group have to execute in lock-step. Therefore, if the
+// code contains divergent branches (i.e., threads in a group do not agree on
+// which path of the branch to take), the group of threads has to execute all
+// the paths from that branch with different subsets of threads enabled until
+// they converge at the immediately post-dominating BB of the paths.
+//
+// Due to this execution model, some optimizations such as jump threading and
+// loop unswitching can unfortunately be harmful when performed on divergent
+// branches. Therefore, an analysis that computes which branches in a GPU
+// program are divergent can help the compiler to selectively run these
+// optimizations.
+//
+// The analysis implemented here computes a conservative but non-trivial
+// approximation of all divergent branches in a GPU program. It partially
+// implements the approach described in
+//
+//   Divergence Analysis and Optimizations
+//   Coutinho, Sampaio, Pereira, Meira
+//   PACT '11
+//
+// The divergence analysis identifies the sources of divergence (e.g., special
+// variables that hold the thread ID), and recursively marks variables that
+// are data or sync dependent on a source of divergence as divergent.
+//
+// While data dependency is a well-known concept, the notion of sync
+// dependency is worth more explanation. Sync dependence characterizes the
+// control flow aspect of the propagation of branch divergence. For example,
+//
+//   %cond = icmp slt i32 %tid, 10
+//   br i1 %cond, label %then, label %else
+// then:
+//   br label %merge
+// else:
+//   br label %merge
+// merge:
+//   %a = phi i32 [ 0, %then ], [ 1, %else ]
+//
+// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
+// because %tid is not on its use-def chains, %a is sync dependent on %tid
+// because the branch "br i1 %cond" depends on %tid and affects which value %a
+// is assigned to.
+//
+// Given LLVM's SSA form, computing sync dependency is easy. For any predicate
+// P of a branch, all the PHINodes in the branch's immediate post-dominator
+// are sync dependent on P.
+//
+// The current implementation has the following limitations:
+// 1. Intra-procedural: it conservatively considers the arguments of a
+//    non-kernel-entry function as divergent.
+// 2. Memory as a black box: it conservatively considers values loaded from
+//    memory as divergent. This can be improved by leveraging pointer analysis.
+//===----------------------------------------------------------------------===//
+
+#include <vector>
+#include "llvm/IR/Dominators.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "divergence"
+
+namespace {
+class DivergenceAnalysis : public FunctionPass {
+public:
+  static char ID;
+
+  DivergenceAnalysis() : FunctionPass(ID) {
+    initializeDivergenceAnalysisPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<PostDominatorTree>();
+    AU.setPreservesAll();
+  }
+
+  bool runOnFunction(Function &F) override;
+
+  void print(raw_ostream &OS, const Module *) const override;
+
+  bool isBranchDivergent(const TerminatorInst *TI) const {
+    return DivergentBranches.count(TI);
+  }
+
+private:
+  DenseSet<const TerminatorInst *> DivergentBranches;
+};
+} // End of anonymous namespace
+
+// Register this pass.
+char DivergenceAnalysis::ID = 0;
+INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis",
+                      false, true)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTree)
+INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis",
+                    false, true)
+
+namespace {
+
+class DivergencePropagator {
+public:
+  DivergencePropagator(Function &F, TargetTransformInfo &TTI,
+                       PostDominatorTree &PDT,
+                       DenseSet<const TerminatorInst *> &DB)
+      : F(F), TTI(TTI), PDT(PDT), DB(DB) {}
+  void populateWithSourcesOfDivergence();
+  void propagate();
+
+private:
+  // A helper function that explores data dependents of V.
+  void exploreDataDependency(Value *V);
+  // A helper function that explores sync dependents of TI.
+  void exploreSyncDependency(TerminatorInst *TI);
+
+  Function &F;
+  TargetTransformInfo &TTI;
+  PostDominatorTree &PDT;
+  // Place to store the analysis results.
+  DenseSet<const TerminatorInst *> &DB;
+  // Auxiliary data structures for the DFS.
+  std::vector<Value *> Worklist;
+  DenseSet<const Value *> Visited;
+};
+
+void DivergencePropagator::populateWithSourcesOfDivergence() {
+  Worklist.clear();
+  Visited.clear();
+  for (auto &I : inst_range(F)) {
+    if (TTI.isSourceOfDivergence(&I)) {
+      Worklist.push_back(&I);
+      Visited.insert(&I);
+    }
+  }
+  for (auto &Arg : F.args()) {
+    if (TTI.isSourceOfDivergence(&Arg)) {
+      Worklist.push_back(&Arg);
+      Visited.insert(&Arg);
+    }
+  }
+}
+
+void DivergencePropagator::exploreSyncDependency(TerminatorInst *TI) {
+  Value *Cond = nullptr;
+  if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+    if (BI->isConditional())
+      Cond = BI->getCondition();
+  } else if (IndirectBrInst *IBI = dyn_cast<IndirectBrInst>(TI)) {
+    Cond = IBI->getAddress();
+  } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
+    Cond = SI->getCondition();
+  }
+  if (Cond == nullptr)
+    return;
+
+  // Since TI is divergent, Cond is also divergent. Per the definition of sync
+  // dependency, we mark all PHINodes in TI's immediate post dominator block as
+  // divergent.
+  BasicBlock *IPostDom = PDT.getNode(TI->getParent())->getIDom()->getBlock();
+  if (IPostDom == nullptr)
+    return;
+  for (auto I = IPostDom->begin(); IPostDom->getFirstNonPHI() != I; ++I) {
+    if (Visited.insert(I).second)
+      Worklist.push_back(I);
+  }
+}
+
+void DivergencePropagator::exploreDataDependency(Value *V) {
+  // Follow the def-use chains of V.
+  for (User *U : V->users()) {
+    Instruction *UserInst = cast<Instruction>(U);
+    if (Visited.insert(UserInst).second)
+      Worklist.push_back(UserInst);
+  }
+}
+
+void DivergencePropagator::propagate() {
+  // Traverse the dependency graph using DFS.
+  while (!Worklist.empty()) {
+    Value *V = Worklist.back();
+    Worklist.pop_back();
+    if (TerminatorInst *TI = dyn_cast<TerminatorInst>(V)) {
+      // Terminator instructions with fewer than two successors are not really
+      // branches, and can be ignored for lookup.
+      if (TI->getNumSuccessors() > 1)
+        DB.insert(TI);
+      exploreSyncDependency(TI);
+    }
+    exploreDataDependency(V);
+  }
+}
+
+} // End of anonymous namespace
+
+FunctionPass *llvm::createDivergenceAnalysisPass() {
+  return new DivergenceAnalysis();
+}
+
+bool DivergenceAnalysis::runOnFunction(Function &F) {
+  auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
+  if (TTIWP == nullptr)
+    return false;
+
+  TargetTransformInfo &TTI = TTIWP->getTTI(F);
+  // Fast path: if the target does not have branch divergence, we do not mark
+  // any branch as divergent.
+  if (!TTI.hasBranchDivergence())
+    return false;
+
+  DivergentBranches.clear();
+  DivergencePropagator DP(F, TTI, getAnalysis<PostDominatorTree>(),
+                          DivergentBranches);
+  DP.populateWithSourcesOfDivergence();
+  DP.propagate();
+  return false;
+}
+
+void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
+  if (DivergentBranches.empty())
+    return;
+  const Function *F = (*DivergentBranches.begin())->getParent()->getParent();
+  // Dump all divergent branches in F. Iterate over the instructions using
+  // inst_range to ensure a deterministic order.
+  for (auto &I : inst_range(F)) {
+    if (const TerminatorInst *TI = dyn_cast<TerminatorInst>(&I)) {
+      if (DivergentBranches.count(TI))
+        OS << "DIVERGENT:" << I << "\n";
+    }
+  }
+}

Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -76,6 +76,10 @@
   return TTIImpl->hasBranchDivergence();
 }

+bool TargetTransformInfo::isSourceOfDivergence(const Value *V) const {
+  return TTIImpl->isSourceOfDivergence(V);
+}
+
 bool TargetTransformInfo::isLoweredToCall(const Function *F) const {
   return TTIImpl->isLoweredToCall(F);
 }

Index: lib/Target/NVPTX/NVPTXTargetTransformInfo.h
===================================================================
--- lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -61,6 +61,8 @@

   bool hasBranchDivergence() { return true; }

+  bool isSourceOfDivergence(const Value *V);
+
   unsigned getArithmeticInstrCost(
       unsigned Opcode, Type *Ty,
       TTI::OperandValueKind Opd1Info = TTI::OK_AnyValue,

Index: lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
===================================================================
--- lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//

 #include "NVPTXTargetTransformInfo.h"
+#include "NVPTXUtilities.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -19,6 +20,64 @@

 #define DEBUG_TYPE "NVPTXtti"

+// Whether the given intrinsic reads threadIdx.x/y/z.
+static bool readsThreadIndex(const IntrinsicInst *II) {
+  switch (II->getIntrinsicID()) {
+  default: return false;
+  case Intrinsic::nvvm_read_ptx_sreg_tid_x:
+  case Intrinsic::nvvm_read_ptx_sreg_tid_y:
+  case Intrinsic::nvvm_read_ptx_sreg_tid_z:
+    return true;
+  }
+}
+
+// Whether the given intrinsic is an atomic instruction in PTX.
+static bool isNVVMAtomic(const IntrinsicInst *II) {
+  switch (II->getIntrinsicID()) {
+  default: return false;
+  case Intrinsic::nvvm_atomic_load_add_f32:
+  case Intrinsic::nvvm_atomic_load_inc_32:
+  case Intrinsic::nvvm_atomic_load_dec_32:
+    return true;
+  }
+}
+
+bool NVPTXTTIImpl::isSourceOfDivergence(const Value *V) {
+  // Without inter-procedural analysis, we conservatively assume that arguments
+  // to __device__ functions are divergent.
+  if (const Argument *Arg = dyn_cast<Argument>(V))
+    return !isKernelFunction(*Arg->getParent());
+
+  if (const Instruction *I = dyn_cast<Instruction>(V)) {
+    // Without pointer analysis, we conservatively assume values loaded from
+    // memory are divergent.
+    if (isa<LoadInst>(I))
+      return true;
+    // Atomic instructions may cause divergence. Atomic instructions are
+    // executed sequentially across all threads in a warp. Therefore, an
+    // earlier executed thread may see different memory inputs than a later
+    // executed thread. For example, suppose *a = 0 initially.
+    //
+    //   atom.global.add.s32 d, [a], 1
+    //
+    // returns 0 for the first thread that enters the critical region, and 1
+    // for the second thread.
+    if (I->isAtomic())
+      return true;
+    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+      // Instructions that read threadIdx are obviously divergent.
+      if (readsThreadIndex(II))
+        return true;
+      // Handle the NVPTX atomic intrinsics which cannot be represented as an
+      // atomic IR instruction.
+      if (isNVVMAtomic(II))
+        return true;
+    }
+  }
+
+  return false;
+}
+
 unsigned NVPTXTTIImpl::getArithmeticInstrCost(
     unsigned Opcode, Type *Ty, TTI::OperandValueKind Opd1Info,
     TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo,

Index: test/Analysis/DivergenceAnalysis/NVPTX/diverge.ll
===================================================================
--- /dev/null
+++ test/Analysis/DivergenceAnalysis/NVPTX/diverge.ll
@@ -0,0 +1,106 @@
+; RUN: opt %s -analyze -divergence | FileCheck %s
+
+target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
+target triple = "nvptx64-nvidia-cuda"
+
+; return (n < 0 ? a + threadIdx.x : b + threadIdx.x)
+define i32 @no_diverge(i32 %n, i32 %a, i32 %b) {
+; CHECK-LABEL: Printing analysis 'Divergence Analysis' for function 'no_diverge'
+; CHECK-NOT: DIVERGENT
+entry:
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %cond = icmp slt i32 %n, 0
+  br i1 %cond, label %then, label %else ; not divergent
+then:
+  %a1 = add i32 %a, %tid
+  br label %merge
+else:
+  %b2 = add i32 %b, %tid
+  br label %merge
+merge:
+  %c = phi i32 [ %a1, %then ], [ %b2, %else ]
+  ret i32 %c
+}
+
+; c = a;
+; if (threadIdx.x < 5)    // divergent: data dependent
+;   c = b;
+; return (c == 0 ? a : b) // divergent: sync dependent
+define i32 @sync(i32 %a, i32 %b) {
+; CHECK-LABEL: Printing analysis 'Divergence Analysis' for function 'sync'
bb1:
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %cond = icmp slt i32 %tid, 5
+  br i1 %cond, label %bb2, label %bb3
+; CHECK: DIVERGENT: br i1 %cond,
+bb2:
+  br label %bb3
+bb3:
+  %c = phi i32 [ %a, %bb1 ], [ %b, %bb2 ] ; sync dependent on tid
+  %cond2 = icmp eq i32 %c, 0
+  br i1 %cond2, label %bb4, label %bb5 ; therefore divergent
+; CHECK: DIVERGENT: br i1 %cond2,
+bb4:
+  ret i32 %a
+bb5:
+  ret i32 %b
+}
+
+; d = a;
+; if (threadIdx.x >= 5) {    // divergent
+;   c = (n >= 0 ? a : b);    // non-divergent because n is not divergent
+;   d = (c == 0 ? c : b);    // non-divergent because c is not divergent
+; }
+; return (d == 0 ? a : b);   // divergent because d is sync dependent on
+;                            // threadIdx.x >= 5
+define i32 @mixed(i32 %n, i32 %a, i32 %b) {
+; CHECK-LABEL: Printing analysis 'Divergence Analysis' for function 'mixed'
+bb1:
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %cond = icmp slt i32 %tid, 5
+  br i1 %cond, label %bb6, label %bb2
+; CHECK: DIVERGENT: br i1 %cond,
+bb2:
+  %cond2 = icmp slt i32 %n, 0
+  br i1 %cond2, label %bb4, label %bb3
+; CHECK-NOT: DIVERGENT: br i1 %cond2,
+bb3:
+  br label %bb4
+bb4:
+  %c = phi i32 [ %a, %bb2 ], [ %b, %bb3 ]
+  %cond3 = icmp eq i32 %c, 0
+  br i1 %cond3, label %bb6, label %bb5
+; CHECK-NOT: DIVERGENT: br i1 %cond3,
+bb5:
+  br label %bb6
+bb6:
+  %d = phi i32 [ %a, %bb1 ], [ %b, %bb4 ], [ %c, %bb5 ]
+  %cond4 = icmp eq i32 %d, 0
+  br i1 %cond4, label %bb7, label %bb8
+; CHECK: DIVERGENT: br i1 %cond4,
+bb7:
+  ret i32 %a
+bb8:
+  ret i32 %b
+}
+
+; We conservatively treat all parameters of a __device__ function as divergent.
+define i32 @device(i32 %n, i32 %a, i32 %b) {
+; CHECK-LABEL: Printing analysis 'Divergence Analysis' for function 'device'
+entry:
+  %cond = icmp slt i32 %n, 0
+  br i1 %cond, label %then, label %else
+; CHECK: DIVERGENT: br i1 %cond,
+then:
+  br label %merge
+else:
+  br label %merge
+merge:
+  %c = phi i32 [ %a, %then ], [ %b, %else ]
+  ret i32 %c
+}
+
+declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+
+!nvvm.annotations = !{!0, !1, !2}
+!0 = !{i32 (i32, i32, i32)* @no_diverge, !"kernel", i32 1}
+!1 = !{i32 (i32, i32)* @sync, !"kernel", i32 1}
+!2 = !{i32 (i32, i32, i32)* @mixed, !"kernel", i32 1}

Index: test/Analysis/DivergenceAnalysis/NVPTX/lit.local.cfg
===================================================================
--- /dev/null
+++ test/Analysis/DivergenceAnalysis/NVPTX/lit.local.cfg
@@ -0,0 +1,2 @@
+if not 'NVPTX' in config.root.targets:
+    config.unsupported = True
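For reference, a target other than NVPTX would wire itself up the same way as the NVPTX change above: return true from hasBranchDivergence() and classify its own sources of divergence in isSourceOfDivergence(). The sketch below mirrors that pattern for a hypothetical target; MyGPUTTIImpl, Intrinsic::mygpu_workitem_id, and the surrounding class declaration are made-up names for illustration, and the includes the real target file would already have are omitted.

// Sketch only: MyGPUTTIImpl and Intrinsic::mygpu_workitem_id are hypothetical.
// The pattern mirrors NVPTXTTIImpl::isSourceOfDivergence in this patch.
bool MyGPUTTIImpl::isSourceOfDivergence(const Value *V) {
  // Without inter-procedural analysis, any function argument may differ from
  // thread to thread (coarser than NVPTX, which exempts kernel entry points).
  if (isa<Argument>(V))
    return true;

  if (const Instruction *I = dyn_cast<Instruction>(V)) {
    // Without pointer analysis, loaded values are conservatively divergent.
    if (isa<LoadInst>(I))
      return true;
    // Atomics hand different old values to different threads of a wavefront.
    if (I->isAtomic())
      return true;
    // Values read from the per-thread ID register are divergent by definition.
    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
      return II->getIntrinsicID() == Intrinsic::mygpu_workitem_id;
  }
  return false;
}

Everything else — pass registration, the propagation over data and sync dependencies, and the printed output checked by the tests — comes for free from the target-independent code added by this patch.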