Index: llvm/include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolution.h
+++ llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -219,7 +219,7 @@
   FoldingSetNodeIDRef FastID;
 
 public:
-  enum SCEVPredicateKind { P_Union, P_Compare, P_Wrap };
+  enum SCEVPredicateKind { P_Union, P_Expr, P_Wrap };
 
 protected:
   SCEVPredicateKind Kind;
@@ -241,7 +241,7 @@
   virtual bool isAlwaysTrue() const = 0;
 
   /// Returns true if this predicate implies \p N.
-  virtual bool implies(const SCEVPredicate *N) const = 0;
+  virtual bool implies(ScalarEvolution &SE, const SCEVPredicate *N) const = 0;
 
   /// Prints a textual representation of this predicate with an indentation of
   /// \p Depth.
@@ -272,35 +272,25 @@
   }
 };
 
-/// This class represents an assumption that the expression LHS Pred RHS
-/// evaluates to true, and this can be checked at run-time.
-class SCEVComparePredicate final : public SCEVPredicate {
-  /// We assume that LHS Pred RHS is true.
-  const ICmpInst::Predicate Pred;
-  const SCEV *LHS;
-  const SCEV *RHS;
+/// This class represents an assumption that the expression Expr evaluates
+/// to true, and this can be checked at run-time.
+class SCEVExprPredicate final : public SCEVPredicate {
+  const SCEV *Expr;
 
 public:
-  SCEVComparePredicate(const FoldingSetNodeIDRef ID,
-                       const ICmpInst::Predicate Pred,
-                       const SCEV *LHS, const SCEV *RHS);
+  SCEVExprPredicate(const FoldingSetNodeIDRef ID, const SCEV *Expr);
 
   /// Implementation of the SCEVPredicate interface
-  bool implies(const SCEVPredicate *N) const override;
+  bool implies(ScalarEvolution &SE, const SCEVPredicate *N) const override;
   void print(raw_ostream &OS, unsigned Depth = 0) const override;
   bool isAlwaysTrue() const override;
 
-  ICmpInst::Predicate getPredicate() const { return Pred; }
-
-  /// Returns the left hand side of the predicate.
-  const SCEV *getLHS() const { return LHS; }
-
-  /// Returns the right hand side of the predicate.
-  const SCEV *getRHS() const { return RHS; }
+  /// Returns the boolean (i1) expression that is assumed to be true.
+  const SCEV *getExpr() const { return Expr; }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const SCEVPredicate *P) {
-    return P->getKind() == P_Compare;
+    return P->getKind() == P_Expr;
   }
 };
 
@@ -393,7 +383,7 @@
 
   /// Implementation of the SCEVPredicate interface
   const SCEVAddRecExpr *getExpr() const;
-  bool implies(const SCEVPredicate *N) const override;
+  bool implies(ScalarEvolution &SE, const SCEVPredicate *N) const override;
   void print(raw_ostream &OS, unsigned Depth = 0) const override;
   bool isAlwaysTrue() const override;
 
@@ -429,7 +419,7 @@
 
   /// Implementation of the SCEVPredicate interface
   bool isAlwaysTrue() const override;
-  bool implies(const SCEVPredicate *N) const override;
+  bool implies(ScalarEvolution &SE, const SCEVPredicate *N) const override;
   void print(raw_ostream &OS, unsigned Depth) const override;
 
   /// We estimate the complexity of a union predicate as the size number of
@@ -1047,6 +1037,11 @@
   Optional<bool> evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
                                      const SCEV *RHS, const Instruction *CtxI);
 
+  /// Return true if the boolean condition \p A implies the boolean condition
+  /// \p B (note the argument order). Returns false conservatively unless A and
+  /// B are identical or both are SCEVCompareExprs.
+  bool isImpliedCond(const SCEV *B, const SCEV *A);
+
   /// Test if the condition described by Pred, LHS, RHS is known to be true on
   /// every iteration of the loop of the recurrency LHS.
   bool isKnownOnEveryIteration(ICmpInst::Predicate Pred,
Index: llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -293,9 +293,8 @@
   Value *expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *Loc);
 
   /// A specialized variant of expandCodeForPredicate, handling the case when
