Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -587,7 +587,9 @@
   const SCEV *getUMaxExpr(const SCEV *LHS, const SCEV *RHS);
   const SCEV *getUMaxExpr(SmallVectorImpl<const SCEV *> &Operands);
   const SCEV *getSMinExpr(const SCEV *LHS, const SCEV *RHS);
+  const SCEV *getSMinExpr(SmallVectorImpl<const SCEV *> &Operands);
   const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS);
+  const SCEV *getUMinExpr(SmallVectorImpl<const SCEV *> &Operands);
   const SCEV *getUnknown(Value *V);
   const SCEV *getCouldNotCompute();
 
@@ -650,6 +652,11 @@
   /// then perform a umin operation with them.
   const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS);
 
+  /// Promote the operands to the wider of the types using zero-extension, and
+  /// then perform a umin operation with them. N-ary function; \p Ops must not
+  /// be empty.
+  const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops);
+
   /// Transitively follow the chain of pointer-type operands until reaching a
   /// SCEV that does not have a single pointer operand. This returns a
   /// SCEVUnknown pointer for well-formed pointer-type expressions, but corner
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -3548,14 +3548,30 @@
 
 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
-  // ~smax(~x, ~y) == smin(x, y).
-  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
+  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+  return getSMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+  // ~smax(~x, ~y, ~z) == smin(x, y, z).
+  SmallVector<const SCEV *, 2> NotOps;
+  for (auto *S : Ops)
+    NotOps.push_back(getNotSCEV(S));
+  return getNotSCEV(getSMaxExpr(NotOps));
 }
 
 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
-  // ~umax(~x, ~y) == umin(x, y)
-  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
+  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+  return getUMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+  // ~umax(~x, ~y, ~z) == umin(x, y, z).
+  SmallVector<const SCEV *, 2> NotOps;
+  for (auto *S : Ops)
+    NotOps.push_back(getNotSCEV(S));
+  return getNotSCEV(getUMaxExpr(NotOps));
 }
 
 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
@@ -3943,15 +3959,32 @@
 
 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
                                                         const SCEV *RHS) {
-  const SCEV *PromotedLHS = LHS;
-  const SCEV *PromotedRHS = RHS;
+  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+  return getUMinFromMismatchedTypes(Ops);
+}
 
-  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
-    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
-  else
-    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
+const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
+    SmallVectorImpl<const SCEV *> &Ops) {
+  assert(!Ops.empty() && "Cannot get empty umin!");
+  // Trivial case: nothing to promote and nothing to take the minimum of.
+  if (Ops.size() == 1)
+    return Ops[0];
+
+  // Find the max type first.
+  Type *MaxType = nullptr;
+  for (auto *S : Ops)
+    if (MaxType)
+      MaxType = getWiderType(MaxType, S->getType());
+    else
+      MaxType = S->getType();
+
+  // Extend all ops to max type.
+  SmallVector<const SCEV *, 2> PromotedOps;
+  for (auto *S : Ops)
+    PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
 
-  return getUMinExpr(PromotedLHS, PromotedRHS);
+  // Generate umin.
+  return getUMinExpr(PromotedOps);
 }
 
 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
@@ -6666,7 +6699,6 @@
   if (!isComplete() || ExitNotTaken.empty())
     return SE->getCouldNotCompute();
 
-  const SCEV *BECount = nullptr;
   const BasicBlock *Latch = L->getLoopLatch();
   // All exiting blocks we have collected must dominate the only backedge.
   if (!Latch)
@@ -6674,16 +6706,15 @@
 
   // All exiting blocks we have gathered dominate loop's latch, so exact trip
   // count is simply a minimum out of all these calculated exit counts.
+  SmallVector<const SCEV *, 2> Ops;
   for (auto &ENT : ExitNotTaken) {
-    assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "Bad exit SCEV!");
+    const SCEV *BECount = ENT.ExactNotTaken;
+    assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
     assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
            "We should only have known counts for exiting blocks that dominate "
            "latch!");
 
-    if (!BECount)
-      BECount = ENT.ExactNotTaken;
-    else if (BECount != ENT.ExactNotTaken)
-      BECount = SE->getUMinFromMismatchedTypes(BECount, ENT.ExactNotTaken);
+    Ops.push_back(BECount);
 
     if (Preds && !ENT.hasAlwaysTruePredicate())
       Preds->add(ENT.Predicate.get());
@@ -6692,8 +6723,8 @@
            "Predicate should be always true!");
   }
 
-  assert(BECount && "Invalid not taken count for loop exit");
-  return BECount;
+  assert(!Ops.empty() && "Loop without exits");
+  return SE->getUMinFromMismatchedTypes(Ops);
 }
 
 /// Get the exact not taken count for this loop exit.