diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -53,6 +53,11 @@
   // indicated by the compiler.
   using FunctionT = typename _FunctionT::invalidTemplateInstanceError;
 
+  // A UseT represents a data-edge from the defining instruction to the using
+  // instruction.
+  //
+  // using UseT = ...
+
   // Initialize the SSA context with information about the FunctionT being
   // processed.
   //
diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -330,6 +330,7 @@
   using FunctionT = typename ContextT::FunctionT;
   using ValueRefT = typename ContextT::ValueRefT;
   using ConstValueRefT = typename ContextT::ConstValueRefT;
+  using UseT = typename ContextT::UseT;
   using InstructionT = typename ContextT::InstructionT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
 
@@ -384,6 +385,8 @@
   /// \brief Whether \p Val is divergent at its definition.
   bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }
 
+  bool isDivergentUse(const UseT &U) const;
+
   bool hasDivergentTerminator(const BlockT &B) const {
     return DivergentTermBlocks.contains(&B);
   }
@@ -462,9 +465,9 @@
 
   bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;
 
-  /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
+  /// \brief Whether \p Def is divergent when read in \p ObservingBlock.
   bool isTemporalDivergent(const BlockT &ObservingBlock,
-                           ConstValueRefT Val) const;
+                           const InstructionT &Def) const;
 };
 
 template <typename ImplT>
@@ -1091,6 +1094,20 @@
   return Ext;
 }
 
+template <typename ContextT>
+bool GenericUniformityAnalysisImpl<ContextT>::isTemporalDivergent(
+    const BlockT &ObservingBlock, const InstructionT &Def) const {
+  const BlockT *DefBlock = Def.getParent();
+  for (const CycleT *Cycle = CI.getCycle(DefBlock);
+       Cycle && !Cycle->contains(&ObservingBlock);
+       Cycle = Cycle->getParentCycle()) {
+    if (DivergentExitCycles.contains(Cycle)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 template <typename ContextT>
 void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
     const InstructionT &Term) {
@@ -1273,6 +1290,11 @@
   return DA->isDivergent(*I);
 }
 
+template <typename ContextT>
+bool GenericUniformityInfo<ContextT>::isDivergentUse(const UseT &U) const {
+  return DA->isDivergentUse(U);
+}
+
 template <typename ContextT>
 bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
   return DA->hasDivergentTerminator(B);
diff --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h
--- a/llvm/include/llvm/ADT/GenericUniformityInfo.h
+++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h
@@ -36,6 +36,7 @@
   using FunctionT = typename ContextT::FunctionT;
   using ValueRefT = typename ContextT::ValueRefT;
   using ConstValueRefT = typename ContextT::ConstValueRefT;
+  using UseT = typename ContextT::UseT;
   using InstructionT = typename ContextT::InstructionT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
   using ThisT = GenericUniformityInfo<ContextT>;
@@ -69,6 +70,10 @@
   bool isUniform(const InstructionT *I) const { return !isDivergent(I); };
   bool isDivergent(const InstructionT *I) const;
 
+  /// \brief Whether \p U is divergent. Uses of a uniform value can be
+  /// divergent.
+  bool isDivergentUse(const UseT &U) const;
+
   bool hasDivergentTerminator(const BlockT &B);
 
   void print(raw_ostream &Out) const;
diff --git a/llvm/include/llvm/CodeGen/MachineSSAContext.h b/llvm/include/llvm/CodeGen/MachineSSAContext.h
--- a/llvm/include/llvm/CodeGen/MachineSSAContext.h
+++ b/llvm/include/llvm/CodeGen/MachineSSAContext.h
@@ -45,6 +45,7 @@
   using ValueRefT = Register;
   using ConstValueRefT = Register;
   static const Register ValueRefNull;
+  using UseT = MachineOperand;
   using DominatorTreeT = DominatorTreeBase<MachineBasicBlock, false>;
 
   void setFunction(MachineFunction &Fn);
diff --git a/llvm/include/llvm/IR/SSAContext.h b/llvm/include/llvm/IR/SSAContext.h
--- a/llvm/include/llvm/IR/SSAContext.h
+++ b/llvm/include/llvm/IR/SSAContext.h
@@ -44,6 +44,7 @@
   using ValueRefT = Value *;
   using ConstValueRefT = const Value *;
   static Value *ValueRefNull;
+  using UseT = Use;
   using DominatorTreeT = DominatorTreeBase<BasicBlock, false>;
 
   void setFunction(Function &Fn);
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -78,6 +78,19 @@
   return false;
 }
 
+template <>
+bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
+    const Use &U) const {
+  const auto *V = U.get();
+  if (isDivergent(V))
+    return true;
+  if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
+    const auto *UseInstr = cast<Instruction>(U.getUser());
+    return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
+  }
+  return false;
+}
+
 // This ensures explicit instantiation of
 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
 template class llvm::GenericUniformityInfo<SSAContext>;
@@ -122,6 +135,7 @@
 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                       "Uniformity Analysis", true, true)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity", "Uniformity "
                     "Analysis", true, true)
@@ -129,7 +143,7 @@
 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
   AU.addRequired<DominatorTreeWrapperPass>();
-  AU.addRequired<CycleInfoWrapperPass>();
+  AU.addRequiredTransitive<CycleInfoWrapperPass>();
   AU.addRequired<TargetTransformInfoWrapperPass>();
 }
 
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -113,6 +113,26 @@
   return false;
 }
 
