Index: lib/Analysis/ScalarEvolutionExpander.cpp
===================================================================
--- lib/Analysis/ScalarEvolutionExpander.cpp
+++ lib/Analysis/ScalarEvolutionExpander.cpp
@@ -1853,6 +1853,31 @@
   return NumElim;
 }
 
+// Recursively search I and its operands, up to a small fixed depth, for an
+// instruction whose SCEV is S. Helper for findExistingExpansion.
+static Instruction *findCongruentInst(ScalarEvolution &SE, const SCEV *S,
+                                      Instruction *I, unsigned Depth) {
+  // Limit the recursion depth.
+  if (Depth > 3)
+    return nullptr;
+
+  // Comparing types and checking isSCEVable first avoids creating too many
+  // SCEVUnknowns.
+  if (I->getType() == S->getType() && SE.isSCEVable(I->getType()) &&
+      SE.getSCEV(I) == S)
+    return I;
+
+  for (Value *Val : I->operands()) {
+    if (Instruction *Inst = dyn_cast<Instruction>(Val)) {
+      Instruction *Op = findCongruentInst(SE, S, Inst, Depth + 1);
+      if (Op)
+        return Op;
+    }
+  }
+
+  return nullptr;
+}
+
 Value *SCEVExpander::findExistingExpansion(const SCEV *S,
                                            const Instruction *At, Loop *L) {
   using namespace llvm::PatternMatch;
@@ -1863,19 +1888,37 @@
   // Look for suitable value in simple conditions at the loop exits.
   for (BasicBlock *BB : ExitingBlocks) {
     ICmpInst::Predicate Pred;
-    Instruction *LHS, *RHS;
+    Value *LHSV, *RHSV;
     BasicBlock *TrueBB, *FalseBB;
 
-    if (!match(BB->getTerminator(),
-               m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)),
-                    TrueBB, FalseBB)))
+    if (!match(
+            BB->getTerminator(),
+            m_Br(m_ICmp(Pred, m_Value(LHSV), m_Value(RHSV)), TrueBB, FalseBB)))
       continue;
 
-    if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At))
-      return LHS;
+    // Returns Val if it is an instruction with SCEV S that dominates At;
+    // otherwise, if S occurs within Val's SCEV, a congruent instruction from
+    // Val's operand tree that dominates At; otherwise null.
+    auto FindExistingInst = [&](Value *Val) -> Instruction * {
+      if (Instruction *Inst = dyn_cast<Instruction>(Val)) {
+        const SCEV *IExpr = SE.getSCEV(Inst);
+        if (IExpr == S && SE.DT.dominates(Inst, At))
+          return Inst;
+        if (SE.hasOperand(IExpr, S)) {
+          auto *ExistingInst = findCongruentInst(SE, S, Inst, 0);
+          if (ExistingInst &&
+              SE.getSCEVValues(SE.getSCEV(ExistingInst)) &&
+              SE.DT.dominates(ExistingInst, At))
+            return ExistingInst;
+        }
+      }
+      return nullptr;
+    };
 
-    if (SE.getSCEV(RHS) == S && SE.DT.dominates(RHS, At))
-      return RHS;
+    if (Value *FoundVal = FindExistingInst(LHSV))
+      return FoundVal;
+    if (Value *FoundVal = FindExistingInst(RHSV))
+      return FoundVal;
   }
 
   // There is potential to make this significantly smarter, but this simple
Index: lib/Transforms/Utils/LoopUnrollRuntime.cpp
===================================================================
--- lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -311,9 +311,12 @@
     return false;
 
   BasicBlock *Header = L->getHeader();
+  BasicBlock *PH = L->getLoopPreheader();
+  BranchInst *PreHeaderBR = cast<BranchInst>(PH->getTerminator());
   const DataLayout &DL = Header->getModule()->getDataLayout();
   SCEVExpander Expander(*SE, DL, "loop-unroll");
-  if (!AllowExpensiveTripCount && Expander.isHighCostExpansion(TripCountSC, L))
+  if (!AllowExpensiveTripCount &&
+      Expander.isHighCostExpansion(TripCountSC, L, PreHeaderBR))
     return false;
 
   // We only handle cases when the unroll factor is a power of 2.
@@ -331,13 +334,13 @@
   if (Loop *ParentLoop = L->getParentLoop())
     SE->forgetLoop(ParentLoop);
 
-  BasicBlock *PH = L->getLoopPreheader();
   BasicBlock *Latch = L->getLoopLatch();
   // It helps to splits the original preheader twice, one for the end of the
   // prolog code and one for a new loop preheader
   BasicBlock *PEnd = SplitEdge(PH, Header, DT, LI);
   BasicBlock *NewPH = SplitBlock(PEnd, PEnd->getTerminator(), DT, LI);
-  BranchInst *PreHeaderBR = cast<BranchInst>(PH->getTerminator());
+  // The splits above may have given PH a new terminator, so re-fetch it.
+  PreHeaderBR = cast<BranchInst>(PH->getTerminator());
 
   // Compute the number of extra iterations required, which is:
   // extra iterations = run-time trip count % (loop unroll factor + 1)
Index: test/Transforms/LoopUnroll/high-cost-trip-count-computation.ll
===================================================================
--- test/Transforms/LoopUnroll/high-cost-trip-count-computation.ll
+++ test/Transforms/LoopUnroll/high-cost-trip-count-computation.ll
@@ -24,4 +24,32 @@
   ret i32 0
 }
 
+;; Although the SCEV for the loop trip count contains a division,
+;; it shouldn't be considered expensive, since the division already
+;; exists in the code and we don't need to expand it once more.
+;; Thus, it shouldn't prevent us from unrolling the loop.
+
+define i32 @test2(i64* %loc, i64 %conv7) {
+; CHECK-LABEL: @test2(
+; CHECK: for.body.prol
+entry:
+  %rem0 = load i64, i64* %loc, align 8
+  %div11 = udiv i64 %rem0, %conv7
+  %cmp.i38 = icmp ugt i64 %div11, 1
+  %div12 = select i1 %cmp.i38, i64 %div11, i64 1
+  br label %for.body
+for.body:
+  %rem1 = phi i64 [ %rem0, %entry ], [ %rem2, %for.body ]
+  %k1 = phi i64 [ %div12, %entry ], [ %dec, %for.body ]
+  %mul1 = mul i64 %rem1, 48271
+  %rem2 = urem i64 %mul1, 2147483647
+  %dec = add i64 %k1, -1
+  %cmp = icmp eq i64 %dec, 0
+  br i1 %cmp, label %exit, label %for.body
+exit:
+  %rem3 = phi i64 [ %rem2, %for.body ]
+  store i64 %rem3, i64* %loc, align 8
+  ret i32 0
+}
+
 !0 = !{i64 1, i64 100}