diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -27,6 +27,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" @@ -505,6 +506,19 @@ // know of kinds of multiexit loops that would benefit from unrolling. } +static void UpdateLatchBranchWeights(Loop *L, uint64_t BackEdgeWeight, + uint64_t ExitWeight) { + BasicBlock *Header = L->getHeader(); + BasicBlock *Latch = L->getLoopLatch(); + auto *LatchBR = cast(Latch->getTerminator()); + unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); + MDBuilder MDB(LatchBR->getContext()); + MDNode *WeightNode = + HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight) + : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); + LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); +} + /// Insert code in the prolog/epilog code when unrolling a loop with a /// run-time trip-count. /// @@ -788,6 +802,18 @@ InsertTop, InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI); + // Assign the maximum possible trip count as the back edge weight for the + // remainder loop if the original loop comes with a branch weight. + if (remainderLoop && !UnrollRemainder) { + uint64_t TrueWeight, FalseWeight; + if (LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) { + uint64_t ExitWeight = + LatchBR->getSuccessor(0) == L->getHeader() ? FalseWeight : TrueWeight; + UpdateLatchBranchWeights(remainderLoop, (Count - 1) * ExitWeight, + ExitWeight); + } + } + // Insert the cloned blocks into the function. F->getBasicBlockList().splice(InsertBot->getIterator(), F->getBasicBlockList(), diff --git a/llvm/test/Transforms/LoopUnroll/runtime-loop-branchweight.ll b/llvm/test/Transforms/LoopUnroll/runtime-loop-branchweight.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopUnroll/runtime-loop-branchweight.ll @@ -0,0 +1,34 @@ +; RUN: opt < %s -S -loop-unroll -unroll-runtime=true -unroll-count=4 | FileCheck %s +; RUN: opt < %s -S -passes=loop-unroll -unroll-runtime=true -unroll-count=4 | FileCheck %s + +;; Check that the remainder loop is properly assigned a branch weight for its latch branch. +; CHECK-LABEL: @test( +; CHECK-LABEL: for.body: +; CHECK: br i1 [[COND1:%.*]], label %for.end.loopexit.unr-lcssa.loopexit, label %for.body, !prof ![[#PROF:]], !llvm.loop ![[#LOOP:]] +; CHECK-LABEL: for.body.epil: +; CHECK: br i1 [[COND2:%.*]], label %for.body.epil, label %for.end.loopexit.epilog-lcssa, !prof ![[#PROF2:]], !llvm.loop ![[#LOOP2:]] +; CHECK: ![[#PROF]] = !{!"branch_weights", i32 1, i32 9999} +; CHECK: ![[#PROF2]] = !{!"branch_weights", i32 3, i32 1} + +define i3 @test(i3* %a, i3 %n) { +entry: + %cmp1 = icmp eq i3 %n, 0 + br i1 %cmp1, label %for.end, label %for.body + +for.body: + %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %entry ] + %sum.02 = phi i3 [ %add, %for.body ], [ 0, %entry ] + %arrayidx = getelementptr inbounds i3, i3* %a, i64 %indvars.iv + %0 = load i3, i3* %arrayidx + %add = add nsw i3 %0, %sum.02 + %indvars.iv.next = add i64 %indvars.iv, 1 + %lftr.wideiv = trunc i64 %indvars.iv.next to i3 + %exitcond = icmp eq i3 %lftr.wideiv, %n + br i1 %exitcond, label %for.end, label %for.body, !prof !0 + +for.end: + %sum.0.lcssa = phi i3 [ 0, %entry ], [ %add, %for.body ] + ret i3 %sum.0.lcssa +} + +!0 = !{!"branch_weights", i32 1, i32 9999}