Index: llvm/trunk/lib/Transforms/Scalar/LoopPredication.cpp =================================================================== --- llvm/trunk/lib/Transforms/Scalar/LoopPredication.cpp +++ llvm/trunk/lib/Transforms/Scalar/LoopPredication.cpp @@ -178,6 +178,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopPredication.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -202,6 +203,20 @@ static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop", cl::Hidden, cl::init(true)); + +static cl::opt<bool> + SkipProfitabilityChecks("loop-predication-skip-profitability-checks", + cl::Hidden, cl::init(false)); + +// This is the scale factor for the latch probability. We use this during +// profitability analysis to find other exiting blocks that have a much higher +// probability of exiting the loop than the latch does. +// This value should be greater than 1 for a sane profitability check. +static cl::opt<float> LatchExitProbabilityScale( + "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0), + cl::desc("scale factor for the latch probability. Value should be greater " + "than 1. Lower values are ignored")); + namespace { class LoopPredication { /// Represents an induction variable check: @@ -221,6 +236,7 @@ }; ScalarEvolution *SE; + BranchProbabilityInfo *BPI; Loop *L; const DataLayout *DL; @@ -254,6 +270,12 @@ IRBuilder<> &Builder); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); + // If the loop is much more likely to exit through another block than through + // the latch, we should not predicate based on the latch check. For example, + // the latch check can be a very coarse-grained check and there can be more + // fine-grained exit checks within the loop. We identify such loops via BPI.
+ bool isLoopProfitableToPredicate(); + // When the IV type is wider than the range operand type, we can still do loop // predication, by generating SCEVs for the range and latch that are of the // same type. We achieve this by generating a SCEV truncate expression for the @@ -272,7 +294,8 @@ Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType); public: - LoopPredication(ScalarEvolution *SE) : SE(SE){}; + LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI) + : SE(SE), BPI(BPI){}; bool runOnLoop(Loop *L); }; @@ -284,6 +307,7 @@ } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<BranchProbabilityInfoWrapperPass>(); getLoopAnalysisUsage(AU); } @@ -291,7 +315,9 @@ if (skipLoop(L)) return false; auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LoopPredication LP(SE); + BranchProbabilityInfo &BPI = + getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + LoopPredication LP(SE, &BPI); return LP.runOnLoop(L); } }; @@ -301,6 +327,7 @@ INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) +INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) @@ -312,7 +339,11 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { - LoopPredication LP(&AR.SE); + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + Function *F = L.getHeader()->getParent(); + auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); + LoopPredication LP(&AR.SE, BPI); if (!LP.runOnLoop(&L)) return PreservedAnalyses::all(); @@ -690,6 +721,62 @@ Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; } +bool LoopPredication::isLoopProfitableToPredicate() { + if (SkipProfitabilityChecks || !BPI) + return true; + + SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8> ExitEdges; + L->getExitEdges(ExitEdges); + // If there is only one exiting edge in the loop, it is always profitable to + // predicate the loop.
+ if (ExitEdges.size() == 1) + return true; + + // Calculate the exiting probabilities of all exiting edges from the loop, + // starting with the LatchExitProbability. + // Heuristic for profitability: If any of the exiting blocks' probability of + // exiting the loop is larger than the scaled latch exit probability, it's not + // profitable to predicate the loop. + auto *LatchBlock = L->getLoopLatch(); + assert(LatchBlock && "Should have a single latch at this point!"); + auto *LatchTerm = LatchBlock->getTerminator(); + assert(LatchTerm->getNumSuccessors() == 2 && + "expected to be an exiting block with 2 succs!"); + unsigned LatchBrExitIdx = + LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0; + BranchProbability LatchExitProbability = + BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx); + + // Protect against degenerate inputs provided by the user. Providing a value + // less than one can invert the definition of profitable loop predication. + float ScaleFactor = LatchExitProbabilityScale; + if (ScaleFactor < 1) { + DEBUG( + dbgs() + << "Ignored user setting for loop-predication-latch-probability-scale: " + << LatchExitProbabilityScale << "\n"); + DEBUG(dbgs() << "The value is set to 1.0\n"); + ScaleFactor = 1.0; + } + // Scale in floating point: BranchProbability only multiplies by uint32_t, + // which truncates fractional scales. Probabilities share one denominator. + const double LatchProbabilityThreshold = + LatchExitProbability.getNumerator() * static_cast<double>(ScaleFactor); + + for (const auto &ExitEdge : ExitEdges) { + BranchProbability ExitingBlockProbability = + BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second); + // Some exiting edge is more probable than the scaled latch exiting edge. + // No longer profitable to predicate. + if (ExitingBlockProbability.getNumerator() > LatchProbabilityThreshold) + return false; + } + // Using BPI, we have concluded that the most probable way to exit from the + // loop is through the latch (or there's no profile information and all + // exits are equally likely).
+ return true; + } + bool LoopPredication::runOnLoop(Loop *Loop) { L = Loop; @@ -718,6 +805,10 @@ DEBUG(dbgs() << "Latch check:\n"); DEBUG(LatchCheck.dump()); + if (!isLoopProfitableToPredicate()) { + DEBUG(dbgs() << "Loop not profitable to predicate!\n"); + return false; + } // Collect all the guards into a vector and process later, so as not // to invalidate the instruction iterator. SmallVector<IntrinsicInst *, 4> Guards; Index: llvm/trunk/test/Transforms/LoopPredication/profitability.ll =================================================================== --- llvm/trunk/test/Transforms/LoopPredication/profitability.ll +++ llvm/trunk/test/Transforms/LoopPredication/profitability.ll @@ -0,0 +1,120 @@ +; RUN: opt -S -loop-predication -loop-predication-skip-profitability-checks=false < %s 2>&1 | FileCheck %s +; RUN: opt -S -loop-predication-skip-profitability-checks=false -passes='require<scalar-evolution>,require<branch-prob>,loop(loop-predication)' < %s 2>&1 | FileCheck %s + +; latch block exits to a speculation block. BPI already knows (without prof +; data) that deopt is very rarely +; taken. So we do not predicate this loop using that coarse latch check. +; LatchExitProbability: 0x04000000 / 0x80000000 = 3.12% +; ExitingBlockProbability: 0x7ffa572a / 0x80000000 = 99.98% +define i64 @donot_predicate(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) { +; CHECK-LABEL: donot_predicate( +entry: + %length.ext = zext i32 %length to i64 + %n.pre = load i64, i64* %n_addr, align 4 + br label %Header + +; CHECK-LABEL: Header: +; CHECK: %within.bounds = icmp ult i64 %j2, %length.ext +; CHECK-NEXT: call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) +Header: ; preds = %entry, %Latch + %result.in3 = phi i64* [ %arg2, %entry ], [ %arg, %Latch ] + %j2 = phi i64 [ 0, %entry ], [ %j.next, %Latch ] + %within.bounds = icmp ult i64 %j2, %length.ext + call void (i1, ...) 
@llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + %innercmp = icmp eq i64 %j2, %n.pre + %j.next = add nuw nsw i64 %j2, 1 + br i1 %innercmp, label %Latch, label %exit, !prof !0 + +Latch: ; preds = %Header + %speculate_trip_count = icmp ult i64 %j.next, 1048576 + br i1 %speculate_trip_count, label %Header, label %deopt + +deopt: ; preds = %Latch + %counted_speculation_failed = call i64 (...) @llvm.experimental.deoptimize.i64(i64 30) [ "deopt"(i32 0) ] + ret i64 %counted_speculation_failed + +exit: ; preds = %Header + %result.in3.lcssa = phi i64* [ %result.in3, %Header ] + %result.le = load i64, i64* %result.in3.lcssa, align 8 + ret i64 %result.le +} +!0 = !{!"branch_weights", i32 18, i32 104200} + +; predicate loop since there's no profile information and BPI concluded all +; exiting blocks have same probability of exiting from loop. +define i64 @predicate(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) { +; CHECK-LABEL: predicate( +; CHECK-LABEL: entry: +; CHECK: [[limit_check:[^ ]+]] = icmp ule i64 1048576, %length.ext +; CHECK-NEXT: [[first_iteration_check:[^ ]+]] = icmp ult i64 0, %length.ext +; CHECK-NEXT: [[wide_cond:[^ ]+]] = and i1 [[first_iteration_check]], [[limit_check]] +entry: + %length.ext = zext i32 %length to i64 + %n.pre = load i64, i64* %n_addr, align 4 + br label %Header + +; CHECK-LABEL: Header: +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ] +Header: ; preds = %entry, %Latch + %result.in3 = phi i64* [ %arg2, %entry ], [ %arg, %Latch ] + %j2 = phi i64 [ 0, %entry ], [ %j.next, %Latch ] + %within.bounds = icmp ult i64 %j2, %length.ext + call void (i1, ...) 
@llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + %innercmp = icmp eq i64 %j2, %n.pre + %j.next = add nuw nsw i64 %j2, 1 + br i1 %innercmp, label %Latch, label %exit + +Latch: ; preds = %Header + %speculate_trip_count = icmp ult i64 %j.next, 1048576 + br i1 %speculate_trip_count, label %Header, label %exitLatch + +exitLatch: ; preds = %Latch + ret i64 1 + +exit: ; preds = %Header + %result.in3.lcssa = phi i64* [ %result.in3, %Header ] + %result.le = load i64, i64* %result.in3.lcssa, align 8 + ret i64 %result.le +} + +; Same as test above but with profiling data that the most probable exit from +; the loop is the header exiting block (not the latch block). So do not predicate. +; LatchExitProbability: 0x000020e1 / 0x80000000 = 0.00% +; ExitingBlockProbability: 0x7ffcbb86 / 0x80000000 = 99.99% +define i64 @donot_predicate_prof(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) { +; CHECK-LABEL: donot_predicate_prof( +; CHECK-LABEL: entry: +entry: + %length.ext = zext i32 %length to i64 + %n.pre = load i64, i64* %n_addr, align 4 + br label %Header + +; CHECK-LABEL: Header: +; CHECK: %within.bounds = icmp ult i64 %j2, %length.ext +; CHECK-NEXT: call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) +Header: ; preds = %entry, %Latch + %result.in3 = phi i64* [ %arg2, %entry ], [ %arg, %Latch ] + %j2 = phi i64 [ 0, %entry ], [ %j.next, %Latch ] + %within.bounds = icmp ult i64 %j2, %length.ext + call void (i1, ...) 
@llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + %innercmp = icmp eq i64 %j2, %n.pre + %j.next = add nuw nsw i64 %j2, 1 + br i1 %innercmp, label %Latch, label %exit, !prof !1 + +Latch: ; preds = %Header + %speculate_trip_count = icmp ult i64 %j.next, 1048576 + br i1 %speculate_trip_count, label %Header, label %exitLatch, !prof !2 + +exitLatch: ; preds = %Latch + ret i64 1 + +exit: ; preds = %Header + %result.in3.lcssa = phi i64* [ %result.in3, %Header ] + %result.le = load i64, i64* %result.in3.lcssa, align 8 + ret i64 %result.le +} +declare i64 @llvm.experimental.deoptimize.i64(...) +declare void @llvm.experimental.guard(i1, ...) + +!1 = !{!"branch_weights", i32 104, i32 1042861} +!2 = !{!"branch_weights", i32 255129, i32 1}