-  /// we are expanding code for a SCEVComparePredicate.
-  Value *expandComparePredicate(const SCEVComparePredicate *Pred,
-                                Instruction *Loc);
+  /// we are expanding code for a SCEVExprPredicate.
+  Value *expandExprPredicate(const SCEVExprPredicate *Pred, Instruction *Loc);
 
   /// Generates code that evaluates if the \p AR expression will overflow.
   Value *generateOverflowCheck(const SCEVAddRecExpr *AR, Instruction *Loc,
Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4468,6 +4468,11 @@
     return getConstant(
                 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
 
+  // Fold ~(x pred y) to (x inverse-pred y).
+  if (const SCEVCompareExpr *C = dyn_cast<SCEVCompareExpr>(V))
+    return getCompareExpr(ICmpInst::getInversePredicate(C->getPredicate()),
+                          C->getLHS(), C->getRHS());
+
   // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
   if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
     auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
@@ -5544,8 +5549,9 @@
     return true;
 
   auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
-    if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
-        !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
+    if (Expr1 != Expr2 &&
+        !Preds->implies(SE, SE.getEqualPredicate(Expr1, Expr2)) &&
+        !Preds->implies(SE, SE.getEqualPredicate(Expr2, Expr1)))
       return false;
     return true;
   };
@@ -10407,6 +10413,17 @@
   return None;
 }
 
+bool ScalarEvolution::isImpliedCond(const SCEV *B, const SCEV *A) {
+  // Any condition implies itself, whether or not it is a SCEVCompareExpr.
+  if (A == B)
+    return true;
+  const auto *CmpA = dyn_cast<SCEVCompareExpr>(A);
+  const auto *CmpB = dyn_cast<SCEVCompareExpr>(B);
+  return CmpA && CmpB &&
+         isImpliedCond(CmpB->getPredicate(), CmpB->getLHS(), CmpB->getRHS(),
+                       CmpA->getPredicate(), CmpA->getLHS(), CmpA->getRHS());
+}
+
 bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
                                               const SCEVAddRecExpr *LHS,
                                               const SCEV *RHS) {
@@ -13701,18 +13718,18 @@
   FoldingSetNodeID ID;
   assert(LHS->getType() == RHS->getType() &&
          "Type mismatch between LHS and RHS");
-  // Unique this node based on the arguments
-  ID.AddInteger(SCEVPredicate::P_Compare);
-  ID.AddInteger(Pred);
-  ID.AddPointer(LHS);
-  ID.AddPointer(RHS);
+
+  // Unique this node based on the compare expression it wraps.
+  const SCEV *Expr = getCompareExpr(Pred, LHS, RHS);
+  ID.AddInteger(SCEVPredicate::P_Expr);
+  ID.AddPointer(Expr);
   void *IP = nullptr;
   if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
     return S;
-  SCEVComparePredicate *Eq = new (SCEVAllocator)
-      SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
-  UniquePreds.InsertNode(Eq, IP);
-  return Eq;
+  auto *P = new (SCEVAllocator)
+      SCEVExprPredicate(ID.Intern(SCEVAllocator), Expr);
+  UniquePreds.InsertNode(P, IP);
+  return P;
 }
 
 const SCEVPredicate *ScalarEvolution::getWrapPredicate(
@@ -13754,16 +13771,22 @@
 
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
     if (Pred) {
+      // If P assumes "Expr == X" for some X, return X; otherwise nullptr.
+      auto MatchEqCompare = [&](const SCEVPredicate *P) -> const SCEV * {
+        if (const auto *IPred = dyn_cast<SCEVExprPredicate>(P))
+          if (const auto *Cmp = dyn_cast<SCEVCompareExpr>(IPred->getExpr()))
+            if (Cmp->getLHS() == Expr &&
+                Cmp->getPredicate() == ICmpInst::ICMP_EQ)
+              return Cmp->getRHS();
+        return nullptr;
+      };
+
       if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
         for (auto *Pred : U->getPredicates())
-          if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
-            if (IPred->getLHS() == Expr &&
-                IPred->getPredicate() == ICmpInst::ICMP_EQ)
-              return IPred->getRHS();
-      } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
-        if (IPred->getLHS() == Expr &&
-            IPred->getPredicate() == ICmpInst::ICMP_EQ)
-          return IPred->getRHS();
+          if (auto *Res = MatchEqCompare(Pred))
+            return Res;
+      } else if (auto *Res = MatchEqCompare(Pred)) {
+        return Res;
       }
     }
     return convertToAddRecWithPreds(Expr);
