diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -72,12 +72,45 @@
     "unroll-force-peel-count", cl::init(0), cl::Hidden,
     cl::desc("Force a peel count regardless of profiling information."));
 
+static cl::opt<bool> DisableAdvancedPeeling(
+    "disable-advanced-peeling", cl::init(false), cl::Hidden,
+    cl::desc(
+        "Disable advanced peeling. Issues for convergent targets (D134803)."));
+
 static const char *PeeledCountMetaData = "llvm.loop.peeled.count";
 
 // Check whether we are capable of peeling this loop.
 bool llvm::canPeel(Loop *L) {
   // Make sure the loop is in simplified form
-  return L->isLoopSimplifyForm();
+  if (!L->isLoopSimplifyForm())
+    return false;
+  // Unless advanced peeling is disabled, simplified form is sufficient.
+  if (!DisableAdvancedPeeling)
+    return true;
+
+  // Don't try to peel loops where the latch is not the exiting block.
+  // This can be an indication of two different things:
+  // 1) The loop is not rotated.
+  // 2) The loop contains irreducible control flow that involves the latch.
+  const BasicBlock *Latch = L->getLoopLatch();
+  if (!L->isLoopExiting(Latch))
+    return false;
+
+  // Peeling is only supported if the latch is a branch.
+  if (!isa<BranchInst>(Latch->getTerminator()))
+    return false;
+
+  SmallVector<BasicBlock *, 4> Exits;
+  L->getUniqueNonLatchExitBlocks(Exits);
+  // The latch must either be the only exiting block or all non-latch exit
+  // blocks have either a deopt or unreachable terminator or compose a chain of
+  // blocks where the last one is either deopt or unreachable terminated. Both
+  // deopt and unreachable terminators are a strong indication they are not
+  // taken. Note that this is a profitability check, not a legality check. Also
+  // note that LoopPeeling currently can only update the branch weights of latch
+  // blocks and branch weights to blocks with deopt or unreachable do not need
+  // updating.
+  return llvm::all_of(Exits, IsBlockFollowedByDeoptOrUnreachable);
 }
 
 // This function calculates the number of iterations after which the given Phi
@@ -493,6 +526,27 @@
                               : 1;
 }
 
+/// Update the latch branch weights of a peeled-off iteration when advanced
+/// peeling is disabled; FallThroughWeight is then reduced by ExitWeight
+/// (clamped to a minimum of 1) for the next iteration.
+static void updateBranchWeightsLimited(BasicBlock *Header, BranchInst *LatchBR,
+                                       uint64_t ExitWeight,
+                                       uint64_t &FallThroughWeight) {
+  // A zero FallThroughWeight means the original latch has no branch weights
+  // or the estimated trip count is zero.
+  if (!FallThroughWeight)
+    return;
+
+  unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1);
+  MDBuilder MDB(LatchBR->getContext());
+  MDNode *WeightNode =
+      HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight)
+                : MDB.createBranchWeights(FallThroughWeight, ExitWeight);
+  LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
+  FallThroughWeight =
+      FallThroughWeight > ExitWeight ? FallThroughWeight - ExitWeight : 1;
+}
+
 /// Initialize the weights for all exiting blocks.
 static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
                               Loop *L) {
@@ -537,6 +591,20 @@
   }
 }
 
+/// Read the latch branch weights into ExitWeight/FallThroughWeight when
+/// advanced peeling is disabled; both are left unchanged if the latch carries
+/// no branch weight metadata.
+static void initBranchWeightsLimited(BasicBlock *Header, BranchInst *LatchBR,
+                                     uint64_t &ExitWeight,
+                                     uint64_t &FallThroughWeight) {
+  uint64_t TrueWeight, FalseWeight;
+  if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
+    return;
+  unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
+  ExitWeight = HeaderIdx ? TrueWeight : FalseWeight;
+  FallThroughWeight = HeaderIdx ? FalseWeight : TrueWeight;
+}
+
 /// Update the weights of original exiting block after peeling off all
 /// iterations.
 static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
@@ -545,6 +613,25 @@
                     MDB.createBranchWeights(Info.Weights));
 }
 
+/// Set the branch weights of the original latch after peeling off all
+/// iterations; used only when advanced peeling is disabled.
+static void fixupBranchWeightsLimited(BasicBlock *Header, BranchInst *LatchBR,
+                                      uint64_t ExitWeight,
+                                      uint64_t FallThroughWeight) {
+  // A zero FallThroughWeight means the original latch has no branch weights
+  // or the estimated trip count is zero.
+  if (!FallThroughWeight)
+    return;
+
+  // Sets the branch weights on the loop exit.
+  MDBuilder MDB(LatchBR->getContext());
+  unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
+  MDNode *WeightNode =
+      HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight)
+                : MDB.createBranchWeights(FallThroughWeight, ExitWeight);
+  LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
+}
+
 /// Clones the body of the loop L, putting it between \p InsertTop and \p
 /// InsertBot.
 /// \param IterNumber The serial number of the iteration currently being
