Index: llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h
===================================================================
--- llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h
+++ llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h
@@ -71,6 +71,25 @@
                         LoopStandardAnalysisResults &AR, LPMUpdater &U);
 };
 
+/// This pass is the loop-nest counterpart of SimpleLoopUnswitchPass.
+///
+/// It first picks the best non-trivial unswitch candidate of the whole nest,
+/// searching the blocks of every loop in it, and then lets the per-loop
+/// unswitcher perform a non-trivial unswitch only on exactly that
+/// terminator. If the nest has no such candidate, the IR is left unchanged.
+class SimpleLoopNestUnswitchPass
+    : public PassInfoMixin<SimpleLoopNestUnswitchPass> {
+  bool NonTrivial;
+  bool Trivial;
+
+public:
+  SimpleLoopNestUnswitchPass(bool NonTrivial = false, bool Trivial = true)
+      : NonTrivial(NonTrivial), Trivial(Trivial) {}
+
+  PreservedAnalyses run(LoopNest &LN, LoopAnalysisManager &AM,
+                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
+};
+
 /// Create the legacy pass object for the simple loop unswitcher.
 ///
 /// See the documentaion for `SimpleLoopUnswitchPass` for details.
Index: llvm/lib/Passes/PassRegistry.def
===================================================================
--- llvm/lib/Passes/PassRegistry.def
+++ llvm/lib/Passes/PassRegistry.def
@@ -488,4 +488,11 @@
                       },
                       parseLoopUnswitchOptions,
                       "nontrivial;no-nontrivial;trivial;no-trivial")
+LOOP_PASS_WITH_PARAMS("simple-loop-nest-unswitch",
+                      "SimpleLoopNestUnswitchPass",
+                      [](std::pair<bool, bool> Params) {
+                        return SimpleLoopNestUnswitchPass(Params.first, Params.second);
+                      },
+                      parseLoopUnswitchOptions,
+                      "nontrivial;no-nontrivial;trivial;no-trivial")
 #undef LOOP_PASS_WITH_PARAMS
Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -2666,11 +2666,18 @@
   return CostMultiplier;
 }
 
-static bool unswitchBestCondition(
+/// Find the best non-trivial unswitch candidate of \p L.
+///
+/// Returns its terminator, or nullptr when no candidate is both legal and
+/// below the cost threshold. On success, \p PartialIVInfo, \p ExitBlocks and
+/// \p BestUnswitchInvariants are set up for unswitchNontrivialInvariants.
+/// With \p LoopNestMode, blocks of loops nested inside \p L are searched too.
+static Instruction *findBestUnswitchTI(
     Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-    AAResults &AA, TargetTransformInfo &TTI,
-    function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
-    ScalarEvolution *SE, MemorySSAUpdater *MSSAU) {
+    AAResults &AA, TargetTransformInfo &TTI, ScalarEvolution *SE,
+    MemorySSAUpdater *MSSAU, IVConditionInfo &PartialIVInfo,
+    SmallVectorImpl<BasicBlock *> &ExitBlocks,
+    TinyPtrVector<Value *> &BestUnswitchInvariants, bool LoopNestMode = false) {
   // Collect all invariant conditions within this loop (as opposed to an inner
   // loop which would be handled when visiting that inner loop).
   SmallVector<std::pair<Instruction *, TinyPtrVector<Value *>>, 4>
@@ -2685,9 +2692,8 @@
     CollectGuards = true;
   }
 
-  IVConditionInfo PartialIVInfo;
   for (auto *BB : L.blocks()) {
-    if (LI.getLoopFor(BB) != &L)
+    if (!LoopNestMode && LI.getLoopFor(BB) != &L)
       continue;
 
     if (CollectGuards)
@@ -2759,7 +2765,7 @@
 
   // If we didn't find any candidates, we're done.
   if (UnswitchCandidates.empty())
-    return false;
+    return nullptr;
 
   // Check if there are irreducible CFG cycles in this loop. If so, we cannot
   // easily unswitch non-trivial edges out of the loop. Doing so might turn the
@@ -2770,9 +2776,9 @@
   LoopBlocksRPO RPOT(&L);
   RPOT.perform(&LI);
   if (containsIrreducibleCFG<const BasicBlock *>(RPOT, LI))
-    return false;
+    return nullptr;
 
-  SmallVector<BasicBlock *, 4> ExitBlocks;
+  ExitBlocks.clear();
   L.getUniqueExitBlocks(ExitBlocks);
 
   // We cannot unswitch if exit blocks contain a cleanuppad/catchswitch
@@ -2784,7 +2790,7 @@
     if (isa<CleanupPadInst>(I) || isa<CatchSwitchInst>(I)) {
       LLVM_DEBUG(dbgs() << "Cannot unswitch because of cleanuppad/catchswitch "
                            "in exit block\n");
-      return false;
+      return nullptr;
     }
   }
 
@@ -2818,10 +2824,10 @@
         continue;
 
       if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB))
