Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp =================================================================== --- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -253,8 +253,11 @@ /// (splitting the exit block as necessary). It simplifies the branch within /// the loop to an unconditional branch but doesn't remove it entirely. Further /// cleanup can be done with some simplify-cfg like pass. +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { assert(BI.isConditional() && "Can only unswitch a conditional branch!"); LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); @@ -318,6 +321,16 @@ } }); + // If we have scalar evolutions, we need to invalidate them including this + // loop and the loop containing the exit block. + if (SE) { + if (Loop *ExitL = LI.getLoopFor(LoopExitBB)) + SE->forgetLoop(ExitL); + else + // Forget the entire nest as this exits the entire nest. + SE->forgetTopmostLoop(&L); + } + // Split the preheader, so that we know that there is a safe place to insert // the conditional branch. We will change the preheader to have a conditional // branch on LoopCond. @@ -420,8 +433,11 @@ /// switch will not be revisited. If after unswitching there is only a single /// in-loop successor, the switch is further simplified to an unconditional /// branch. Still more cleanup can be done with some simplify-cfg like pass. +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. 
static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { LLVM_DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n"); Value *LoopCond = SI.getCondition(); @@ -448,18 +464,33 @@ LLVM_DEBUG(dbgs() << " unswitching trivial cases...\n"); + // We may need to invalidate SCEVs for the outermost loop reached by any of + // the exits. + Loop *OuterL = &L; + SmallVector, 4> ExitCases; ExitCases.reserve(ExitCaseIndices.size()); // We walk the case indices backwards so that we remove the last case first // and don't disrupt the earlier indices. for (unsigned Index : reverse(ExitCaseIndices)) { auto CaseI = SI.case_begin() + Index; + // Compute the outer loop from this exit. + Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor()); + if (!ExitL || ExitL->contains(OuterL)) + OuterL = ExitL; // Save the value of this case. ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()}); // Delete the unswitched cases. SI.removeCase(CaseI); } + if (SE) { + if (OuterL) + SE->forgetLoop(OuterL); + else + SE->forgetTopmostLoop(&L); + } + // Check if after this all of the remaining cases point at the same // successor. BasicBlock *CommonSuccBB = nullptr; @@ -617,8 +648,11 @@ /// /// The return value indicates whether anything was unswitched (and therefore /// changed). +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { bool Changed = false; // If loop header has only one reachable successor we should keep looking for @@ -652,7 +686,7 @@ if (isa(SI->getCondition())) return Changed; - if (!unswitchTrivialSwitch(L, *SI, DT, LI)) + if (!unswitchTrivialSwitch(L, *SI, DT, LI, SE)) // Couldn't unswitch this one so we're done. 
return Changed; @@ -684,7 +718,7 @@ // Found a trivial condition candidate: non-foldable conditional branch. If // we fail to unswitch this, we can't do anything else that is trivial. - if (!unswitchTrivialBranch(L, *BI, DT, LI)) + if (!unswitchTrivialBranch(L, *BI, DT, LI, SE)) return Changed; // Mark that we managed to unswitch something. @@ -1622,7 +1656,8 @@ static bool unswitchNontrivialInvariants( Loop &L, TerminatorInst &TI, ArrayRef Invariants, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - function_ref)> UnswitchCB) { + function_ref)> UnswitchCB, + ScalarEvolution *SE) { auto *ParentBB = TI.getParent(); BranchInst *BI = dyn_cast(&TI); SwitchInst *SI = BI ? nullptr : cast(&TI); @@ -1705,6 +1740,16 @@ OuterExitL = NewOuterExitL; } + // At this point, we're definitely going to unswitch something so invalidate + // any cached information in ScalarEvolution for the outermost loop + // containing an exit block and all nested loops. + if (SE) { + if (OuterExitL) + SE->forgetLoop(OuterExitL); + else + SE->forgetTopmostLoop(&L); + } + // If the edge from this terminator to a successor dominates that successor, // store a map from each block in its dominator subtree to it. This lets us // tell when cloning for a particular successor if a block is dominated by @@ -1968,10 +2013,11 @@ return Cost; } -static bool unswitchBestCondition( - Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - TargetTransformInfo &TTI, - function_ref)> UnswitchCB) { +static bool +unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, TargetTransformInfo &TTI, + function_ref)> UnswitchCB, + ScalarEvolution *SE) { // Collect all invariant conditions within this loop (as opposed to an inner // loop which would be handled when visiting that inner loop). 
SmallVector>, 4> @@ -2164,7 +2210,7 @@ << BestUnswitchCost << ") terminator: " << *BestUnswitchTI << "\n"); return unswitchNontrivialInvariants( - L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB); + L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB, SE); } /// Unswitch control flow predicated on loop invariant conditions. @@ -2173,10 +2219,25 @@ /// require duplicating any part of the loop) out of the loop body. It then /// looks at other loop invariant control flows and tries to unswitch those as /// well by cloning the loop if the result is small enough. -static bool -unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - TargetTransformInfo &TTI, bool NonTrivial, - function_ref)> UnswitchCB) { +/// +/// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also +/// updated based on the unswitch. +/// +/// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is +/// true, we will attempt to do non-trivial unswitching as well as trivial +/// unswitching. +/// +/// The `UnswitchCB` callback provided will be run after unswitching is +/// complete, with the first parameter set to `true` if the provided loop +/// remains a loop, and a list of new sibling loops created. +/// +/// If `SE` is non-null, we will update that analysis based on the unswitching +/// done. +static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, TargetTransformInfo &TTI, + bool NonTrivial, + function_ref)> UnswitchCB, + ScalarEvolution *SE) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); bool Changed = false; @@ -2186,7 +2247,7 @@ return false; // Try trivial unswitch first before loop over other basic blocks in the loop. 
- if (unswitchAllTrivialConditions(L, DT, LI)) { + if (unswitchAllTrivialConditions(L, DT, LI, SE)) { // If we unswitched successfully we will want to clean up the loop before // processing it further so just mark it as unswitched and return. UnswitchCB(/*CurrentLoopValid*/ true, {}); @@ -2207,7 +2268,7 @@ // Try to unswitch the best invariant condition. We prefer this full unswitch to // a partial unswitch when possible below the threshold. - if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB)) + if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE)) return true; // No other opportunities to unswitch. @@ -2241,8 +2302,8 @@ U.markLoopAsDeleted(L, LoopName); }; - if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, - UnswitchCB)) + if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, UnswitchCB, + &AR.SE)) return PreservedAnalyses::all(); // Historically this pass has had issues with the dominator tree so verify it @@ -2290,6 +2351,9 @@ auto &AC = getAnalysis().getAssumptionCache(F); auto &TTI = getAnalysis().getTTI(F); + auto *SEWP = getAnalysisIfAvailable(); + auto *SE = SEWP ? &SEWP->getSE() : nullptr; + auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, ArrayRef NewLoops) { // If we did a non-trivial unswitch, we have added new (cloned) loops. @@ -2305,8 +2369,7 @@ LPM.markLoopAsDeleted(*L); }; - bool Changed = - unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB); + bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE); // If anything was unswitched, also clear any cached information about this // loop. 
Index: llvm/test/Transforms/SimpleLoopUnswitch/update-scev.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/SimpleLoopUnswitch/update-scev.ll @@ -0,0 +1,188 @@ +; RUN: opt -passes='print,loop(unswitch,loop-instsimplify),print' -enable-nontrivial-unswitch -S < %s 2>%t.scev | FileCheck %s +; RUN: FileCheck %s --check-prefix=SCEV < %t.scev + +target triple = "x86_64-unknown-linux-gnu" + +declare void @f() + +; Check that trivially unswitching an inner loop resets both the inner and outer +; loop trip count. +define void @test1(i32 %n, i32 %m, i1 %cond) { +; Check that SCEV has no trip count before unswitching. +; SCEV-LABEL: Determining loop execution counts for: @test1 +; SCEV: Loop %inner_loop_begin: Unpredictable backedge-taken count. +; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count. +; +; Now check that after unswitching and simplifying instructions we get clean +; backedge-taken counts. +; SCEV-LABEL: Determining loop execution counts for: @test1 +; SCEV: Loop %inner_loop_begin: backedge-taken count is (-1 + (1 smax %m)) +; SCEV: Loop %outer_loop_begin: backedge-taken count is (-1 + (1 smax %n)) +; +; And verify the code matches what we expect. +; CHECK-LABEL: define void @test1( +entry: + br label %outer_loop_begin +; Ensure the outer loop didn't get unswitched. +; CHECK: entry: +; CHECK-NEXT: br label %outer_loop_begin + +outer_loop_begin: + %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ] + ; Block unswitching of the outer loop with a noduplicate call. + call void @f() noduplicate + br label %inner_loop_begin +; Ensure the inner loop got unswitched into the outer loop. 
+; CHECK: outer_loop_begin: +; CHECK-NEXT: %{{.*}} = phi i32 +; CHECK-NEXT: call void @f() +; CHECK-NEXT: br i1 %cond, + +inner_loop_begin: + %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ] + br i1 %cond, label %inner_loop_latch, label %inner_loop_early_exit + +inner_loop_latch: + %j.next = add nsw i32 %j, 1 + %j.cmp = icmp slt i32 %j.next, %m + br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit + +inner_loop_early_exit: + %j.lcssa = phi i32 [ %i, %inner_loop_begin ] + br label %outer_loop_latch + +inner_loop_late_exit: + br label %outer_loop_latch + +outer_loop_latch: + %i.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ %i, %inner_loop_late_exit ] + %i.next = add nsw i32 %i.phi, 1 + %i.cmp = icmp slt i32 %i.next, %n + br i1 %i.cmp, label %outer_loop_begin, label %exit + +exit: + ret void +} + +; Check that trivially unswitching a switch in an inner loop resets both the +; inner and outer loop trip count. +define void @test2(i32 %n, i32 %m, i32 %cond) { +; Check that SCEV has no trip count before unswitching. +; SCEV-LABEL: Determining loop execution counts for: @test2 +; SCEV: Loop %inner_loop_begin: Unpredictable backedge-taken count. +; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count. +; +; Now check that after unswitching and simplifying instructions we get clean +; backedge-taken counts. +; SCEV-LABEL: Determining loop execution counts for: @test2 +; SCEV: Loop %inner_loop_begin: backedge-taken count is (-1 + (1 smax %m)) +; FIXME: The following backedge-taken count should be known, but apparently +; isn't, just because of a switch in the outer loop. +; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count. +; +; CHECK-LABEL: define void @test2( +entry: + br label %outer_loop_begin +; Ensure the outer loop didn't get unswitched. 
+; CHECK: entry: +; CHECK-NEXT: br label %outer_loop_begin + +outer_loop_begin: + %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ] + ; Block unswitching of the outer loop with a noduplicate call. + call void @f() noduplicate + br label %inner_loop_begin +; Ensure the inner loop got unswitched into the outer loop. +; CHECK: outer_loop_begin: +; CHECK-NEXT: %{{.*}} = phi i32 +; CHECK-NEXT: call void @f() +; CHECK-NEXT: switch i32 %cond, + +inner_loop_begin: + %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ] + switch i32 %cond, label %inner_loop_early_exit [ + i32 1, label %inner_loop_latch + i32 2, label %inner_loop_latch + ] + +inner_loop_latch: + %j.next = add nsw i32 %j, 1 + %j.cmp = icmp slt i32 %j.next, %m + br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit + +inner_loop_early_exit: + %j.lcssa = phi i32 [ %i, %inner_loop_begin ] + br label %outer_loop_latch + +inner_loop_late_exit: + br label %outer_loop_latch + +outer_loop_latch: + %i.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ %i, %inner_loop_late_exit ] + %i.next = add nsw i32 %i.phi, 1 + %i.cmp = icmp slt i32 %i.next, %n + br i1 %i.cmp, label %outer_loop_begin, label %exit + +exit: + ret void +} + +; Check that non-trivial unswitching of a branch in an inner loop into the outer +; loop invalidates both inner and outer. +define void @test3(i32 %n, i32 %m, i1 %cond) { +; Check that SCEV has no trip count before unswitching. +; SCEV-LABEL: Determining loop execution counts for: @test3 +; SCEV: Loop %inner_loop_begin: Unpredictable backedge-taken count. +; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count. +; +; Now check that after unswitching and simplifying instructions we get clean +; backedge-taken counts. 
+; SCEV-LABEL: Determining loop execution counts for: @test3 +; SCEV: Loop %inner_loop_begin{{.*}}: backedge-taken count is (-1 + (1 smax %m)) +; SCEV: Loop %outer_loop_begin: backedge-taken count is (-1 + (1 smax %n)) +; +; And verify the code matches what we expect. +; CHECK-LABEL: define void @test3( +entry: + br label %outer_loop_begin +; Ensure the outer loop didn't get unswitched. +; CHECK: entry: +; CHECK-NEXT: br label %outer_loop_begin + +outer_loop_begin: + %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ] + ; Block unswitching of the outer loop with a noduplicate call. + call void @f() noduplicate + br label %inner_loop_begin +; Ensure the inner loop got unswitched into the outer loop. +; CHECK: outer_loop_begin: +; CHECK-NEXT: %{{.*}} = phi i32 +; CHECK-NEXT: call void @f() +; CHECK-NEXT: br i1 %cond, + +inner_loop_begin: + %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ] + %j.tmp = add nsw i32 %j, 1 + br i1 %cond, label %inner_loop_latch, label %inner_loop_early_exit + +inner_loop_latch: + %j.next = add nsw i32 %j, 1 + %j.cmp = icmp slt i32 %j.next, %m + br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit + +inner_loop_early_exit: + %j.lcssa = phi i32 [ %j.tmp, %inner_loop_begin ] + br label %outer_loop_latch + +inner_loop_late_exit: + br label %outer_loop_latch + +outer_loop_latch: + %inc.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ 1, %inner_loop_late_exit ] + %i.next = add nsw i32 %i, %inc.phi + %i.cmp = icmp slt i32 %i.next, %n + br i1 %i.cmp, label %outer_loop_begin, label %exit + +exit: + ret void +}