Index: llvm/include/llvm/Analysis/DivergenceAnalysis.h =================================================================== --- llvm/include/llvm/Analysis/DivergenceAnalysis.h +++ llvm/include/llvm/Analysis/DivergenceAnalysis.h @@ -21,6 +21,7 @@ #include namespace llvm { +class AssumptionCache; class Function; class Instruction; class Loop; @@ -44,6 +45,7 @@ /// region in LCSSA form. DivergenceAnalysisImpl(const Function &F, const Loop *RegionLoop, const DominatorTree &DT, const LoopInfo &LI, + const TargetTransformInfo &TTI, AssumptionCache &AC, SyncDependenceAnalysis &SDA, bool IsLCSSAForm); /// \brief The loop that defines the analyzed region (if any). @@ -104,6 +106,9 @@ void analyzeTemporalDivergence(const Instruction &I, const Loop &OuterDivLoop); + /// Check if \p V can be assumed uniform at \p User. + bool isUseAssumedAllUniform(const Instruction &User, const Value &V) const; + /// \brief Push all users of \p Val (in the region) to the worklist. void pushUsers(const Value &I); @@ -119,6 +124,8 @@ const DominatorTree &DT; const LoopInfo &LI; + const TargetTransformInfo &TTI; + AssumptionCache &AC; // Recognized divergent loops DenseSet DivergentLoops; @@ -153,7 +160,8 @@ public: DivergenceInfo(Function &F, const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI, - const TargetTransformInfo &TTI, bool KnownReducible); + const TargetTransformInfo &TTI, AssumptionCache &AC, + bool KnownReducible); /// Whether any divergence was detected. bool hasDivergence() const { Index: llvm/include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- llvm/include/llvm/Analysis/TargetTransformInfo.h +++ llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -339,6 +339,19 @@ // even taking non-uniform arguments bool isAlwaysUniform(const Value *V) const; + enum class BallotKind { + NotBallot, + All, + Any + }; + + /// For targets with execution units that progress in lock step. 
+ /// + /// Check if \p I is a call that performs a ballot or vote / operation + /// (e.g. OpenCL's sub_group_all or sub_group_any). Returns the asserted + /// value, and the ballot type. + std::pair isBallot(const Instruction *I) const; + /// Returns the address space ID for a target's 'flat' address space. Note /// this is not necessarily the same as addrspace(0), which LLVM sometimes /// refers to as the generic address space. The flat address space is a @@ -1580,6 +1593,8 @@ virtual bool useGPUDivergenceAnalysis() = 0; virtual bool isSourceOfDivergence(const Value *V) = 0; virtual bool isAlwaysUniform(const Value *V) = 0; + virtual std::pair + isBallot(const Instruction *I) = 0; virtual unsigned getFlatAddressSpace() = 0; virtual bool collectFlatAddressOperands(SmallVectorImpl &OpIndexes, Intrinsic::ID IID) const = 0; @@ -1948,6 +1963,11 @@ return Impl.isAlwaysUniform(V); } + std::pair + isBallot(const Instruction *I) override { + return Impl.isBallot(I); + } + unsigned getFlatAddressSpace() override { return Impl.getFlatAddressSpace(); } bool collectFlatAddressOperands(SmallVectorImpl &OpIndexes, Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -94,6 +94,11 @@ bool isAlwaysUniform(const Value *V) const { return false; } + std::pair + isBallot(const Instruction *) const { + return {nullptr, TTI::BallotKind::NotBallot}; + } + unsigned getFlatAddressSpace() const { return -1; } bool collectFlatAddressOperands(SmallVectorImpl &OpIndexes, Index: llvm/include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -270,6 +270,10 @@ bool isAlwaysUniform(const Value *V) { return false; } + std::pair isBallot(const Instruction *I) { + return 
{nullptr, TTI::BallotKind::NotBallot}; + } + unsigned getFlatAddressSpace() { // Return an invalid address space. return -1; Index: llvm/lib/Analysis/AssumptionCache.cpp =================================================================== --- llvm/lib/Analysis/AssumptionCache.cpp +++ llvm/lib/Analysis/AssumptionCache.cpp @@ -130,6 +130,11 @@ std::tie(Ptr, AS) = TTI->getPredicatedAddrSpace(Cond); if (Ptr) AddAffected(const_cast(Ptr->stripInBoundsOffsets())); + + if (const Instruction *CondInst = dyn_cast(Cond)) { + if (const Value *BallotVal = TTI->isBallot(CondInst).first) + AddAffected(const_cast(BallotVal)); + } } } Index: llvm/lib/Analysis/DivergenceAnalysis.cpp =================================================================== --- llvm/lib/Analysis/DivergenceAnalysis.cpp +++ llvm/lib/Analysis/DivergenceAnalysis.cpp @@ -74,13 +74,17 @@ #include "llvm/Analysis/DivergenceAnalysis.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/Value.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -91,8 +95,9 @@ DivergenceAnalysisImpl::DivergenceAnalysisImpl( const Function &F, const Loop *RegionLoop, const DominatorTree &DT, - const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) - : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), + const LoopInfo &LI, const TargetTransformInfo &TTI, AssumptionCache &AC, + SyncDependenceAnalysis &SDA, bool IsLCSSAForm) + : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), TTI(TTI), AC(AC), SDA(SDA), IsLCSSAForm(IsLCSSAForm) {} bool DivergenceAnalysisImpl::markDivergent(const Value 
&DivVal) { @@ -132,6 +137,28 @@ return RegionLoop ? RegionLoop->contains(&BB) : (BB.getParent() == &F); } +bool DivergenceAnalysisImpl::isUseAssumedAllUniform(const Instruction &UserInst, + const Value &V) const { + for (auto &AssumeVH : AC.assumptionsFor(&V)) { + assert(AssumeVH && "IR changed during analysis?"); + + CallInst *CI = cast(AssumeVH); + if (!isValidAssumeForContext(CI, &UserInst, &DT)) + continue; + + const Value *BallotVal; + TargetTransformInfo::BallotKind BK; + std::tie(BallotVal, BK) = + TTI.isBallot(cast(CI->getArgOperand(0))); + if (BK == TargetTransformInfo::BallotKind::All) { + assert(BallotVal == &V); + return true; + } + } + + return false; +} + void DivergenceAnalysisImpl::pushUsers(const Value &V) { const auto *I = dyn_cast(&V); @@ -149,6 +176,13 @@ if (!inRegion(*UserInst)) continue; + // Ignore any uses that are assumed uniform at the use point. Currently this + // is assumed to only apply for boolean ballots. If there are any other + // divergent operands, the instruction should be found through its other + // divergent operands. 
+ if (V.getType()->isIntegerTy(1) && isUseAssumedAllUniform(*UserInst, V)) + continue; + // All users of divergent values are immediate divergent if (markDivergent(*UserInst)) Worklist.push_back(UserInst); @@ -346,7 +380,7 @@ DivergenceInfo::DivergenceInfo(Function &F, const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI, const TargetTransformInfo &TTI, - bool KnownReducible) + AssumptionCache &AC, bool KnownReducible) : F(F) { if (!KnownReducible) { using RPOTraversal = ReversePostOrderTraversal; @@ -358,7 +392,8 @@ } } SDA = std::make_unique(DT, PDT, LI); - DA = std::make_unique(F, nullptr, DT, LI, *SDA, + DA = std::make_unique(F, nullptr, DT, LI, TTI, AC, + *SDA, /* LCSSA */ false); for (auto &I : instructions(F)) { if (TTI.isSourceOfDivergence(&I)) { @@ -384,8 +419,10 @@ auto &PDT = AM.getResult(F); auto &LI = AM.getResult(F); auto &TTI = AM.getResult(F); + auto &AC = AM.getResult(F); - return DivergenceInfo(F, DT, PDT, LI, TTI, /* KnownReducible = */ false); + return DivergenceInfo(F, DT, PDT, LI, TTI, AC, + /* KnownReducible = */ false); } PreservedAnalyses Index: llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp =================================================================== --- llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp +++ llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp @@ -66,6 +66,7 @@ #include "llvm/Analysis/LegacyDivergenceAnalysis.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/DivergenceAnalysis.h" #include "llvm/Analysis/LoopInfo.h" @@ -292,6 +293,7 @@ INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(LegacyDivergenceAnalysis, "divergence", "Legacy Divergence Analysis", false, true) @@ -303,6 +305,7 @@ AU.addRequiredTransitive(); 
AU.addRequiredTransitive(); AU.addRequiredTransitive(); + AU.addRequiredTransitive(); AU.setPreservesAll(); } @@ -336,11 +339,12 @@ auto &DT = getAnalysis().getDomTree(); auto &PDT = getAnalysis().getPostDomTree(); + auto &AC = getAnalysis().getAssumptionCache(F); if (shouldUseGPUDivergenceAnalysis(F, TTI)) { // run the new GPU divergence analysis auto &LI = getAnalysis().getLoopInfo(); - gpuDA = std::make_unique(F, DT, PDT, LI, TTI, + gpuDA = std::make_unique(F, DT, PDT, LI, TTI, AC, /* KnownReducible = */ true); } else { Index: llvm/lib/Analysis/TargetTransformInfo.cpp =================================================================== --- llvm/lib/Analysis/TargetTransformInfo.cpp +++ llvm/lib/Analysis/TargetTransformInfo.cpp @@ -250,6 +250,11 @@ return TTIImpl->isAlwaysUniform(V); } +std::pair +llvm::TargetTransformInfo::isBallot(const Instruction *I) const { + return TTIImpl->isBallot(I); +} + unsigned TargetTransformInfo::getFlatAddressSpace() const { return TTIImpl->getFlatAddressSpace(); } Index: llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h =================================================================== --- llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h +++ llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h @@ -166,6 +166,8 @@ bool isReadRegisterSourceOfDivergence(const IntrinsicInst *ReadReg) const; bool isSourceOfDivergence(const Value *V) const; bool isAlwaysUniform(const Value *V) const; + std::pair + isBallot(const Instruction *I) const; unsigned getFlatAddressSpace() const { // Don't bother running InferAddressSpaces pass on graphics shaders which Index: llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp +++ llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp @@ -1005,6 +1005,40 @@ return false; } +std::pair +GCNTTIImpl::isBallot(const Instruction *I) const { + using namespace PatternMatch; + + 
ICmpInst::Predicate Pred; + Value *BallotVal = nullptr; + Value *ReadReg = nullptr; + if (!match(I, m_c_ICmp( + Pred, + m_Intrinsic(m_Value(BallotVal)), + m_Intrinsic(m_Value(ReadReg)))) || + !ICmpInst::isEquality(Pred)) + return {nullptr, TTI::BallotKind::NotBallot}; + + // Make sure the read exec is done in the same block as the ballot + // + // FIXME: Should really check they have the same convergence + // token. Alternatively, could have a dedicated vote all intrinsic. + if (cast(I->getOperand(0))->getParent() != + cast(I->getOperand(1))->getParent()) + return {nullptr, TTI::BallotKind::NotBallot}; + + auto *Node = cast(cast(ReadReg)->getMetadata()); + auto *ReadRegName = cast(Node->getOperand(0)); + const StringRef ExecName = ST->isWave64() ? "exec" : "exec_lo"; + if (ReadRegName->getString() != ExecName) + return {nullptr, TTI::BallotKind::NotBallot}; + + if (Pred == ICmpInst::ICMP_EQ) + return {BallotVal, TTI::BallotKind::All}; + + return {nullptr, TTI::BallotKind::NotBallot}; +} + bool GCNTTIImpl::collectFlatAddressOperands(SmallVectorImpl &OpIndexes, Intrinsic::ID IID) const { switch (IID) { Index: llvm/test/Analysis/DivergenceAnalysis/AMDGPU/assume.ll =================================================================== --- /dev/null +++ llvm/test/Analysis/DivergenceAnalysis/AMDGPU/assume.ll @@ -0,0 +1,289 @@ +; RUN: opt -mtriple amdgcn-unknown-amdhsa -passes='print' -disable-output %s 2>&1 | FileCheck -strict-whitespace %s + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_eq_neg1' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: br i1 %cmp, label %foo, label %bar +define void @assume_ballot_eq_neg1(i32 %x) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %all = icmp eq i64 %ballot, -1 + call void @llvm.assume(i1 %all) + br i1 %cmp, label %foo, label %bar + +foo: + ret void + +bar: + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_eq_0' +; CHECK: 
{{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: br i1 %cmp, label %foo, label %bar +define void @assume_ballot_eq_0(i32 %x) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %all = icmp eq i64 %ballot, 0 + call void @llvm.assume(i1 %all) + br i1 %cmp, label %foo, label %bar + +foo: + ret void + +bar: + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_eq_popcnt64' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: br i1 %cmp, label %foo, label %bar +define void @assume_ballot_eq_popcnt64(i32 %x) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %ctpop = call i64 @llvm.ctpop.i64(i64 %ballot) + %all = icmp eq i64 %ctpop, 64 + call void @llvm.assume(i1 %all) + br i1 %cmp, label %foo, label %bar + +foo: + ret void + +bar: + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_ne_popcnt64' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: br i1 %cmp, label %foo, label %bar +define void @assume_ballot_ne_popcnt64(i32 %x) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %ctpop = call i64 @llvm.ctpop.i64(i64 %ballot) + %all = icmp ne i64 %ctpop, 64 + call void @llvm.assume(i1 %all) + br i1 %cmp, label %foo, label %bar + +foo: + ret void + +bar: + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_eq_read_exec' +; CHECK: {{^}}{{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}{{^}} br i1 %cmp, label %foo, label %bar +define void @assume_ballot_eq_read_exec(i32 %x) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %exec = call i64 @llvm.read_register(metadata !0) + %all = icmp eq i64 %ballot, %exec + call void @llvm.assume(i1 %all) + br i1 %cmp, label %foo, label %bar + +foo: + ret void + +bar: + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 
'assume_ballot_eq_read_wrong_reg' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: br i1 %cmp, label %foo, label %bar +define void @assume_ballot_eq_read_wrong_reg(i32 %x) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %exec = call i64 @llvm.read_register(metadata !1) + %all = icmp eq i64 %ballot, %exec + call void @llvm.assume(i1 %all) + br i1 %cmp, label %foo, label %bar + +foo: + ret void + +bar: + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_ne_read_exec' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: br i1 %cmp, label %foo, label %bar +define void @assume_ballot_ne_read_exec(i32 %x) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %exec = call i64 @llvm.read_register(metadata !0) + %all = icmp ne i64 %ballot, %exec + call void @llvm.assume(i1 %all) + br i1 %cmp, label %foo, label %bar + +foo: + ret void + +bar: + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_select_user' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}} %select = select i1 %cmp, i32 123, i32 456 +define void @assume_ballot_select_user(i32 %x, ptr addrspace(1) %ptr) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %exec = call i64 @llvm.read_register(metadata !0) + %all = icmp eq i64 %ballot, %exec + %load = load i32, ptr addrspace(1) %ptr + call void @llvm.assume(i1 %all) + %select = select i1 %cmp, i32 123, i32 456 + store i32 %select, ptr addrspace(1) %ptr + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_argument_select_user' +; CHECK: {{^}} %select = select i1 %arg.bool, i32 123, i32 456 +define void @assume_ballot_argument_select_user(i1 %arg.bool, i32 %x, ptr addrspace(1) %ptr) { + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %arg.bool) + %exec = call i64 @llvm.read_register(metadata !0) + %all = icmp eq i64 
%ballot, %exec + %load = load i32, ptr addrspace(1) %ptr + call void @llvm.assume(i1 %all) + %select = select i1 %arg.bool, i32 123, i32 456 + store i32 %select, ptr addrspace(1) %ptr + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_wrong_bool_uniform' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: %select = select i1 %cmp, i32 123, i32 456 +define void @assume_wrong_bool_uniform(i1 %arg.bool, i32 %x, ptr addrspace(1) %ptr) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %exec = call i64 @llvm.read_register(metadata !0) + %all = icmp eq i64 %ballot, %exec + call void @llvm.assume(i1 %arg.bool) + %load = load i32, ptr addrspace(1) %ptr + %select = select i1 %cmp, i32 123, i32 456 + store i32 %select, ptr addrspace(1) %ptr + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_select_user_other_divergent_input' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: %select = select i1 %cmp, i32 123, i32 %load +define void @assume_ballot_select_user_other_divergent_input(i32 %x, ptr addrspace(1) %ptr) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %exec = call i64 @llvm.read_register(metadata !0) + %all = icmp eq i64 %ballot, %exec + %load = load i32, ptr addrspace(1) %ptr + call void @llvm.assume(i1 %all) + %select = select i1 %cmp, i32 123, i32 %load + store i32 %select, ptr addrspace(1) %ptr + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_eq_read_exec_out_of_block_user' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}} %select = select i1 %cmp, i32 123, i32 456 + define void @assume_ballot_eq_read_exec_out_of_block_user(i32 %x, ptr addrspace(1) %ptr) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %exec = call i64 @llvm.read_register(metadata !0) + %all = icmp eq i64 %ballot, %exec + %load = load i32, ptr 
addrspace(1) %ptr + call void @llvm.assume(i1 %all) + br i1 %cmp, label %foo, label %bar + +foo: + %select = select i1 %cmp, i32 123, i32 456 + %add = add i32 %load, %load + store i32 %add, ptr addrspace(1) %ptr + ret void + +bar: + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_assume_wrong_place' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: br i1 %cmp, label %foo, label %bar +define void @assume_ballot_assume_wrong_place(i32 %x, ptr addrspace(1) %ptr) { + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + %exec = call i64 @llvm.read_register(metadata !0) + %all = icmp eq i64 %ballot, %exec + %load = load i32, ptr addrspace(1) %ptr + br i1 %cmp, label %foo, label %bar + +foo: + call void @llvm.assume(i1 %all) + %select = select i1 %cmp, i32 123, i32 456 + %add = add i32 %load, %load + store i32 %add, ptr addrspace(1) %ptr + ret void + +bar: + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_select_i1_i1_i1' +; CHECK: {{^}}DIVERGENT: %cmp0 = icmp eq i32 %x, 0 +; CHECK: {{^}} %all0 = icmp eq i64 %ballot0, %exec0 +; CHECK: {{^}}DIVERGENT: %cmp1 = icmp eq i32 %y, 0 +; CHECK: {{^}} %all1 = icmp eq i64 %ballot1, %exec1 +; CHECK: {{^}}DIVERGENT: %cmp2 = icmp eq i32 %z, 0 +; CHECK: {{^}} %all2 = icmp eq i64 %ballot2, %exec2 +; CHECK: {{^}} %select = select i1 %cmp0, i1 %cmp1, i1 %cmp2 +define void @assume_ballot_select_i1_i1_i1(i32 %x, i32 %y, i32 %z, ptr addrspace(1) %ptr) { + %cmp0 = icmp eq i32 %x, 0 + %ballot0 = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp0) + %exec0 = call i64 @llvm.read_register(metadata !0) + %all0 = icmp eq i64 %ballot0, %exec0 + call void @llvm.assume(i1 %all0) + + %cmp1 = icmp eq i32 %y, 0 + %ballot1 = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp1) + %exec1 = call i64 @llvm.read_register(metadata !0) + %all1 = icmp eq i64 %ballot1, %exec1 + call void @llvm.assume(i1 %all1) + + %cmp2 = icmp eq i32 %z, 0 + %ballot2 = call i64 
@llvm.amdgcn.ballot.i64(i1 %cmp2) + %exec2 = call i64 @llvm.read_register(metadata !0) + %all2 = icmp eq i64 %ballot2, %exec2 + call void @llvm.assume(i1 %all2) + + %select = select i1 %cmp0, i1 %cmp1, i1 %cmp2 + store i1 %select, ptr addrspace(1) %ptr + ret void +} + +; CHECK-LABEL: Divergence Analysis' for function 'assume_ballot_read_register_wrong_block' +; CHECK: {{^}}DIVERGENT: %cmp = icmp eq i32 %x, 0 +; CHECK: {{^}}DIVERGENT: br i1 %br.cond, label %bb1, label %bb2 +; CHECK: {{^}}DIVERGENT: %select = select i1 %cmp, i32 123, i32 456 +define void @assume_ballot_read_register_wrong_block(i1 %br.cond, i32 %x, ptr addrspace(1) %ptr) { +bb0: + %cmp = icmp eq i32 %x, 0 + %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp) + br i1 %br.cond, label %bb1, label %bb2 + +bb1: + %exec = call i64 @llvm.read_register(metadata !0) + %all = icmp eq i64 %ballot, %exec + %load = load i32, ptr addrspace(1) %ptr + call void @llvm.assume(i1 %all) + %select = select i1 %cmp, i32 123, i32 456 + store i32 %select, ptr addrspace(1) %ptr + br label %bb2 + +bb2: + ret void +} + +declare i64 @llvm.amdgcn.ballot.i64(i1) +declare i64 @llvm.ctpop.i64(i64) +declare void @llvm.assume(i1) +declare i64 @llvm.read_register(metadata) +!0 = !{!"exec"} +!1 = !{!"s[0:3]"} Index: llvm/unittests/Analysis/DivergenceAnalysisTest.cpp =================================================================== --- llvm/unittests/Analysis/DivergenceAnalysisTest.cpp +++ llvm/unittests/Analysis/DivergenceAnalysisTest.cpp @@ -6,13 +6,14 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Analysis/DivergenceAnalysis.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/DivergenceAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/SyncDependenceAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include 
"llvm/AsmParser/Parser.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" @@ -51,6 +52,8 @@ std::unique_ptr PDT; std::unique_ptr LI; std::unique_ptr SDA; + std::unique_ptr AC; + std::unique_ptr TTI; DivergenceAnalysisTest() : M("", Context), TLII(), TLI(TLII) {} @@ -59,7 +62,9 @@ PDT.reset(new PostDominatorTree(F)); LI.reset(new LoopInfo(*DT)); SDA.reset(new SyncDependenceAnalysis(*DT, *PDT, *LI)); - return DivergenceAnalysisImpl(F, nullptr, *DT, *LI, *SDA, IsLCSSA); + AC.reset(new AssumptionCache(F, &*TTI)); + return DivergenceAnalysisImpl(F, nullptr, *DT, *LI, *TTI, *AC, *SDA, + IsLCSSA); } void runWithDA(