Index: llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h
===================================================================
--- llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h
+++ llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h
@@ -71,6 +71,18 @@
                         LoopStandardAnalysisResults &AR, LPMUpdater &U);
 };
 
+class SimpleLoopNestUnswitchPass : public PassInfoMixin<SimpleLoopNestUnswitchPass> {
+  bool NonTrivial;
+  bool Trivial;
+
+public:
+  SimpleLoopNestUnswitchPass(bool NonTrivial = false, bool Trivial = true)
+      : NonTrivial(NonTrivial), Trivial(Trivial) {}
+
+  PreservedAnalyses run(LoopNest &LN, LoopAnalysisManager &AM,
+                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
+};
+
 /// Create the legacy pass object for the simple loop unswitcher.
 ///
 /// See the documentaion for `SimpleLoopUnswitchPass` for details.
@@ -78,4 +90,5 @@
 
 } // end namespace llvm
 
+
 #endif // LLVM_TRANSFORMS_SCALAR_SIMPLELOOPUNSWITCH_H
Index: llvm/lib/Passes/PassRegistry.def
===================================================================
--- llvm/lib/Passes/PassRegistry.def
+++ llvm/lib/Passes/PassRegistry.def
@@ -452,4 +452,11 @@
                       },
                       parseLoopUnswitchOptions,
                       "nontrivial;no-nontrivial;trivial;no-trivial")
+LOOP_PASS_WITH_PARAMS("simple-loop-nest-unswitch",
+                      "SimpleLoopNestUnswitchPass",
+                      [](std::pair<bool, bool> Params) {
+                        return SimpleLoopNestUnswitchPass(Params.first, Params.second);
+                      },
+                      parseLoopUnswitchOptions,
+                      "nontrivial;no-nontrivial;trivial;no-trivial")
 #undef LOOP_PASS_WITH_PARAMS
Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -3107,6 +3107,86 @@
   return PA;
 }
 
