diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10075,23 +10075,48 @@
 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
                                                     const SCEV *LHS,
                                                     const SCEV *RHS) {
-  // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
-  // Return Y via OutY.
-  auto MatchBinaryAddToConst =
-      [this](const SCEV *Result, const SCEV *X, APInt &OutY,
-             SCEV::NoWrapFlags ExpectedFlags) {
-    const SCEV *NonConstOp, *ConstOp;
-    SCEV::NoWrapFlags FlagsPresent;
-
-    if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
-        !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+  // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
+  // C1 and C2 are constant integers. If either X or Y is not an add expression,
+  // consider it as X + 0 (resp. Y + 0), which trivially has no overflow. C1 and
+  // C2 are returned via OutC1 and OutC2.
+  auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
+                                      APInt &OutC1, APInt &OutC2,
+                                      SCEV::NoWrapFlags ExpectedFlags) {
+    const SCEV *XNonConstOp, *XConstOp;
+    const SCEV *YNonConstOp, *YConstOp;
+    SCEV::NoWrapFlags XFlagsPresent;
+    SCEV::NoWrapFlags YFlagsPresent;
+
+    if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
+      XConstOp = getZero(X->getType());
+      XNonConstOp = X;
+      XFlagsPresent = ExpectedFlags;
+    }
+    if (!isa<SCEVConstant>(XConstOp) ||
+        (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
       return false;
 
-    OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
-    return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+    if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
+      YConstOp = getZero(Y->getType());
+      YNonConstOp = Y;
+      YFlagsPresent = ExpectedFlags;
+    }
+
+    if (!isa<SCEVConstant>(YConstOp) ||
+        (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
+      return false;
+
+    if (YNonConstOp != XNonConstOp)
+      return false;
+
+    OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
+    OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
+
+    return true;
   };
 
-  APInt C;
+  APInt C1;
+  APInt C2;
 
   switch (Pred) {
   default:
@@ -10101,45 +10126,38 @@
     std::swap(LHS, RHS);
     LLVM_FALLTHROUGH;
   case ICmpInst::ICMP_SLE:
-    // X s<= (X + C)<nsw> if C >= 0
-    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+    // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
+    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
       return true;
 
-    // (X + C)<nsw> s<= X if C <= 0
-    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
-        !C.isStrictlyPositive())
-      return true;
     break;
 
   case ICmpInst::ICMP_SGT:
     std::swap(LHS, RHS);
     LLVM_FALLTHROUGH;
   case ICmpInst::ICMP_SLT:
-    // X s< (X + C)<nsw> if C > 0
-    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
-        C.isStrictlyPositive())
+    // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
+    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
       return true;
 
-    // (X + C)<nsw> s< X if C < 0
-    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
-      return true;
     break;
 
   case ICmpInst::ICMP_UGE:
     std::swap(LHS, RHS);
     LLVM_FALLTHROUGH;
   case ICmpInst::ICMP_ULE:
-    // X u<= (X + C)<nuw> for any C
-    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW))
+    // (X + C1)<nuw> u<= (X + C2)<nuw> if C1 u<= C2.
+    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
       return true;
+
     break;
 
   case ICmpInst::ICMP_UGT:
     std::swap(LHS, RHS);
     LLVM_FALLTHROUGH;
   case ICmpInst::ICMP_ULT:
-    // X u< (X + C)<nuw> if C != 0
-    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW) && !C.isNullValue())
+    // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
+    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
       return true;
     break;
   }