Index: llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -36,6 +36,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/GuardUtils.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/ScalarEvolution.h"
@@ -2804,6 +2805,41 @@
   return Changed;
 }
 
+/// Return the unique predecessor of \p BB paired with \p BB itself, or
+/// {nullptr, nullptr} if \p BB does not have a single predecessor.
+/// NOTE: The predecessor may have other successors, so despite the name the
+/// returned pair is not guaranteed to be control-equivalent.
+static std::pair<BasicBlock *, BasicBlock *>
+getLikelyControlEqPredecessorForBB(BasicBlock *BB) {
+  if (BasicBlock *Pred = BB->getSinglePredecessor())
+    return {Pred, BB};
+  return {nullptr, nullptr};
+}
+
+/// If we can (cheaply) find a widenable branch which controls entry into the
+/// loop, return it.  Starting from the loop predecessor, this walks up the
+/// chain of unique predecessors and returns the first widenable branch whose
+/// true successor is the next block down that chain towards the loop.
+static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L) {
+  // TODO: Walk only through control-equivalent blocks, then accept *one*
+  // non-control-equivalent block at the top.
+  for (std::pair<BasicBlock *, BasicBlock *>
+           Pair(L->getLoopPredecessor(), L->getHeader());
+       Pair.first; Pair = getLikelyControlEqPredecessorForBB(Pair.first)) {
+    Instruction *Term = Pair.first->getTerminator();
+    LLVM_DEBUG(dbgs() << "INDVARS: Considering " << *Term << "\n");
+
+    Value *Cond, *WC;
+    BasicBlock *IfTrueBB, *IfFalseBB;
+    if (!parseWidenableBranch(Term, Cond, WC, IfTrueBB, IfFalseBB))
+      continue;
+    LLVM_DEBUG(dbgs() << "INDVARS: Found widenable branch " << *Term << "\n");
+    if (IfTrueBB == Pair.second)
+      return cast<BranchInst>(Term);
+  }
+  return nullptr;
+}
+
 bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
   SmallVector<BasicBlock*, 16> ExitingBlocks;
   L->getExitingBlocks(ExitingBlocks);
@@ -2834,6 +2870,39 @@
       !isSafeToExpand(ExactBTC, *SE))
     return Changed;
 
+  // If a widenable branch controls entry into the loop, some exits can be
+  // predicated by folding their conditions into it, so that it takes its false
+  // successor (instead of entering the loop) whenever the exit would have been
+  // taken.  This requires ExactBTC to be expandable at that branch.
+  BranchInst *WidenableBR = FindWidenableTerminatorAboveLoop(L);
+  if (WidenableBR && !isSafeToExpandAt(ExactBTC, WidenableBR, *SE))
+    WidenableBR = nullptr;
+
+  // Returns true if the exit from ExitingBB, whose exit count is ExitCount,
+  // is to be predicated by folding it into WidenableBR.  The latch exit is
+  // never folded, so the loop always keeps a real exit.
+  auto CanFoldIntoWidenableBR = [&](BasicBlock *ExitingBB,
+                                    const SCEV *ExitCount) {
+    return WidenableBR && ExitingBB != L->getLoopLatch() &&
+           SE->isLoopInvariant(ExitCount, L) &&
+           isSafeToExpandAt(ExitCount, WidenableBR, *SE);
+  };
+
+  // Whether the loop may have side effects or implicit exits (in which case
+  // ExactBTC isn't actually exact).  Computed lazily, at most once per loop.
+  Optional<bool> MayHaveSideEffectOrEarlyExitCache;
+  auto MayHaveSideEffectOrEarlyExit = [&]() {
+    if (!MayHaveSideEffectOrEarlyExitCache)
+      MayHaveSideEffectOrEarlyExitCache =
+          llvm::any_of(L->blocks(), [](BasicBlock *BB) {
+            // TODO: isGuaranteedToTransferExecutionToSuccessor
+            return llvm::any_of(*BB, [](Instruction &I) {
+              return I.mayHaveSideEffects() || I.mayThrow();
+            });
+          });
+    return *MayHaveSideEffectOrEarlyExitCache;
+  };
+
   auto BadExit = [&](BasicBlock *ExitingBB) {
     // If our exiting block exits multiple loops, we can only rewrite the
     // innermost one.  Otherwise, we're changing how many times the innermost
@@ -2850,6 +2919,17 @@
     if (isa<Constant>(BI->getCondition()))
       return true;
 
+    const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
+    assert(!isa<SCEVCouldNotCompute>(ExitCount) &&
+           "implied by having exact trip count");
+
+    // If we can fold this exit into the widenable branch above the loop, we
+    // don't need to worry about either a) side effects in the loop (further
+    // restricting a widenable condition is always legal), or b) any need to
+    // compute values down the loop exit, as the exit will never be taken.
+    if (CanFoldIntoWidenableBR(ExitingBB, ExitCount))
+      return false;
+
     // If the exit block has phis, we need to be able to compute the values
     // within the loop which contains them.  This assumes trivially lcssa phis
     // have already been removed; TODO: generalize
@@ -2858,12 +2938,16 @@
     if (!ExitBlock->phis().empty())
       return true;
 
-    const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
-    assert(!isa<SCEVCouldNotCompute>(ExactBTC) && "implied by having exact trip count");
     if (!SE->isLoopInvariant(ExitCount, L) ||
         !isSafeToExpand(ExitCount, *SE))
       return true;
 
+    // Check to see that the loop doesn't have side effects and doesn't have
+    // any implicit exits (because then our exact BTC isn't actually exact).
+    // The result is cached, so this is computed at most once per loop.
+    if (MayHaveSideEffectOrEarlyExit())
+      return true;
+
     return false;
   };
 
