diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -2329,11 +2329,42 @@
   return MadeAnyChanges;
 }
 
+// Returns true if the condition of the exiting branch \p BI can be proved to
+// hold on every iteration of \p L, i.e. \p BI never exits \p L when
+// \p ProvingLoopExit is false, and always exits \p L when it is true.
+static bool isTrivialCond(const Loop *L, BranchInst *BI, ScalarEvolution *SE,
+                          bool ProvingLoopExit) {
+  ICmpInst::Predicate Pred;
+  Value *LHS, *RHS;
+  using namespace PatternMatch;
+  BasicBlock *TrueSucc, *FalseSucc;
+  if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
+                      m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
+    return false;
+
+  assert((L->contains(TrueSucc) != L->contains(FalseSucc)) &&
+         "Not a loop exit!");
+
+  // Normalize so that 'LHS pred RHS' means that we stay in the loop.
+  if (L->contains(FalseSucc))
+    Pred = CmpInst::getInversePredicate(Pred);
+
+  // If we are proving loop exit, invert the predicate.
+  if (ProvingLoopExit)
+    Pred = CmpInst::getInversePredicate(Pred);
+
+  const SCEV *LHSS = SE->getSCEVAtScope(LHS, L);
+  const SCEV *RHSS = SE->getSCEVAtScope(RHS, L);
+  // Can we prove it to be trivially true?
+  return SE->isKnownPredicate(Pred, LHSS, RHSS);
+}
+
 bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
   SmallVector<BasicBlock*, 16> ExitingBlocks;
   L->getExitingBlocks(ExitingBlocks);
 
-  // Remove all exits which aren't both rewriteable and analyzeable.
+  // Remove all exits which either aren't rewriteable or don't execute on every
+  // iteration.
   auto NewEnd = llvm::remove_if(ExitingBlocks, [&](BasicBlock *ExitingBB) {
     // If our exitting block exits multiple loops, we can only rewrite the
     // innermost one.  Otherwise, we're changing how many times the innermost
@@ -2350,9 +2381,10 @@
     if (isa<Constant>(BI->getCondition()))
       return true;
 
-    const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
-    if (isa<SCEVCouldNotCompute>(ExitCount))
+    // Likewise, the loop latch must be dominated by the exiting BB.
+    if (!DT->dominates(ExitingBB, L->getLoopLatch()))
       return true;
+
     return false;
   });
   ExitingBlocks.erase(NewEnd, ExitingBlocks.end());
@@ -2365,10 +2397,9 @@
   if (isa<SCEVCouldNotCompute>(MaxExitCount))
     return false;
 
