diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -635,8 +635,10 @@ const SCEV *getUMaxExpr(SmallVectorImpl &Operands); const SCEV *getSMinExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getSMinExpr(SmallVectorImpl &Operands); - const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS); - const SCEV *getUMinExpr(SmallVectorImpl &Operands); + const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS, + bool PoisonSafe = false); + const SCEV *getUMinExpr(SmallVectorImpl &Operands, + bool PoisonSafe = false); const SCEV *getUnknown(Value *V); const SCEV *getCouldNotCompute(); @@ -728,11 +730,13 @@ /// Promote the operands to the wider of the types using zero-extension, and /// then perform a umin operation with them. - const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS); + const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, + bool PoisonSafe = false); /// Promote the operands to the wider of the types using zero-extension, and /// then perform a umin operation with them. N-ary function. - const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl &Ops); + const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl &Ops, + bool PoisonSafe = false); /// Transitively follow the chain of pointer-type operands until reaching a /// SCEV that does not have a single pointer operand. 
This returns a diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h --- a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h @@ -42,6 +42,7 @@ void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} void visitSMinExpr(const SCEVSMinExpr *Numerator) {} void visitUMinExpr(const SCEVUMinExpr *Numerator) {} + void visitSafeUMinExpr(const SCEVSafeUMinExpr *Numerator) {} void visitUnknown(const SCEVUnknown *Numerator) {} void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -40,7 +40,7 @@ // folders simpler. scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUMinExpr, scSMinExpr, - scPtrToInt, scUnknown, scCouldNotCompute + scPtrToInt, scSafeUMinExpr, scUnknown, scCouldNotCompute }; /// This class represents a constant integer value. @@ -231,6 +231,7 @@ return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || + S->getSCEVType() == scSafeUMinExpr || S->getSCEVType() == scAddRecExpr; } }; @@ -247,7 +248,8 @@ static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || - S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; + S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || + S->getSCEVType() == scSafeUMinExpr; } /// Set flags for a non-recurrence without clearing previously set flags. 
@@ -432,7 +434,7 @@ static bool isMinMaxType(enum SCEVTypes T) { return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr || - T == scUMinExpr; + T == scUMinExpr || T == scSafeUMinExpr; } protected: @@ -524,6 +526,47 @@ } }; + /// This node is the base class for poison-safe min/max selections. + class SCEVSafeMinMaxExpr : public SCEVMinMaxExpr { + friend class ScalarEvolution; + + static bool isSafeMinMaxType(enum SCEVTypes T) { + return T == scSafeUMinExpr; + } + + protected: + /// Note: Constructing subclasses via this constructor is allowed + SCEVSafeMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, + const SCEV *const *O, size_t N) + : SCEVMinMaxExpr(ID, T, O, N) { + assert(isSafeMinMaxType(T)); + // Min and max never overflow + setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); + } + + public: + Type *getType() const { return getOperand(0)->getType(); } + + static bool classof(const SCEV *S) { + return isSafeMinMaxType(S->getSCEVType()); + } + }; + + /// This class represents a poison-safe unsigned minimum selection. + class SCEVSafeUMinExpr : public SCEVSafeMinMaxExpr { + friend class ScalarEvolution; + + SCEVSafeUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, + size_t N) + : SCEVSafeMinMaxExpr(ID, scSafeUMinExpr, O, N) {} + + public: + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const SCEV *S) { + return S->getSCEVType() == scSafeUMinExpr; + } + }; + /// This means that we are dealing with an entirely unknown SCEV /// value, and only represent it as its LLVM Value. This is the /// "bottom" value for the analysis. 
@@ -602,6 +645,8 @@ return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); case scUMinExpr: return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); + case scSafeUMinExpr: + return ((SC *)this)->visitSafeUMinExpr((const SCEVSafeUMinExpr *)S); case scUnknown: return ((SC*)this)->visitUnknown((const SCEVUnknown*)S); case scCouldNotCompute: @@ -657,6 +702,7 @@ case scUMaxExpr: case scSMinExpr: case scUMinExpr: + case scSafeUMinExpr: case scAddRecExpr: for (const auto *Op : cast(S)->operands()) push(Op); @@ -845,6 +891,16 @@ return !Changed ? Expr : SE.getUMinExpr(Operands); } + const SCEV *visitSafeUMinExpr(const SCEVSafeUMinExpr *Expr) { + SmallVector Operands; + bool Changed = false; + for (auto *Op : Expr->operands()) { + Operands.push_back(((SC *)this)->visit(Op)); + Changed |= Op != Operands.back(); + } + return !Changed ? Expr : SE.getUMinExpr(Operands, /*PoisonSafe=*/true); + } + const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h --- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h @@ -450,6 +450,14 @@ /// Determine the most "relevant" loop for the given SCEV. 
const Loop *getRelevantLoop(const SCEV *); + Value *expandSMaxExpr(const SCEVNAryExpr *S); + + Value *expandUMaxExpr(const SCEVNAryExpr *S); + + Value *expandSMinExpr(const SCEVNAryExpr *S); + + Value *expandUMinExpr(const SCEVNAryExpr *S); + Value *visitConstant(const SCEVConstant *S) { return S->getValue(); } Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S); @@ -476,6 +484,8 @@ Value *visitUMinExpr(const SCEVUMinExpr *S); + Value *visitSafeUMinExpr(const SCEVSafeUMinExpr *S); + Value *visitUnknown(const SCEVUnknown *S) { return S->getValue(); } void rememberInstruction(Value *I); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -301,7 +301,8 @@ case scUMaxExpr: case scSMaxExpr: case scUMinExpr: - case scSMinExpr: { + case scSMinExpr: + case scSafeUMinExpr: { const SCEVNAryExpr *NAry = cast(this); const char *OpStr = nullptr; switch (NAry->getSCEVType()) { @@ -315,6 +316,9 @@ case scSMinExpr: OpStr = " smin "; break; + case scSafeUMinExpr: + OpStr = " umin_safe "; + break; default: llvm_unreachable("There are no other nary expression types."); } @@ -391,6 +395,7 @@ case scSMaxExpr: case scUMinExpr: case scSMinExpr: + case scSafeUMinExpr: return cast(this)->getType(); case scAddExpr: return cast(this)->getType(); @@ -774,7 +779,8 @@ case scSMaxExpr: case scUMaxExpr: case scSMinExpr: - case scUMinExpr: { + case scUMinExpr: + case scSafeUMinExpr: { const SCEVNAryExpr *LC = cast(LHS); const SCEVNAryExpr *RC = cast(RHS); @@ -3757,7 +3763,7 @@ return APIntOps::smin(LHS, RHS); else if (Kind == scUMaxExpr) return APIntOps::umax(LHS, RHS); - else if (Kind == scUMinExpr) + else if (Kind == scUMinExpr || Kind == scSafeUMinExpr) return APIntOps::umin(LHS, RHS); llvm_unreachable("Unknown SCEV min/max opcode"); }; @@ -3885,14 +3891,15 @@ return getMinMaxExpr(scSMinExpr, Ops); } -const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, - const 
SCEV *RHS) { +const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS, + bool PoisonSafe) { SmallVector Ops = { LHS, RHS }; - return getUMinExpr(Ops); + return getUMinExpr(Ops, PoisonSafe); } -const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl &Ops) { - return getMinMaxExpr(scUMinExpr, Ops); +const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl &Ops, + bool PoisonSafe) { + return getMinMaxExpr(PoisonSafe ? scSafeUMinExpr : scUMinExpr, Ops); } const SCEV * @@ -4375,13 +4382,15 @@ } const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, - const SCEV *RHS) { + const SCEV *RHS, + bool PoisonSafe) { SmallVector Ops = { LHS, RHS }; - return getUMinFromMismatchedTypes(Ops); + return getUMinFromMismatchedTypes(Ops, PoisonSafe); } -const SCEV *ScalarEvolution::getUMinFromMismatchedTypes( - SmallVectorImpl &Ops) { +const SCEV * +ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl &Ops, + bool PoisonSafe) { assert(!Ops.empty() && "At least one operand must be!"); // Trivial case. if (Ops.size() == 1) @@ -4402,7 +4411,7 @@ PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType)); // Generate umin. - return getUMinExpr(PromotedOps); + return getUMinExpr(PromotedOps, PoisonSafe); } const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { @@ -5513,6 +5522,7 @@ case scSMaxExpr: case scUMinExpr: case scSMinExpr: + case scSafeUMinExpr: // These expressions are available if their operand(s) is/are. 
return true; @@ -6092,6 +6102,14 @@ ConservativeResult.intersectWith(X, RangeType)); } + if (const SCEVSafeUMinExpr *UMin = dyn_cast(S)) { + ConstantRange X = getRangeRef(UMin->getOperand(0), SignHint); + for (unsigned i = 1, e = UMin->getNumOperands(); i != e; ++i) + X = X.umin(getRangeRef(UMin->getOperand(i), SignHint)); + return setRange(UMin, SignHint, + ConservativeResult.intersectWith(X, RangeType)); + } + if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint); ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint); @@ -8161,22 +8179,22 @@ // umin(EL0.ExactNotTaken, EL1.ExactNotTaken) is unsafe in general. // To see the detailed examples, please see // test/Analysis/ScalarEvolution/exit-count-select.ll - bool PoisonSafe = isa(ExitCond); - if (!PoisonSafe) + bool NeverPoison = isa(ExitCond); + if (!NeverPoison) // Even if ExitCond is select, we can safely derive BECount using both // EL0 and EL1 in these cases: // (1) EL0.ExactNotTaken is non-zero // (2) EL1.ExactNotTaken is non-poison // (3) EL0.ExactNotTaken is zero (BECount should be simply zero and // it cannot be umin(0, ..)) - // The PoisonSafe assignment below is simplified and the assertion after + // The NeverPoison assignment below is simplified and the assertion after // BECount calculation fully guarantees the condition (3). 
- PoisonSafe = isa(EL0.ExactNotTaken) || - isa(EL1.ExactNotTaken); + NeverPoison = isa(EL0.ExactNotTaken) || + isa(EL1.ExactNotTaken); if (EL0.ExactNotTaken != getCouldNotCompute() && - EL1.ExactNotTaken != getCouldNotCompute() && PoisonSafe) { - BECount = - getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); + EL1.ExactNotTaken != getCouldNotCompute()) { + BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken, + /*PoisonSafe=*/!NeverPoison); // If EL0.ExactNotTaken was zero and ExitCond was a short-circuit form, // it should have been simplified to zero (see the condition (3) above) @@ -8977,7 +8995,8 @@ case scUMaxExpr: case scSMinExpr: case scUMinExpr: - return nullptr; // TODO: smax, umax, smin, umax. + case scSafeUMinExpr: + return nullptr; // TODO: smax, umax, smin, umin, umin_safe. } llvm_unreachable("Unknown SCEV kind!"); } @@ -11359,6 +11378,7 @@ case ICmpInst::ICMP_ULE: return // min(A, ...) <= A + // FIXME: what about umin_safe? IsMinMaxConsistingOf(LHS, RHS) || // A <= max(A, ...) 
IsMinMaxConsistingOf(RHS, LHS); @@ -12759,7 +12779,8 @@ case scUMaxExpr: case scSMaxExpr: case scUMinExpr: - case scSMinExpr: { + case scSMinExpr: + case scSafeUMinExpr: { bool HasVarying = false; for (auto *Op : cast(S)->operands()) { LoopDisposition D = getLoopDisposition(Op, L); @@ -12849,7 +12870,8 @@ case scUMaxExpr: case scSMaxExpr: case scUMinExpr: - case scSMinExpr: { + case scSMinExpr: + case scSafeUMinExpr: { const SCEVNAryExpr *NAry = cast(S); bool Proper = true; for (const SCEV *NAryOp : NAry->operands()) { diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -1671,7 +1671,7 @@ return Builder.CreateSExt(V, Ty); } -Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { +Value *SCEVExpander::expandSMaxExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands()-2; i >= 0; --i) { @@ -1700,7 +1700,7 @@ return LHS; } -Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { +Value *SCEVExpander::expandUMaxExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands()-2; i >= 0; --i) { @@ -1729,7 +1729,7 @@ return LHS; } -Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) { +Value *SCEVExpander::expandSMinExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands() - 2; i >= 0; --i) { @@ -1758,7 +1758,7 @@ return LHS; } -Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) { +Value *SCEVExpander::expandUMinExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands() - 2; i >= 0; --i) { @@ -1787,6 
+1787,39 @@ return LHS; } +Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { + return expandSMaxExpr(S); +} + +Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { + return expandUMaxExpr(S); +} + +Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) { + return expandSMinExpr(S); +} + +Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) { + return expandUMinExpr(S); +} + +Value *SCEVExpander::visitSafeUMinExpr(const SCEVSafeUMinExpr *S) { + SmallVector Ops; + for (const SCEV *Op : S->operands()) + Ops.emplace_back(expand(Op)); + + Value *Boundary = Constant::getNullValue(S->getType()); + + SmallVector OpIsZero; + for (Value *Op : ArrayRef(Ops).drop_back()) + OpIsZero.emplace_back(Builder.CreateICmpEQ(Op, Boundary)); + + Value *AnyOpIsZero = Builder.CreateOr(OpIsZero); + + Value *NaiveUMin = expandUMinExpr(S); + return Builder.CreateSelect(AnyOpIsZero, Boundary, NaiveUMin); +} + Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *IP, bool Root) { setInsertPoint(IP); @@ -2271,10 +2304,26 @@ case scSMaxExpr: case scUMaxExpr: case scSMinExpr: - case scUMinExpr: { + case scUMinExpr: + case scSafeUMinExpr: { // FIXME: should this ask the cost for Intrinsic's? + // The reduction tree. Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1, 0, 1); Cost += CmpSelCost(Instruction::Select, S->getNumOperands() - 1, 0, 2); + switch (S->getSCEVType()) { + case scSafeUMinExpr: { + // The safety net against poison. + // FIXME: this is broken. + Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1, 0, 0); + Cost += ArithCost(Instruction::Or, + S->getNumOperands() > 2 ? 
S->getNumOperands() - 2 : 0); + Cost += CmpSelCost(Instruction::Select, 1, 0, 1); + break; + } + default: + assert(!isa(S) && "Unhandled SCEV expression type?"); + break; + } break; } case scAddRecExpr: { @@ -2399,7 +2448,8 @@ case scUMaxExpr: case scSMaxExpr: case scUMinExpr: - case scSMinExpr: { + case scSMinExpr: + case scSafeUMinExpr: { assert(cast(S)->getNumOperands() > 1 && "Nary expr should have more than 1 operand."); // The simple nary expr will require one less op (or pair of ops) diff --git a/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll b/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll --- a/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll +++ b/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll @@ -1,21 +1,21 @@ ; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py ; RUN: opt -disable-output "-passes=print" %s 2>&1 | FileCheck %s -; exact-not-taken cannot be umin(n, m) because it is possible for (n, m) to be (0, poison) -; https://alive2.llvm.org/ce/z/NsP9ue define i32 @logical_and_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: 'logical_and_2ops' ; CHECK-NEXT: Classifying expressions for: @logical_and_2ops ; CHECK-NEXT: %i = phi i32 [ 0, %entry ], [ %i.next, %loop ] -; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%n umin_safe %m) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %i.next = add i32 %i, 1 -; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%n umin_safe %m)) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %cond = select i1 %cond_p0, i1 %cond_p1, i1 false ; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: Determining loop execution counts for: 
@logical_and_2ops -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: backedge-taken count is (%n umin_safe %m) ; CHECK-NEXT: Loop %loop: max backedge-taken count is -1 -; CHECK-NEXT: Loop %loop: Unpredictable predicated backedge-taken count. +; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%n umin_safe %m) +; CHECK-NEXT: Predicates: +; CHECK: Loop %loop: Trip multiple is 1 ; entry: br label %loop @@ -30,21 +30,21 @@ ret i32 %i } -; exact-not-taken cannot be umin(n, m) because it is possible for (n, m) to be (0, poison) -; https://alive2.llvm.org/ce/z/ApRitq define i32 @logical_or_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: 'logical_or_2ops' ; CHECK-NEXT: Classifying expressions for: @logical_or_2ops ; CHECK-NEXT: %i = phi i32 [ 0, %entry ], [ %i.next, %loop ] -; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%n umin_safe %m) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %i.next = add i32 %i, 1 -; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%n umin_safe %m)) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %cond = select i1 %cond_p0, i1 true, i1 %cond_p1 ; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: Determining loop execution counts for: @logical_or_2ops -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: backedge-taken count is (%n umin_safe %m) ; CHECK-NEXT: Loop %loop: max backedge-taken count is -1 -; CHECK-NEXT: Loop %loop: Unpredictable predicated backedge-taken count. 
+; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%n umin_safe %m) +; CHECK-NEXT: Predicates: +; CHECK: Loop %loop: Trip multiple is 1 ; entry: br label %loop @@ -63,17 +63,19 @@ ; CHECK-LABEL: 'logical_and_3ops' ; CHECK-NEXT: Classifying expressions for: @logical_and_3ops ; CHECK-NEXT: %i = phi i32 [ 0, %entry ], [ %i.next, %loop ] -; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%n umin_safe %m umin_safe %k) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %i.next = add i32 %i, 1 -; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%n umin_safe %m umin_safe %k)) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %cond_p3 = select i1 %cond_p0, i1 %cond_p1, i1 false ; CHECK-NEXT: --> %cond_p3 U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: %cond = select i1 %cond_p3, i1 %cond_p2, i1 false ; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: Determining loop execution counts for: @logical_and_3ops -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: backedge-taken count is (%n umin_safe %m umin_safe %k) ; CHECK-NEXT: Loop %loop: max backedge-taken count is -1 -; CHECK-NEXT: Loop %loop: Unpredictable predicated backedge-taken count. 
+; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%n umin_safe %m umin_safe %k) +; CHECK-NEXT: Predicates: +; CHECK: Loop %loop: Trip multiple is 1 ; entry: br label %loop @@ -94,17 +96,19 @@ ; CHECK-LABEL: 'logical_or_3ops' ; CHECK-NEXT: Classifying expressions for: @logical_or_3ops ; CHECK-NEXT: %i = phi i32 [ 0, %entry ], [ %i.next, %loop ] -; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%n umin_safe %m umin_safe %k) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %i.next = add i32 %i, 1 -; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%n umin_safe %m umin_safe %k)) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %cond_p3 = select i1 %cond_p0, i1 true, i1 %cond_p1 ; CHECK-NEXT: --> %cond_p3 U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: %cond = select i1 %cond_p3, i1 true, i1 %cond_p2 ; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: Determining loop execution counts for: @logical_or_3ops -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: backedge-taken count is (%n umin_safe %m umin_safe %k) ; CHECK-NEXT: Loop %loop: max backedge-taken count is -1 -; CHECK-NEXT: Loop %loop: Unpredictable predicated backedge-taken count. 
+; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%n umin_safe %m umin_safe %k) +; CHECK-NEXT: Predicates: +; CHECK: Loop %loop: Trip multiple is 1 ; entry: br label %loop diff --git a/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll b/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll --- a/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll +++ b/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll @@ -4,17 +4,14 @@ define i32 @logical_and_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: @logical_and_2ops( ; CHECK-NEXT: entry: +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[M:%.*]], i32 [[N:%.*]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[I_NEXT:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 -; CHECK-NEXT: [[COND_P0:%.*]] = icmp ult i32 [[I]], [[N:%.*]] -; CHECK-NEXT: [[COND_P1:%.*]] = icmp ult i32 [[I]], [[M:%.*]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[COND_P0]], i1 [[COND_P1]], i1 false -; CHECK-NEXT: br i1 [[COND]], label [[LOOP]], label [[EXIT:%.*]] +; CHECK-NEXT: br i1 false, label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[I_LCSSA:%.*]] = phi i32 [ [[I]], [[LOOP]] ] -; CHECK-NEXT: ret i32 [[I_LCSSA]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[N]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[TMP0]], i32 0, i32 [[UMIN]] +; CHECK-NEXT: ret i32 [[TMP1]] ; entry: br label %loop @@ -32,17 +29,14 @@ define i32 @logical_or_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: @logical_or_2ops( ; CHECK-NEXT: entry: +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[M:%.*]], i32 [[N:%.*]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[I_NEXT:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 -; CHECK-NEXT: [[COND_P0:%.*]] = icmp uge i32 [[I]], [[N:%.*]] -; CHECK-NEXT: [[COND_P1:%.*]] = icmp uge i32 [[I]], [[M:%.*]] -; CHECK-NEXT: [[COND:%.*]] = select i1 
[[COND_P0]], i1 true, i1 [[COND_P1]] -; CHECK-NEXT: br i1 [[COND]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK-NEXT: br i1 true, label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: [[I_LCSSA:%.*]] = phi i32 [ [[I]], [[LOOP]] ] -; CHECK-NEXT: ret i32 [[I_LCSSA]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[N]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[TMP0]], i32 0, i32 [[UMIN]] +; CHECK-NEXT: ret i32 [[TMP1]] ; entry: br label %loop @@ -60,19 +54,17 @@ define i32 @logical_and_3ops(i32 %n, i32 %m, i32 %k) { ; CHECK-LABEL: @logical_and_3ops( ; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[M:%.*]], 0 +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[K:%.*]], i32 [[M]]) +; CHECK-NEXT: [[UMIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[UMIN]], i32 [[N:%.*]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[I_NEXT:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 -; CHECK-NEXT: [[COND_P0:%.*]] = icmp ult i32 [[I]], [[N:%.*]] -; CHECK-NEXT: [[COND_P1:%.*]] = icmp ult i32 [[I]], [[M:%.*]] -; CHECK-NEXT: [[COND_P2:%.*]] = icmp ult i32 [[I]], [[K:%.*]] -; CHECK-NEXT: [[COND_P3:%.*]] = select i1 [[COND_P0]], i1 [[COND_P1]], i1 false -; CHECK-NEXT: [[COND:%.*]] = select i1 [[COND_P3]], i1 [[COND_P2]], i1 false -; CHECK-NEXT: br i1 [[COND]], label [[LOOP]], label [[EXIT:%.*]] +; CHECK-NEXT: br i1 false, label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[I_LCSSA:%.*]] = phi i32 [ [[I]], [[LOOP]] ] -; CHECK-NEXT: ret i32 [[I_LCSSA]] +; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[N]], 0 +; CHECK-NEXT: [[TMP2:%.*]] = or i1 [[TMP1]], [[TMP0]] +; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i32 0, i32 [[UMIN1]] +; CHECK-NEXT: ret i32 [[TMP3]] ; entry: br label %loop @@ -92,19 +84,17 @@ define i32 @logical_or_3ops(i32 %n, i32 %m, i32 %k) { ; CHECK-LABEL: @logical_or_3ops( ; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[M:%.*]], 0 +; 
CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[K:%.*]], i32 [[M]]) +; CHECK-NEXT: [[UMIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[UMIN]], i32 [[N:%.*]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[I_NEXT:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 -; CHECK-NEXT: [[COND_P0:%.*]] = icmp uge i32 [[I]], [[N:%.*]] -; CHECK-NEXT: [[COND_P1:%.*]] = icmp uge i32 [[I]], [[M:%.*]] -; CHECK-NEXT: [[COND_P2:%.*]] = icmp uge i32 [[I]], [[K:%.*]] -; CHECK-NEXT: [[COND_P3:%.*]] = select i1 [[COND_P0]], i1 true, i1 [[COND_P1]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[COND_P3]], i1 true, i1 [[COND_P2]] -; CHECK-NEXT: br i1 [[COND]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK-NEXT: br i1 true, label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: [[I_LCSSA:%.*]] = phi i32 [ [[I]], [[LOOP]] ] -; CHECK-NEXT: ret i32 [[I_LCSSA]] +; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[N]], 0 +; CHECK-NEXT: [[TMP2:%.*]] = or i1 [[TMP1]], [[TMP0]] +; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i32 0, i32 [[UMIN1]] +; CHECK-NEXT: ret i32 [[TMP3]] ; entry: br label %loop diff --git a/polly/include/polly/Support/SCEVAffinator.h b/polly/include/polly/Support/SCEVAffinator.h --- a/polly/include/polly/Support/SCEVAffinator.h +++ b/polly/include/polly/Support/SCEVAffinator.h @@ -111,6 +111,7 @@ PWACtx visitSMinExpr(const llvm::SCEVSMinExpr *E); PWACtx visitUMaxExpr(const llvm::SCEVUMaxExpr *E); PWACtx visitUMinExpr(const llvm::SCEVUMinExpr *E); + PWACtx visitSafeUMinExpr(const llvm::SCEVSafeUMinExpr *E); PWACtx visitUnknown(const llvm::SCEVUnknown *E); PWACtx visitSDivInstruction(llvm::Instruction *SDiv); PWACtx visitSRemInstruction(llvm::Instruction *SRem); diff --git a/polly/lib/Support/SCEVAffinator.cpp b/polly/lib/Support/SCEVAffinator.cpp --- a/polly/lib/Support/SCEVAffinator.cpp +++ b/polly/lib/Support/SCEVAffinator.cpp @@ -465,6 +465,10 @@ llvm_unreachable("SCEVUMinExpr not yet supported"); 
} +PWACtx SCEVAffinator::visitSafeUMinExpr(const SCEVSafeUMinExpr *Expr) { + llvm_unreachable("SCEVSafeUMinExpr not yet supported"); +} + PWACtx SCEVAffinator::visitUDivExpr(const SCEVUDivExpr *Expr) { // The handling of unsigned division is basically the same as for signed // division, except the interpretation of the operands. As the divisor diff --git a/polly/lib/Support/SCEVValidator.cpp b/polly/lib/Support/SCEVValidator.cpp --- a/polly/lib/Support/SCEVValidator.cpp +++ b/polly/lib/Support/SCEVValidator.cpp @@ -332,6 +332,22 @@ return ValidatorResult(SCEVType::PARAM, Expr); } + class ValidatorResult visitSafeUMinExpr(const SCEVSafeUMinExpr *Expr) { + // We do not support unsigned min operations. If 'Expr' is constant during + // Scop execution we treat this as a parameter, otherwise we bail out. + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { + ValidatorResult Op = visit(Expr->getOperand(i)); + + if (!Op.isConstant()) { + LLVM_DEBUG( + dbgs() << "INVALID: SafeUMinExpr has a non-constant operand"); + return ValidatorResult(SCEVType::INVALID); + } + } + + return ValidatorResult(SCEVType::PARAM, Expr); + } + ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { if (R->contains(I)) { LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " diff --git a/polly/lib/Support/ScopHelper.cpp b/polly/lib/Support/ScopHelper.cpp --- a/polly/lib/Support/ScopHelper.cpp +++ b/polly/lib/Support/ScopHelper.cpp @@ -391,6 +391,12 @@ NewOps.push_back(visit(Op)); return SE.getSMinExpr(NewOps); } + const SCEV *visitSafeUMinExpr(const SCEVSafeUMinExpr *E) { + SmallVector NewOps; + for (const SCEV *Op : E->operands()) + NewOps.push_back(visit(Op)); + return SE.getUMinExpr(NewOps, /*PoisonSafe=*/true); + } const SCEV *visitAddRecExpr(const SCEVAddRecExpr *E) { SmallVector NewOps; for (const SCEV *Op : E->operands())