+template <>
+bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
+    const MachineOperand &U) const {
+  if (!U.isReg())
+    return false;
+
+  auto Reg = U.getReg();
+  if (isDivergent(Reg))
+    return true;
+
+  const auto &RegInfo = F.getRegInfo();
+  auto *Def = RegInfo.getOneDef(Reg);
+  if (!Def)
+    return true;
+
+  auto *DefInstr = Def->getParent();
+  auto *UseInstr = U.getParent();
+  return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
+}
+
 // This ensures explicit instantiation of
 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
 template class llvm::GenericUniformityInfo<MachineSSAContext>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUAtomicOptimizer.cpp b/llvm/lib/Target/AMDGPU/AMDGPUAtomicOptimizer.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPUAtomicOptimizer.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUAtomicOptimizer.cpp
@@ -15,7 +15,7 @@
 
 #include "AMDGPU.h"
 #include "GCNSubtarget.h"
-#include "llvm/Analysis/LegacyDivergenceAnalysis.h"
+#include "llvm/Analysis/UniformityAnalysis.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
@@ -42,7 +42,7 @@
                               public InstVisitor<AMDGPUAtomicOptimizer> {
 private:
   SmallVector<ReplacementInfo, 8> ToReplace;
-  const LegacyDivergenceAnalysis *DA;
+  const UniformityInfo *UA;
   const DataLayout *DL;
   DominatorTree *DT;
   const GCNSubtarget *ST;
@@ -65,7 +65,7 @@
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addPreserved<DominatorTreeWrapperPass>();
-    AU.addRequired<LegacyDivergenceAnalysis>();
+    AU.addRequired<UniformityInfoWrapperPass>();
     AU.addRequired<TargetPassConfig>();
   }
@@ -84,7 +84,7 @@
     return false;
   }
 
-  DA = &getAnalysis<LegacyDivergenceAnalysis>();
+  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
   DL = &F.getParent()->getDataLayout();
   DominatorTreeWrapperPass *const DTW =
       getAnalysisIfAvailable<DominatorTreeWrapperPass>();
@@ -139,11 +139,11 @@
 
   // If the pointer operand is divergent, then each lane is doing an atomic
   // operation on a different address, and we cannot optimize that.
-  if (DA->isDivergentUse(&I.getOperandUse(PtrIdx))) {
+  if (UA->isDivergentUse(I.getOperandUse(PtrIdx))) {
     return;
   }
 
-  const bool ValDivergent = DA->isDivergentUse(&I.getOperandUse(ValIdx));
+  const bool ValDivergent = UA->isDivergentUse(I.getOperandUse(ValIdx));
 
   // If the value operand is divergent, each lane is contributing a different
   // value to the atomic calculation. We can only optimize divergent values if
@@ -217,7 +217,7 @@
 
   const unsigned ValIdx = 0;
 
-  const bool ValDivergent = DA->isDivergentUse(&I.getOperandUse(ValIdx));
+  const bool ValDivergent = UA->isDivergentUse(I.getOperandUse(ValIdx));
 
   // If the value operand is divergent, each lane is contributing a different
   // value to the atomic calculation. We can only optimize divergent values if
@@ -231,7 +231,7 @@
   // If any of the other arguments to the intrinsic are divergent, we can't
   // optimize the operation.
   for (unsigned Idx = 1; Idx < I.getNumOperands(); Idx++) {
-    if (DA->isDivergentUse(&I.getOperandUse(Idx))) {
+    if (UA->isDivergentUse(I.getOperandUse(Idx))) {
       return;
     }
   }
@@ -705,7 +705,7 @@
 
 INITIALIZE_PASS_BEGIN(AMDGPUAtomicOptimizer, DEBUG_TYPE,
                       "AMDGPU atomic optimizations", false, false)
-INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis)
+INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
 INITIALIZE_PASS_END(AMDGPUAtomicOptimizer, DEBUG_TYPE,
                     "AMDGPU atomic optimizations", false, false)
diff --git a/llvm/test/CodeGen/AMDGPU/divergence-at-use.ll b/llvm/test/CodeGen/AMDGPU/divergence-at-use.ll
--- a/llvm/test/CodeGen/AMDGPU/divergence-at-use.ll
+++ b/llvm/test/CodeGen/AMDGPU/divergence-at-use.ll
@@ -1,5 +1,4 @@
 ; RUN: llc -march=amdgcn -mcpu=gfx900 -amdgpu-atomic-optimizations=true < %s | FileCheck %s
-; RUN: llc -march=amdgcn -mcpu=gfx900 -amdgpu-atomic-optimizations=true < %s -use-gpu-divergence-analysis | FileCheck %s
 
 @local = addrspace(3) global i32 undef
 
@@ -20,4 +19,26 @@
   ret void
 }
 
+define amdgpu_kernel void @def_in_nested_cycle() {
+; CHECK-LABEL: def_in_nested_cycle:
+; CHECK-NOT: dpp
+entry:
+  %x = call i32 @llvm.amdgcn.workitem.id.x()
+  br label %loop
+loop:
+  %i = phi i32 [ 0, %entry ], [ 0, %innerloop ], [ %i1, %loop ]
+  %cond = icmp ult i32 %i, %x
+  %i1 = add i32 %i, 1
+  br i1 %cond, label %innerloop, label %loop
+innerloop:
+  %i.inner = phi i32 [ 0, %loop ], [ %i1.inner, %innerloop ]
+  %gep = getelementptr i32, ptr addrspace(3) @local, i32 %i
+  %i1.inner = add i32 %i, 1
+  %cond.inner = icmp ult i32 %i, %x
+  br i1 %cond, label %innerloop, label %loop
+exit:
+  %old = atomicrmw add ptr addrspace(3) %gep, i32 %x acq_rel
+  ret void
+}
+
 declare i32 @llvm.amdgcn.workitem.id.x()