Index: lib/Transforms/Scalar/LoopPredication.cpp
===================================================================
--- lib/Transforms/Scalar/LoopPredication.cpp
+++ lib/Transforms/Scalar/LoopPredication.cpp
@@ -825,56 +825,125 @@
          Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
 }
 
+/// Return true if \p ProfileData is well-formed "branch_weights" metadata for
+/// \p Term, i.e. a tag string followed by exactly one weight per successor.
+static bool isValidProfileData(MDNode *ProfileData, const Instruction *Term) {
+  // Check the operand count first so that reading operand 0 below is safe.
+  if (!ProfileData ||
+      ProfileData->getNumOperands() != 1 + Term->getNumSuccessors())
+    return false;
+  // Anything other than a "branch_weights" tag (including a non-string or
+  // null first operand) is not something we know how to interpret.
+  auto *Tag = dyn_cast_or_null<MDString>(ProfileData->getOperand(0));
+  return Tag && Tag->getString().equals("branch_weights");
+}
+
+/// Compute the probability of leaving \p ExitingBlock via the edge to
+/// \p ExitBlock. Uses \p BPI when available and otherwise falls back to the
+/// raw "branch_weights" metadata on the terminator. Returns None if there is
+/// no usable profile information.
+static Optional<BranchProbability>
+ComputeBranchProbability(BranchProbabilityInfo *BPI,
+                         const BasicBlock *ExitingBlock,
+                         const BasicBlock *ExitBlock) {
+  if (BPI)
+    return BPI->getEdgeProbability(ExitingBlock, ExitBlock);
+
+  auto *Term = ExitingBlock->getTerminator();
+  MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof);
+  if (!isValidProfileData(ProfileData, Term))
+    return None;
+  uint64_t Numerator = 0, Denominator = 0;
+  for (unsigned I = 0, E = Term->getNumSuccessors(); I != E; ++I) {
+    auto *Weight = mdconst::dyn_extract_or_null<ConstantInt>(
+        ProfileData->getOperand(I + 1));
+    // Malformed weight operand; treat as if there were no profile data.
+    if (!Weight)
+      return None;
+    const uint64_t ProfVal = Weight->getZExtValue();
+    if (Term->getSuccessor(I) == ExitBlock)
+      Numerator += ProfVal;
+    Denominator += ProfVal;
+  }
+  // All-zero weights carry no information, and BranchProbability cannot
+  // represent a zero denominator.
+  if (Denominator == 0)
+    return None;
+  return BranchProbability::getBranchProbability(Numerator, Denominator);
+}
+
 bool LoopPredication::isLoopProfitableToPredicate() {
-  if (SkipProfitabilityChecks || !BPI)
+  if (SkipProfitabilityChecks)
     return true;
 
+  // Check to see that no other exit is taken more than a constant factor more
+  // often than the latch. We are going to use the latch limit as an
+  // approximation of the loop's backedge taken count, so we need to make sure
+  // this is actually true. (i.e. We want to be confident in our prediction
+  // that the loop exits through the latch.) We ignore implicit exits under
+  // the assumption they're rare. Note that we still end up with only an
+  // *estimation* of the true backedge taken count.
+  //
+  // Heuristic: If any of the exiting blocks' probability of exiting the loop
+  // is larger than exiting through the latch block, it's not profitable to
+  // predicate the loop. If we can't tell how frequently the latch exit is
+  // taken, be conservative and don't perform loop predication.
+  //
+  // TODO: We may wish to be even more conservative here.
+
   SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8> ExitEdges;
   L->getExitEdges(ExitEdges);
   // If there is only one exiting edge in the loop, it is always profitable to
-  // predicate the loop.
+  // predicate the loop, regardless of whether we have profiling information.
   if (ExitEdges.size() == 1)
     return true;
 
-  // Calculate the exiting probabilities of all exiting edges from the loop,
-  // starting with the LatchExitProbability.
-  // Heuristic for profitability: If any of the exiting blocks' probability of
-  // exiting the loop is larger than exiting through the latch block, it's not
-  // profitable to predicate the loop.
   auto *LatchBlock = L->getLoopLatch();
   assert(LatchBlock && "Should have a single latch at this point!");
   auto *LatchTerm = LatchBlock->getTerminator();
   assert(LatchTerm->getNumSuccessors() == 2 &&
          "expected to be an exiting block with 2 succs!");
   unsigned LatchBrExitIdx =
       LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
-  BranchProbability LatchExitProbability =
-      BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx);
+
+  auto LatchExitProbability =
+      ComputeBranchProbability(BPI, LatchBlock,
+                               LatchTerm->getSuccessor(LatchBrExitIdx));
+  if (!LatchExitProbability)
+    return false;
 
   // Protect against degenerate inputs provided by the user. Providing a value
   // less than one, can invert the definition of profitable loop predication.
   float ScaleFactor = LatchExitProbabilityScale;
   if (ScaleFactor < 1) {
-    LLVM_DEBUG(
-        dbgs()
-        << "Ignored user setting for loop-predication-latch-probability-scale: "
-        << LatchExitProbabilityScale << "\n");
+    LLVM_DEBUG(dbgs() << "Ignored user setting for "
+                         "loop-predication-latch-probability-scale: "
+                      << LatchExitProbabilityScale << "\n");
     LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
     ScaleFactor = 1.0;
   }
