Index: lib/Transforms/Scalar/LoopPredication.cpp =================================================================== --- lib/Transforms/Scalar/LoopPredication.cpp +++ lib/Transforms/Scalar/LoopPredication.cpp @@ -178,6 +178,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopPredication.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -202,6 +203,15 @@ static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop", cl::Hidden, cl::init(true)); + +static cl::opt<bool> + SkipProfitabilityChecks("loop-predication-skip-profitability-checks", + cl::Hidden, cl::init(false)); + +static cl::opt<float> + ProbabilityThreshold("loop-predication-threshold-for-profitability", + cl::Hidden, cl::init(2.0)); + namespace { class LoopPredication { /// Represents an induction variable check: @@ -221,6 +231,7 @@ }; ScalarEvolution *SE; + BranchProbabilityInfo *BPI; Loop *L; const DataLayout *DL; @@ -254,6 +265,12 @@ IRBuilder<> &Builder); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); + // If the loop always exits through another block in the loop, we should not + // predicate based on the latch check. For example, the latch check can be a + // very coarse grained check and there can be more fine grained exit checks + // within the loop. We identify such unprofitable loops through BPI. + bool isLoopProfitableToPredicate(); + // When the IV type is wider than the range operand type, we can still do loop // predication, by generating SCEVs for the range and latch that are of the // same type. 
We achieve this by generating a SCEV truncate expression for the @@ -272,7 +289,8 @@ Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType); public: - LoopPredication(ScalarEvolution *SE) : SE(SE){}; + LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI) + : SE(SE), BPI(BPI){}; bool runOnLoop(Loop *L); }; @@ -284,6 +302,7 @@ } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<BranchProbabilityInfoWrapperPass>(); getLoopAnalysisUsage(AU); } @@ -291,7 +310,9 @@ if (skipLoop(L)) return false; auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LoopPredication LP(SE); + BranchProbabilityInfo &BPI = + getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + LoopPredication LP(SE, &BPI); return LP.runOnLoop(L); } }; @@ -301,6 +322,7 @@ INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) +INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) @@ -312,7 +334,11 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { - LoopPredication LP(&AR.SE); + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + Function *F = L.getHeader()->getParent(); + auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); + LoopPredication LP(&AR.SE, BPI); if (!LP.runOnLoop(&L)) return PreservedAnalyses::all(); @@ -690,6 +716,42 @@ Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; } +bool LoopPredication::isLoopProfitableToPredicate() { + if (SkipProfitabilityChecks || !BPI) + return true; + SmallVector<BasicBlock *, 8> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + if (ExitingBlocks.size() < 2) + return true; + auto *LatchBlock = L->getLoopLatch(); + assert(LatchBlock && "Should have a single latch at this point!"); + unsigned LatchBrExitIdx = + LatchBlock->getTerminator()->getSuccessor(0) == L->getHeader() ? 
1 : 0; + BranchProbability LatchExitProbability = + BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx); + + for (auto *EB : ExitingBlocks) { + if (EB == LatchBlock) + continue; + auto *EBTerminator = dyn_cast<BranchInst>(EB->getTerminator()); + if (!EBTerminator) + continue; + assert(EBTerminator->isConditional() && "definition of exiting block!"); + unsigned EBExitIdx = L->contains(EBTerminator->getSuccessor(0)) ? 1 : 0; + BranchProbability ExitingBlockProbability = + BPI->getEdgeProbability(EB, EBExitIdx); + // If any of the exiting blocks' probability of exiting the + // loop is larger than LatchExitProbability, it's not profitable to + // predicate. + if (ExitingBlockProbability > LatchExitProbability * ProbabilityThreshold) + return false; + } + // Using BPI, we have concluded that the most probable way to exit from the + // loop is through the latch (or there's no profile information and all + // exits are equally likely). + return true; +} + bool LoopPredication::runOnLoop(Loop *Loop) { L = Loop; @@ -718,6 +780,10 @@ DEBUG(dbgs() << "Latch check:\n"); DEBUG(LatchCheck.dump()); + if (!isLoopProfitableToPredicate()) { + DEBUG(dbgs()<< "Loop not profitable to predicate!\n"); + return false; + } // Collect all the guards into a vector and process later, so as not // to invalidate the instruction iterator. SmallVector<IntrinsicInst *, 4> Guards; Index: test/Transforms/LoopPredication/profitability.ll =================================================================== --- /dev/null +++ test/Transforms/LoopPredication/profitability.ll @@ -0,0 +1,120 @@ +; RUN: opt -S -loop-predication -loop-predication-skip-profitability-checks=false < %s 2>&1 | FileCheck %s +; RUN: opt -S -loop-predication-skip-profitability-checks=false -passes='require<branch-prob>,require<scalar-evolution>,loop(loop-predication)' < %s 2>&1 | FileCheck %s + +; latch block exits to a speculation block. BPI already knows (without prof +; data) that deopt is very rarely +; taken. So we do not predicate this loop using that coarse latch check. 
+; LatchExitProbability: 0x04000000 / 0x80000000 = 3.12% +; ExitingBlockProbability: 0x7ffa572a / 0x80000000 = 99.98% +define i64 @donot_predicate(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) { +; CHECK-LABEL: donot_predicate( +entry: + %length.ext = zext i32 %length to i64 + %n.pre = load i64, i64* %n_addr, align 4 + br label %Header + +; CHECK-LABEL: Header: +; CHECK: %within.bounds = icmp ult i64 %j2, %length.ext +; CHECK-NEXT: call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) +Header: ; preds = %entry, %Latch + %result.in3 = phi i64* [ %arg2, %entry ], [ %arg, %Latch ] + %j2 = phi i64 [ 0, %entry ], [ %j.next, %Latch ] + %within.bounds = icmp ult i64 %j2, %length.ext + call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + %innercmp = icmp eq i64 %j2, %n.pre + %j.next = add nuw nsw i64 %j2, 1 + br i1 %innercmp, label %Latch, label %exit, !prof !0 + +Latch: ; preds = %Header + %speculate_trip_count = icmp ult i64 %j.next, 1048576 + br i1 %speculate_trip_count, label %Header, label %deopt + +deopt: ; preds = %Latch + %counted_speculation_failed = call i64 (...) @llvm.experimental.deoptimize.i64(i64 30) [ "deopt"(i32 0) ] + ret i64 %counted_speculation_failed + +exit: ; preds = %Header + %result.in3.lcssa = phi i64* [ %result.in3, %Header ] + %result.le = load i64, i64* %result.in3.lcssa, align 8 + ret i64 %result.le +} +!0 = !{!"branch_weights", i32 18, i32 104200} + +; predicate loop since there's no profile information and BPI concluded all +; exiting blocks have same probability of exiting from loop. 
+define i64 @predicate(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) { +; CHECK-LABEL: predicate( +; CHECK-LABEL: entry: +; CHECK: [[limit_check:[^ ]+]] = icmp ule i64 1048576, %length.ext +; CHECK-NEXT: [[first_iteration_check:[^ ]+]] = icmp ult i64 0, %length.ext +; CHECK-NEXT: [[wide_cond:[^ ]+]] = and i1 [[first_iteration_check]], [[limit_check]] +entry: + %length.ext = zext i32 %length to i64 + %n.pre = load i64, i64* %n_addr, align 4 + br label %Header + +; CHECK-LABEL: Header: +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ] +Header: ; preds = %entry, %Latch + %result.in3 = phi i64* [ %arg2, %entry ], [ %arg, %Latch ] + %j2 = phi i64 [ 0, %entry ], [ %j.next, %Latch ] + %within.bounds = icmp ult i64 %j2, %length.ext + call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + %innercmp = icmp eq i64 %j2, %n.pre + %j.next = add nuw nsw i64 %j2, 1 + br i1 %innercmp, label %Latch, label %exit + +Latch: ; preds = %Header + %speculate_trip_count = icmp ult i64 %j.next, 1048576 + br i1 %speculate_trip_count, label %Header, label %exitLatch + +exitLatch: ; preds = %Latch + ret i64 1 + +exit: ; preds = %Header + %result.in3.lcssa = phi i64* [ %result.in3, %Header ] + %result.le = load i64, i64* %result.in3.lcssa, align 8 + ret i64 %result.le +} + +; Same as test above but with profiling data that the most probable exit from +; the loop is the header exiting block (not the latch block). So do not predicate. 
+; LatchExitProbability: 0x000020e1 / 0x80000000 = 0.00% +; ExitingBlockProbability: 0x7ffcbb86 / 0x80000000 = 99.99% +define i64 @donot_predicate_prof(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) { +; CHECK-LABEL: donot_predicate_prof( +; CHECK-LABEL: entry: +entry: + %length.ext = zext i32 %length to i64 + %n.pre = load i64, i64* %n_addr, align 4 + br label %Header + +; CHECK-LABEL: Header: +; CHECK: %within.bounds = icmp ult i64 %j2, %length.ext +; CHECK-NEXT: call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) +Header: ; preds = %entry, %Latch + %result.in3 = phi i64* [ %arg2, %entry ], [ %arg, %Latch ] + %j2 = phi i64 [ 0, %entry ], [ %j.next, %Latch ] + %within.bounds = icmp ult i64 %j2, %length.ext + call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + %innercmp = icmp eq i64 %j2, %n.pre + %j.next = add nuw nsw i64 %j2, 1 + br i1 %innercmp, label %Latch, label %exit, !prof !1 + +Latch: ; preds = %Header + %speculate_trip_count = icmp ult i64 %j.next, 1048576 + br i1 %speculate_trip_count, label %Header, label %exitLatch, !prof !2 + +exitLatch: ; preds = %Latch + ret i64 1 + +exit: ; preds = %Header + %result.in3.lcssa = phi i64* [ %result.in3, %Header ] + %result.le = load i64, i64* %result.in3.lcssa, align 8 + ret i64 %result.le +} +declare i64 @llvm.experimental.deoptimize.i64(...) +declare void @llvm.experimental.guard(i1, ...) + +!1 = !{!"branch_weights", i32 104, i32 1042861} +!2 = !{!"branch_weights", i32 255129, i32 1}