Index: include/llvm/Analysis/LoopAccessAnalysis.h =================================================================== --- include/llvm/Analysis/LoopAccessAnalysis.h +++ include/llvm/Analysis/LoopAccessAnalysis.h @@ -32,6 +32,7 @@ class ScalarEvolution; class Loop; class SCEV; +class SCEVPredicateSet; /// Optimization analysis message produced during vectorization. Messages inform /// the user why vectorization did not occur. @@ -176,10 +177,10 @@ const SmallVectorImpl &Instrs) const; }; - MemoryDepChecker(ScalarEvolution *Se, const Loop *L) + MemoryDepChecker(ScalarEvolution *Se, const Loop *L, SCEVPredicateSet &Pred) : SE(Se), InnermostLoop(L), AccessIdx(0), ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true), - RecordInterestingDependences(true) {} + RecordInterestingDependences(true), Pred(Pred) {} /// \brief Register the location (instructions are given increasing numbers) /// of a write access. @@ -289,6 +290,9 @@ /// \brief Check whether the data dependence could prevent store-load /// forwarding. bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize); + + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Pred; }; /// \brief Holds information about the memory runtime legality checks to verify @@ -331,7 +335,8 @@ /// Insert a pointer and calculate the start and end SCEVs. void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, - unsigned ASId, const ValueToValueMap &Strides); + unsigned ASId, const ValueToValueMap &Strides, + SCEVPredicateSet &Pred); /// \brief No run-time memory checking is necessary. bool empty() const { return Pointers.empty(); } @@ -537,6 +542,9 @@ return StoreToLoopInvariantAddress; } + /// The SCEV predicate containing all the SCEV-related assumptions. + std::unique_ptr Pred; + private: /// \brief Analyze the loop. Substitute symbolic strides using Strides. void analyzeLoop(const ValueToValueMap &Strides); @@ -563,6 +571,9 @@ DominatorTree *DT; LoopInfo *LI; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet ScPredicates; + unsigned NumLoads; unsigned NumStores; @@ -590,12 +601,13 @@ /// stride as collected by LoopVectorizationLegality::collectStridedAccess. const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE, const ValueToValueMap &PtrToStride, + SCEVPredicateSet &Pred, const Loop *L, Value *Ptr, Value *OrigPtr = nullptr); /// \brief Check the stride of the pointer and ensure that it does not wrap in /// the address space. int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap); + const ValueToValueMap &StridesMap, SCEVPredicateSet &Pred); /// \brief This analysis provides dependence information for the memory accesses /// of a loop. Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -51,6 +51,8 @@ class SCEVUnknown; class SCEVAddRecExpr; class SCEV; + class SCEVExpander; + template<> struct FoldingSetTrait; /// SCEV - This class represents an analyzed expression in the program. These @@ -171,6 +173,118 @@ static bool classof(const SCEV *S); }; + enum SCEVPredicateTypes { pSet, pEqual }; + + //===--------------------------------------------------------------------===// + /// SCEVPredicate - This class represents an assumption made using SCEV + /// expressions which can be checked at run-time. + /// + class SCEVPredicate { + protected: + unsigned short SCEVPredicateType; + + public: + SCEVPredicate(unsigned short Type); + virtual ~SCEVPredicate() {} + unsigned short getType() const { return SCEVPredicateType; } + /// Returns true if the predicate is always true. This means that no + /// assumptions were made and nothing needs to be checked at run-time. + virtual bool isAlwaysTrue() const = 0; + /// Return true if we consider this to be always false or if we've + /// given up on this set of assumptions (for example because of the + /// high cost of checking at run-time). + virtual bool isAlwaysFalse() const = 0; + /// Returns true if this predicate implies \p N. + virtual bool contains(const SCEVPredicate *N) const = 0; + /// Prints a textual representation of this predicate. + virtual void print(raw_ostream &OS, unsigned Depth) const = 0; + + /// Generates a run-time check for this predicate. + virtual Value *generateCheck(Instruction *Loc, ScalarEvolution *SE, + const DataLayout *DL, SCEVExpander &Exp) = 0; + }; + + class SCEVEqualPredicate : public SCEVPredicate { + // We assume that E0 == E1 + const SCEV *E0; + const SCEV *E1; + + public: + SCEVEqualPredicate() + : SCEVPredicate(pEqual), E0(nullptr), E1(nullptr) {} + SCEVEqualPredicate(const SCEV *E0, const SCEV *E1); + + /// Implementation of the SCEVPredicate interface + bool contains(const SCEVPredicate *N) const override; + void print(raw_ostream &OS, unsigned Depth) const override; + bool isAlwaysTrue() const override; + bool isAlwaysFalse() const override; + + const SCEV *getLHS() { return E0; } + const SCEV *getRHS() { return E1; } + + Value *generateCheck(Instruction *Loc, ScalarEvolution *SE, + const DataLayout *DL, SCEVExpander &Exp) override; + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVPredicate *P) { + return P->getType() == pEqual; + } + }; + + //===--------------------------------------------------------------------===// + /// SCEVPredicateSet - This class represents a composition of other + /// SCEV predicates, and is the class that most clients will interact with. + /// + class SCEVPredicateSet : public SCEVPredicate { + protected: + /// Flag used to track if this predicate set is invalid. + bool Never; + /// Storage for different predicates that make up this Predicate Set. + //SmallVector AddRecOverflows; + SmallVector IdPreds; + + public: + SCEVPredicateSet(); + /// The copy constructor. + SCEVPredicateSet(const SCEVPredicateSet &Old); + /// Adds a predicate to this predicate set. + void add(const SCEVPredicate *N); + + /// Vector with references to all predicates in this set. + SmallVector Preds; + + /// Generates a run-time check for all the contained predicates. + /// This is a wrapper around generateCheck, and provides an interface + /// similar to other run-time checks used for versioning. + std::pair + generateGuardCond(Instruction *Loc, ScalarEvolution *SE); + + /// Implementation of the SCEVPredicate interface + bool isAlwaysTrue() const override; + bool isAlwaysFalse() const override; + bool contains(const SCEVPredicate *N) const override; + void print(raw_ostream &OS, unsigned Depth) const; + Value *generateCheck(Instruction *Loc, ScalarEvolution *SE, + const DataLayout *DL, SCEVExpander &Exp) override; + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVPredicate *P) { + return P->getType() == pSet; + } + + /// The copy operator. + const SCEVPredicateSet &operator=(const SCEVPredicateSet &RHS) { + Never = RHS.Never; + IdPreds = RHS.IdPreds; + Preds.clear(); + for (unsigned II = 0; II < IdPreds.size(); ++II) + Preds.push_back(&IdPreds[II]); + assert(Preds.size() == RHS.Preds.size() && "Wrong Preds size after copy"); + return *this; + } + }; + /// The main scalar evolution driver. Because client code (intentionally) /// can't do much with the SCEV objects directly, they must ask this class /// for services. @@ -1069,6 +1183,11 @@ SmallVectorImpl &Sizes, const SCEV *ElementSize); + /// Re-writes the SCEV according to the Predicates in \p Preds, by + /// applying overflow assumptions and sinking sext/zext expressions. + const SCEV *rewriteUsingPredicate(const SCEV *Scev, const Loop *L, + SCEVPredicateSet &A); + private: /// Compute the backedge taken count knowing the interval difference, the /// stride and presence of the equality in the comparison. Index: lib/Analysis/LoopAccessAnalysis.cpp =================================================================== --- lib/Analysis/LoopAccessAnalysis.cpp +++ lib/Analysis/LoopAccessAnalysis.cpp @@ -15,6 +15,7 @@ #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DiagnosticInfo.h" @@ -89,6 +90,7 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE, const ValueToValueMap &PtrToStride, + SCEVPredicateSet &Pred, const Loop *L, Value *Ptr, Value *OrigPtr) { const SCEV *OrigSCEV = SE->getSCEV(Ptr); @@ -104,15 +106,21 @@ StrideVal = stripIntegerCast(StrideVal); // Replace symbolic stride by one. - Value *One = ConstantInt::get(StrideVal->getType(), 1); + Value *VOne = ConstantInt::get(StrideVal->getType(), 1); + ConstantInt *One = static_cast(VOne); ValueToValueMap RewriteMap; RewriteMap[StrideVal] = One; - const SCEV *ByOne = - SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true); - DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne - << "\n"); - return ByOne; + SCEVPredicateSet P = Pred; + const SCEVEqualPredicate SI(SE->getSCEV(StrideVal), SE->getConstant(One)); + P.add(&SI); + if (!P.isAlwaysFalse()) { + const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, L, P); + DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne + << "\n"); + Pred = P; + return ByOne; + } } // Otherwise, just return the SCEV of the original pointer. @@ -121,13 +129,13 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, unsigned ASId, - const ValueToValueMap &Strides) { + const ValueToValueMap &Strides, + SCEVPredicateSet &Pred) { // Get the stride replaced scev. - const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Pred, Lp, Ptr); const SCEVAddRecExpr *AR = dyn_cast(Sc); assert(AR && "Invalid addrec expression"); const SCEV *Ex = SE->getBackedgeTakenCount(Lp); - const SCEV *ScStart = AR->getStart(); const SCEV *ScEnd = AR->evaluateAtIteration(Ex, *SE); const SCEV *Step = AR->getStepRecurrence(*SE); @@ -417,9 +425,9 @@ typedef SmallPtrSet MemAccessInfoSet; AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI, - MemoryDepChecker::DepCandidates &DA) - : DL(Dl), AST(*AA), LI(LI), DepCands(DA), - IsRTCheckAnalysisNeeded(false) {} + MemoryDepChecker::DepCandidates &DA, SCEVPredicateSet &Pred) + : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false), + Pred(Pred) {} /// \brief Register a load and whether it is only read from. void addLoad(MemoryLocation &Loc, bool IsReadOnly) { @@ -504,14 +512,18 @@ /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared /// while this remains set if we have potentially dependent accesses. bool IsRTCheckAnalysisNeeded; + + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Pred; }; } // end anonymous namespace /// \brief Check whether a pointer can participate in a runtime bounds check. static bool hasComputableBounds(ScalarEvolution *SE, - const ValueToValueMap &Strides, Value *Ptr) { - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const ValueToValueMap &Strides, Value *Ptr, + Loop *L, SCEVPredicateSet &Pred) { + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Pred, L, Ptr); const SCEVAddRecExpr *AR = dyn_cast(PtrScev); if (!AR) return false; @@ -554,11 +566,11 @@ else ++NumReadPtrChecks; - if (hasComputableBounds(SE, StridesMap, Ptr) && + if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Pred) && // When we run after a failing dependency check we have to make sure // we don't have wrapping pointers. (!ShouldCheckStride || - isStridedPtr(SE, Ptr, TheLoop, StridesMap) == 1)) { + isStridedPtr(SE, Ptr, TheLoop, StridesMap, Pred) == 1)) { // The id of the dependence set. unsigned DepId; @@ -572,7 +584,7 @@ // Each access has its own dependence set. DepId = RunningDepId++; - RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap); + RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Pred); DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); } else { @@ -803,7 +815,8 @@ /// \brief Check whether the access through \p Ptr has a constant stride. int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap) { + const ValueToValueMap &StridesMap, + SCEVPredicateSet &Pred) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); @@ -815,7 +828,8 @@ return 0; } - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr); + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Pred, + Lp, Ptr); const SCEVAddRecExpr *AR = dyn_cast(PtrScev); if (!AR) { @@ -1026,11 +1040,13 @@ BPtr->getType()->getPointerAddressSpace()) return Dependence::Unknown; - const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr); - const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr); + const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, Pred, + InnermostLoop, APtr); + const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, Pred, + InnermostLoop, BPtr); - int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides); - int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides); + int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Pred); + int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Pred); const SCEV *Src = AScev; const SCEV *Sink = BScev; @@ -1429,7 +1445,7 @@ MemoryDepChecker::DepCandidates DependentAccesses; AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), - AA, LI, DependentAccesses); + AA, LI, DependentAccesses, *Pred); // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects // multiple times on the same object. If the ptr is accessed twice, once @@ -1480,7 +1496,8 @@ // read a few words, modify, and write a few words, and some of the // words may be written to the same address. bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr).second || !isStridedPtr(SE, Ptr, TheLoop, Strides)) { + if (Seen.insert(Ptr).second || + !isStridedPtr(SE, Ptr, TheLoop, Strides, *Pred)) { ++NumReads; IsReadOnlyPtr = true; } @@ -1728,8 +1745,9 @@ const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : PtrRtChecking(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), - TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), + : Pred(new SCEVPredicateSet), PtrRtChecking(SE), DepChecker(SE, L, *Pred), + TheLoop(L), SE(SE), DL(DL), TLI(TLI), + AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { if (canAnalyzeLoop()) @@ -1763,6 +1781,9 @@ OS.indent(Depth) << "Store to invariant address was " << (StoreToLoopInvariantAddress ? "" : "not ") << "found in loop.\n"; + + OS.indent(Depth) << "SCEV assumptions:\n"; + Pred->print(OS, Depth); } const LoopAccessInfo & Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -67,6 +67,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -80,6 +81,8 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" @@ -109,6 +112,13 @@ "derived loop"), cl::init(100)); +static cl::opt +SCEVCheckThreshold("force-max-scev-checks", cl::init(16), + cl::Hidden, + cl::desc("Don't create SCEV predicates with more than " + "this number of assumptions.")); + + // FIXME: Enable this with XDEBUG when the test suite is clean. static cl::opt VerifySCEV("verify-scev", @@ -8888,3 +8898,284 @@ AU.addRequiredTransitive(); AU.addRequiredTransitive(); } + +static Instruction *getFirstInst(Instruction *FirstInst, Value *V, + Instruction *Loc) { + if (FirstInst) + return FirstInst; + if (Instruction *I = dyn_cast(V)) + return I->getParent() == Loc->getParent() ? I : nullptr; + return nullptr; +} + +struct SCEVPredicateRewriter + : public SCEVVisitor { +public: + SCEVPredicateSet &P; + + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE, SCEVPredicateSet &A) { + SCEVPredicateRewriter Rewriter(L, SE, A); + return Rewriter.visit(Scev); + } + + SCEVPredicateRewriter(const Loop *L, ScalarEvolution &S, SCEVPredicateSet &P) + : P(P), SE(S), L(L) {} + + const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } + + const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getTruncateExpr(Operand, Expr->getType()); + } + + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getZeroExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getSignExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getAddExpr(Operands); + } + + const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getMulExpr(Operands); + } + + const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { + return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS())); + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + + const Loop *L = Expr->getLoop(); + return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); + } + + const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getSMaxExpr(Operands); + } + + const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getUMaxExpr(Operands); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + for (SCEVPredicate *Pred : P.Preds) { + if (SCEVEqualPredicate *IPred = dyn_cast(Pred)) { + if (IPred->getLHS() == Expr) + return IPred->getRHS(); + } + } + return Expr; + } + + const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return Expr; + } + +private: + ScalarEvolution &SE; + const Loop *L; +}; + +const SCEV * +ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, const Loop *L, + SCEVPredicateSet &Pred) { + return SCEVPredicateRewriter::rewrite(Scev, L, *this, Pred); +} + +//// SCEV predicates +SCEVPredicate::SCEVPredicate(unsigned short Type) : SCEVPredicateType(Type) {} + +SCEVEqualPredicate::SCEVEqualPredicate( + const SCEV *E0, const SCEV *E1) + : SCEVPredicate(pEqual), E0(E0), E1(E1) {} + +bool SCEVEqualPredicate::contains(const SCEVPredicate *N) const { + const SCEVEqualPredicate *Op = + dyn_cast(N); + + if (!Op) + return false; + + if ((Op->E0 == E0 && Op->E1 == E1) || + (Op->E0 == E1 && Op->E1 == E0)) + return true; + + return false; +} + +bool SCEVEqualPredicate::isAlwaysTrue() const { + if (E0 == E1) return true; + return false; +} + +bool SCEVEqualPredicate::isAlwaysFalse() const { return false; } + +Value *SCEVEqualPredicate::generateCheck(Instruction *Loc, + ScalarEvolution *SE, + const DataLayout *DL, + SCEVExpander &Exp) { + IRBuilder<> Builder(Loc); + + Value *Expr0 = Exp.expandCodeFor(E0, E0->getType(), Loc); + Value *Expr1 = Exp.expandCodeFor(E1, E1->getType(), Loc); + + Value *C = Builder.CreateICmpNE(Expr0, Expr1, "ident.check"); + return C; +} + +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << "Equal predicate: " << *E0 << " == " << *E1 << "\n"; +} + +SCEVPredicateSet::SCEVPredicateSet() : SCEVPredicate(pSet), Never(false) {} + +SCEVPredicateSet::SCEVPredicateSet(const SCEVPredicateSet &Old) + : SCEVPredicateSet() { + this->Never = Old.Never; + if (Never) + return; + IdPreds = Old.IdPreds; + for (unsigned II = 0; II < IdPreds.size(); ++II) + Preds.push_back(&IdPreds[II]); +} + +bool SCEVPredicateSet::isAlwaysTrue() const { + if (Never) + return false; + + for (auto II = Preds.begin(), EE = Preds.end(); II != EE; ++II) { + const SCEVPredicate *OA = *II; + if (!OA->isAlwaysTrue()) + return false; + } + + return true; +} + +bool SCEVPredicateSet::isAlwaysFalse() const { return Never; } + +std::pair +SCEVPredicateSet::generateGuardCond(Instruction *Loc, ScalarEvolution *SE) { + Instruction *tnullptr = nullptr; + + assert(!Never && "Cannot generate a runtime check on " + "a predicate with the Never flag set"); + + if (isAlwaysTrue()) + return std::pair(tnullptr, tnullptr); + + IRBuilder<> OFBuilder(Loc); + Instruction *FirstInst = nullptr; + Module *M = Loc->getParent()->getParent()->getParent(); + const DataLayout &DL = M->getDataLayout(); + SCEVExpander Exp(*SE, DL, "start"); + + Value *Check = generateCheck(Loc, SE, &DL, Exp); + + if (!Check) + return std::make_pair(nullptr, nullptr); + + Instruction *CheckInst = + BinaryOperator::CreateOr(Check, ConstantInt::getFalse(M->getContext())); + OFBuilder.Insert(CheckInst, "scev.check"); + + FirstInst = getFirstInst(FirstInst, CheckInst, Loc); + return std::make_pair(FirstInst, CheckInst); +} + +Value *SCEVPredicateSet::generateCheck(Instruction *Loc, ScalarEvolution *SE, + const DataLayout *DL, + SCEVExpander &Exp) { + + IRBuilder<> OFBuilder(Loc); + Value *AllCheck = nullptr; + + // Loop over all checks in this set. + for (auto II = Preds.begin(), EE = Preds.end(); II != EE; ++II) { + SCEVPredicate *OA = *II; + + if (OA->isAlwaysTrue()) + continue; + + Value *CheckResult = OA->generateCheck(Loc, SE, DL, Exp); + + if (!AllCheck) + AllCheck = CheckResult; + else + AllCheck = OFBuilder.CreateOr(AllCheck, CheckResult); + } + + return AllCheck; +} + +bool SCEVPredicateSet::contains(const SCEVPredicate *N) const { + if (Never) + return false; + + if (const SCEVPredicateSet *Set = dyn_cast(N)) { + for (auto II = Set->Preds.begin(), EE = Set->Preds.end(); II != EE; ++II) { + if (!contains(*II)) + return false; + } + return true; + } + for (auto II = Preds.begin(), EE = Preds.end(); II != EE; ++II) { + if ((*II)->contains(N)) + return true; + } + return false; +} + +void SCEVPredicateSet::print(raw_ostream &OS, unsigned Depth) const { + for (auto II = Preds.begin(), EE = Preds.end(); II != EE; ++II) + (*II)->print(OS, Depth); +} + +void SCEVPredicateSet::add(const SCEVPredicate *N) { + if (Preds.size() > SCEVCheckThreshold || N->isAlwaysFalse()) { + Never = true; + return; + } + if (const SCEVEqualPredicate *OP = + dyn_cast(N)) { + for (unsigned II = 0, EE = IdPreds.size(); II < EE; ++II) + if (IdPreds[II].contains(OP)) + return; + + IdPreds.push_back(*OP); + Preds.push_back(&IdPreds.back()); + + return; + } + if (const SCEVPredicateSet *Set = + dyn_cast(N)) { + for (auto II = Set->Preds.begin(), EE = Set->Preds.end(); II != EE; ++II) { + add(*II); + } + return; + } + llvm_unreachable("Unknown SCEV predicate type!"); +} Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -272,12 +272,12 @@ InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, unsigned VecWidth, - unsigned UnrollFactor) + unsigned UnrollFactor, SCEVPredicateSet &Pred) : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()), Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor), TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr), - AddedSafetyChecks(false) {} + AddedSafetyChecks(false), Pred(Pred) {} // Perform the actual loop widening (vectorization). void vectorize(LoopVectorizationLegality *L) { @@ -315,6 +315,10 @@ /// pair as (first, last). std::pair addStrideCheck(Instruction *Loc); + // Adds code to check the overflow assumptions made by SCEV + std::pair + addRuntimeOverflowChecks(Instruction *Loc); + /// Create an empty loop, based on the loop ranges of the old loop. void createEmptyLoop(); /// Create a new induction variable inside L. @@ -395,8 +399,8 @@ void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass); /// Emit a bypass check to see if the vector trip count is nonzero. void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass); - /// Emit bypass checks to check if strides we've assumed to be one really are. - void emitStrideChecks(Loop *L, BasicBlock *Bypass); + + void emitSCEVChecks(Loop *L, BasicBlock *Bypass); /// Emit bypass checks to check any memory assumptions we may have made. void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); @@ -503,14 +507,19 @@ // Record whether runtime check is added. bool AddedSafetyChecks; + + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Pred; }; class InnerLoopUnroller : public InnerLoopVectorizer { public: InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, unsigned UnrollFactor) - : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {} + const TargetTransformInfo *TTI, unsigned UnrollFactor, + SCEVPredicateSet &Pred) + : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor, + Pred) {} private: void scalarizeInstruction(Instruction *Instr, @@ -731,8 +740,9 @@ /// between the member and the group in a map. class InterleavedAccessInfo { public: - InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT) - : SE(SE), TheLoop(L), DT(DT) {} + InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT, + SCEVPredicateSet &Pred) + : SE(SE), TheLoop(L), DT(DT), Pred(Pred) {} ~InterleavedAccessInfo() { SmallSet DelSet; @@ -765,6 +775,8 @@ ScalarEvolution *SE; Loop *TheLoop; DominatorTree *DT; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Pred; /// Holds the relationships between the members and the interleave group. DenseMap InterleaveGroupMap; @@ -1128,11 +1140,13 @@ Function *F, const TargetTransformInfo *TTI, LoopAccessAnalysis *LAA, LoopVectorizationRequirements *R, - const LoopVectorizeHints *H) + const LoopVectorizeHints *H, + SCEVPredicateSet &Pred) : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F), - TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT), - Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), - Requirements(R), Hints(H) {} + TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), + InterleaveInfo(SE, L, DT, Pred), Induction(nullptr), + WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), + Hints(H), Pred(Pred) {} /// ReductionList contains the reduction descriptors for all /// of the reductions that were found in the loop. @@ -1331,7 +1345,10 @@ /// While vectorizing these instructions we have to generate a /// call to the appropriate masked intrinsic - SmallPtrSet MaskedOp; + SmallPtrSet MaskedOp; + + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Pred; }; /// LoopVectorizationCostModel - estimates the expected speedups due to @@ -1348,9 +1365,11 @@ const TargetTransformInfo &TTI, const TargetLibraryInfo *TLI, AssumptionCache *AC, const Function *F, const LoopVectorizeHints *Hints, - SmallPtrSetImpl &ValuesToIgnore) + SmallPtrSetImpl &ValuesToIgnore, + SCEVPredicateSet &Pred) : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), - TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {} + TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore), + Pred(Pred) {} /// Information about vectorization costs struct VectorizationFactor { @@ -1436,6 +1455,8 @@ const LoopVectorizeHints *Hints; // Values to ignore in the cost model. const SmallPtrSetImpl &ValuesToIgnore; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Pred; }; /// \brief This holds vectorization requirements that must be verified late in @@ -1665,11 +1686,12 @@ return false; } } + SCEVPredicateSet Pred; // Check if it is legal to vectorize the loop. LoopVectorizationRequirements Requirements; LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA, - &Requirements, &Hints); + &Requirements, &Hints, Pred); if (!LVL.canVectorize()) { DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); emitMissedWarning(F, L, Hints); @@ -1688,7 +1710,7 @@ // Use the cost model. LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints, - ValuesToIgnore); + ValuesToIgnore, Pred); // Check the function attributes to find out if this function should be // optimized for size. @@ -1799,7 +1821,7 @@ assert(IC > 1 && "interleave count should not be 1 or 0"); // If we decided that it is not legal to vectorize the loop then // interleave it. - InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC); + InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Pred); Unroller.vectorize(&LVL); emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), @@ -1807,7 +1829,7 @@ Twine(IC) + ")"); } else { // If we decided that it is *legal* to vectorize the loop then do it. - InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC); + InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Pred); LB.vectorize(&LVL); ++LoopsVectorized; @@ -1967,7 +1989,7 @@ // %idxprom = zext i32 %mul to i64 << Safe cast. // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom // - Last = replaceSymbolicStrideSCEV(SE, Strides, + Last = replaceSymbolicStrideSCEV(SE, Strides, Pred, TheLoop, Gep->getOperand(InductionOperand), Gep); if (const SCEVCastExpr *C = dyn_cast(Last)) Last = @@ -2710,26 +2732,26 @@ LoopBypassBlocks.push_back(BB); } -void InnerLoopVectorizer::emitStrideChecks(Loop *L, - BasicBlock *Bypass) { +void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { BasicBlock *BB = L->getLoopPreheader(); - + // Generate the code to check that the strides we assumed to be one are really // one. We want the new basic block to start at the first instruction in a // sequence of instructions that form a check. - Instruction *StrideCheck; + Instruction *SCEVCheck; Instruction *FirstCheckInst; - std::tie(FirstCheckInst, StrideCheck) = addStrideCheck(BB->getTerminator()); - if (!StrideCheck) + std::tie(FirstCheckInst, SCEVCheck) = + Pred.generateGuardCond(BB->getTerminator(), SE); + if (!SCEVCheck) return; // Create a new block containing the stride check. - BB->setName("vector.stridecheck"); + BB->setName("vector.scevcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); if (L->getParentLoop()) L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); ReplaceInstWithInst(BB->getTerminator(), - BranchInst::Create(Bypass, NewBB, StrideCheck)); + BranchInst::Create(Bypass, NewBB, SCEVCheck)); LoopBypassBlocks.push_back(BB); AddedSafetyChecks = true; } @@ -2849,10 +2871,10 @@ // Now, compare the new count to zero. If it is zero skip the vector loop and // jump to the scalar loop. emitVectorLoopEnteredCheck(Lp, ScalarPH); - // Generate the code to check that the strides we assumed to be one are really - // one. We want the new basic block to start at the first instruction in a - // sequence of instructions that form a check. - emitStrideChecks(Lp, ScalarPH); + // Generate the code to check any assumptions that we've made for SCEV + // expressions. + emitSCEVChecks(Lp, ScalarPH); + // Generate the code that checks in runtime if arrays overlap. We put the // checks into a separate block to make the more common case of few elements // faster. @@ -4292,6 +4314,15 @@ } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); + Pred.add(&*LAI->Pred); + + if (Pred.isAlwaysFalse()) { + emitAnalysis(VectorizationReport() + << "Too many SCEV assuptions need to be made and checked " + << "at runtime"); + DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); + return false; + } return true; } @@ -4406,7 +4437,7 @@ StoreInst *SI = dyn_cast(I); Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); - int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides); + int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Pred); // The factor of the corresponding interleave group. unsigned Factor = std::abs(Stride); @@ -4415,7 +4446,7 @@ if (Factor < 2 || Factor > MaxInterleaveGroupFactor) continue; - const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Pred, TheLoop, Ptr); PointerType *PtrTy = dyn_cast(Ptr->getType()); unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());