@@ -13810,7 +13833,7 @@
   bool addOverflowAssumption(const SCEVPredicate *P) {
     if (!NewPreds) {
       // Check if we've already made this assumption.
-      return Pred && Pred->implies(P);
+      return Pred && Pred->implies(SE, P);
     }
     NewPreds->insert(P);
     return true;
@@ -13883,36 +13906,26 @@
                              SCEVPredicateKind Kind)
     : FastID(ID), Kind(Kind) {}
 
-SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
-                                           const ICmpInst::Predicate Pred,
-                                           const SCEV *LHS, const SCEV *RHS)
-    : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
-  assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
-  assert(LHS != RHS && "LHS and RHS are the same SCEV");
+SCEVExprPredicate::SCEVExprPredicate(const FoldingSetNodeIDRef ID,
+                                     const SCEV *Expr)
+    : SCEVPredicate(ID, P_Expr), Expr(Expr) {
+  assert(Expr->getType()->isIntegerTy(1) && "must be a boolean");
 }
 
-bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
-  const auto *Op = dyn_cast<SCEVComparePredicate>(N);
-
+bool SCEVExprPredicate::implies(ScalarEvolution &SE,
+                                const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<SCEVExprPredicate>(N);
   if (!Op)
     return false;
 
-  if (Pred != ICmpInst::ICMP_EQ)
-    return false;
-
-  return Op->LHS == LHS && Op->RHS == RHS;
+  // N holds whenever our (assumed true) expression implies N's expression.
+  return SE.isImpliedCond(Op->Expr, this->Expr);
 }
 
-bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
-
-void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
-  if (Pred == ICmpInst::ICMP_EQ)
-    OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
-  else
-    OS.indent(Depth) << "Compare predicate: " << *LHS
-                     << " " << CmpInst::getPredicateName(Pred) << ") "
-                     << *RHS << "\n";
+bool SCEVExprPredicate::isAlwaysTrue() const { return Expr->isOne(); }
 
+void SCEVExprPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << "Expr predicate: " << *Expr << "\n";
 }
 
 SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
@@ -13922,7 +13935,8 @@
 
 const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
 
-bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVWrapPredicate::implies(ScalarEvolution &SE,
+                                const SCEVPredicate *N) const {
   const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
 
   return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
@@ -13980,13 +13994,14 @@
                 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
 }
 
-bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVUnionPredicate::implies(ScalarEvolution &SE,
+                                 const SCEVPredicate *N) const {
   if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
     return all_of(Set->Preds,
-                  [this](const SCEVPredicate *I) { return this->implies(I); });
+                  [&](const SCEVPredicate *I) { return this->implies(SE, I); });
 
   return any_of(Preds,
-                [N](const SCEVPredicate *I) { return I->implies(N); });
+                [&](const SCEVPredicate *I) { return I->implies(SE, N); });
 }
 
 void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
@@ -14051,7 +14066,7 @@
 }
 
 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
-  if (Preds->implies(&Pred))
+  if (Preds->implies(SE, &Pred))
     return;
 
   auto &OldPreds = Preds->getPredicates();
Index: llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
===================================================================
--- llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2487,8 +2487,8 @@
   switch (Pred->getKind()) {
   case SCEVPredicate::P_Union:
     return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
-  case SCEVPredicate::P_Compare:
-    return expandComparePredicate(cast<SCEVComparePredicate>(Pred), IP);
+  case SCEVPredicate::P_Expr:
+    return expandExprPredicate(cast<SCEVExprPredicate>(Pred), IP);
   case SCEVPredicate::P_Wrap: {
     auto *AddRecPred = cast<SCEVWrapPredicate>(Pred);
     return expandWrapPredicate(AddRecPred, IP);
@@ -2497,17 +2497,15 @@
   llvm_unreachable("Unknown SCEV predicate type");
 }
 
-Value *SCEVExpander::expandComparePredicate(const SCEVComparePredicate *Pred,
-                                            Instruction *IP) {
-  Value *Expr0 =
-      expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP, false);
-  Value *Expr1 =
-      expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP, false);
-
-  Builder.SetInsertPoint(IP);
-  auto InvPred = ICmpInst::getInversePredicate(Pred->getPredicate());
-  auto *I = Builder.CreateICmp(InvPred, Expr0, Expr1, "ident.check");
-  return I;
+Value *SCEVExpander::expandExprPredicate(const SCEVExprPredicate *Pred,
+                                         Instruction *IP) {
+  // The emitted value must be true when the predicate does NOT hold (the
+  // runtime-check users branch to the fallback path on true), so expand the
+  // negated expression. WARNING: This only works because SCEVExprPredicate is
+  // currently only used for leaf predicates. When SCEVUnionPredicate is
+  // removed, the leaf handling needs to be adjusted.
+  const SCEV *S = SE.getNotSCEV(Pred->getExpr());
+  return expandCodeForImpl(S, S->getType(), IP, false);
 }
 
 Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