-  // Visit our exit blocks in order of dominance.  We know from the fact that
-  // all exits (left) are analyzeable that the must be a total dominance order
-  // between them as each must dominate the latch.  The visit order only
-  // matters for the provably equal case.
+  // Visit our exit blocks in order of dominance. Since all remaining exits
+  // dominate the latch, there is a total dominance order between them. The
+  // visit order only matters for the provably equal exit count case.
   llvm::sort(ExitingBlocks,
              [&](BasicBlock *A, BasicBlock *B) {
                // std::sort sorts in ascending order, so we want the inverse of
@@ -2399,7 +2430,19 @@
   SmallSet<const SCEV*, 8> DominatingExitCounts;
   for (BasicBlock *ExitingBB : ExitingBlocks) {
     const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
-    assert(!isa<SCEVCouldNotCompute>(ExitCount) && "checked above");
+    if (isa<SCEVCouldNotCompute>(ExitCount)) {
+      // Okay, we do not know the exit count here. Can we at least prove that
+      // this exit is either never taken or always taken?
+      auto *BI = cast<BranchInst>(ExitingBB->getTerminator());
+      if (isTrivialCond(L, BI, SE, false)) {
+        FoldExit(ExitingBB, false);
+        Changed = true;
+      } else if (isTrivialCond(L, BI, SE, true)) {
+        FoldExit(ExitingBB, true);
+        Changed = true;
+      }
+      continue;
+    }
 
     // If we know we'd exit on the first iteration, rewrite the exit to
     // reflect this.  This does not imply the loop must exit through this
diff --git a/llvm/test/Transforms/IndVarSimplify/eliminate-comparison.ll b/llvm/test/Transforms/IndVarSimplify/eliminate-comparison.ll
--- a/llvm/test/Transforms/IndVarSimplify/eliminate-comparison.ll
+++ b/llvm/test/Transforms/IndVarSimplify/eliminate-comparison.ll
@@ -176,32 +176,18 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[FORCOND:%.*]]
 ; CHECK:       forcond:
-; CHECK-NEXT:    [[__KEY6_0:%.*]] = phi i32 [ 2, [[ENTRY:%.*]] ], [ [[TMP37:%.*]], [[NOASSERT:%.*]] ]
-; CHECK-NEXT:    [[EXITCOND1:%.*]] = icmp ne i32 [[__KEY6_0]], 10
-; CHECK-NEXT:    br i1 [[EXITCOND1]], label [[NOASSERT]], label [[FORCOND38_PREHEADER:%.*]]
+; CHECK-NEXT:    br i1 false, label [[NOASSERT:%.*]], label [[FORCOND38_PREHEADER:%.*]]
 ; CHECK:       forcond38.preheader:
 ; CHECK-NEXT:    br label [[FORCOND38:%.*]]
 ; CHECK:       noassert:
-; CHECK-NEXT:    [[TMP13:%.*]] = sdiv i32 -32768, [[__KEY6_0]]
-; CHECK-NEXT:    [[TMP2936:%.*]] = shl i32 [[TMP13]], 24
-; CHECK-NEXT:    [[SEXT23:%.*]] = shl i32 [[TMP13]], 24
-; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i32 [[TMP2936]], [[SEXT23]]
-; CHECK-NEXT:    [[TMP37]] = add nuw nsw i32 [[__KEY6_0]], 1
-; CHECK-NEXT:    br i1 [[TMP32]], label [[FORCOND]], label [[ASSERT33:%.*]]
+; CHECK-NEXT:    br i1 true, label [[FORCOND]], label [[ASSERT33:%.*]]
 ; CHECK:       assert33:
 ; CHECK-NEXT:    tail call void @llvm.trap()
 ; CHECK-NEXT:    unreachable
 ; CHECK:       forcond38:
-; CHECK-NEXT:    [[__KEY8_0:%.*]] = phi i32 [ [[TMP81:%.*]], [[NOASSERT68:%.*]] ], [ 2, [[FORCOND38_PREHEADER]] ]
-; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp ne i32 [[__KEY8_0]], 10
-; CHECK-NEXT:    br i1 [[EXITCOND]], label [[NOASSERT68]], label [[UNROLLEDEND:%.*]]
+; CHECK-NEXT:    br i1 false, label [[NOASSERT68:%.*]], label [[UNROLLEDEND:%.*]]
 ; CHECK:       noassert68:
-; CHECK-NEXT:    [[TMP57:%.*]] = sdiv i32 -32768, [[__KEY8_0]]
-; CHECK-NEXT:    [[SEXT34:%.*]] = shl i32 [[TMP57]], 16
-; CHECK-NEXT:    [[SEXT21:%.*]] = shl i32 [[TMP57]], 16
-; CHECK-NEXT:    [[TMP76:%.*]] = icmp eq i32 [[SEXT34]], [[SEXT21]]
-; CHECK-NEXT:    [[TMP81]] = add nuw nsw i32 [[__KEY8_0]], 1
-; CHECK-NEXT:    br i1 [[TMP76]], label [[FORCOND38]], label [[ASSERT77:%.*]]
+; CHECK-NEXT:    br i1 true, label [[FORCOND38]], label [[ASSERT77:%.*]]
 ; CHECK:       assert77:
 ; CHECK-NEXT:    tail call void @llvm.trap()
 ; CHECK-NEXT:    unreachable
@@ -252,6 +238,73 @@
   ret i32 0
 }
 
+define i32 @func_11_flipped() nounwind uwtable {
+; CHECK-LABEL: @func_11_flipped(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[FORCOND:%.*]]
+; CHECK:       forcond:
+; CHECK-NEXT:    br i1 true, label [[FORCOND38_PREHEADER:%.*]], label [[NOASSERT:%.*]]
+; CHECK:       forcond38.preheader:
+; CHECK-NEXT:    br label [[FORCOND38:%.*]]
+; CHECK:       noassert:
+; CHECK-NEXT:    br i1 true, label [[FORCOND]], label [[ASSERT33:%.*]]
+; CHECK:       assert33:
+; CHECK-NEXT:    tail call void @llvm.trap()
+; CHECK-NEXT:    unreachable
+; CHECK:       forcond38:
+; CHECK-NEXT:    br i1 false, label [[NOASSERT68:%.*]], label [[UNROLLEDEND:%.*]]
+; CHECK:       noassert68:
+; CHECK-NEXT:    br i1 true, label [[FORCOND38]], label [[ASSERT77:%.*]]
+; CHECK:       assert77:
+; CHECK-NEXT:    tail call void @llvm.trap()
+; CHECK-NEXT:    unreachable
+; CHECK:       unrolledend:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  br label %forcond
+
+forcond:                                          ; preds = %noassert, %entry
+  %__key6.0 = phi i32 [ 2, %entry ], [ %tmp37, %noassert ]
+  %tmp5 = icmp sge i32 %__key6.0, 10
+  br i1 %tmp5, label %forcond38.preheader, label %noassert
+
+forcond38.preheader:                              ; preds = %forcond
+  br label %forcond38
+
+noassert:                                         ; preds = %forcond
+  %tmp13 = sdiv i32 -32768, %__key6.0
+  %tmp2936 = shl i32 %tmp13, 24
+  %sext23 = shl i32 %tmp13, 24
+  %tmp32 = icmp eq i32 %tmp2936, %sext23
+  %tmp37 = add i32 %__key6.0, 1
+  br i1 %tmp32, label %forcond, label %assert33
+
+assert33:                                         ; preds = %noassert
+  tail call void @llvm.trap()
+  unreachable
+
+forcond38:                                        ; preds = %noassert68, %forcond38.preheader
+  %__key8.0 = phi i32 [ %tmp81, %noassert68 ], [ 2, %forcond38.preheader ]
+  %tmp46 = icmp slt i32 %__key8.0, 10
+  br i1 %tmp46, label %noassert68, label %unrolledend
+
+noassert68:                                       ; preds = %forcond38
+  %tmp57 = sdiv i32 -32768, %__key8.0
+  %sext34 = shl i32 %tmp57, 16
+  %sext21 = shl i32 %tmp57, 16
+  %tmp76 = icmp eq i32 %sext34, %sext21
+  %tmp81 = add i32 %__key8.0, 1
+  br i1 %tmp76, label %forcond38, label %assert77
+
+assert77:                                         ; preds = %noassert68
+  tail call void @llvm.trap()
+  unreachable
+
+unrolledend:                                      ; preds = %forcond38
+  ret i32 0
+}
+
 declare void @llvm.trap() noreturn nounwind
 
 ; In this case the second loop only has a single iteration, fold the header away
@@ -260,18 +313,11 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[FORCOND:%.*]]
 ; CHECK:       forcond:
-; CHECK-NEXT:    [[__KEY6_0:%.*]] = phi i32 [ 2, [[ENTRY:%.*]] ], [ [[TMP37:%.*]], [[NOASSERT:%.*]] ]
-; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp ne i32 [[__KEY6_0]], 10
-; CHECK-NEXT:    br i1 [[EXITCOND]], label [[NOASSERT]], label [[FORCOND38_PREHEADER:%.*]]
+; CHECK-NEXT:    br i1 false, label [[NOASSERT:%.*]], label [[FORCOND38_PREHEADER:%.*]]
 ; CHECK:       forcond38.preheader:
 ; CHECK-NEXT:    br label [[FORCOND38:%.*]]
 ; CHECK:       noassert:
-; CHECK-NEXT:    [[TMP13:%.*]] = sdiv i32 -32768, [[__KEY6_0]]
-; CHECK-NEXT:    [[TMP2936:%.*]] = shl i32 [[TMP13]], 24
-; CHECK-NEXT:    [[SEXT23:%.*]] = shl i32 [[TMP13]], 24
-; CHECK-NEXT:    [[TMP32:%.*]] = icmp eq i32 [[TMP2936]], [[SEXT23]]
-; CHECK-NEXT:    [[TMP37]] = add nuw nsw i32 [[__KEY6_0]], 1
-; CHECK-NEXT:    br i1 [[TMP32]], label [[FORCOND]], label [[ASSERT33:%.*]]
+; CHECK-NEXT:    br i1 true, label [[FORCOND]], label [[ASSERT33:%.*]]
 ; CHECK:       assert33:
 ; CHECK-NEXT:    tail call void @llvm.trap()
 ; CHECK-NEXT:    unreachable