-        return false;
+        return nullptr;
       if (auto *CB = dyn_cast<CallBase>(&I))
         if (CB->isConvergent() || CB->cannotDuplicate())
-          return false;
+          return nullptr;
 
       Cost += TTI.getUserCost(&I, CostKind);
     }
@@ -2906,10 +2912,9 @@
   };
   Instruction *BestUnswitchTI = nullptr;
   InstructionCost BestUnswitchCost = 0;
-  ArrayRef<Value *> BestUnswitchInvariants;
   for (auto &TerminatorAndInvariants : UnswitchCandidates) {
     Instruction &TI = *TerminatorAndInvariants.first;
-    ArrayRef<Value *> Invariants = TerminatorAndInvariants.second;
+    const TinyPtrVector<Value *> &Invariants = TerminatorAndInvariants.second;
     BranchInst *BI = dyn_cast<BranchInst>(&TI);
     InstructionCost CandidateCost = ComputeUnswitchedCost(
         TI, /*FullUnswitch*/ !BI || (Invariants.size() == 1 &&
@@ -2942,7 +2947,7 @@
   if (BestUnswitchCost >= UnswitchThreshold) {
     LLVM_DEBUG(dbgs() << "Cannot unswitch, lowest cost found: "
                       << BestUnswitchCost << "\n");
-    return false;
+    return nullptr;
   }
 
   if (BestUnswitchTI != PartialIVCondBranch)
@@ -2956,9 +2961,42 @@
   LLVM_DEBUG(dbgs() << "  Unswitching non-trivial (cost = "
                     << BestUnswitchCost << ") terminator: " << *BestUnswitchTI
                     << "\n");
+
+  return BestUnswitchTI;
+}
+
+/// Non-trivially unswitch the best candidate of \p L, if there is one.
+///
+/// When \p BestUnswitchTIForLoopNest is non-null, \p L is only unswitched if
+/// its own best candidate is exactly that terminator.
+static bool unswitchBestCondition(
+    Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
+    AAResults &AA, TargetTransformInfo &TTI,
+    function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
+    ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
+    Instruction *BestUnswitchTIForLoopNest = nullptr) {
+  IVConditionInfo PartialIVInfo;
+  SmallVector<BasicBlock *, 4> ExitBlocks;
+  TinyPtrVector<Value *> BestUnswitchInvariants;
+
+  Instruction *BestUnswitchTI =
+      findBestUnswitchTI(L, DT, LI, AC, AA, TTI, SE, MSSAU, PartialIVInfo,
+                         ExitBlocks, BestUnswitchInvariants);
+
+  if (BestUnswitchTI == nullptr)
+    return false;
+
+  // If BestUnswitchTIForLoopNest is specified and it is the same instruction as
+  // BestUnswitchTI, we can continue unswitching. Otherwise, we stop
+  // unswitching.
+  if (BestUnswitchTIForLoopNest != nullptr &&
+      BestUnswitchTI != BestUnswitchTIForLoopNest)
+    return false;
+
   unswitchNontrivialInvariants(L, *BestUnswitchTI, BestUnswitchInvariants,
                                ExitBlocks, PartialIVInfo, DT, LI, AC,
                                UnswitchCB, SE, MSSAU);
+
   return true;
 }
 
@@ -2988,7 +3026,8 @@
              AAResults &AA, TargetTransformInfo &TTI, bool Trivial,
              bool NonTrivial,
              function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
-             ScalarEvolution *SE, MemorySSAUpdater *MSSAU) {
+             ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
+             Instruction *BestUnswitchTIForLoopNest = nullptr) {
   assert(L.isRecursivelyLCSSAForm(DT, LI) &&
          "Loops must be in LCSSA form before unswitching.");
 
@@ -3036,7 +3075,8 @@
 
   // Try to unswitch the best invariant condition. We prefer this full unswitch to
   // a partial unswitch when possible below the threshold.
-  if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU))
+  if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU,
+                            BestUnswitchTIForLoopNest))
     return true;
 
   // No other opportunities to unswitch.
