Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp =================================================================== --- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -181,8 +181,11 @@ /// (splitting the exit block as necessary). It simplifies the branch within /// the loop to an unconditional branch but doesn't remove it entirely. Further /// cleanup can be done with some simplify-cfg like pass. +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { assert(BI.isConditional() && "Can only unswitch a conditional branch!"); LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); @@ -215,6 +218,16 @@ LLVM_DEBUG(dbgs() << " unswitching trivial branch when: " << CondVal << " == " << LoopCond << "\n"); + // If we have scalar evolutions, we need to invalidate them including this + // loop and the loop containing the exit block. + if (SE) { + if (Loop *ExitL = LI.getLoopFor(LoopExitBB)) + SE->forgetLoop(ExitL); + else + // Forget the entire nest as this exits the entire nest. + SE->forgetTopmostLoop(&L); + } + // Split the preheader, so that we know that there is a safe place to insert // the conditional branch. We will change the preheader to have a conditional // branch on LoopCond. @@ -290,8 +303,11 @@ /// switch will not be revisited. If after unswitching there is only a single /// in-loop successor, the switch is further simplified to an unconditional /// branch. Still more cleanup can be done with some simplify-cfg like pass. +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. 
static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { LLVM_DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n"); Value *LoopCond = SI.getCondition(); @@ -318,18 +334,33 @@ LLVM_DEBUG(dbgs() << " unswitching trivial cases...\n"); + // We may need to invalidate SCEVs for the outermost loop reached by any of + // the exits. + Loop *OuterL = &L; + SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases; ExitCases.reserve(ExitCaseIndices.size()); // We walk the case indices backwards so that we remove the last case first // and don't disrupt the earlier indices. for (unsigned Index : reverse(ExitCaseIndices)) { auto CaseI = SI.case_begin() + Index; + // Compute the outer loop from this exit. + Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor()); + if (!ExitL || ExitL->contains(OuterL)) + OuterL = ExitL; // Save the value of this case. ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()}); // Delete the unswitched cases. SI.removeCase(CaseI); } + if (SE) { + if (OuterL) + SE->forgetLoop(OuterL); + else + SE->forgetTopmostLoop(&L); + } + // Check if after this all of the remaining cases point at the same // successor. BasicBlock *CommonSuccBB = nullptr; @@ -487,8 +518,11 @@ /// /// The return value indicates whether anything was unswitched (and therefore /// changed). +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { bool Changed = false; // If loop header has only one reachable successor we should keep looking for @@ -522,7 +556,7 @@ if (isa<Constant>(SI->getCondition())) return Changed; - if (!unswitchTrivialSwitch(L, *SI, DT, LI)) + if (!unswitchTrivialSwitch(L, *SI, DT, LI, SE)) // Coludn't unswitch this one so we're done. 
return Changed; @@ -554,7 +588,7 @@ // Found a trivial condition candidate: non-foldable conditional branch. If // we fail to unswitch this, we can't do anything else that is trivial. - if (!unswitchTrivialBranch(L, *BI, DT, LI)) + if (!unswitchTrivialBranch(L, *BI, DT, LI, SE)) return Changed; // Mark that we managed to unswitch something. @@ -1461,10 +1495,13 @@ /// /// Once unswitching has been performed it runs the provided callback to report /// the new loops and no-longer valid loops to the caller. +/// +/// If `SE` is non-null, we will update the analysis if unswitching occurs. static bool unswitchInvariantBranch( Loop &L, BranchInst &BI, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) { + function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE) { assert(BI.isConditional() && "Can only unswitch a conditional branch!"); assert(L.isLoopInvariant(BI.getCondition()) && "Can only unswitch an invariant branch condition!"); @@ -1498,9 +1535,6 @@ SmallPtrSet<BasicBlock *, 4> ExitBlockSet(ExitBlocks.begin(), ExitBlocks.end()); - // Compute the parent loop now before we start hacking on things. - Loop *ParentL = L.getParentLoop(); - // Compute the outer-most loop containing one of our exit blocks. This is the // furthest up our loopnest which can be mutated, which we will use below to // update things. @@ -1516,6 +1550,19 @@ OuterExitL = NewOuterExitL; } + // At this point, we're definitely going to unswitch something so invalidate + // any cached information in ScalarEvolution for the outer most loop + // containing an exit block and all nested loops. + if (SE) { + if (OuterExitL) + SE->forgetLoop(OuterExitL); + else + SE->forgetTopmostLoop(&L); + } + + // Compute the parent loop now before we start hacking on things. + Loop *ParentL = L.getParentLoop(); + // If the edge we *aren't* cloning in the unswitch (the continuing edge) // dominates its target, we can skip cloning the dominated region of the loop // and its exits. 
We compute this as a set of nodes to be skipped. @@ -1738,10 +1785,25 @@ /// require duplicating any part of the loop) out of the loop body. It then /// looks at other loop invariant control flows and tries to unswitch those as /// well by cloning the loop if the result is small enough. -static bool -unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - TargetTransformInfo &TTI, bool NonTrivial, - function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) { +/// +/// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also +/// updated based on the unswitch. +/// +/// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is +/// true, we will attempt to do non-trivial unswitching as well as trivial +/// unswitching. +/// +/// The `UnswitchCB` callback provided will be run after unswitching is +/// complete, with the first parameter set to `true` if the provided loop +/// remains a loop, and a list of new sibling loops created. +/// +/// If `SE` is non-null, we will update that analysis based on the unswitching +/// done. +static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, TargetTransformInfo &TTI, + bool NonTrivial, + function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); @@ -1750,7 +1812,7 @@ return false; // Try trivial unswitch first before loop over other basic blocks in the loop. - if (unswitchAllTrivialConditions(L, DT, LI)) { + if (unswitchAllTrivialConditions(L, DT, LI, SE)) { // If we unswitched successfully we will want to clean up the loop before // processing it further so just mark it as unswitched and return. 
UnswitchCB(/*CurrentLoopValid*/ true, {}); @@ -1899,7 +1961,7 @@ << BestUnswitchCost << ") branch: " << *BestUnswitchTI << "\n"); return unswitchInvariantBranch(L, cast<BranchInst>(*BestUnswitchTI), DT, LI, - AC, UnswitchCB); + AC, UnswitchCB, SE); } PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, @@ -1930,7 +1992,7 @@ }; if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, - UnswitchCB)) + UnswitchCB, &AR.SE)) return PreservedAnalyses::all(); // Historically this pass has had issues with the dominator tree so verify it @@ -1978,6 +2040,9 @@ auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); + auto *SE = SEWP ? &SEWP->getSE() : nullptr; + auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, ArrayRef<Loop *> NewLoops) { // If we did a non-trivial unswitch, we have added new (cloned) loops. @@ -1994,7 +2059,7 @@ }; bool Changed = - unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB); + unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE); // If anything was unswitched, also clear any cached information about this // loop. Index: llvm/test/Transforms/SimpleLoopUnswitch/update-scev.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/SimpleLoopUnswitch/update-scev.ll @@ -0,0 +1,188 @@ +; RUN: opt -passes='print<scalar-evolution>,loop(unswitch,loop-instsimplify),print<scalar-evolution>' -enable-nontrivial-unswitch -S < %s 2>%t.scev | FileCheck %s +; RUN: FileCheck %s --check-prefix=SCEV < %t.scev + +target triple = "x86_64-unknown-linux-gnu" + +declare void @f() + +; Check that trivially unswitching an inner loop resets both the inner and outer +; loop trip count. +define void @test1(i32 %n, i32 %m, i1 %cond) { +; Check that SCEV has no trip count before unswitching. +; SCEV-LABEL: Determining loop execution counts for: @test1 +; SCEV: Loop %inner_loop_begin: Unpredictable backedge-taken count. 
+; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count. +; +; Now check that after unswitching and simplifying instructions we get clean +; backedge-taken counts. +; SCEV-LABEL: Determining loop execution counts for: @test1 +; SCEV: Loop %inner_loop_begin: backedge-taken count is (-1 + (1 smax %m)) +; SCEV: Loop %outer_loop_begin: backedge-taken count is (-1 + (1 smax %n)) +; +; And verify the code matches what we expect. +; CHECK-LABEL: define void @test1( +entry: + br label %outer_loop_begin +; Ensure the outer loop didn't get unswitched. +; CHECK: entry: +; CHECK-NEXT: br label %outer_loop_begin + +outer_loop_begin: + %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ] + ; Block unswitching of the outer loop with a noduplicate call. + call void @f() noduplicate + br label %inner_loop_begin +; Ensure the inner loop got unswitched into the outer loop. +; CHECK: outer_loop_begin: +; CHECK-NEXT: %{{.*}} = phi i32 +; CHECK-NEXT: call void @f() +; CHECK-NEXT: br i1 %cond, + +inner_loop_begin: + %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ] + br i1 %cond, label %inner_loop_latch, label %inner_loop_early_exit + +inner_loop_latch: + %j.next = add nsw i32 %j, 1 + %j.cmp = icmp slt i32 %j.next, %m + br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit + +inner_loop_early_exit: + %j.lcssa = phi i32 [ %i, %inner_loop_begin ] + br label %outer_loop_latch + +inner_loop_late_exit: + br label %outer_loop_latch + +outer_loop_latch: + %i.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ %i, %inner_loop_late_exit ] + %i.next = add nsw i32 %i.phi, 1 + %i.cmp = icmp slt i32 %i.next, %n + br i1 %i.cmp, label %outer_loop_begin, label %exit + +exit: + ret void +} + +; Check that trivially unswitching an inner loop resets both the inner and outer +; loop trip count. +define void @test2(i32 %n, i32 %m, i32 %cond) { +; Check that SCEV has no trip count before unswitching. 
+; SCEV-LABEL: Determining loop execution counts for: @test2 +; SCEV: Loop %inner_loop_begin: Unpredictable backedge-taken count. +; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count. +; +; Now check that after unswitching and simplifying instructions we get clean +; backedge-taken counts. +; SCEV-LABEL: Determining loop execution counts for: @test2 +; SCEV: Loop %inner_loop_begin: backedge-taken count is (-1 + (1 smax %m)) +; FIXME: The following backedge taken count should be known but isn't apparently +; just because of a switch in the outer loop. +; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count. +; +; CHECK-LABEL: define void @test2( +entry: + br label %outer_loop_begin +; Ensure the outer loop didn't get unswitched. +; CHECK: entry: +; CHECK-NEXT: br label %outer_loop_begin + +outer_loop_begin: + %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ] + ; Block unswitching of the outer loop with a noduplicate call. + call void @f() noduplicate + br label %inner_loop_begin +; Ensure the inner loop got unswitched into the outer loop. 
+; CHECK: outer_loop_begin: +; CHECK-NEXT: %{{.*}} = phi i32 +; CHECK-NEXT: call void @f() +; CHECK-NEXT: switch i32 %cond, + +inner_loop_begin: + %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ] + switch i32 %cond, label %inner_loop_early_exit [ + i32 1, label %inner_loop_latch + i32 2, label %inner_loop_latch + ] + +inner_loop_latch: + %j.next = add nsw i32 %j, 1 + %j.cmp = icmp slt i32 %j.next, %m + br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit + +inner_loop_early_exit: + %j.lcssa = phi i32 [ %i, %inner_loop_begin ] + br label %outer_loop_latch + +inner_loop_late_exit: + br label %outer_loop_latch + +outer_loop_latch: + %i.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ %i, %inner_loop_late_exit ] + %i.next = add nsw i32 %i.phi, 1 + %i.cmp = icmp slt i32 %i.next, %n + br i1 %i.cmp, label %outer_loop_begin, label %exit + +exit: + ret void +} + +; Check that non-trivial unswitching of a branch in an inner loop into the outer +; loop invalidates both inner and outer. +define void @test3(i32 %n, i32 %m, i1 %cond) { +; Check that SCEV has no trip count before unswitching. +; SCEV-LABEL: Determining loop execution counts for: @test3 +; SCEV: Loop %inner_loop_begin: Unpredictable backedge-taken count. +; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count. +; +; Now check that after unswitching and simplifying instructions we get clean +; backedge-taken counts. +; SCEV-LABEL: Determining loop execution counts for: @test3 +; SCEV: Loop %inner_loop_begin{{.*}}: backedge-taken count is (-1 + (1 smax %m)) +; SCEV: Loop %outer_loop_begin: backedge-taken count is (-1 + (1 smax %n)) +; +; And verify the code matches what we expect. +; CHECK-LABEL: define void @test3( +entry: + br label %outer_loop_begin +; Ensure the outer loop didn't get unswitched. 
+; CHECK: entry: +; CHECK-NEXT: br label %outer_loop_begin + +outer_loop_begin: + %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ] + ; Block unswitching of the outer loop with a noduplicate call. + call void @f() noduplicate + br label %inner_loop_begin +; Ensure the inner loop got unswitched into the outer loop. +; CHECK: outer_loop_begin: +; CHECK-NEXT: %{{.*}} = phi i32 +; CHECK-NEXT: call void @f() +; CHECK-NEXT: br i1 %cond, + +inner_loop_begin: + %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ] + %j.tmp = add nsw i32 %j, 1 + br i1 %cond, label %inner_loop_latch, label %inner_loop_early_exit + +inner_loop_latch: + %j.next = add nsw i32 %j, 1 + %j.cmp = icmp slt i32 %j.next, %m + br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit + +inner_loop_early_exit: + %j.lcssa = phi i32 [ %j.tmp, %inner_loop_begin ] + br label %outer_loop_latch + +inner_loop_late_exit: + br label %outer_loop_latch + +outer_loop_latch: + %inc.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ 1, %inner_loop_late_exit ] + %i.next = add nsw i32 %i, %inc.phi + %i.cmp = icmp slt i32 %i.next, %n + br i1 %i.cmp, label %outer_loop_begin, label %exit + +exit: + ret void +} \ No newline at end of file