[SimpleLoopUnswitch] Skip cold loops when profile data is available

unswitchLoop() now takes a ProfileSummaryInfo and a BlockFrequencyInfo and
returns early for loops whose header block is cold, since unswitching them
increases code size for little benefit. The new-PM pass looks PSI up through
the cached module analysis proxy (and tolerates the proxy being absent); the
legacy pass passes null for both, so its behaviour is unchanged. PassBuilder
sets UseBFI for loop pipelines that contain simple-loop-unswitch.

diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -1395,7 +1395,10 @@
     // Add the nested pass manager with the appropriate adaptor.
     bool UseMemorySSA = (Name == "loop-mssa");
     bool UseBFI = llvm::any_of(
-        InnerPipeline, [](auto Pipeline) { return Pipeline.Name == "licm"; });
+        InnerPipeline, [](const auto &Pipeline) {
+          return Pipeline.Name.contains("licm") ||
+                 Pipeline.Name.contains("simple-loop-unswitch");
+        });
     bool UseBPI = llvm::any_of(InnerPipeline, [](auto Pipeline) {
       return Pipeline.Name == "loop-predication";
     });
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/CodeMetrics.h"
 #include "llvm/Analysis/GuardUtils.h"
@@ -26,6 +27,7 @@
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Analysis/MustExecute.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -3044,6 +3046,8 @@
              bool NonTrivial,
              function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
              ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
+             ProfileSummaryInfo *PSI,
+             BlockFrequencyInfo *BFI,
              function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
   assert(L.isRecursivelyLCSSAForm(DT, LI) &&
          "Loops must be in LCSSA form before unswitching.");
@@ -3080,6 +3084,14 @@
   if (L.getHeader()->getParent()->hasOptSize())
     return false;
 
+  // Skip cold loops, as unswitching them brings little benefit but increases
+  // the code size.
+  if (PSI && PSI->hasProfileSummary() && BFI &&
+      PSI->isColdBlock(L.getHeader(), BFI)) {
+    LLVM_DEBUG(dbgs() << " Skip cold loop: " << L << "\n");
+    return false;
+  }
+
   // Skip non-trivial unswitching for loops that cannot be cloned.
   if (!L.isSafeToClone())
     return false;
@@ -3105,7 +3117,13 @@
                                               LPMUpdater &U) {
   Function &F = *L.getHeader()->getParent();
   (void)F;
-
+  // Only use a profile summary that is already cached. The module proxy
+  // lookup yields null when nothing is cached for this function; in that case
+  // PSI stays null and unswitchLoop() skips its cold-loop check.
+  ProfileSummaryInfo *PSI = nullptr;
+  if (auto *OuterProxy =
+          AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR)
+              .getCachedResult<ModuleAnalysisManagerFunctionProxy>(F))
+    PSI = OuterProxy->getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
   LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L
                     << "\n");
 
@@ -3152,6 +3170,7 @@
   }
   if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial,
                     UnswitchCB, &AR.SE, MSSAU ? MSSAU.getPointer() : nullptr,
+                    PSI, AR.BFI,
                     DestroyLoopCB))
     return PreservedAnalyses::all();
 
@@ -3214,12 +3233,15 @@
 
   LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L
                     << "\n");
-
   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  // Profile-guided cold-loop skipping is not wired up for the legacy pass;
+  // pass null so that unswitchLoop() ignores profile data here.
+  BlockFrequencyInfo *BFI = nullptr;
+  ProfileSummaryInfo *PSI = nullptr;
   MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
   MemorySSAUpdater MSSAU(MSSA);
 
@@ -3251,9 +3273,8 @@
 
   if (VerifyMemorySSA)
     MSSA->verifyMemorySSA();
-
   bool Changed = unswitchLoop(*L, DT, LI, AC, AA, TTI, true, NonTrivial,
-                              UnswitchCB, SE, &MSSAU, DestroyLoopCB);
+                              UnswitchCB, SE, &MSSAU, PSI, BFI, DestroyLoopCB);
 
   if (VerifyMemorySSA)
     MSSA->verifyMemorySSA();
