Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/DivergenceAnalysis.h"
 #include "llvm/Analysis/GuardUtils.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
@@ -27,6 +28,7 @@
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Analysis/MustExecute.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -2677,7 +2679,8 @@
     AAResults &AA, TargetTransformInfo &TTI,
     function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
     ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
-    function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+    function_ref<void(Loop &, StringRef)> DestroyLoopCB,
+    bool UseDivergenceInfo) {
   // Collect all invariant conditions within this loop (as opposed to an inner
   // loop which would be handled when visiting that inner loop).
   SmallVector<std::pair<Instruction *, TinyPtrVector<Value *>>, 4>
@@ -2795,6 +2798,18 @@
     }
   }
 
+  if (UseDivergenceInfo) {
+    Function *F = L.getHeader()->getParent();
+    PostDominatorTree PDT(*F);
+    DivergenceInfo DI(*F, DT, PDT, LI, TTI, /*KnownReducible*/ true);
+    llvm::erase_if(UnswitchCandidates,
+                   [&](std::pair<Instruction *, TinyPtrVector<Value *>> Cand) {
+                     return DI.isDivergent(*Cand.second[0]);
+                   });
+    if (UnswitchCandidates.empty())
+      return false;
+  }
+
   LLVM_DEBUG(
       dbgs() << "Considering " << UnswitchCandidates.size()
              << " non-trivial loop invariant conditions for unswitching.\n");
@@ -2996,7 +3011,8 @@
                          bool NonTrivial,
                          function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
                          ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
-                         function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+                         function_ref<void(Loop &, StringRef)> DestroyLoopCB,
+                         bool UseDivergenceInfo) {
   assert(L.isRecursivelyLCSSAForm(DT, LI) &&
          "Loops must be in LCSSA form before unswitching.");
 
@@ -3018,14 +3034,7 @@
   // NonTrivial: Parameter that enables non-trivial unswitching for this
   //             invocation of the transform. But this should be allowed only
   //             for targets without branch divergence.
-  //
-  // FIXME: If divergence analysis becomes available to a loop
-  // transform, we should allow unswitching for non-trivial uniform
-  // branches even on targets that have divergence.
-  // https://bugs.llvm.org/show_bug.cgi?id=48819
-  bool ContinueWithNonTrivial =
-      EnableNonTrivialUnswitch || (NonTrivial && !TTI.hasBranchDivergence());
-  if (!ContinueWithNonTrivial)
+  if (!EnableNonTrivialUnswitch && !NonTrivial)
     return false;
 
   // Skip non-trivial unswitching for optsize functions.
@@ -3045,7 +3054,7 @@
   // Try to unswitch the best invariant condition. We prefer this full unswitch to
   // a partial unswitch when possible below the threshold.
   if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU,
-                            DestroyLoopCB))
+                            DestroyLoopCB, UseDivergenceInfo))
     return true;
 
   // No other opportunities to unswitch.
@@ -3105,7 +3114,7 @@
   if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial,
                     UnswitchCB, &AR.SE,
                     MSSAU.hasValue() ? MSSAU.getPointer() : nullptr,
-                    DestroyLoopCB))
+                    DestroyLoopCB, AR.TTI.useGPUDivergenceAnalysis()))
     return PreservedAnalyses::all();
 
   if (AR.MSSA && VerifyMemorySSA)
@@ -3195,7 +3204,8 @@
     MSSA->verifyMemorySSA();
 
   bool Changed = unswitchLoop(*L, DT, LI, AC, AA, TTI, true, NonTrivial,
-                              UnswitchCB, SE, &MSSAU, DestroyLoopCB);
+                              UnswitchCB, SE, &MSSAU, DestroyLoopCB,
+                              /*UseDivergenceInfo*/ false);
 
   if (VerifyMemorySSA)
     MSSA->verifyMemorySSA();
Index: llvm/test/Transforms/SimpleLoopUnswitch/divergent-nontrivial-unswitch.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimpleLoopUnswitch/divergent-nontrivial-unswitch.ll
@@ -0,0 +1,73 @@
+; RUN: opt -mtriple=amdgcn-unknown-amdhsa -passes='loop(simple-loop-unswitch<nontrivial>),verify' -S < %s | FileCheck %s
+; REQUIRES: amdgpu-registered-target
+
+; Check that non-trivial loop unswitch occurs on a target with divergence.
+
+; CHECK-LABEL: @nontrivial_unswitch(
+; CHECK: for.body.us:
+; CHECK: if.then.us:
+; CHECK: for.inc.us:
+; CHECK: for.body:
+; CHECK: for.inc:
+define amdgpu_kernel void @nontrivial_unswitch(i32 * nocapture %out, i32 %n, i1 %cond) {
+entry:
+  br label %for.body
+
+for.body:
+  %i = phi i32 [ 0, %entry ], [ %inc, %for.inc ]
+  br i1 %cond, label %if.then, label %for.inc
+
+if.then:
+  %arrayidx = getelementptr inbounds i32, i32 * %out, i32 %i
+  store i32 %i, i32 * %arrayidx, align 4
+  br label %for.inc
+
+for.inc:
+  %inc = add nuw nsw i32 %i, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:
+  ret void
+}
+
+
+; Check that loop unswitch does not happen if the condition is divergent.
+
+; CHECK-LABEL: @divergent_unswitch(
+; CHECK: entry:
+; CHECK: [[IF_COND:%[a-z0-9]+]] = icmp {{.*}}, 1
+; CHECK: br label
+; CHECK: for.body:
+; CHECK: br i1 [[IF_COND]]
+
+define amdgpu_kernel void @divergent_unswitch(i32 * nocapture %out, i32 %n) {
+entry:
+  br label %for.body.lr.ph
+
+for.body.lr.ph:
+  %call = tail call i32 @llvm.amdgcn.workitem.id.x() #0
+  %cmp2 = icmp eq i32 %call, 1
+  br label %for.body
+
+for.body:
+  %i = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %for.inc ]
+  br i1 %cmp2, label %if.then, label %for.inc
+
+if.then:
+  %arrayidx = getelementptr inbounds i32, i32 * %out, i32 %i
+  store i32 %i, i32 * %arrayidx, align 4
+  br label %for.inc
+
+for.inc:
+  %inc = add nuw nsw i32 %i, 1
+  %exitcond = icmp eq i32 %inc, %n
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:
+  ret void
+}
+
+declare i32 @llvm.amdgcn.workitem.id.x() #0
+
+attributes #0 = { nounwind readnone }
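
The patch gates the new filtering on TTI.useGPUDivergenceAnalysis() and drops any unswitch candidate whose invariant condition is divergent. As a minimal illustration of the underlying query, the sketch below shows how the DivergenceAnalysis interface included by this patch can classify a branch condition in isolation; the helper name isDivergentBranchCondition is hypothetical and not part of the change.

  #include "llvm/Analysis/DivergenceAnalysis.h"
  #include "llvm/Analysis/LoopInfo.h"
  #include "llvm/Analysis/PostDominators.h"
  #include "llvm/Analysis/TargetTransformInfo.h"
  #include "llvm/IR/Dominators.h"
  #include "llvm/IR/Instructions.h"

  using namespace llvm;

  // Hypothetical helper: returns true if the condition of a conditional
  // branch is divergent according to DivergenceInfo. This mirrors the query
  // the patch performs on each unswitch candidate's invariant value.
  static bool isDivergentBranchCondition(BranchInst &BI, DominatorTree &DT,
                                         LoopInfo &LI,
                                         TargetTransformInfo &TTI) {
    if (!BI.isConditional())
      return false;
    Function &F = *BI.getFunction();
    // DivergenceInfo needs a post-dominator tree; the loop pass pipeline does
    // not provide one, so it is built on the fly, as the patch does.
    PostDominatorTree PDT(F);
    DivergenceInfo DI(F, DT, PDT, LI, TTI, /*KnownReducible=*/true);
    return DI.isDivergent(*BI.getCondition());
  }

In the divergent_unswitch test the loop branch condition is derived from llvm.amdgcn.workitem.id.x, so such a query reports it divergent and the candidate is discarded; in nontrivial_unswitch the condition is a kernel argument and stays uniform, so non-trivial unswitching proceeds.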