@@ -2909,19 +2993,6 @@
         return DT->dominates(ExitingBB, L->getLoopLatch());
       }));
 
-  // At this point, ExitingBlocks consists of only those blocks which are
-  // predicatable.  Given that, we know we have at least one exit we can
-  // predicate if the loop is doesn't have side effects and doesn't have any
-  // implicit exits (because then our exact BTC isn't actually exact).
-  // @Reviewers - As structured, this is O(I^2) for loop nests.  Any
-  // suggestions on how to improve this?  I can obviously bail out for outer
-  // loops, but that seems less than ideal.  MemorySSA can find memory writes,
-  // is that enough for *all* side effects?
-  for (BasicBlock *BB : L->blocks())
-    for (auto &I : *BB)
-      // TODO:isGuaranteedToTransfer
-      if (I.mayHaveSideEffects() || I.mayThrow())
-        return Changed;
 
   // Finally, do the actual predication for all predicatable blocks.  A couple
   // of notes here:
@@ -2933,12 +3004,28 @@
   //    predicate even if we can't insert a loop invariant expression as
   //    peeling or unrolling will likely reduce the cost of the otherwise loop
   //    varying check.
-  Rewriter.setInsertPoint(L->getLoopPreheader()->getTerminator());
-  IRBuilder<> B(L->getLoopPreheader()->getTerminator());
   Value *ExactBTCV = nullptr; //lazy generated if needed
+  // The insertion point ExactBTCV was expanded for.  The cached value is only
+  // reused at that same point, where it is known to dominate.
+  Instruction *ExactBTCVInsertPt = nullptr;
   for (BasicBlock *ExitingBB : ExitingBlocks) {
     const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
 
+    // Exits folded into the widenable branch have their conditions computed
+    // at that branch; all others are computed in the loop preheader.
+    const bool FoldIntoWidenableBR =
+        CanFoldIntoWidenableBR(ExitingBB, ExitCount);
+    Instruction *IP = FoldIntoWidenableBR
+                          ? WidenableBR
+                          : L->getLoopPreheader()->getTerminator();
+    Rewriter.setInsertPoint(IP);
+    IRBuilder<> B(IP);
+    if (IP != ExactBTCVInsertPt) {
+      ExactBTCV = nullptr;
+      ExactBTCVInsertPt = IP;
+    }
+
     auto *BI = cast<BranchInst>(ExitingBB->getTerminator());
     Value *NewCond;
     if (ExitCount == ExactBTC) {
@@ -2959,7 +3046,22 @@
       NewCond = B.CreateICmp(Pred, ECV, RHS);
     }
     Value *OldCond = BI->getCondition();
-    BI->setCondition(NewCond);
+    if (FoldIntoWidenableBR) {
+      // Instead of rewriting the exit in place, make WidenableBR take its
+      // false successor whenever this exit would have been taken, and make
+      // the exit itself never taken.
+      // TODO: Profitability: this replaces a loop exit with the widenable
+      // branch's false path, which is only desirable for *rare* exits.
+      bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
+      // NewCond is the condition under which BI branches to its first
+      // successor, so the exit is taken iff (ExitIfTrue ? NewCond : !NewCond).
+      Value *NotExiting = ExitIfTrue ? B.CreateNot(NewCond) : NewCond;
+      WidenableBR->setCondition(
+          B.CreateAnd(NotExiting, WidenableBR->getCondition()));
+      BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
+    } else {
+      BI->setCondition(NewCond);
+    }
     if (OldCond->use_empty())
       DeadInsts.push_back(OldCond);
     Changed = true;
Index: llvm/test/Transforms/IndVarSimplify/loop-predication.ll
===================================================================
--- llvm/test/Transforms/IndVarSimplify/loop-predication.ll
+++ llvm/test/Transforms/IndVarSimplify/loop-predication.ll
@@ -850,8 +850,88 @@
   ret i32 %result2
 }
 
