Index: llvm/include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolution.h
+++ llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -643,6 +643,10 @@
                           bool Sequential = false);
   const SCEV *getUMinExpr(SmallVectorImpl &Operands,
                           bool Sequential = false);
+  /// Return a SCEV for the i1 result of the integer comparison
+  /// "LHS Pred RHS".
+  const SCEV *getCompareExpr(ICmpInst::Predicate Pred, const SCEV *LHS,
+                             const SCEV *RHS);
   const SCEV *getUnknown(Value *V);
   const SCEV *getCouldNotCompute();
 
Index: llvm/include/llvm/Analysis/ScalarEvolutionDivision.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolutionDivision.h
+++ llvm/include/llvm/Analysis/ScalarEvolutionDivision.h
@@ -43,6 +43,7 @@
   void visitSMinExpr(const SCEVSMinExpr *Numerator) {}
   void visitUMinExpr(const SCEVUMinExpr *Numerator) {}
   void visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Numerator) {}
+  void visitCompareExpr(const SCEVCompareExpr *) {}
   void visitUnknown(const SCEVUnknown *Numerator) {}
   void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}
 
Index: llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -51,6 +51,7 @@
   scSMinExpr,
   scSequentialUMinExpr,
   scPtrToInt,
+  scCompareExpr,
   scUnknown,
   scCouldNotCompute
 };
@@ -564,6 +565,48 @@
   }
 };
 
+/// This class represents an icmp comparing two SCEVs. Currently, this node
+/// does not appear in generic SCEV expressions. It is only used to support
+/// PredicatedScalarEvolution.
+class SCEVCompareExpr : public SCEV {
+  friend class ScalarEvolution;
+
+  const ICmpInst::Predicate Pred;
+  std::array<const SCEV *, 2> Operands;
+
+  SCEVCompareExpr(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred,
+                  const SCEV *LHS, const SCEV *RHS)
+      : SCEV(ID, scCompareExpr, computeExpressionSize({LHS, RHS})), Pred(Pred) {
+    Operands[0] = LHS;
+    Operands[1] = RHS;
+  }
+
+public:
+  ICmpInst::Predicate getPredicate() const { return Pred; }
+  const SCEV *getLHS() const { return Operands[0]; }
+  const SCEV *getRHS() const { return Operands[1]; }
+  size_t getNumOperands() const { return 2; }
+  const SCEV *getOperand(unsigned i) const {
+    assert((i == 0 || i == 1) && "Operand index out of range!");
+    return i == 0 ? getLHS() : getRHS();
+  }
+
+  using op_iterator = std::array<const SCEV *, 2>::const_iterator;
+  using op_range = iterator_range<op_iterator>;
+  op_range operands() const {
+    return make_range(Operands.begin(), Operands.end());
+  }
+
+  Type *getType() const {
+    return IntegerType::getInt1Ty(getRHS()->getType()->getContext());
+  }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const SCEV *S) {
+    return S->getSCEVType() == scCompareExpr;
+  }
+};
+
 /// This means that we are dealing with an entirely unknown SCEV
 /// value, and only represent it as its LLVM Value. This is the
 /// "bottom" value for the analysis.
@@ -642,6 +685,8 @@
     case scSequentialUMinExpr:
       return ((SC *)this)
           ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
+    case scCompareExpr:
+      return ((SC *)this)->visitCompareExpr((const SCEVCompareExpr *)S);
     case scUnknown:
       return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
     case scCouldNotCompute:
@@ -697,6 +742,7 @@
       case scSMinExpr:
       case scUMinExpr:
       case scSequentialUMinExpr:
+      case scCompareExpr:
       case scAddRecExpr:
         for (const auto *Op : cast(S)->operands())
           push(Op);
@@ -892,6 +938,17 @@
     return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
   }
 
+  const SCEV *visitCompareExpr(const SCEVCompareExpr *Expr) {
+    SmallVector<const SCEV *, 2> Operands;
+    bool Changed = false;
+    for (auto *Op : Expr->operands()) {
+      Operands.push_back(((SC *)this)->visit(Op));
+      Changed |= Op != Operands.back();
+    }
+    return !Changed ? Expr : SE.getCompareExpr(Expr->getPredicate(),
+                                               Operands[0], Operands[1]);
+  }
+
   const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
 
   const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
Index: llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -488,6 +488,8 @@
 
   Value *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S);
 
+  Value *visitCompareExpr(const SCEVCompareExpr *Expr);
+
   Value *visitUnknown(const SCEVUnknown *S) { return S->getValue(); }
 
   void rememberInstruction(Value *I);
Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -346,6 +346,13 @@
     OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
     return;
   }
+  case scCompareExpr: {
+    const SCEVCompareExpr *Cmp = cast<SCEVCompareExpr>(this);
+    OS << "(" << *Cmp->getLHS() << " "
+       << CmpInst::getPredicateName(Cmp->getPredicate())
+       << " " << *Cmp->getRHS() << ")";
+    return;
+  }
   case scUnknown: {
     const SCEVUnknown *U = cast(this);
     Type *AllocTy;
@@ -402,6 +409,8 @@
     return cast(this)->getType();
   case scUDivExpr:
     return cast(this)->getType();
+  case scCompareExpr:
+    return cast<SCEVCompareExpr>(this)->getType();
   case scUnknown:
     return cast(this)->getType();
   case scCouldNotCompute:
@@ -816,6 +825,21 @@
       EqCacheSCEV.unionSets(LHS, RHS);
     return X;
   }