@@ -821,7 +908,12 @@
   // If we have branch weight information, we'll want to update it for the
   // newly created branches.
   DenseMap<Instruction *, WeightInfo> Weights;
-  initBranchWeights(Weights, L);
+  uint64_t ExitWeight = 0, FallThroughWeight = 0;
+  if (DisableAdvancedPeeling)
+    initBranchWeightsLimited(Header, cast<BranchInst>(LatchTerm), ExitWeight,
+                             FallThroughWeight);
+  else
+    initBranchWeights(Weights, L);
 
   // Identify what noalias metadata is inside the loop: if it is inside the
   // loop, the associated metadata must be cloned for each iteration.
@@ -850,15 +942,25 @@
     assert(DT.verify(DominatorTree::VerificationLevel::Fast));
 #endif
 
-    for (auto &[Term, Info] : Weights) {
-      auto *TermCopy = cast<Instruction>(VMap[Term]);
-      updateBranchWeights(TermCopy, Info);
-    }
+    if (DisableAdvancedPeeling) {
+      auto *LatchBRCopy = cast<BranchInst>(VMap[LatchTerm]);
+      updateBranchWeightsLimited(InsertBot, LatchBRCopy, ExitWeight,
+                                 FallThroughWeight);
 
-    // Remove Loop metadata from the latch branch instruction
-    // because it is not the Loop's latch branch anymore.
-    auto *LatchTermCopy = cast<Instruction>(VMap[LatchTerm]);
-    LatchTermCopy->setMetadata(LLVMContext::MD_loop, nullptr);
+      // Remove Loop metadata from the latch branch instruction
+      // because it is not the Loop's latch branch anymore.
+      LatchBRCopy->setMetadata(LLVMContext::MD_loop, nullptr);
+    } else {
+      for (auto &[Term, Info] : Weights) {
+        auto *TermCopy = cast<Instruction>(VMap[Term]);
+        updateBranchWeights(TermCopy, Info);
+      }
+
+      // Remove Loop metadata from the latch branch instruction
+      // because it is not the Loop's latch branch anymore.
+      auto *LatchTermCopy = cast<Instruction>(VMap[LatchTerm]);
+      LatchTermCopy->setMetadata(LLVMContext::MD_loop, nullptr);
+    }
 
     InsertTop = InsertBot;
     InsertBot = SplitBlock(InsertBot, InsertBot->getTerminator(), &DT, LI);
@@ -881,8 +983,13 @@
     PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
   }
 
-  for (const auto &[Term, Info] : Weights)
-    fixupBranchWeights(Term, Info);
+  if (DisableAdvancedPeeling) {
+    fixupBranchWeightsLimited(Header, cast<BranchInst>(LatchTerm), ExitWeight,
+                              FallThroughWeight);
+  } else {
+    for (const auto &[Term, Info] : Weights)
+      fixupBranchWeights(Term, Info);
+  }
 
   // Update Metadata for count of peeled off iterations.
   unsigned AlreadyPeeled = 0;
diff --git a/llvm/test/Transforms/LoopUnroll/peel-branch-weights.ll b/llvm/test/Transforms/LoopUnroll/peel-branch-weights.ll
--- a/llvm/test/Transforms/LoopUnroll/peel-branch-weights.ll
+++ b/llvm/test/Transforms/LoopUnroll/peel-branch-weights.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
 ; RUN: opt < %s -S -loop-unroll -unroll-force-peel-count=2 2>&1 | FileCheck %s
+; RUN: opt < %s -S -loop-unroll -unroll-force-peel-count=2 -disable-advanced-peeling 2>&1 | FileCheck %s --check-prefix=DISABLEADV
 
 declare i32 @get.x()
 
@@ -50,7 +51,23 @@
 ; CHECK-NEXT:    br label [[LOOP_EXIT]]
 ; CHECK:       loop.exit:
 ; CHECK-NEXT:    ret void
-;
+
+; The latch does not exit, so -disable-advanced-peeling must not peel the loop.
+; DISABLEADV-LABEL: @test()
+; DISABLEADV-NEXT: entry:
+; DISABLEADV-NEXT: br label %loop
+; DISABLEADV: loop:
+; DISABLEADV-NEXT: %x = call i32 @get.x()
+; DISABLEADV-NEXT: switch i32 %x, label %loop.latch [
+; DISABLEADV-NEXT: i32 0, label %loop.latch
+; DISABLEADV-NEXT: i32 1, label %loop.exit
+; DISABLEADV-NEXT: i32 2, label %loop.exit
+; DISABLEADV-NEXT: ], !prof !0
+; DISABLEADV: loop.latch:
+; DISABLEADV-NEXT: br label %loop
+; DISABLEADV: loop.exit:
+; DISABLEADV-NEXT: ret void
+
 entry:
   br label %loop
 