@@ -3269,10 +3290,12 @@
 INITIALIZE_PASS_BEGIN(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch",
                       "Simple unswitch loops", false, false)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopPass)
 INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch",
                     "Simple unswitch loops", false, false)
diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/PGO-nontrivial-unswitch.ll b/llvm/test/Transforms/SimpleLoopUnswitch/PGO-nontrivial-unswitch.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/SimpleLoopUnswitch/PGO-nontrivial-unswitch.ll
@@ -0,0 +1,77 @@
+; RUN: opt < %s -passes='require<profile-summary>,function(loop-mssa(simple-loop-unswitch<nontrivial>)),print<loops>' \
+; RUN:   --disable-output 2>&1 | sort -b -k 1 | FileCheck %s
+
+; With a profile summary present, the hot loop is expected to be unswitched
+; (a ".us" copy of it shows up in the loop listing), while the cold loop is
+; expected to remain a single loop.
+
+declare i32 @a()
+declare i32 @b()
+
+; CHECK: Loop at depth 1 containing: %cold_loop_begin
+; CHECK: Loop at depth 1 containing: %hot_loop_begin.us
+; CHECK: Loop at depth 1 containing: %hot_loop_begin
+define void @f1(i32 %i, i1 %cond, i1 %hot_cond, i1 %cold_cond, i1* %ptr) !prof !0 {
+entry:
+  br label %entry_hot_loop
+
+entry_hot_loop:
+  br i1 %hot_cond, label %hot_loop_begin, label %hot_loop_exit, !prof !15
+
+hot_loop_begin:
+  br i1 %cond, label %hot_loop_a, label %hot_loop_b
+
+hot_loop_a:
+  call i32 @a()
+  br label %hot_loop_latch
+
+hot_loop_b:
+  call i32 @b()
+  br label %hot_loop_latch
+
+hot_loop_latch:
+  %v1 = load i1, i1* %ptr
+  br i1 %v1, label %hot_loop_begin, label %hot_loop_exit
+
+hot_loop_exit:
+  br label %entry_cold_loop
+
+entry_cold_loop:
+  br i1 %cold_cond, label %cold_loop_begin, label %cold_loop_exit, !prof !16
+
+cold_loop_begin:
+  br i1 %cond, label %cold_loop_a, label %cold_loop_b
+
+cold_loop_a:
+  call i32 @a()
+  br label %cold_loop_latch
+
+cold_loop_b:
+  call i32 @b()
+  br label %cold_loop_latch
+
+cold_loop_latch:
+  %v2 = load i1, i1* %ptr
+  br i1 %v2, label %cold_loop_begin, label %cold_loop_exit
+
+cold_loop_exit:
+  ret void
+}
+
+!llvm.module.flags = !{!1}
+!0 = !{!"function_entry_count", i64 400}
+!1 = !{i32 1, !"ProfileSummary", !2}
+!2 = !{!3, !4, !5, !6, !7, !8, !9, !10}
+!3 = !{!"ProfileFormat", !"InstrProf"}
+!4 = !{!"TotalCount", i64 10000}
+!5 = !{!"MaxCount", i64 10}
+!6 = !{!"MaxInternalCount", i64 1}
+!7 = !{!"MaxFunctionCount", i64 1000}
+!8 = !{!"NumCounts", i64 3}
+!9 = !{!"NumFunctions", i64 3}
+!10 = !{!"DetailedSummary", !11}
+!11 = !{!12, !13, !14}
+!12 = !{i32 10000, i64 100, i32 1}
+!13 = !{i32 999000, i64 100, i32 1}
+!14 = !{i32 999999, i64 1, i32 2}
+!15 = !{!"branch_weights", i32 100, i32 0}
+!16 = !{!"branch_weights", i32 0, i32 100}