-  const auto LatchProbabilityThreshold =
-      LatchExitProbability * ScaleFactor;
+  const auto LatchProbabilityThreshold = *LatchExitProbability * ScaleFactor;
 
   for (const auto &ExitEdge : ExitEdges) {
-    BranchProbability ExitingBlockProbability =
-        BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second);
+    auto ExitingBlockProbability =
+        ComputeBranchProbability(BPI, ExitEdge.first, ExitEdge.second);
+    if (!ExitingBlockProbability) {
+      assert(ExitEdge.first != LatchBlock &&
+             "Latch term should always have profile data!");
+      unsigned succ_num = std::distance(succ_begin(ExitEdge.first),
+                                        succ_end(ExitEdge.first));
+      // No profile data, so we choose the weight as 1/num_of_succ(Src)
+      // TODO: what about multiple edges between the same pair of blocks?
+      ExitingBlockProbability =
+          BranchProbability::getBranchProbability(1, succ_num);
+    }
     // Some exiting edge has higher probability than the latch exiting edge.
     // No longer profitable to predicate.
-    if (ExitingBlockProbability > LatchProbabilityThreshold)
+    if (*ExitingBlockProbability > LatchProbabilityThreshold)
       return false;
   }
-  // Using BPI, we have concluded that the most probable way to exit from the
+  // We have concluded that the most probable way to exit from the
   // loop is through the latch (or there's no profile information and all
   // exits are equally likely).
   return true;
 }
Index: test/Transforms/LoopPredication/profitability.ll
===================================================================
--- test/Transforms/LoopPredication/profitability.ll
+++ test/Transforms/LoopPredication/profitability.ll
@@ -1,6 +1,10 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -S -loop-predication -loop-predication-skip-profitability-checks=false < %s 2>&1 | FileCheck %s
+; Note: This test for the new pass manager is *fragile*. It relies on us not
+; needing to make any changes in LCSSA or LoopSimplify (which would invalidate
+; BPI currently).
 ; RUN: opt -S -loop-predication-skip-profitability-checks=false -passes='require<scalar-evolution>,require<branch-prob>,loop(loop-predication)' < %s 2>&1 | FileCheck %s
+; RUN: opt -S -loop-predication-skip-profitability-checks=false -passes='require<scalar-evolution>,loop(loop-predication)' < %s 2>&1 | FileCheck %s
 
 ; latch block exits to a speculation block. BPI already knows (without prof
 ; data) that deopt is very rarely
@@ -59,7 +63,6 @@
   %result.le = load i64, i64* %result.in3.lcssa, align 8
   ret i64 %result.le
 }