Index: llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
===================================================================
--- llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
+++ llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
@@ -19,7 +19,7 @@
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Non vectorizable stores to invariant address were not found in loop.
 ; CHECK-NEXT:      SCEV assumptions:
-; CHECK-NEXT:      Equal predicate: %stride == 1
+; CHECK-NEXT:      Expr predicate: (%stride eq 1)
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Expressions re-written:
 ; CHECK-NEXT:      [PSE] %gep.A = getelementptr inbounds i32, i32* %A, i64 %mul:
@@ -63,7 +63,7 @@
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Non vectorizable stores to invariant address were not found in loop.
 ; CHECK-NEXT:      SCEV assumptions:
-; CHECK-NEXT:      Equal predicate: %stride == 1
+; CHECK-NEXT:      Expr predicate: (%stride eq 1)
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Expressions re-written:
 ; CHECK-NEXT:      [PSE] %gep.A = getelementptr inbounds { i32, i8 }, { i32, i8 }* %A, i64 %mul:
@@ -110,8 +110,8 @@
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Non vectorizable stores to invariant address were not found in loop.
 ; CHECK-NEXT:      SCEV assumptions:
-; CHECK-NEXT:      Equal predicate: %stride.2 == 1
-; CHECK-NEXT:      Equal predicate: %stride.1 == 1
+; CHECK-NEXT:      Expr predicate: (%stride.2 eq 1)
+; CHECK-NEXT:      Expr predicate: (%stride.1 eq 1)
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Expressions re-written:
 ; CHECK-NEXT:      [PSE] %gep.A = getelementptr inbounds i32, i32* %A, i64 %mul:
Index: llvm/test/Transforms/LoopDistribute/symbolic-stride.ll
===================================================================
--- llvm/test/Transforms/LoopDistribute/symbolic-stride.ll
+++ llvm/test/Transforms/LoopDistribute/symbolic-stride.ll
@@ -24,8 +24,8 @@
 ; DEFAULT-NEXT:  entry:
 ; DEFAULT-NEXT:    br label [[FOR_BODY_LVER_CHECK:%.*]]
 ; DEFAULT:       for.body.lver.check:
-; DEFAULT-NEXT:    [[IDENT_CHECK:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
-; DEFAULT-NEXT:    br i1 [[IDENT_CHECK]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH_LDIST1:%.*]]
+; DEFAULT-NEXT:    [[TMP0:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
+; DEFAULT-NEXT:    br i1 [[TMP0]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH_LDIST1:%.*]]
 ; DEFAULT:       for.body.ph.lver.orig:
 ; DEFAULT-NEXT:    br label [[FOR_BODY_LVER_ORIG:%.*]]
 ; DEFAULT:       for.body.lver.orig:
Index: llvm/test/Transforms/LoopLoadElim/symbolic-stride.ll
===================================================================
--- llvm/test/Transforms/LoopLoadElim/symbolic-stride.ll
+++ llvm/test/Transforms/LoopLoadElim/symbolic-stride.ll
@@ -20,11 +20,10 @@
 ;
 ;
 ;
-;
 ; DEFAULT-LABEL: @f(
 ; DEFAULT-NEXT:  for.body.lver.check:
-; DEFAULT-NEXT:    [[IDENT_CHECK:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
-; DEFAULT-NEXT:    br i1 [[IDENT_CHECK]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
+; DEFAULT-NEXT:    [[TMP0:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
+; DEFAULT-NEXT:    br i1 [[TMP0]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
 ; DEFAULT:       for.body.ph.lver.orig:
 ; DEFAULT-NEXT:    br label [[FOR_BODY_LVER_ORIG:%.*]]
 ; DEFAULT:       for.body.lver.orig:
@@ -85,8 +84,8 @@
 ;
 ; THRESHOLD-LABEL: @f(
 ; THRESHOLD-NEXT:  for.body.lver.check:
-; THRESHOLD-NEXT:    [[IDENT_CHECK:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
-; THRESHOLD-NEXT:    br i1 [[IDENT_CHECK]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
+; THRESHOLD-NEXT:    [[TMP0:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
+; THRESHOLD-NEXT:    br i1 [[TMP0]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
 ; THRESHOLD:       for.body.ph.lver.orig:
 ; THRESHOLD-NEXT:    br label [[FOR_BODY_LVER_ORIG:%.*]]
 ; THRESHOLD:       for.body.lver.orig:
@@ -159,8 +158,8 @@
 ;
 ; DEFAULT-LABEL: @f_struct(
 ; DEFAULT-NEXT:  for.body.lver.check:
-; DEFAULT-NEXT:    [[IDENT_CHECK:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
-; DEFAULT-NEXT:    br i1 [[IDENT_CHECK]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
+; DEFAULT-NEXT:    [[TMP0:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
+; DEFAULT-NEXT:    br i1 [[TMP0]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
 ; DEFAULT:       for.body.ph.lver.orig:
 ; DEFAULT-NEXT:    br label [[FOR_BODY_LVER_ORIG:%.*]]
 ; DEFAULT:       for.body.lver.orig:
@@ -230,8 +229,8 @@
 ;
 ; THRESHOLD-LABEL: @f_struct(
 ; THRESHOLD-NEXT:  for.body.lver.check:
-; THRESHOLD-NEXT:    [[IDENT_CHECK:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
-; THRESHOLD-NEXT:    br i1 [[IDENT_CHECK]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
+; THRESHOLD-NEXT:    [[TMP0:%.*]] = icmp ne i64 [[STRIDE:%.*]], 1
+; THRESHOLD-NEXT:    br i1 [[TMP0]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
 ; THRESHOLD:       for.body.ph.lver.orig:
 ; THRESHOLD-NEXT:    br label [[FOR_BODY_LVER_ORIG:%.*]]
 ; THRESHOLD:       for.body.lver.orig:
@@ -320,10 +319,10 @@
 ;
 ; DEFAULT-LABEL: @two_strides(
 ; DEFAULT-NEXT:  for.body.lver.check:
-; DEFAULT-NEXT:    [[IDENT_CHECK:%.*]] = icmp ne i64 [[STRIDE_2:%.*]], 1
-; DEFAULT-NEXT:    [[IDENT_CHECK1:%.*]] = icmp ne i64 [[STRIDE_1:%.*]], 1
-; DEFAULT-NEXT:    [[TMP0:%.*]] = or i1 [[IDENT_CHECK]], [[IDENT_CHECK1]]
-; DEFAULT-NEXT:    br i1 [[TMP0]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
+; DEFAULT-NEXT:    [[TMP0:%.*]] = icmp ne i64 [[STRIDE_2:%.*]], 1
+; DEFAULT-NEXT:    [[TMP1:%.*]] = icmp ne i64 [[STRIDE_1:%.*]], 1
+; DEFAULT-NEXT:    [[TMP2:%.*]] = or i1 [[TMP0]], [[TMP1]]
+; DEFAULT-NEXT:    br i1 [[TMP2]], label [[FOR_BODY_PH_LVER_ORIG:%.*]], label [[FOR_BODY_PH:%.*]]
 ; DEFAULT:       for.body.ph.lver.orig:
 ; DEFAULT-NEXT:    br label [[FOR_BODY_LVER_ORIG:%.*]]
 ; DEFAULT:       for.body.lver.orig:
@@ -357,10 +356,10 @@
 ; DEFAULT-NEXT:    [[ARRAYIDX_NEXT:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 [[MUL_2]]
 ; DEFAULT-NEXT:    store i32 [[ADD]], i32* [[ARRAYIDX_NEXT]], align 4
 ; DEFAULT-NEXT:    [[EXITCOND:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[N]]
-; DEFAULT-NEXT:    br i1 [[EXITCOND]], label [[FOR_END_LOOPEXIT2:%.*]], label [[FOR_BODY]]
+; DEFAULT-NEXT:    br i1 [[EXITCOND]], label [[FOR_END_LOOPEXIT1:%.*]], label [[FOR_BODY]]
 ; DEFAULT:       for.end.loopexit:
 ; DEFAULT-NEXT:    br label [[FOR_END:%.*]]
-; DEFAULT:       for.end.loopexit2:
+; DEFAULT:       for.end.loopexit1:
 ; DEFAULT-NEXT:    br label [[FOR_END]]
 ; DEFAULT:       for.end:
 ; DEFAULT-NEXT:    ret void