Index: lib/Transforms/Scalar/LoopPredication.cpp =================================================================== --- lib/Transforms/Scalar/LoopPredication.cpp +++ lib/Transforms/Scalar/LoopPredication.cpp @@ -269,24 +269,29 @@ /// trivial result would be the at the User itself, but we try to return a /// loop invariant location if possible. Instruction *findInsertPt(Instruction *User, ArrayRef Ops); + /// Same as above, *except* operands are tested with SE->isLoopInvariant: a + /// SCEV may be invariant yet have no IR value outside the loop until + /// SCEVExpander materializes it. isSafeToExpand is not checked here. Thus, + /// only pass the result as an insert point to SCEVExpander! + Instruction *findInsertPt(Instruction *User, ArrayRef Ops); bool CanExpand(const SCEV* S); - Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, + Value *expandCheck(SCEVExpander &Expander, Instruction *Guard, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); Optional widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, - IRBuilder<> &Builder); + Instruction *Guard); Optional widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, - IRBuilder<> &Builder); + Instruction *Guard); Optional widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, - IRBuilder<> &Builder); + Instruction *Guard); unsigned collectChecks(SmallVectorImpl &Checks, Value *Condition, - SCEVExpander &Expander, IRBuilder<> &Builder); + SCEVExpander &Expander, Instruction *Guard); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander); // If the loop always exits through another block in the loop, we should not @@ -394,21 +399,24 @@ } Value *LoopPredication::expandCheck(SCEVExpander &Expander, - IRBuilder<> &Builder, + Instruction *Guard, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { Type *Ty = 
LHS->getType(); assert(Ty == RHS->getType() && "expandCheck operands have different types?"); - if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS)) - return Builder.getTrue(); - if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred), - LHS, RHS)) - return Builder.getFalse(); - - Instruction *InsertAt = &*Builder.GetInsertPoint(); - Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt); - Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt); + if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) { + IRBuilder<> Builder(Guard); // Only needed to build the constants below. + if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS)) + return Builder.getTrue(); + if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred), + LHS, RHS)) + return Builder.getFalse(); + } + + Value *LHSV = Expander.expandCodeFor(LHS, Ty, findInsertPt(Guard, {LHS})); + Value *RHSV = Expander.expandCodeFor(RHS, Ty, findInsertPt(Guard, {RHS})); + IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV})); return Builder.CreateICmp(Pred, LHSV, RHSV); } @@ -452,13 +460,22 @@ return Preheader->getTerminator(); } +Instruction *LoopPredication::findInsertPt(Instruction *Use, + ArrayRef Ops) { + for (const SCEV *Op : Ops) + if (!SE->isLoopInvariant(Op, L)) + return Use; + return Preheader->getTerminator(); +} + + bool LoopPredication::CanExpand(const SCEV* S) { return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); } Optional LoopPredication::widenICmpRangeCheckIncrementingLoop( LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck, - SCEVExpander &Expander, IRBuilder<> &Builder) { + SCEVExpander &Expander, Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); // Generate the widened condition for the forward loop: // guardStart u< guardLimit && @@ -488,15 +505,16 @@ LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); auto *LimitCheck = - expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS); - auto *FirstIterationCheck = expandCheck(Expander, Builder, 
RangeCheck.Pred, + expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS); + auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred, GuardStart, GuardLimit); + IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck})); return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } Optional LoopPredication::widenICmpRangeCheckDecrementingLoop( LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck, - SCEVExpander &Expander, IRBuilder<> &Builder) { + SCEVExpander &Expander, Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); const SCEV *GuardStart = RangeCheck.IV->getStart(); const SCEV *GuardLimit = RangeCheck.Limit; @@ -522,10 +540,12 @@ // See the header comment for reasoning of the checks. auto LimitCheckPred = ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred); - auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT, + auto *FirstIterationCheck = expandCheck(Expander, Guard, + ICmpInst::ICMP_ULT, GuardStart, GuardLimit); - auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, + auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, SE->getOne(Ty)); + IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck})); return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } @@ -534,7 +554,7 @@ /// returns None. 
Optional LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, - IRBuilder<> &Builder) { + Instruction *Guard) { LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); LLVM_DEBUG(ICI->dump()); @@ -588,18 +608,18 @@ if (Step->isOne()) return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck, - Expander, Builder); + Expander, Guard); else { assert(Step->isAllOnesValue() && "Step should be -1!"); return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck, - Expander, Builder); + Expander, Guard); } } unsigned LoopPredication::collectChecks(SmallVectorImpl &Checks, Value *Condition, SCEVExpander &Expander, - IRBuilder<> &Builder) { + Instruction *Guard) { unsigned NumWidened = 0; // The guard condition is expected to be in form of: // cond1 && cond2 && cond3 ... @@ -631,7 +651,7 @@ if (ICmpInst *ICI = dyn_cast(Condition)) { if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, - Builder)) { + Guard)) { Checks.push_back(NewRangeCheck.getValue()); NumWidened++; continue; @@ -657,16 +677,15 @@ TotalConsidered++; SmallVector Checks; - IRBuilder<> Builder(cast(Preheader->getTerminator())); unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander, - Builder); + Guard); if (NumWidened == 0) return false; TotalWidened += NumWidened; // Emit the new guard condition - Builder.SetInsertPoint(findInsertPt(Guard, Checks)); + IRBuilder<> Builder(findInsertPt(Guard, Checks)); Value *LastCheck = nullptr; for (auto *Check : Checks) if (!LastCheck) @@ -689,16 +708,15 @@ TotalConsidered++; SmallVector Checks; - IRBuilder<> Builder(cast(Preheader->getTerminator())); unsigned NumWidened = collectChecks(Checks, BI->getCondition(), - Expander, Builder); + Expander, BI); if (NumWidened == 0) return false; TotalWidened += NumWidened; // Emit the new guard condition - Builder.SetInsertPoint(findInsertPt(BI, Checks)); + IRBuilder<> Builder(findInsertPt(BI, Checks)); Value *LastCheck = nullptr; for (auto *Check : Checks) if 
(!LastCheck)