+PreservedAnalyses SimpleLoopNestUnswitchPass::run(LoopNest &LN, LoopAnalysisManager &AM,
+                                                  LoopStandardAnalysisResults &AR,
+                                                  LPMUpdater &U) {
+  bool DidSomething = false;
+  ArrayRef<Loop *> Loops = LN.getLoops();
+  Loop *OutermostLoop = &LN.getOutermostLoop();
+
+  SmallPriorityWorklist<Loop *, 4> Worklist;
+  appendLoopsToWorklist(Loops, Worklist);
+
+  while(!Worklist.empty()) {
+    Loop *L = Worklist.pop_back_val();
+
+    Function &F = *L->getHeader()->getParent();
+    (void)F;
+
+    LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L
+                      << "\n");
+
+    // Save the current loop name in a variable so that we can report it even
+    // after it has been deleted.
+    std::string LoopName = std::string(L->getName());
+
+    auto UnswitchCB = [&L, &U, &LoopName, &OutermostLoop](bool CurrentLoopValid,
+                                                          bool PartiallyInvariant,
+                                                          ArrayRef<Loop *> NewLoops) {
+      // If we did a non-trivial unswitch, we have added new (cloned) loops.
+      //
+      if (OutermostLoop == L && !NewLoops.empty())
+        U.addSiblingLoops(NewLoops);
+
+      // If the current loop remains valid, we should revisit it to catch any
+      // other unswitch opportunities. Otherwise, we need to mark it as deleted.
+      if (CurrentLoopValid) {
+        if (PartiallyInvariant) {
+          // Mark the new loop as partially unswitched, to avoid unswitching on
+          // the same condition again.
+          auto &Context = L->getHeader()->getContext();
+          MDNode *DisableUnswitchMD = MDNode::get(
+              Context,
+              MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
+          MDNode *NewLoopID = makePostTransformationMetadata(
+              Context, L->getLoopID(), {"llvm.loop.unswitch.partial"},
+              {DisableUnswitchMD});
+          L->setLoopID(NewLoopID);
+        } else
+          U.revisitCurrentLoop();
+      } else
+        U.markLoopAsDeleted(*L, LoopName);
+    };
+
+    Optional<MemorySSAUpdater> MSSAU;
+    if (AR.MSSA) {
+      MSSAU = MemorySSAUpdater(AR.MSSA);
+      if (VerifyMemorySSA)
+        AR.MSSA->verifyMemorySSA();
+    }
+    if (!unswitchLoop(*L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial,
+                      UnswitchCB, &AR.SE,
+                      MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) continue;
+
+    DidSomething = true;
+
+    if (AR.MSSA && VerifyMemorySSA)
+      AR.MSSA->verifyMemorySSA();
+
+    // Historically this pass has had issues with the dominator tree so verify it
+    // in asserts builds.
+    assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
+  }
+
+  if (!DidSomething)
+    return PreservedAnalyses::all();
+
+  auto PA = getLoopPassPreservedAnalyses();
+  if (AR.MSSA)
+    PA.preserve<MemorySSAAnalysis>();
+  return PA;
+}
+
 namespace {
 
 class SimpleLoopUnswitchLegacyPass : public LoopPass {
Index: llvm/test/Transforms/SimpleLoopUnswitch/loop-nest-unswitch.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimpleLoopUnswitch/loop-nest-unswitch.ll
@@ -0,0 +1,84 @@
+; RUN: opt -passes='simple-loop-nest-unswitch' -S < %s | FileCheck %s
+
+; This test represents the following code:
+;
+; int test(bool flag, int n) {
+;   int sum = 0;
+;   for (int i = 0; i < n; i++) {
+;     for (int k = 0; k < n; k++) {
+;       sum++;
+;       if (flag)
+;         sum++;
+;     }
+;   }
+;   return sum;
+; }
+
+define dso_local i32 @test(i1 zeroext %flag, i32 %n) {
+entry:
+  %frombool = zext i1 %flag to i8
+  %tobool = trunc i8 %frombool to i1
+  br label %for.cond
+
+; Check if SimpleLoopNestUnswitch creates new blocks for unswitching
+; CHECK-LABEL: entry.split.us
+; CHECK-LABEL: for.cond.us
+; CHECK-LABEL: for.body.us
+; CHECK-LABEL: for.inc6.us
+; CHECK-LABEL: for.end.us
+; CHECK-LABEL: for.body.split.us.us
+; CHECK-LABEL: for.cond1.us.us
+; CHECK-LABEL: for.body3.us.us
+; CHECK-LABEL: if.then.us.us
+; CHECK-LABEL: if.end.us.us
+; CHECK-LABEL: for.inc.us.us
+; CHECK-LABEL: for.end.split.us.us
+; CHECK-LABEL: for.end8.split.us
+
+for.cond:                                         ; preds = %for.inc6, %entry
+  %sum.0 = phi i32 [ 0, %entry ], [ %sum.1.lcssa, %for.inc6 ]
+  %i.0 = phi i32 [ 0, %entry ], [ %inc7, %for.inc6 ]
+  %cmp = icmp slt i32 %i.0, %n
+  br i1 %cmp, label %for.body, label %for.end8
+
+for.body:                                         ; preds = %for.cond
+  br label %for.cond1
+
+for.cond1:                                        ; preds = %for.inc, %for.body
+  %sum.1 = phi i32 [ %sum.0, %for.body ], [ %sum.2, %for.inc ]
+  %k.0 = phi i32 [ 0, %for.body ], [ %inc5, %for.inc ]
+  %cmp2 = icmp slt i32 %k.0, %n
+  br i1 %cmp2, label %for.body3, label %for.end
+
+for.body3:                                        ; preds = %for.cond1
+  %inc = add nsw i32 %sum.1, 1
+  br i1 %tobool, label %if.then, label %if.end
+
+if.then:                                          ; preds = %for.body3
+  %inc4 = add nsw i32 %inc, 1
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %for.body3
+  %sum.2 = phi i32 [ %inc4, %if.then ], [ %inc, %for.body3 ]
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.end
+  %inc5 = add nsw i32 %k.0, 1
+  br label %for.cond1, !llvm.loop !0
+
+for.end:                                          ; preds = %for.cond1
+  %sum.1.lcssa = phi i32 [ %sum.1, %for.cond1 ]
+  br label %for.inc6
+
+for.inc6:                                         ; preds = %for.end
+  %inc7 = add nsw i32 %i.0, 1
+  br label %for.cond, !llvm.loop !2
+
+for.end8:                                         ; preds = %for.cond
+  %sum.0.lcssa = phi i32 [ %sum.0, %for.cond ]
+  ret i32 %sum.0.lcssa
+}
+
+!0 = distinct !{!0, !1}
+!1 = !{!"llvm.loop.mustprogress"}
+!2 = distinct !{!2, !1}