-!0 = !{!"branch_weights", i32 18, i32 104200}
 
 ; predicate loop since there's no profile information and BPI concluded all
 ; exiting blocks have same probability of exiting from loop.
@@ -78,10 +81,10 @@
 ; CHECK-NEXT:    call void (i1, ...) @llvm.experimental.guard(i1 [[TMP2]], i32 9) [ "deopt"() ]
 ; CHECK-NEXT:    [[INNERCMP:%.*]] = icmp eq i64 [[J2]], [[N_PRE]]
 ; CHECK-NEXT:    [[J_NEXT]] = add nuw nsw i64 [[J2]], 1
-; CHECK-NEXT:    br i1 [[INNERCMP]], label [[LATCH]], label [[EXIT:%.*]]
+; CHECK-NEXT:    br i1 [[INNERCMP]], label [[LATCH]], label [[EXIT:%.*]], !prof !1
 ; CHECK:       Latch:
 ; CHECK-NEXT:    [[SPECULATE_TRIP_COUNT:%.*]] = icmp ult i64 [[J_NEXT]], 1048576
-; CHECK-NEXT:    br i1 [[SPECULATE_TRIP_COUNT]], label [[HEADER]], label [[EXITLATCH:%.*]]
+; CHECK-NEXT:    br i1 [[SPECULATE_TRIP_COUNT]], label [[HEADER]], label [[EXITLATCH:%.*]], !prof !2
 ; CHECK:       exitLatch:
 ; CHECK-NEXT:    ret i64 1
 ; CHECK:       exit:
@@ -101,11 +104,11 @@
   call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
   %innercmp = icmp eq i64 %j2, %n.pre
   %j.next = add nuw nsw i64 %j2, 1
-  br i1 %innercmp, label %Latch, label %exit
+  br i1 %innercmp, label %Latch, label %exit, !prof !3
 
 Latch:                                            ; preds = %Header
   %speculate_trip_count = icmp ult i64 %j.next, 1048576
-  br i1 %speculate_trip_count, label %Header, label %exitLatch
+  br i1 %speculate_trip_count, label %Header, label %exitLatch, !prof !4
 
 exitLatch:                                        ; preds = %Latch
   ret i64 1
@@ -133,10 +136,10 @@
 ; CHECK-NEXT:    call void (i1, ...) @llvm.experimental.guard(i1 [[WITHIN_BOUNDS]], i32 9) [ "deopt"() ]
 ; CHECK-NEXT:    [[INNERCMP:%.*]] = icmp eq i64 [[J2]], [[N_PRE]]
 ; CHECK-NEXT:    [[J_NEXT]] = add nuw nsw i64 [[J2]], 1
-; CHECK-NEXT:    br i1 [[INNERCMP]], label [[LATCH]], label [[EXIT:%.*]], !prof !1
+; CHECK-NEXT:    br i1 [[INNERCMP]], label [[LATCH]], label [[EXIT:%.*]], !prof !3
 ; CHECK:       Latch:
 ; CHECK-NEXT:    [[SPECULATE_TRIP_COUNT:%.*]] = icmp ult i64 [[J_NEXT]], 1048576
-; CHECK-NEXT:    br i1 [[SPECULATE_TRIP_COUNT]], label [[HEADER]], label [[EXITLATCH:%.*]], !prof !2
+; CHECK-NEXT:    br i1 [[SPECULATE_TRIP_COUNT]], label [[HEADER]], label [[EXITLATCH:%.*]], !prof !4
 ; CHECK:       exitLatch:
 ; CHECK-NEXT:    ret i64 1
 ; CHECK:       exit:
@@ -173,5 +176,8 @@
 declare i64 @llvm.experimental.deoptimize.i64(...)
 declare void @llvm.experimental.guard(i1, ...)
 
+!0 = !{!"branch_weights", i32 18, i32 104200}
 !1 = !{!"branch_weights", i32 104, i32 1042861}
 !2 = !{!"branch_weights", i32 255129, i32 1}
+!3 = !{!"branch_weights", i64 1024000, i64 1}
+!4 = !{!"branch_weights", i64 1024, i64 1}