+; The bounds check exit is folded into the widenable branch guarding loop
+; entry: the loop is only entered if that exit will not be taken.
+define i32 @wc(i32* %array, i32 %length, i32 %n, i1 %cond_0) {
+; CHECK-LABEL: @wc(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDENABLE_COND:%.*]] = call i1 @llvm.experimental.widenable.condition()
+; CHECK-NEXT:    [[EXIPLICIT_GUARD_COND:%.*]] = and i1 [[COND_0:%.*]], [[WIDENABLE_COND]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp ugt i32 [[N:%.*]], 1
+; CHECK-NEXT:    [[UMAX:%.*]] = select i1 [[TMP0]], i32 [[N]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[UMAX]], -1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i32 [[LENGTH:%.*]], [[TMP1]]
+; CHECK-NEXT:    [[UMIN:%.*]] = select i1 [[TMP2]], i32 [[LENGTH]], i32 [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ne i32 [[LENGTH]], [[UMIN]]
+; CHECK-NEXT:    [[TMP4:%.*]] = and i1 [[TMP3]], [[EXIPLICIT_GUARD_COND]]
+; CHECK-NEXT:    br i1 [[TMP4]], label [[LOOP_PREHEADER:%.*]], label [[DEOPT:%.*]], !prof !0
+; CHECK:       deopt:
+; CHECK-NEXT:    [[DEOPTRET:%.*]] = call i32 (...) @llvm.experimental.deoptimize.i32() [ "deopt"() ]
+; CHECK-NEXT:    ret i32 [[DEOPTRET]]
+; CHECK:       loop.preheader:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[LOOP_ACC:%.*]] = phi i32 [ [[LOOP_ACC_NEXT:%.*]], [[GUARDED:%.*]] ], [ 0, [[LOOP_PREHEADER]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i32 [ [[I_NEXT:%.*]], [[GUARDED]] ], [ 0, [[LOOP_PREHEADER]] ]
+; CHECK-NEXT:    call void @unknown()
+; CHECK-NEXT:    br i1 true, label [[GUARDED]], label [[DEOPT2:%.*]], !prof !0
+; CHECK:       deopt2:
+; CHECK-NEXT:    call void @prevent_merging()
+; CHECK-NEXT:    ret i32 -1
+; CHECK:       guarded:
+; CHECK-NEXT:    [[I_I64:%.*]] = zext i32 [[I]] to i64
+; CHECK-NEXT:    [[ARRAY_I_PTR:%.*]] = getelementptr inbounds i32, i32* [[ARRAY:%.*]], i64 [[I_I64]]
+; CHECK-NEXT:    [[ARRAY_I:%.*]] = load i32, i32* [[ARRAY_I_PTR]], align 4
+; CHECK-NEXT:    [[LOOP_ACC_NEXT]] = add i32 [[LOOP_ACC]], [[ARRAY_I]]
+; CHECK-NEXT:    [[I_NEXT]] = add nuw i32 [[I]], 1
+; CHECK-NEXT:    [[CONTINUE:%.*]] = icmp ult i32 [[I_NEXT]], [[N]]
+; CHECK-NEXT:    br i1 [[CONTINUE]], label [[LOOP]], label [[EXIT:%.*]]
+; CHECK:       exit:
+; CHECK-NEXT:    [[RESULT:%.*]] = phi i32 [ [[LOOP_ACC_NEXT]], [[GUARDED]] ]
+; CHECK-NEXT:    ret i32 [[RESULT]]
+;
+entry:
+  %widenable_cond = call i1 @llvm.experimental.widenable.condition()
+  %exiplicit_guard_cond = and i1 %cond_0, %widenable_cond
+  br i1 %exiplicit_guard_cond, label %loop.preheader, label %deopt, !prof !0
+
+deopt:
+  %deoptret = call i32 (...) @llvm.experimental.deoptimize.i32() [ "deopt"() ]
+  ret i32 %deoptret
+
+loop.preheader:
+  br label %loop
+
+loop:
+  %loop.acc = phi i32 [ %loop.acc.next, %guarded ], [ 0, %loop.preheader ]
+  %i = phi i32 [ %i.next, %guarded ], [ 0, %loop.preheader ]
+  call void @unknown()
+  %within.bounds = icmp ult i32 %i, %length
+  br i1 %within.bounds, label %guarded, label %deopt2, !prof !0
+
+deopt2:
+  call void @prevent_merging()
+  ret i32 -1
+
+guarded:
+  %i.i64 = zext i32 %i to i64
+  %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64
+  %array.i = load i32, i32* %array.i.ptr, align 4
+  %loop.acc.next = add i32 %loop.acc, %array.i
+  %i.next = add nuw i32 %i, 1
+  %continue = icmp ult i32 %i.next, %n
+  br i1 %continue, label %loop, label %exit
+
+exit:
+  %result = phi i32 [ %loop.acc.next, %guarded ]
+  ret i32 %result
+}
+
+declare void @unknown()
+declare i1 @llvm.experimental.widenable.condition()
 
 declare i32 @llvm.experimental.deoptimize.i32(...)
+declare void @llvm.experimental.deoptimize.isVoid(...)
 
 !0 = !{!"branch_weights", i32 1048576, i32 1}
 !1 = !{i32 1, i32 -2147483648}