diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -331,6 +331,24 @@
   }
 }
 
+// Return the top-most loop containing ExitBB and having ExitBB as exiting block
+// or the loop containing ExitBB, if there is no parent loop containing ExitBB
+// as exiting block. Return nullptr if ExitBB is not contained in any loop.
+static Loop *getTopMostExitingLoop(const BasicBlock *ExitBB,
+                                   const LoopInfo &LI) {
+  Loop *TopMost = LI.getLoopFor(ExitBB);
+  // Walk outwards through every enclosing loop. Current must be tested before
+  // it is dereferenced: when ExitBB is not inside any loop, getLoopFor()
+  // returns nullptr and there is no parent chain to visit.
+  Loop *Current = TopMost;
+  while (Current) {
+    if (Current->isLoopExiting(ExitBB))
+      TopMost = Current;
+    Current = Current->getParentLoop();
+  }
+  return TopMost;
+}
+
 /// Unswitch a trivial branch if the condition is loop invariant.
 ///
 /// This routine should only be called when loop code leading to the branch has
@@ -415,9 +433,10 @@
   });
 
   // If we have scalar evolutions, we need to invalidate them including this
-  // loop and the loop containing the exit block.
+  // loop, the loop containing the exit block and the outermost loop that has
+  // LoopExitBB as an exiting block.
   if (SE) {
-    if (Loop *ExitL = LI.getLoopFor(LoopExitBB))
+    if (Loop *ExitL = getTopMostExitingLoop(LoopExitBB, LI))
       SE->forgetLoop(ExitL);
     else
       // Forget the entire nest as this exits the entire nest.
diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/preserve-scev-exiting-multiple-loops.ll b/llvm/test/Transforms/SimpleLoopUnswitch/preserve-scev-exiting-multiple-loops.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/SimpleLoopUnswitch/preserve-scev-exiting-multiple-loops.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+
+; We run -indvars before -simple-loop-unswitch to compute SCEV exit counts before
+; running -simple-loop-unswitch.
+; RUN: opt -indvars -simple-loop-unswitch -S %s -verify-scev | FileCheck %s
+
+; Test for PR43972.
+
+; We have 3 nested loops (l1 <- l2 <- l3).
%for.cond5 is the exit block of
+; l3 and belongs to l2, but it is also an exiting block of l1 (it branches to
+; %for.body7, outside all loops), so l1 must be invalidated to preserve SCEV.
+
+define void @f() {
+; CHECK-LABEL: @f(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[LNOT:%.*]] = xor i1 undef, true
+; CHECK-NEXT: br label [[FOR_COND:%.*]]
+; CHECK: for.cond.loopexit:
+; CHECK-NEXT: br label [[FOR_COND]]
+; CHECK: for.cond:
+; CHECK-NEXT: br label [[FOR_BODY:%.*]]
+; CHECK: for.cond1:
+; CHECK-NEXT: br i1 true, label [[FOR_BODY]], label [[FOR_COND_LOOPEXIT:%.*]]
+; CHECK: for.body:
+; CHECK-NEXT: br i1 [[LNOT]], label [[FOR_BODY_SPLIT:%.*]], label [[FOR_COND5_SPLIT:%.*]]
+; CHECK: for.body.split:
+; CHECK-NEXT: br label [[LAND_RHS:%.*]]
+; CHECK: for.cond2:
+; CHECK-NEXT: br i1 true, label [[LAND_RHS]], label [[FOR_COND5:%.*]]
+; CHECK: land.rhs:
+; CHECK-NEXT: br label [[FOR_COND2:%.*]]
+; CHECK: for.cond5:
+; CHECK-NEXT: br label [[FOR_COND5_SPLIT]]
+; CHECK: for.cond5.split:
+; CHECK-NEXT: br i1 true, label [[FOR_BODY7:%.*]], label [[FOR_COND1:%.*]]
+; CHECK: for.body7:
+; CHECK-NEXT: ret void
+;
+entry:
+  %lnot = xor i1 undef, true
+  br label %for.cond
+
+for.cond: ; preds = %for.cond1, %entry
+  br label %for.body
+
+for.cond1: ; preds = %for.cond5
+  br i1 true, label %for.body, label %for.cond
+
+for.body: ; preds = %for.cond1, %for.cond
+  br label %land.rhs
+
+for.cond2: ; preds = %land.rhs
+  br i1 true, label %land.rhs, label %for.cond5
+
+land.rhs: ; preds = %for.cond2, %for.body
+  br i1 %lnot, label %for.cond2, label %for.cond5
+
+for.cond5: ; preds = %land.rhs, %for.cond2
+  br i1 true, label %for.body7, label %for.cond1
+
+for.body7: ; preds = %for.cond5
+  ret void
+}