@@ -3107,6 +3147,115 @@
   return PA;
 }
 
+PreservedAnalyses
+SimpleLoopNestUnswitchPass::run(LoopNest &LN, LoopAnalysisManager &AM,
+                                LoopStandardAnalysisResults &AR,
+                                LPMUpdater &U) {
+  Loop *OutermostLoop = &LN.getOutermostLoop();
+
+  Optional<MemorySSAUpdater> MSSAU;
+  if (AR.MSSA) {
+    MSSAU = MemorySSAUpdater(AR.MSSA);
+    if (VerifyMemorySSA)
+      AR.MSSA->verifyMemorySSA();
+  }
+  MemorySSAUpdater *MSSAUPtr = MSSAU.hasValue() ? MSSAU.getPointer() : nullptr;
+
+  // Pick one terminator for the whole nest, searching the blocks of every loop
+  // in it. Below, a loop is only non-trivially unswitched when its own best
+  // candidate is exactly this terminator.
+  IVConditionInfo PartialIVInfo;
+  SmallVector<BasicBlock *, 4> ExitBlocks;
+  TinyPtrVector<Value *> BestUnswitchInvariants;
+  Instruction *BestUnswitchTIForLoopNest = findBestUnswitchTI(
+      *OutermostLoop, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, &AR.SE, MSSAUPtr,
+      PartialIVInfo, ExitBlocks, BestUnswitchInvariants, /*LoopNestMode=*/true);
+
+  if (BestUnswitchTIForLoopNest == nullptr)
+    return PreservedAnalyses::all();
+
+  // appendLoopsToWorklist orders the worklist so that inner loops are popped
+  // before the loops that contain them.
+  SmallPriorityWorklist<Loop *, 4> Worklist;
+  appendLoopsToWorklist(LN.getLoops(), Worklist);
+
+  bool Changed = false;
+  while (!Worklist.empty()) {
+    Loop *L = Worklist.pop_back_val();
+
+    Function &F = *L->getHeader()->getParent();
+    (void)F;
+
+    LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L
+                      << "\n");
+
+    // Save the current loop name in a variable so that we can report it even
+    // after it has been deleted.
+    std::string LoopName = std::string(L->getName());
+
+    auto UnswitchCB = [&L, &U, &LoopName, &OutermostLoop](
+                          bool CurrentLoopValid, bool PartiallyInvariant,
+                          ArrayRef<Loop *> NewLoops) {
+      // If we did a non-trivial unswitch, we have added new (cloned) loops.
+      // Only clones of the outermost loop are registered, as siblings of this
+      // nest; clones of an inner loop are assumed to stay inside the nest.
+      if (OutermostLoop == L && !NewLoops.empty())
+        U.addSiblingLoops(NewLoops);
+
+      // If the current loop remains valid, we should revisit it to catch any
+      // other unswitch opportunities. Otherwise, we need to mark it as deleted.
+      if (CurrentLoopValid) {
+        if (PartiallyInvariant) {
+          // Mark the new loop as partially unswitched, to avoid unswitching on
+          // the same condition again.
+          auto &Context = L->getHeader()->getContext();
+          MDNode *DisableUnswitchMD = MDNode::get(
+              Context,
+              MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
+          MDNode *NewLoopID = makePostTransformationMetadata(
+              Context, L->getLoopID(), {"llvm.loop.unswitch.partial"},
+              {DisableUnswitchMD});
+          L->setLoopID(NewLoopID);
+        } else
+          U.revisitCurrentLoop();
+      } else {
+        // NOTE(review): confirm that LPMUpdater accepts a loop other than the
+        // outermost one here when it runs in loop-nest mode.
+        U.markLoopAsDeleted(*L, LoopName);
+      }
+    };
+
+    bool Unswitched = unswitchLoop(
+        *L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial, UnswitchCB,
+        &AR.SE, MSSAUPtr, BestUnswitchTIForLoopNest);
+
+    if (AR.MSSA && VerifyMemorySSA)
+      AR.MSSA->verifyMemorySSA();
+
+    // Historically this pass has had issues with the dominator tree so verify
+    // it in asserts builds.
+    assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
+
+    // A false result is expected for loops whose own best candidate is not
+    // BestUnswitchTIForLoopNest. After a successful unswitch, however, that
+    // terminator may have been erased and the nest has changed shape, so stop
+    // here; UnswitchCB requests another visit via U.revisitCurrentLoop()
+    // unless the loop was deleted or only partially unswitched.
+    if (Unswitched) {
+      Changed = true;
+      break;
+    }
+  }
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+
+  auto PA = getLoopPassPreservedAnalyses();
+  if (AR.MSSA)
+    PA.preserve<MemorySSAAnalysis>();
+  return PA;
+}
+
 namespace {
 
 class SimpleLoopUnswitchLegacyPass : public LoopPass {
Index: llvm/test/Transforms/SimpleLoopUnswitch/loop-nest-unswitch.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimpleLoopUnswitch/loop-nest-unswitch.ll
@@ -0,0 +1,174 @@
+; RUN: opt -passes='simple-loop-nest-unswitch<nontrivial>' -S < %s | FileCheck %s --check-prefixes CHECK,SLNU
+; RUN: opt -passes='simple-loop-unswitch<nontrivial>' -S < %s | FileCheck %s --check-prefixes CHECK,SLU
+
+; This test represents the following code:
+;
+; int test(bool flag, int n) {
+;   int sum = 0;
+;   for (int i = 0; i < n; i++) {
+;     for (int k = 0; k < n; k++) {
+;       sum++;
+;       if (flag)
+;         sum++;
+;     }
+;   }
+;   return sum;
+; }
+;
+; Both SimpleLoopUnswitch and SimpleLoopNestUnswitch do unswitching.
+
+define dso_local i32 @test(i1 zeroext %flag, i32 %n) {
+; CHECK-LABEL: @test(
+entry:
+  %frombool = zext i1 %flag to i8
+  %tobool = trunc i8 %frombool to i1
+  br label %for.cond
+
+; Check that both passes create the new blocks of the unswitched loops
+; CHECK: entry.split.us
+; CHECK: for.cond.us
+; CHECK: for.body.us
+; CHECK: for.inc6.us
+; CHECK: for.end.us
+; CHECK: for.body.split.us.us
+; CHECK: for.cond1.us.us
+; CHECK: for.body3.us.us
+; CHECK: if.then.us.us
+; CHECK: if.end.us.us
+; CHECK: for.inc.us.us
+; CHECK: for.end.split.us.us
+; CHECK: for.end8.split.us
+
+for.cond:                                         ; preds = %for.inc6, %entry
+  %sum.0 = phi i32 [ 0, %entry ], [ %sum.1.lcssa, %for.inc6 ]
+  %i.0 = phi i32 [ 0, %entry ], [ %inc7, %for.inc6 ]
+  %cmp = icmp slt i32 %i.0, %n
+  br i1 %cmp, label %for.body, label %for.end8
+
+for.body:                                         ; preds = %for.cond
+  br label %for.cond1
+
+for.cond1:                                        ; preds = %for.inc, %for.body
+  %sum.1 = phi i32 [ %sum.0, %for.body ], [ %sum.2, %for.inc ]
+  %k.0 = phi i32 [ 0, %for.body ], [ %inc5, %for.inc ]
+  %cmp2 = icmp slt i32 %k.0, %n
+  br i1 %cmp2, label %for.body3, label %for.end
+
+for.body3:                                        ; preds = %for.cond1
+  %inc = add nsw i32 %sum.1, 1
+  br i1 %tobool, label %if.then, label %if.end
+
+if.then:                                          ; preds = %for.body3
+  %inc4 = add nsw i32 %inc, 1
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %for.body3
+  %sum.2 = phi i32 [ %inc4, %if.then ], [ %inc, %for.body3 ]
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.end
+  %inc5 = add nsw i32 %k.0, 1
+  br label %for.cond1, !llvm.loop !0
+
+for.end:                                          ; preds = %for.cond1
+  %sum.1.lcssa = phi i32 [ %sum.1, %for.cond1 ]
+  br label %for.inc6
+
+for.inc6:                                         ; preds = %for.end
+  %inc7 = add nsw i32 %i.0, 1
+  br label %for.cond, !llvm.loop !2
+
+for.end8:                                         ; preds = %for.cond
+  %sum.0.lcssa = phi i32 [ %sum.0, %for.cond ]
+  ret i32 %sum.0.lcssa
+}
+
+; This test represents the following code:
+;
+; int test2(int n) {
+;   int sum = 0;
+;   for (int i = 0; i < n; i++) {
+;     for (int k = 0; k < n; k++) {
+;       sum++;
+;       if (i == 10)
+;         sum++;
+;     }
+;   }
+;   return sum;
+; }
+;
+; We don't want to unswitch the branch `if (i == 10)` out of for-k because
+; it prevents the loop nest from being a perfect loop nest. This test shows
+; that SimpleLoopUnswitch unswitches the branch, but SimpleLoopNestUnswitch doesn't.
+
+define dso_local i32 @test2(i32 %n) {
+; CHECK-LABEL: @test2(
+entry:
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.inc7, %entry
+  %i.0 = phi i32 [ 0, %entry ], [ %inc8, %for.inc7 ]
+  %sum.0 = phi i32 [ 0, %entry ], [ %sum.1.lcssa, %for.inc7 ]
+  %cmp = icmp slt i32 %i.0, %n
+  br i1 %cmp, label %for.body, label %for.end9
+
+for.body:                                         ; preds = %for.cond
+  %cmp4 = icmp eq i32 %i.0, 10
+  br label %for.cond1
+
+; SLU: for.body.split.us:
+; SLU: for.cond1.us:
+; SLU: for.body3.us:
+; SLU: if.then.us:
+; SLU: if.end.us:
+; SLU: for.inc.us:
+; SLU: for.end.split.us:
+; SLU: for.body.split:
+; SLNU-NOT: for.body.split.us:
+; SLNU-NOT: for.cond1.us:
+; SLNU-NOT: for.body3.us:
+; SLNU-NOT: if.then.us:
+; SLNU-NOT: if.end.us:
+; SLNU-NOT: for.inc.us:
+; SLNU-NOT: for.end.split.us:
+; SLNU-NOT: for.body.split:
+
+for.cond1:                                        ; preds = %for.inc, %for.body
+  %sum.1 = phi i32 [ %sum.0, %for.body ], [ %sum.2, %for.inc ]
+  %k.0 = phi i32 [ 0, %for.body ], [ %inc6, %for.inc ]
+  %cmp2 = icmp slt i32 %k.0, %n
+  br i1 %cmp2, label %for.body3, label %for.end
+
+for.body3:                                        ; preds = %for.cond1
+  %inc = add nsw i32 %sum.1, 1
+  br i1 %cmp4, label %if.then, label %if.end
+
+if.then:                                          ; preds = %for.body3
+  %inc5 = add nsw i32 %inc, 1
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %for.body3
+  %sum.2 = phi i32 [ %inc5, %if.then ], [ %inc, %for.body3 ]
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.end
+  %inc6 = add nsw i32 %k.0, 1
+  br label %for.cond1, !llvm.loop !0
+
+for.end:                                          ; preds = %for.cond1
+  %sum.1.lcssa = phi i32 [ %sum.1, %for.cond1 ]
+  br label %for.inc7
+
+for.inc7:                                         ; preds = %for.end
+  %inc8 = add nsw i32 %i.0, 1
+  br label %for.cond, !llvm.loop !2
+
+for.end9:                                         ; preds = %for.cond
+  %sum.0.lcssa = phi i32 [ %sum.0, %for.cond ]
+  ret i32 %sum.0.lcssa
+}
+
+!0 = distinct !{!0, !1}
+!1 = !{!"llvm.loop.mustprogress"}
+!2 = distinct !{!2, !1}
+