diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -629,14 +629,18 @@ const SCEV *getAbsExpr(const SCEV *Op, bool IsNSW); const SCEV *getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl &Operands); + const SCEV *getSaturatingMinMaxExpr(SCEVTypes Kind, + SmallVectorImpl &Operands); const SCEV *getSMaxExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getSMaxExpr(SmallVectorImpl &Operands); const SCEV *getUMaxExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getUMaxExpr(SmallVectorImpl &Operands); const SCEV *getSMinExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getSMinExpr(SmallVectorImpl &Operands); - const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS); - const SCEV *getUMinExpr(SmallVectorImpl &Operands); + const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS, + bool Saturating = false); + const SCEV *getUMinExpr(SmallVectorImpl &Operands, + bool Saturating = false); const SCEV *getUnknown(Value *V); const SCEV *getCouldNotCompute(); @@ -728,11 +732,13 @@ /// Promote the operands to the wider of the types using zero-extension, and /// then perform a umin operation with them. - const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS); + const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, + bool Saturating = false); /// Promote the operands to the wider of the types using zero-extension, and /// then perform a umin operation with them. N-ary function. - const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl &Ops); + const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl &Ops, + bool Saturating = false); /// Transitively follow the chain of pointer-type operands until reaching a /// SCEV that does not have a single pointer operand. 
This returns a diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h --- a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h @@ -42,6 +42,7 @@ void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} void visitSMinExpr(const SCEVSMinExpr *Numerator) {} void visitUMinExpr(const SCEVUMinExpr *Numerator) {} + void visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Numerator) {} void visitUnknown(const SCEVUnknown *Numerator) {} void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -35,34 +35,45 @@ class Loop; class Type; - enum SCEVTypes : unsigned short { - // These should be ordered in terms of increasing complexity to make the - // folders simpler. - scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, - scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUMinExpr, scSMinExpr, - scPtrToInt, scUnknown, scCouldNotCompute - }; +enum SCEVTypes : unsigned short { + // These should be ordered in terms of increasing complexity to make the + // folders simpler. + scConstant, + scTruncate, + scZeroExtend, + scSignExtend, + scAddExpr, + scMulExpr, + scUDivExpr, + scAddRecExpr, + scUMaxExpr, + scSMaxExpr, + scUMinExpr, + scSMinExpr, + scPtrToInt, + scSequentialUMinExpr, + scUnknown, + scCouldNotCompute +}; - /// This class represents a constant integer value. - class SCEVConstant : public SCEV { - friend class ScalarEvolution; +/// This class represents a constant integer value. 
+class SCEVConstant : public SCEV { + friend class ScalarEvolution; - ConstantInt *V; + ConstantInt *V; - SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) : - SCEV(ID, scConstant, 1), V(v) {} + SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) + : SCEV(ID, scConstant, 1), V(v) {} - public: - ConstantInt *getValue() const { return V; } - const APInt &getAPInt() const { return getValue()->getValue(); } +public: + ConstantInt *getValue() const { return V; } + const APInt &getAPInt() const { return getValue()->getValue(); } - Type *getType() const { return V->getType(); } + Type *getType() const { return V->getType(); } - /// Methods for support type inquiry through isa, cast, and dyn_cast: - static bool classof(const SCEV *S) { - return S->getSCEVType() == scConstant; - } - }; + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } +}; inline unsigned short computeExpressionSize(ArrayRef Args) { APInt Size(16, 1); @@ -231,6 +242,7 @@ return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || + S->getSCEVType() == scSequentialUMinExpr || S->getSCEVType() == scAddRecExpr; } }; @@ -524,6 +536,54 @@ } }; + /// This node is the base class for sequential/in-order min/max selections. + /// Note that their fundamental difference from SCEVMinMaxExpr's is that they + /// are early-returning upon reaching saturation point. + /// I.e. given `0 umin_seq poison`, the result will be `0`, + /// while the result of `0 umin poison` is `poison`. + class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { + friend class ScalarEvolution; + + static bool isSequentialMinMaxType(enum SCEVTypes T) { + return T == scSequentialUMinExpr; + } + + /// Set flags for a non-recurrence without clearing previously set flags. 
+ void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } + + protected: + /// Note: Constructing subclasses via this constructor is allowed + SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, + const SCEV *const *O, size_t N) + : SCEVNAryExpr(ID, T, O, N) { + assert(isSequentialMinMaxType(T)); + // Min and max never overflow + setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); + } + + public: + Type *getType() const { return getOperand(0)->getType(); } + + static bool classof(const SCEV *S) { + return isSequentialMinMaxType(S->getSCEVType()); + } + }; + + /// This class represents a sequential/in-order unsigned minimum selection. + class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { + friend class ScalarEvolution; + + SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, + size_t N) + : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} + + public: + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const SCEV *S) { + return S->getSCEVType() == scSequentialUMinExpr; + } + }; + /// This means that we are dealing with an entirely unknown SCEV /// value, and only represent it as its LLVM Value. This is the /// "bottom" value for the analysis. @@ -602,6 +662,9 @@ return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); case scUMinExpr: return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); + case scSequentialUMinExpr: + return ((SC *)this) + ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); case scUnknown: return ((SC*)this)->visitUnknown((const SCEVUnknown*)S); case scCouldNotCompute: @@ -657,6 +720,7 @@ case scUMaxExpr: case scSMinExpr: case scUMinExpr: + case scSequentialUMinExpr: case scAddRecExpr: for (const auto *Op : cast(S)->operands()) push(Op); @@ -845,6 +909,16 @@ return !Changed ? 
Expr : SE.getUMinExpr(Operands); } + const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { + SmallVector Operands; + bool Changed = false; + for (auto *Op : Expr->operands()) { + Operands.push_back(((SC *)this)->visit(Op)); + Changed |= Op != Operands.back(); + } + return !Changed ? Expr : SE.getUMinExpr(Operands, /*Saturating=*/true); + } + const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h --- a/llvm/include/llvm/IR/IRBuilder.h +++ b/llvm/include/llvm/IR/IRBuilder.h @@ -1571,6 +1571,15 @@ Cond2, Name); } + // NOTE: this is sequential, non-commutative reduction! + Value *CreateLogicalOr(ArrayRef Ops) { + assert(!Ops.empty()); + Value *Accum = Ops[0]; + for (unsigned i = 1; i < Ops.size(); i++) + Accum = CreateLogicalOr(Accum, Ops[i]); + return Accum; + } + CallInst *CreateConstrainedFPBinOp( Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource = nullptr, const Twine &Name = "", MDNode *FPMathTag = nullptr, diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h --- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h @@ -450,6 +450,14 @@ /// Determine the most "relevant" loop for the given SCEV. 
const Loop *getRelevantLoop(const SCEV *); + Value *expandSMaxExpr(const SCEVNAryExpr *S); + + Value *expandUMaxExpr(const SCEVNAryExpr *S); + + Value *expandSMinExpr(const SCEVNAryExpr *S); + + Value *expandUMinExpr(const SCEVNAryExpr *S); + Value *visitConstant(const SCEVConstant *S) { return S->getValue(); } Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S); @@ -476,6 +484,8 @@ Value *visitUMinExpr(const SCEVUMinExpr *S); + Value *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S); + Value *visitUnknown(const SCEVUnknown *S) { return S->getValue(); } void rememberInstruction(Value *I); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -301,7 +301,8 @@ case scUMaxExpr: case scSMaxExpr: case scUMinExpr: - case scSMinExpr: { + case scSMinExpr: + case scSequentialUMinExpr: { const SCEVNAryExpr *NAry = cast(this); const char *OpStr = nullptr; switch (NAry->getSCEVType()) { @@ -315,6 +316,9 @@ case scSMinExpr: OpStr = " smin "; break; + case scSequentialUMinExpr: + OpStr = " umin_seq "; + break; default: llvm_unreachable("There are no other nary expression types."); } @@ -392,6 +396,8 @@ case scUMinExpr: case scSMinExpr: return cast(this)->getType(); + case scSequentialUMinExpr: + return cast(this)->getType(); case scAddExpr: return cast(this)->getType(); case scUDivExpr: @@ -774,7 +780,8 @@ case scSMaxExpr: case scUMaxExpr: case scSMinExpr: - case scUMinExpr: { + case scUMinExpr: + case scSequentialUMinExpr: { const SCEVNAryExpr *LC = cast(LHS); const SCEVNAryExpr *RC = cast(RHS); @@ -3721,6 +3728,7 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl &Ops) { + assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!"); assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG @@ -3857,6 +3865,75 @@ return S; } +const SCEV * 
+ScalarEvolution::getSaturatingMinMaxExpr(SCEVTypes Kind, + SmallVectorImpl &Ops) { + assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) && + "Not a SCEVSequentialMinMaxExpr!"); + assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!"); + if (Ops.size() == 1) + return Ops[0]; +#ifndef NDEBUG + Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); + for (unsigned i = 1, e = Ops.size(); i != e; ++i) { + assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && + "Operand types don't match!"); + assert(Ops[0]->getType()->isPointerTy() == + Ops[i]->getType()->isPointerTy() && + "min/max should be consistently pointerish"); + } +#endif + + // Note that SCEVSequentialMinMaxExpr is *NOT* commutative, + // so we can *NOT* do any kind of sorting of the expressions! + + // Check if we have created the same expression before. + if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) + return S; + + // FIXME: there are *some* simplifications that we can do here. + + // Check to see if one of the operands is of the same kind. If so, expand its + // operands onto our operand list, and recurse to simplify. + { + unsigned Idx = 0; + bool DeletedAny = false; + while (Idx < Ops.size()) { + if (Ops[Idx]->getSCEVType() != Kind) { + ++Idx; + continue; + } + const auto *SMME = cast(Ops[Idx]); + Ops.erase(Ops.begin() + Idx); + Ops.insert(Ops.begin() + Idx, SMME->op_begin(), SMME->op_end()); + DeletedAny = true; + } + + if (DeletedAny) + return getSaturatingMinMaxExpr(Kind, Ops); + } + + // Okay, it looks like we really DO need an expr. Check to see if we + // already have one, otherwise create a new one. 
+ FoldingSetNodeID ID; + ID.AddInteger(Kind); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + void *IP = nullptr; + const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP); + if (ExistingSCEV) + return ExistingSCEV; + + const SCEV **O = SCEVAllocator.Allocate(Ops.size()); + std::uninitialized_copy(Ops.begin(), Ops.end(), O); + SCEV *S = new (SCEVAllocator) + SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size()); + + UniqueSCEVs.InsertNode(S, IP); + registerUser(S, Ops); + return S; +} + const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { SmallVector Ops = {LHS, RHS}; return getSMaxExpr(Ops); @@ -3885,14 +3962,16 @@ return getMinMaxExpr(scSMinExpr, Ops); } -const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, - const SCEV *RHS) { +const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS, + bool Saturating) { SmallVector Ops = { LHS, RHS }; - return getUMinExpr(Ops); + return getUMinExpr(Ops, Saturating); } -const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl &Ops) { - return getMinMaxExpr(scUMinExpr, Ops); +const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl &Ops, + bool Saturating) { + return Saturating ? getSaturatingMinMaxExpr(scSequentialUMinExpr, Ops) + : getMinMaxExpr(scUMinExpr, Ops); } const SCEV * @@ -4375,13 +4454,15 @@ } const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, - const SCEV *RHS) { + const SCEV *RHS, + bool Saturating) { SmallVector Ops = { LHS, RHS }; - return getUMinFromMismatchedTypes(Ops); + return getUMinFromMismatchedTypes(Ops, Saturating); } -const SCEV *ScalarEvolution::getUMinFromMismatchedTypes( - SmallVectorImpl &Ops) { +const SCEV * +ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl &Ops, + bool Saturating) { assert(!Ops.empty() && "At least one operand must be!"); // Trivial case. 
if (Ops.size() == 1) @@ -4402,7 +4483,7 @@ PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType)); // Generate umin. - return getUMinExpr(PromotedOps); + return getUMinExpr(PromotedOps, Saturating); } const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { @@ -5513,6 +5594,7 @@ case scSMaxExpr: case scUMinExpr: case scSMinExpr: + case scSequentialUMinExpr: // These expressions are available if their operand(s) is/are. return true; @@ -6060,7 +6142,7 @@ ConservativeResult.intersectWith(X, RangeType)); } - if (isa(S)) { + if (isa(S) || isa(S)) { Intrinsic::ID ID; switch (S->getSCEVType()) { case scUMaxExpr: @@ -6070,13 +6152,14 @@ ID = Intrinsic::smax; break; case scUMinExpr: + case scSequentialUMinExpr: ID = Intrinsic::umin; break; case scSMinExpr: ID = Intrinsic::smin; break; default: - llvm_unreachable("Unknown SCEVMinMaxExpr."); + llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr."); } const auto *NAry = cast(S); @@ -8169,9 +8252,9 @@ PoisonSafe = isa(EL0.ExactNotTaken) || isa(EL1.ExactNotTaken); if (EL0.ExactNotTaken != getCouldNotCompute() && - EL1.ExactNotTaken != getCouldNotCompute() && PoisonSafe) { - BECount = - getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); + EL1.ExactNotTaken != getCouldNotCompute()) { + BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken, + /*Saturating=*/!PoisonSafe); // If EL0.ExactNotTaken was zero and ExitCond was a short-circuit form, // it should have been simplified to zero (see the condition (3) above) @@ -8972,7 +9055,8 @@ case scUMaxExpr: case scSMinExpr: case scUMinExpr: - return nullptr; // TODO: smax, umax, smin, umax. + case scSequentialUMinExpr: + return nullptr; // TODO: smax, umax, smin, umin, umin_seq. } llvm_unreachable("Unknown SCEV kind!"); } @@ -11354,6 +11438,7 @@ case ICmpInst::ICMP_ULE: return // min(A, ...) <= A + // FIXME: what about umin_seq? IsMinMaxConsistingOf(LHS, RHS) || // A <= max(A, ...) 
IsMinMaxConsistingOf(RHS, LHS); @@ -12754,7 +12839,8 @@ case scUMaxExpr: case scSMaxExpr: case scUMinExpr: - case scSMinExpr: { + case scSMinExpr: + case scSequentialUMinExpr: { bool HasVarying = false; for (auto *Op : cast(S)->operands()) { LoopDisposition D = getLoopDisposition(Op, L); @@ -12844,7 +12930,8 @@ case scUMaxExpr: case scSMaxExpr: case scUMinExpr: - case scSMinExpr: { + case scSMinExpr: + case scSequentialUMinExpr: { const SCEVNAryExpr *NAry = cast(S); bool Proper = true; for (const SCEV *NAryOp : NAry->operands()) { diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -1671,7 +1671,7 @@ return Builder.CreateSExt(V, Ty); } -Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { +Value *SCEVExpander::expandSMaxExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands()-2; i >= 0; --i) { @@ -1700,7 +1700,7 @@ return LHS; } -Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { +Value *SCEVExpander::expandUMaxExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands()-2; i >= 0; --i) { @@ -1729,7 +1729,7 @@ return LHS; } -Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) { +Value *SCEVExpander::expandSMinExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands() - 2; i >= 0; --i) { @@ -1758,7 +1758,7 @@ return LHS; } -Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) { +Value *SCEVExpander::expandUMinExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands() - 2; i >= 0; --i) { @@ 
-1787,6 +1787,40 @@ return LHS; } +Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { + return expandSMaxExpr(S); +} + +Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { + return expandUMaxExpr(S); +} + +Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) { + return expandSMinExpr(S); +} + +Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) { + return expandUMinExpr(S); +} + +Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) { + SmallVector Ops; + for (const SCEV *Op : S->operands()) + Ops.emplace_back(expand(Op)); + + Value *SaturationPoint = + MinMaxIntrinsic::getSaturationPoint(Intrinsic::umin, S->getType()); + + SmallVector OpIsZero; + for (Value *Op : ArrayRef(Ops).drop_back()) + OpIsZero.emplace_back(Builder.CreateICmpEQ(Op, SaturationPoint)); + + Value *AnyOpIsZero = Builder.CreateLogicalOr(OpIsZero); + + Value *NaiveUMin = expandUMinExpr(S); + return Builder.CreateSelect(AnyOpIsZero, SaturationPoint, NaiveUMin); +} + Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *IP, bool Root) { setInsertPoint(IP); @@ -2271,10 +2305,27 @@ case scSMaxExpr: case scUMaxExpr: case scSMinExpr: - case scUMinExpr: { + case scUMinExpr: + case scSequentialUMinExpr: { // FIXME: should this ask the cost for Intrinsic's? + // The reduction tree. Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1, 0, 1); Cost += CmpSelCost(Instruction::Select, S->getNumOperands() - 1, 0, 2); + switch (S->getSCEVType()) { + case scSequentialUMinExpr: { + // The safety net against poison. + // FIXME: this is broken. + Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1, 0, 0); + Cost += ArithCost(Instruction::Or, + S->getNumOperands() > 2 ? 
S->getNumOperands() - 2 : 0); + Cost += CmpSelCost(Instruction::Select, 1, 0, 1); + break; + } + default: + assert(!isa(S) && + "Unhandled SCEV expression type?"); + break; + } break; } case scAddRecExpr: { @@ -2399,7 +2450,8 @@ case scUMaxExpr: case scSMaxExpr: case scUMinExpr: - case scSMinExpr: { + case scSMinExpr: + case scSequentialUMinExpr: { assert(cast(S)->getNumOperands() > 1 && "Nary expr should have more than 1 operand."); // The simple nary expr will require one less op (or pair of ops) diff --git a/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll b/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll --- a/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll +++ b/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll @@ -1,21 +1,21 @@ ; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py ; RUN: opt -disable-output "-passes=print" %s 2>&1 | FileCheck %s -; exact-not-taken cannot be umin(n, m) because it is possible for (n, m) to be (0, poison) -; https://alive2.llvm.org/ce/z/NsP9ue define i32 @logical_and_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: 'logical_and_2ops' ; CHECK-NEXT: Classifying expressions for: @logical_and_2ops ; CHECK-NEXT: %i = phi i32 [ 0, %entry ], [ %i.next, %loop ] -; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%n umin_seq %m) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %i.next = add i32 %i, 1 -; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%n umin_seq %m)) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %cond = select i1 %cond_p0, i1 %cond_p1, i1 false ; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: Determining loop execution counts for: 
@logical_and_2ops -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: backedge-taken count is (%n umin_seq %m) ; CHECK-NEXT: Loop %loop: max backedge-taken count is -1 -; CHECK-NEXT: Loop %loop: Unpredictable predicated backedge-taken count. +; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%n umin_seq %m) +; CHECK-NEXT: Predicates: +; CHECK: Loop %loop: Trip multiple is 1 ; entry: br label %loop @@ -30,21 +30,21 @@ ret i32 %i } -; exact-not-taken cannot be umin(n, m) because it is possible for (n, m) to be (0, poison) -; https://alive2.llvm.org/ce/z/ApRitq define i32 @logical_or_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: 'logical_or_2ops' ; CHECK-NEXT: Classifying expressions for: @logical_or_2ops ; CHECK-NEXT: %i = phi i32 [ 0, %entry ], [ %i.next, %loop ] -; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%n umin_seq %m) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %i.next = add i32 %i, 1 -; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%n umin_seq %m)) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %cond = select i1 %cond_p0, i1 true, i1 %cond_p1 ; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: Determining loop execution counts for: @logical_or_2ops -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: backedge-taken count is (%n umin_seq %m) ; CHECK-NEXT: Loop %loop: max backedge-taken count is -1 -; CHECK-NEXT: Loop %loop: Unpredictable predicated backedge-taken count. 
+; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%n umin_seq %m) +; CHECK-NEXT: Predicates: +; CHECK: Loop %loop: Trip multiple is 1 ; entry: br label %loop @@ -63,17 +63,19 @@ ; CHECK-LABEL: 'logical_and_3ops' ; CHECK-NEXT: Classifying expressions for: @logical_and_3ops ; CHECK-NEXT: %i = phi i32 [ 0, %entry ], [ %i.next, %loop ] -; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%n umin_seq %m umin_seq %k) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %i.next = add i32 %i, 1 -; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%n umin_seq %m umin_seq %k)) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %cond_p3 = select i1 %cond_p0, i1 %cond_p1, i1 false ; CHECK-NEXT: --> %cond_p3 U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: %cond = select i1 %cond_p3, i1 %cond_p2, i1 false ; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: Determining loop execution counts for: @logical_and_3ops -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: backedge-taken count is (%n umin_seq %m umin_seq %k) ; CHECK-NEXT: Loop %loop: max backedge-taken count is -1 -; CHECK-NEXT: Loop %loop: Unpredictable predicated backedge-taken count. 
+; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%n umin_seq %m umin_seq %k) +; CHECK-NEXT: Predicates: +; CHECK: Loop %loop: Trip multiple is 1 ; entry: br label %loop @@ -94,17 +96,19 @@ ; CHECK-LABEL: 'logical_or_3ops' ; CHECK-NEXT: Classifying expressions for: @logical_or_3ops ; CHECK-NEXT: %i = phi i32 [ 0, %entry ], [ %i.next, %loop ] -; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%n umin_seq %m umin_seq %k) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %i.next = add i32 %i, 1 -; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%n umin_seq %m umin_seq %k)) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %cond_p3 = select i1 %cond_p0, i1 true, i1 %cond_p1 ; CHECK-NEXT: --> %cond_p3 U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: %cond = select i1 %cond_p3, i1 true, i1 %cond_p2 ; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: Determining loop execution counts for: @logical_or_3ops -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. +; CHECK-NEXT: Loop %loop: backedge-taken count is (%n umin_seq %m umin_seq %k) ; CHECK-NEXT: Loop %loop: max backedge-taken count is -1 -; CHECK-NEXT: Loop %loop: Unpredictable predicated backedge-taken count. 
+; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%n umin_seq %m umin_seq %k) +; CHECK-NEXT: Predicates: +; CHECK: Loop %loop: Trip multiple is 1 ; entry: br label %loop diff --git a/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll b/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll --- a/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll +++ b/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll @@ -4,17 +4,14 @@ define i32 @logical_and_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: @logical_and_2ops( ; CHECK-NEXT: entry: +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[M:%.*]], i32 [[N:%.*]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[I_NEXT:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 -; CHECK-NEXT: [[COND_P0:%.*]] = icmp ult i32 [[I]], [[N:%.*]] -; CHECK-NEXT: [[COND_P1:%.*]] = icmp ult i32 [[I]], [[M:%.*]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[COND_P0]], i1 [[COND_P1]], i1 false -; CHECK-NEXT: br i1 [[COND]], label [[LOOP]], label [[EXIT:%.*]] +; CHECK-NEXT: br i1 false, label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[I_LCSSA:%.*]] = phi i32 [ [[I]], [[LOOP]] ] -; CHECK-NEXT: ret i32 [[I_LCSSA]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[N]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[TMP0]], i32 0, i32 [[UMIN]] +; CHECK-NEXT: ret i32 [[TMP1]] ; entry: br label %loop @@ -32,17 +29,14 @@ define i32 @logical_or_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: @logical_or_2ops( ; CHECK-NEXT: entry: +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[M:%.*]], i32 [[N:%.*]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[I_NEXT:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 -; CHECK-NEXT: [[COND_P0:%.*]] = icmp uge i32 [[I]], [[N:%.*]] -; CHECK-NEXT: [[COND_P1:%.*]] = icmp uge i32 [[I]], [[M:%.*]] -; CHECK-NEXT: [[COND:%.*]] = select i1 
[[COND_P0]], i1 true, i1 [[COND_P1]] -; CHECK-NEXT: br i1 [[COND]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK-NEXT: br i1 true, label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: [[I_LCSSA:%.*]] = phi i32 [ [[I]], [[LOOP]] ] -; CHECK-NEXT: ret i32 [[I_LCSSA]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[N]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[TMP0]], i32 0, i32 [[UMIN]] +; CHECK-NEXT: ret i32 [[TMP1]] ; entry: br label %loop @@ -60,19 +54,17 @@ define i32 @logical_and_3ops(i32 %n, i32 %m, i32 %k) { ; CHECK-LABEL: @logical_and_3ops( ; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[M:%.*]], 0 +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[K:%.*]], i32 [[M]]) +; CHECK-NEXT: [[UMIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[UMIN]], i32 [[N:%.*]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[I_NEXT:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 -; CHECK-NEXT: [[COND_P0:%.*]] = icmp ult i32 [[I]], [[N:%.*]] -; CHECK-NEXT: [[COND_P1:%.*]] = icmp ult i32 [[I]], [[M:%.*]] -; CHECK-NEXT: [[COND_P2:%.*]] = icmp ult i32 [[I]], [[K:%.*]] -; CHECK-NEXT: [[COND_P3:%.*]] = select i1 [[COND_P0]], i1 [[COND_P1]], i1 false -; CHECK-NEXT: [[COND:%.*]] = select i1 [[COND_P3]], i1 [[COND_P2]], i1 false -; CHECK-NEXT: br i1 [[COND]], label [[LOOP]], label [[EXIT:%.*]] +; CHECK-NEXT: br i1 false, label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[I_LCSSA:%.*]] = phi i32 [ [[I]], [[LOOP]] ] -; CHECK-NEXT: ret i32 [[I_LCSSA]] +; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[N]], 0 +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i1 true, i1 [[TMP0]] +; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i32 0, i32 [[UMIN1]] +; CHECK-NEXT: ret i32 [[TMP3]] ; entry: br label %loop @@ -92,19 +84,17 @@ define i32 @logical_or_3ops(i32 %n, i32 %m, i32 %k) { ; CHECK-LABEL: @logical_or_3ops( ; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 
[[M:%.*]], 0 +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[K:%.*]], i32 [[M]]) +; CHECK-NEXT: [[UMIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[UMIN]], i32 [[N:%.*]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[I_NEXT:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 -; CHECK-NEXT: [[COND_P0:%.*]] = icmp uge i32 [[I]], [[N:%.*]] -; CHECK-NEXT: [[COND_P1:%.*]] = icmp uge i32 [[I]], [[M:%.*]] -; CHECK-NEXT: [[COND_P2:%.*]] = icmp uge i32 [[I]], [[K:%.*]] -; CHECK-NEXT: [[COND_P3:%.*]] = select i1 [[COND_P0]], i1 true, i1 [[COND_P1]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[COND_P3]], i1 true, i1 [[COND_P2]] -; CHECK-NEXT: br i1 [[COND]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK-NEXT: br i1 true, label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: [[I_LCSSA:%.*]] = phi i32 [ [[I]], [[LOOP]] ] -; CHECK-NEXT: ret i32 [[I_LCSSA]] +; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[N]], 0 +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i1 true, i1 [[TMP0]] +; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i32 0, i32 [[UMIN1]] +; CHECK-NEXT: ret i32 [[TMP3]] ; entry: br label %loop diff --git a/polly/include/polly/Support/SCEVAffinator.h b/polly/include/polly/Support/SCEVAffinator.h --- a/polly/include/polly/Support/SCEVAffinator.h +++ b/polly/include/polly/Support/SCEVAffinator.h @@ -111,6 +111,7 @@ PWACtx visitSMinExpr(const llvm::SCEVSMinExpr *E); PWACtx visitUMaxExpr(const llvm::SCEVUMaxExpr *E); PWACtx visitUMinExpr(const llvm::SCEVUMinExpr *E); + PWACtx visitSequentialUMinExpr(const llvm::SCEVSequentialUMinExpr *E); PWACtx visitUnknown(const llvm::SCEVUnknown *E); PWACtx visitSDivInstruction(llvm::Instruction *SDiv); PWACtx visitSRemInstruction(llvm::Instruction *SRem); diff --git a/polly/lib/Support/SCEVAffinator.cpp b/polly/lib/Support/SCEVAffinator.cpp --- a/polly/lib/Support/SCEVAffinator.cpp +++ b/polly/lib/Support/SCEVAffinator.cpp @@ -465,6 +465,11 @@ 
llvm_unreachable("SCEVUMinExpr not yet supported"); } +PWACtx +SCEVAffinator::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { + llvm_unreachable("SCEVSequentialUMinExpr not yet supported"); +} + PWACtx SCEVAffinator::visitUDivExpr(const SCEVUDivExpr *Expr) { // The handling of unsigned division is basically the same as for signed // division, except the interpretation of the operands. As the divisor diff --git a/polly/lib/Support/SCEVValidator.cpp b/polly/lib/Support/SCEVValidator.cpp --- a/polly/lib/Support/SCEVValidator.cpp +++ b/polly/lib/Support/SCEVValidator.cpp @@ -332,6 +332,23 @@ return ValidatorResult(SCEVType::PARAM, Expr); } + class ValidatorResult + visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { + // We do not support unsigned min operations. If 'Expr' is constant during + // Scop execution we treat this as a parameter, otherwise we bail out. + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { + ValidatorResult Op = visit(Expr->getOperand(i)); + + if (!Op.isConstant()) { + LLVM_DEBUG( + dbgs() << "INVALID: SafeUMinExpr has a non-constant operand"); + return ValidatorResult(SCEVType::INVALID); + } + } + + return ValidatorResult(SCEVType::PARAM, Expr); + } + ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { if (R->contains(I)) { LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " diff --git a/polly/lib/Support/ScopHelper.cpp b/polly/lib/Support/ScopHelper.cpp --- a/polly/lib/Support/ScopHelper.cpp +++ b/polly/lib/Support/ScopHelper.cpp @@ -391,6 +391,12 @@ NewOps.push_back(visit(Op)); return SE.getSMinExpr(NewOps); } + const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *E) { + SmallVector NewOps; + for (const SCEV *Op : E->operands()) + NewOps.push_back(visit(Op)); + return SE.getUMinExpr(NewOps, /*Saturating=*/true); + } const SCEV *visitAddRecExpr(const SCEVAddRecExpr *E) { SmallVector NewOps; for (const SCEV *Op : E->operands())