+  case scCompareExpr: {
+    auto *LC = cast<SCEVCompareExpr>(LHS);
+    auto *RC = cast<SCEVCompareExpr>(RHS);
+
+    // Lexicographically compare the operands of the compare expressions.
+    auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
+                                   RC->getLHS(), DT, Depth + 1);
+    if (X != 0)
+      return X;
+    X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
+                              RC->getRHS(), DT, Depth + 1);
+    if (X == 0)
+      EqCacheSCEV.unionSets(LHS, RHS);
+    return X;
+  }
 
   case scPtrToInt:
   case scTruncate:
@@ -3995,6 +4019,8 @@
     return visitAnyMinMaxExpr(Expr);
   }
 
+  RetVal visitCompareExpr(const SCEVCompareExpr *Expr) { return Expr; }
+
   RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
 
   RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
@@ -4124,6 +4150,32 @@
                     : getMinMaxExpr(scUMinExpr, Ops);
 }
 
+const SCEV *ScalarEvolution::getCompareExpr(ICmpInst::Predicate Pred,
+                                            const SCEV *LHS,
+                                            const SCEV *RHS) {
+  FoldingSetNodeID ID;
+  ID.AddInteger(scCompareExpr);
+  ID.AddInteger(Pred);
+  ID.AddPointer(LHS);
+  ID.AddPointer(RHS);
+  void *IP = nullptr;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
+    return S;
+
+  // TODO: Add various simplification rules here by moving/borrowing
+  // existing inference code.
+
+  // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
+  // changes). Make sure we get a new one.
+  IP = nullptr;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  SCEV *S = new (SCEVAllocator) SCEVCompareExpr(ID.Intern(SCEVAllocator),
+                                                Pred, LHS, RHS);
+  UniqueSCEVs.InsertNode(S, IP);
+  registerUser(S, {LHS, RHS});
+  return S;
+}
+
 const SCEV *
 ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
                                              ScalableVectorType *ScalableTy) {
@@ -5748,6 +5800,7 @@
   case scUMinExpr:
   case scSMinExpr:
   case scSequentialUMinExpr:
+  case scCompareExpr:
     // These expressions are available if their operand(s) is/are.
     return true;
 
@@ -9228,6 +9281,7 @@
   case scSMinExpr:
   case scUMinExpr:
   case scSequentialUMinExpr:
+  case scCompareExpr:
     return nullptr; // TODO: smax, umax, smin, umax, umin_seq.
   }
   llvm_unreachable("Unknown SCEV kind!");
@@ -13070,6 +13124,17 @@
     return (LD == LoopInvariant && RD == LoopInvariant) ?
            LoopInvariant : LoopComputable;
   }
+  case scCompareExpr: {
+    auto *Cmp = cast<SCEVCompareExpr>(S);
+    LoopDisposition LD = getLoopDisposition(Cmp->getLHS(), L);
+    if (LD == LoopVariant)
+      return LoopVariant;
+    LoopDisposition RD = getLoopDisposition(Cmp->getRHS(), L);
+    if (RD == LoopVariant)
+      return LoopVariant;
+    return (LD == LoopInvariant && RD == LoopInvariant) ?
+           LoopInvariant : LoopComputable;
+  }
   case scUnknown:
     // All non-instruction values are loop invariant. All instructions are loop
     // invariant if they are not contained in the specified loop.
@@ -13163,6 +13228,18 @@
     return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
       ProperlyDominatesBlock : DominatesBlock;
   }
+  case scCompareExpr: {
+    const SCEVCompareExpr *Cmp = cast<SCEVCompareExpr>(S);
+    const SCEV *LHS = Cmp->getLHS(), *RHS = Cmp->getRHS();
+    BlockDisposition LD = getBlockDisposition(LHS, BB);
+    if (LD == DoesNotDominateBlock)
+      return DoesNotDominateBlock;
+    BlockDisposition RD = getBlockDisposition(RHS, BB);
+    if (RD == DoesNotDominateBlock)
+      return DoesNotDominateBlock;
+    return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
+      ProperlyDominatesBlock : DominatesBlock;
+  }
   case scUnknown:
     if (Instruction *I =
           dyn_cast(cast(S)->getValue())) {
Index: llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
===================================================================
--- llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1821,6 +1821,11 @@
   return Builder.CreateSelect(AnyOpIsZero, SaturationPoint, NaiveUMin);
 }
 
+Value *SCEVExpander::visitCompareExpr(const SCEVCompareExpr *Expr) {
+  return Builder.CreateICmp(Expr->getPredicate(), expand(Expr->getLHS()),
+                            expand(Expr->getRHS()));
+}
+
 Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty,
                                        Instruction *IP, bool Root) {
   setInsertPoint(IP);
@@ -2320,6 +2325,9 @@
     }
     break;
   }
+  case scCompareExpr:
+    Cost += CmpSelCost(Instruction::ICmp, 2, 0, 1);
+    break;
   case scAddRecExpr: {
     // In this polynominal, we may have some zero operands, and we shouldn't
     // really charge for those. So how many non-zero coeffients are there?
@@ -2420,6 +2428,11 @@
         costAndCollectOperands(WorkItem, TTI, CostKind, Worklist);
     return false; // Will answer upon next entry into this function.
   }
+  case scCompareExpr: {
+    Cost += costAndCollectOperands(WorkItem, TTI, CostKind,
+                                   Worklist);
+    return false; // Will answer upon next entry into this function.
+  }
   case scUDivExpr: {
     // UDivExpr is very likely a UDiv that ScalarEvolution's HowFarToZero or
     // HowManyLessThans produced to compute a precise expression, rather than a