Index: include/llvm/Analysis/LoopAccessAnalysis.h
===================================================================
--- include/llvm/Analysis/LoopAccessAnalysis.h
+++ include/llvm/Analysis/LoopAccessAnalysis.h
@@ -33,6 +33,7 @@
class ScalarEvolution;
class Loop;
class SCEV;
+class SCEVPredicateSet;

/// Optimization analysis message produced during vectorization. Messages inform
/// the user why vectorization did not occur.
@@ -177,10 +178,10 @@
               const SmallVectorImpl<Instruction *> &Instrs) const;
  };

-  MemoryDepChecker(ScalarEvolution *Se, const Loop *L)
+  MemoryDepChecker(ScalarEvolution *Se, const Loop *L, SCEVPredicateSet &Pred)
      : SE(Se), InnermostLoop(L), AccessIdx(0),
        ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true),
-        RecordInterestingDependences(true) {}
+        RecordInterestingDependences(true), Pred(Pred) {}

  /// \brief Register the location (instructions are given increasing numbers)
  /// of a write access.
@@ -290,6 +291,9 @@
  /// \brief Check whether the data dependence could prevent store-load
  /// forwarding.
  bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize);
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  SCEVPredicateSet &Pred;
};

/// \brief Holds information about the memory runtime legality checks to verify
@@ -332,7 +336,8 @@
  /// Insert a pointer and calculate the start and end SCEVs.
  void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
-              unsigned ASId, const ValueToValueMap &Strides);
+              unsigned ASId, const ValueToValueMap &Strides,
+              SCEVPredicateSet &Pred);

  /// \brief No run-time memory checking is necessary.
  bool empty() const { return Pointers.empty(); }
@@ -538,6 +543,9 @@
    return StoreToLoopInvariantAddress;
  }

+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  std::unique_ptr<SCEVPredicateSet> Pred;
+
private:
  /// \brief Analyze the loop. Substitute symbolic strides using Strides.
  void analyzeLoop(const ValueToValueMap &Strides);
@@ -564,6 +572,9 @@
  DominatorTree *DT;
  LoopInfo *LI;

+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  SCEVPredicateSet ScPredicates;
+
  unsigned NumLoads;
  unsigned NumStores;
@@ -593,10 +604,30 @@
                                     const ValueToValueMap &PtrToStride,
                                     Value *Ptr, Value *OrigPtr = nullptr);

+/// \brief Return the SCEV of a value with all assumptions applied. This
+/// will replace symbolic strides according to \p PtrToStride and apply any
+/// existing SCEV assumptions contained in \p Preds.
+const SCEV *rewriteSCEV(ScalarEvolution *SE, const ValueToValueMap &PtrToStride,
+                        Value *Ptr, Value *OrigPtr, const Loop *L,
+                        SCEVPredicateSet &Preds);
+
+/// \brief Try to add a minimal set of assumptions that will cause the
+/// re-written SCEV of \p Ptr to be an AddRecExpr. If successful we will
+/// return the modified AddRecExpr and add any assumptions made to \p Preds.
+/// Otherwise, we will make no new assumption and return the same result as
+/// rewriteSCEV.
+const SCEV *convertSCEVToAddRec(ScalarEvolution *SE,
+                                const ValueToValueMap &PtrToStride, Value *Ptr,
+                                Value *OrigPtr, const Loop *L,
+                                SCEVPredicateSet &Preds);
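For orientation, this is how a client is expected to combine the two helpers above: first rewrite without new assumptions, then opt in to them. A minimal sketch; the helper name analyzePtrSCEV and the surrounding setup are illustrative and not part of this patch:

// Illustrative only: shows the intended calling convention of
// rewriteSCEV/convertSCEVToAddRec; analyzePtrSCEV is a hypothetical helper.
static const SCEVAddRecExpr *
analyzePtrSCEV(ScalarEvolution *SE, Value *Ptr, const Loop *L,
               const ValueToValueMap &Strides, SCEVPredicateSet &Preds) {
  // Apply symbolic strides plus any assumptions already collected in Preds.
  const SCEV *S = rewriteSCEV(SE, Strides, Ptr, nullptr, L, Preds);
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
    return AR;
  // Allow new no-wrap assumptions; on failure Preds is left unchanged and
  // the rewriteSCEV result is returned instead.
  S = convertSCEVToAddRec(SE, Strides, Ptr, nullptr, L, Preds);
  return dyn_cast<SCEVAddRecExpr>(S);
}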
/// \brief Check the stride of the pointer and ensure that it does not wrap in
-/// the address space.
+/// the address space. If \p MakeAssumptions is true, we will try to add SCEV
+/// assumptions to \p Pred as necessary in order to compute a non-zero stride.
+/// If we cannot, \p Pred will remain unchanged.
int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
-                 const ValueToValueMap &StridesMap);
+                 const ValueToValueMap &StridesMap, SCEVPredicateSet &Pred,
+                 bool MakeAssumptions);

/// \brief This analysis provides dependence information for the memory accesses
/// of a loop.
@@ -621,7 +652,7 @@
  ///
  /// If the client speculates (and then issues run-time checks) for the values
  /// of symbolic strides, \p Strides provides the mapping (see
-  /// replaceSymbolicStrideSCEV). If there is no cached result available run
+  /// replaceSymbolicStrideSCEV). If there is no cached result available run
  /// the analysis.
  const LoopAccessInfo &getInfo(Loop *L, const ValueToValueMap &Strides);

Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -51,6 +51,8 @@
  class SCEVUnknown;
  class SCEVAddRecExpr;
  class SCEV;
+  class SCEVExpander;
+  template<> struct FoldingSetTrait;

  /// SCEV - This class represents an analyzed expression in the program. These
@@ -171,6 +173,131 @@
    static bool classof(const SCEV *S);
  };

+  enum SCEVPredicateTypes { pAddRecOverflow, pSet };
+
+  //===--------------------------------------------------------------------===//
+  /// SCEVPredicate - This class represents an assumption made using SCEV
+  /// expressions which can be checked at run-time.
+  ///
+  class SCEVPredicate {
+  protected:
+    unsigned short SCEVPredicateType;
+
+  public:
+    SCEVPredicate(unsigned short Type);
+    unsigned short getType() const { return SCEVPredicateType; }
+    /// Returns true if the predicate is always true. This means that no
+    /// assumptions were made and nothing needs to be checked at run-time.
+    virtual bool isAlwaysTrue() const = 0;
+    /// Return true if we consider this to be always false or if we've
+    /// given up on this set of assumptions (for example because of the
+    /// high cost of checking at run-time).
+    virtual bool isAlwaysFalse() const = 0;
+    /// Returns true if this predicate implies \p N.
+    virtual bool contains(const SCEVPredicate *N) const = 0;
+    /// Prints a textual representation of this predicate.
+    virtual void print(raw_ostream &OS, unsigned Depth) const = 0;
+
+    /// Generates a run-time check for this predicate.
+    virtual Value *generateCheck(Instruction *Loc, ScalarEvolution *SE,
+                                 const DataLayout *DL, SCEVExpander &Exp) = 0;
+  };
+
+  //===--------------------------------------------------------------------===//
+  /// SCEVAddRecOverflowPredicate - This class represents an assumption
+  /// made on an AddRec expression. Given an affine AddRec expression
+  /// {a,+,b}, we assume that it has nsw or nuw flags.
+  class SCEVAddRecOverflowPredicate : public SCEVPredicate {
+    const SCEV *AR;
+    SCEV::NoWrapFlags Flags;
+
+  public:
+    SCEVAddRecOverflowPredicate()
+        : SCEVPredicate(pAddRecOverflow), AR(nullptr),
+          Flags(SCEV::FlagAnyWrap) {}
+    SCEVAddRecOverflowPredicate(const SCEV *AR, SCEV::NoWrapFlags Flags);
+
+    /// Returns the set of assumed no-wrap flags.
+    SCEV::NoWrapFlags getFlags() const { return Flags; }
+    /// Add an assumption of no overflow for \p AddedFlags.
+    void addFlags(SCEV::NoWrapFlags AddedFlags);
+    /// Returns the AddRec expression that we've made assumptions for.
+    const SCEV *getExpr() const { return AR; }
+
+    /// Implementation of the SCEVPredicate interface
+    bool contains(const SCEVPredicate *N) const override;
+    void print(raw_ostream &OS, unsigned Depth) const override;
+    bool isAlwaysTrue() const override;
+    bool isAlwaysFalse() const override;
+    Value *generateCheck(Instruction *Loc, ScalarEvolution *SE,
+                         const DataLayout *DL, SCEVExpander &Exp) override;
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVPredicate *P) {
+      return P->getType() == pAddRecOverflow;
+    }
+  };
+
+  //===--------------------------------------------------------------------===//
+  /// SCEVPredicateSet - This class represents a composition of other
+  /// SCEV predicates, and is the class that most clients will interact with.
+  ///
+  class SCEVPredicateSet : public SCEVPredicate {
+  private:
+    /// Flag used to track if this predicate set is invalid.
+    bool Never;
+    /// Storage for different predicates that make up this Predicate Set.
+    SmallVector<SCEVAddRecOverflowPredicate, 4> AddRecOverflows;
+    /// Vector with references to all predicates in this set.
+    SmallVector<SCEVPredicate *, 4> Preds;
+
+  public:
+    SCEVPredicateSet();
+    /// The copy constructor.
+    SCEVPredicateSet(const SCEVPredicateSet &Old);
+    /// Adds a predicate to this predicate set.
+    void add(const SCEVPredicate *N);
+
+    /// Generates a run-time check for all the contained predicates.
+    /// This is a wrapper around generateCheck, and provides an interface
+    /// similar to other run-time checks used for versioning.
+    std::pair<Instruction *, Instruction *>
+    generateGuardCond(Instruction *Loc, ScalarEvolution *SE);
+
+    /// Implementation of the SCEVPredicate interface
+    bool isAlwaysTrue() const override;
+    bool isAlwaysFalse() const override;
+    bool contains(const SCEVPredicate *N) const override;
+    void print(raw_ostream &OS, unsigned Depth) const override;
+    Value *generateCheck(Instruction *Loc, ScalarEvolution *SE,
+                         const DataLayout *DL, SCEVExpander &Exp) override;
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVPredicate *P) {
+      return P->getType() == pSet;
+    }
+
+    /// The copy operator.
+    const SCEVPredicateSet &operator=(const SCEVPredicateSet &RHS) {
+      Never = RHS.Never;
+      AddRecOverflows = RHS.AddRecOverflows;
+      Preds.clear();
+      for (unsigned II = 0; II < AddRecOverflows.size(); ++II) {
+        Preds.push_back(&AddRecOverflows[II]);
+      }
+      assert(Preds.size() == RHS.Preds.size() && "Wrong Preds size after copy");
+      return *this;
+    }
+  };
+
+  /// Associates a SCEV predicate with a SCEV.
+  struct AssumptionResult {
+    const SCEV *Start;
+    const SCEV *Res;
+    SCEVPredicateSet Pred;
+    AssumptionResult(const SCEV *Start) : Start(Start), Res(nullptr) {}
+  };

  /// The main scalar evolution driver. Because client code (intentionally)
  /// can't do much with the SCEV objects directly, they must ask this class
  /// for services.
@@ -269,9 +396,14 @@
    const SCEV *Exact;
    const SCEV *Max;

+    /// A predicate set guard for this ExitLimit. The result is only
+    /// valid if this predicate evaluates to 'true' at run-time.
+    SCEVPredicateSet Pred;
+
    /*implicit*/ ExitLimit(const SCEV *E) : Exact(E), Max(E) {}

-    ExitLimit(const SCEV *E, const SCEV *M) : Exact(E), Max(M) {}
+    ExitLimit(const SCEV *E, const SCEV *M, SCEVPredicateSet &P)
+        : Exact(E), Max(M), Pred(P) {}

    /// hasAnyInfo - Test whether this ExitLimit contains any computed
    /// information, or whether it's all SCEVCouldNotCompute values.
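The intended life cycle of a predicate set, sketched under the assumption that the caller versions the loop itself. versionWithPredicates is a hypothetical caller; only the SCEVPredicateSet API above is from this patch:

// Hypothetical caller illustrating the SCEVPredicateSet life cycle.
static bool versionWithPredicates(Loop *L, ScalarEvolution *SE,
                                  SCEVPredicateSet &Preds) {
  if (Preds.isAlwaysFalse())
    return false; // Too many (or contradictory) assumptions; give up.
  if (Preds.isAlwaysTrue())
    return true;  // All assumptions already proven; nothing to emit.
  // Emit an i1 that is true iff some assumption is violated at run time;
  // the caller branches to the unoptimized loop version on true.
  std::pair<Instruction *, Instruction *> Guard = Preds.generateGuardCond(
      L->getLoopPreheader()->getTerminator(), SE);
  return Guard.second != nullptr;
}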
@@ -286,7 +418,8 @@
  struct ExitNotTakenInfo {
    AssertingVH<BasicBlock> ExitingBlock;
    const SCEV *ExactNotTaken;
-    PointerIntPair<ExitNotTakenInfo*, 1> NextExit;
+    PointerIntPair<ExitNotTakenInfo *, 1> NextExit;
+    SCEVPredicateSet Pred;

    ExitNotTakenInfo() : ExitingBlock(nullptr), ExactNotTaken(nullptr) {}
@@ -322,8 +455,9 @@
    /// Initialize BackedgeTakenInfo from a list of exact exit counts.
    BackedgeTakenInfo(
-        SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts,
-        bool Complete, const SCEV *MaxCount);
+        SmallVectorImpl<std::pair<BasicBlock *, const SCEV *>> &ExitCounts,
+        SmallVectorImpl<SCEVPredicateSet> &ExitPreds, bool Complete,
+        const SCEV *MaxCount);

    /// hasAnyInfo - Test whether this BackedgeTakenInfo contains any
    /// computed information, or whether it's all SCEVCouldNotCompute
    /// values.
@@ -333,11 +467,20 @@
    }

    /// getExact - Return an expression indicating the exact backedge-taken
-    /// count of the loop if it is known, or SCEVCouldNotCompute
-    /// otherwise. This is the number of times the loop header can be
-    /// guaranteed to execute, minus one.
+    /// count of the loop if it is known and always correct (independent
+    /// of any assumptions that should be checked at run-time), or
+    /// SCEVCouldNotCompute otherwise. This is the number of times the
+    /// loop header can be guaranteed to execute, minus one.
    const SCEV *getExact(ScalarEvolution *SE) const;

+    /// getGuardedExact - Return an expression indicating the exact
+    /// backedge-taken count of the loop if it is known, or
+    /// SCEVCouldNotCompute otherwise. Returns the SCEV predicates that
+    /// need to be checked at run-time in order for this answer to be valid
+    /// in \p Predicates. This is the number of times the loop header can be
+    /// guaranteed to execute, minus one.
+    const SCEV *getGuardedExact(ScalarEvolution *SE,
+                                SCEVPredicateSet &Predicates) const;

    /// getExact - Return the number of times this loop exit may fall through
    /// to the back edge, or SCEVCouldNotCompute. The loop is guaranteed not
    /// to exit via this block before this number of iterations, but may exit
    /// via another block.
@@ -462,11 +605,10 @@
  /// ComputeExitLimitFromICmp - Compute the number of times the backedge of
  /// the specified loop will execute if its exit condition were a conditional
  /// branch of the ICmpInst ExitCond, TBB, and FBB.
-  ExitLimit ComputeExitLimitFromICmp(const Loop *L,
-                                     ICmpInst *ExitCond,
-                                     BasicBlock *TBB,
-                                     BasicBlock *FBB,
-                                     bool IsSubExpr);
+  ExitLimit ComputeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond,
+                                     BasicBlock *TBB, BasicBlock *FBB,
+                                     bool IsSubExpr,
+                                     bool UseAssumptions = false);

  /// ComputeExitLimitFromSingleExitSwitch - Compute the number of times the
  /// backedge of the specified loop will execute if its exit condition were a
@@ -495,21 +637,25 @@
  /// HowFarToZero - Return the number of times an exit condition comparing
  /// the specified value to zero will execute. If not computable, return
  /// CouldNotCompute.
-  ExitLimit HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr);
+  ExitLimit HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr,
+                         bool UseAssumptions = false);

  /// HowFarToNonZero - Return the number of times an exit condition checking
  /// the specified value for nonzero will execute. If not computable, return
  /// CouldNotCompute.
-  ExitLimit HowFarToNonZero(const SCEV *V, const Loop *L);
+  ExitLimit HowFarToNonZero(const SCEV *V, const Loop *L,
+                            bool UseAssumptions = false);

  /// HowManyLessThans - Return the number of times an exit condition
  /// containing the specified less-than comparison will execute. If not
  /// computable, return CouldNotCompute. isSigned specifies whether the
  /// less-than is signed.
-  ExitLimit HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
-                             const Loop *L, bool isSigned, bool IsSubExpr);
+  ExitLimit HowManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L,
+                             bool isSigned, bool IsSubExpr,
+                             bool UseAssumptions = false);

  ExitLimit HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
-                                const Loop *L, bool isSigned, bool IsSubExpr);
+                                const Loop *L, bool isSigned, bool IsSubExpr,
+                                bool UseAssumptions = false);

  /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
  /// (which may not be an immediate predecessor) which has exactly one
  /// successor from which BB is reachable, or null if no such block is
  /// found.
@@ -852,6 +998,9 @@
  ///
  const SCEV *getBackedgeTakenCount(const Loop *L);

+  const SCEV *getGuardedBackedgeTakenCount(const Loop *L,
+                                           SCEVPredicateSet &Predicates);
+
  /// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
  /// return the least SCEV value that is known never to be less than the
  /// actual backedge taken count.
@@ -1069,6 +1218,17 @@
                           SmallVectorImpl<const SCEV *> &Sizes,
                           const SCEV *ElementSize);

+  /// Re-writes the SCEV according to the Predicates in \p Preds, by
+  /// applying overflow assumptions and sinking sext/zext expressions.
+  const SCEV *rewriteUsingPredicate(const SCEV *Scev, const Loop *L,
+                                    SCEVPredicateSet &Preds);
+
+  /// Tries to convert a SCEV into an AddRecExpr by making overflow
+  /// assumptions and sinking SCEV nodes. If unsuccessful, we will return
+  /// a nullptr in the Res field. If successful, the predicate set of the
+  /// answer must be checked at run-time in order for the answer to be
+  /// valid.
+  AssumptionResult getAddRecWithRTChecks(const SCEV *S, const Loop *L);

private:
  /// Compute the backedge taken count knowing the interval difference, the
  /// stride and presence of the equality in the comparison.

Index: include/llvm/Analysis/ScalarEvolutionExpressions.h
===================================================================
--- include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -23,6 +23,7 @@
  class ConstantInt;
  class ConstantRange;
  class DominatorTree;
+  class SCEVExpander;

  enum SCEVTypes {
    // These should be ordered in terms of increasing complexity to make the
@@ -745,12 +746,11 @@
    LoopToScevMapT &Map;
  };

-/// Applies the Map (Loop -> SCEV) to the given Scev.
-static inline const SCEV *apply(const SCEV *Scev, LoopToScevMapT &Map,
-                                ScalarEvolution &SE) {
-  return SCEVApplyRewriter::rewrite(Scev, Map, SE);
-}
-
-}
+  /// Applies the Map (Loop -> SCEV) to the given Scev.
+  static inline const SCEV *apply(const SCEV *Scev, LoopToScevMapT &Map,
+                                  ScalarEvolution &SE) {
+    return SCEVApplyRewriter::rewrite(Scev, Map, SE);
+  }
+  }

#endif

Index: include/llvm/Transforms/Utils/LoopVersioning.h
===================================================================
--- include/llvm/Transforms/Utils/LoopVersioning.h
+++ include/llvm/Transforms/Utils/LoopVersioning.h
@@ -23,6 +23,7 @@
class Loop;
class LoopAccessInfo;
class LoopInfo;
+class ScalarEvolution;

/// \brief This class emits a version of the loop where run-time checks ensure
/// that may-alias pointers can't overlap.
@@ -35,12 +36,12 @@
  /// as input. It uses runtime check provided by user.
  LoopVersioning(SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks,
                 const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
-                 DominatorTree *DT);
+                 DominatorTree *DT, ScalarEvolution *SE);

  /// \brief Expects LoopAccessInfo, Loop, LoopInfo, DominatorTree as input.
  /// It uses default runtime check provided by LoopAccessInfo.
  LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L, LoopInfo *LI,
-                 DominatorTree *DT);
+                 DominatorTree *DT, ScalarEvolution *SE);

  /// \brief Performs the CFG manipulation part of versioning the loop including
  /// the DominatorTree and LoopInfo updates.
@@ -93,6 +94,7 @@
  const LoopAccessInfo &LAI;
  LoopInfo *LI;
  DominatorTree *DT;
+  ScalarEvolution *SE;
};
}

Index: lib/Analysis/LoopAccessAnalysis.cpp
===================================================================
--- lib/Analysis/LoopAccessAnalysis.cpp
+++ lib/Analysis/LoopAccessAnalysis.cpp
@@ -119,15 +119,45 @@
  return SE->getSCEV(Ptr);
}

+const SCEV *llvm::rewriteSCEV(ScalarEvolution *SE,
+                              const ValueToValueMap &PtrToStride, Value *Ptr,
+                              Value *OrigPtr, const Loop *L,
+                              SCEVPredicateSet &Preds) {
+
+  const SCEV *Ret = replaceSymbolicStrideSCEV(SE, PtrToStride, Ptr, OrigPtr);
+  Ret = SE->rewriteUsingPredicate(Ret, L, Preds);
+  return Ret;
+}
+
+const SCEV *llvm::convertSCEVToAddRec(ScalarEvolution *SE,
+                                      const ValueToValueMap &PtrToStride,
+                                      Value *Ptr, Value *OrigPtr, const Loop *L,
+                                      SCEVPredicateSet &Preds) {
+
+  const SCEV *Ret = rewriteSCEV(SE, PtrToStride, Ptr, OrigPtr, L, Preds);
+  if (dyn_cast<SCEVAddRecExpr>(Ret))
+    return Ret;
+
+  AssumptionResult R = SE->getAddRecWithRTChecks(Ret, L);
+  R.Pred.add(&Preds);
+  // Only commit to the new predicates if we've succeeded.
+  if (R.Res && !R.Pred.isAlwaysFalse()) {
+    Preds = R.Pred;
+    return R.Res;
+  }
+  return Ret;
+}
+
void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
                                    unsigned DepSetId, unsigned ASId,
-                                    const ValueToValueMap &Strides) {
+                                    const ValueToValueMap &Strides,
+                                    SCEVPredicateSet &Pred) {
  // Get the stride replaced scev.
-  const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
+  const SCEV *Sc = rewriteSCEV(SE, Strides, Ptr, nullptr, Lp, Pred);
  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
  assert(AR && "Invalid addrec expression");
-  const SCEV *Ex = SE->getBackedgeTakenCount(Lp);
+  const SCEV *Ex = SE->getGuardedBackedgeTakenCount(Lp, Pred);
  const SCEV *ScStart = AR->getStart();
  const SCEV *ScEnd = AR->evaluateAtIteration(Ex, *SE);
  const SCEV *Step = AR->getStepRecurrence(*SE);
@@ -417,9 +447,9 @@
  typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet;

  AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI,
-                 MemoryDepChecker::DepCandidates &DA)
-      : DL(Dl), AST(*AA), LI(LI), DepCands(DA),
-        IsRTCheckAnalysisNeeded(false) {}
+                 MemoryDepChecker::DepCandidates &DA, SCEVPredicateSet &Pred)
+      : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false),
+        Pred(Pred) {}

  /// \brief Register a load and whether it is only read from.
  void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
@@ -504,14 +534,18 @@
  /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared
  /// while this remains set if we have potentially dependent accesses.
  bool IsRTCheckAnalysisNeeded;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  SCEVPredicateSet &Pred;
};

} // end anonymous namespace
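To see what the convertSCEVToAddRec retry buys in RuntimePointerChecking::insert and the bounds computation below, consider a loop of the following shape (illustrative C, not taken from the patch):

// With a 32-bit induction variable on a 64-bit target, the address of
// a[i] involves (zext i32 {0,+,2}<%loop> to i64), which is not an
// AddRecExpr because the inner recurrence may wrap when n is close to
// UINT_MAX. Recording a no-wrap predicate lets the zext fold away, so the
// start/end bounds of the access become computable.
void f(int *a, unsigned n) {
  for (unsigned i = 0; i < n; i += 2)
    a[i] = 0;
}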
/// \brief Check whether a pointer can participate in a runtime bounds check.
static bool hasComputableBounds(ScalarEvolution *SE,
-                                const ValueToValueMap &Strides, Value *Ptr) {
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
+                                const ValueToValueMap &Strides, Value *Ptr,
+                                Loop *L, SCEVPredicateSet &Pred) {
+  const SCEV *PtrScev = rewriteSCEV(SE, Strides, Ptr, nullptr, L, Pred);
  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
  if (!AR)
    return false;
@@ -554,11 +588,21 @@
      else
        ++NumReadPtrChecks;

-      if (hasComputableBounds(SE, StridesMap, Ptr) &&
-          // When we run after a failing dependency check we have to make sure
-          // we don't have wrapping pointers.
-          (!ShouldCheckStride ||
-           isStridedPtr(SE, Ptr, TheLoop, StridesMap) == 1)) {
+      bool Bounded = hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Pred);
+      if (!Bounded) {
+        convertSCEVToAddRec(SE, StridesMap, Ptr, nullptr, TheLoop, Pred);
+        Bounded = hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Pred);
+      }
+
+      bool ValidStride = true;
+      if (ShouldCheckStride) {
+        ValidStride =
+            (isStridedPtr(SE, Ptr, TheLoop, StridesMap, Pred, true) == 1);
+      }
+
+      // When we run after a failing dependency check we have to make sure
+      // we don't have wrapping pointers.
+      if (Bounded && ValidStride) {
        // The id of the dependence set.
        unsigned DepId;
@@ -572,7 +616,7 @@
          // Each access has its own dependence set.
          DepId = RunningDepId++;

-        RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap);
+        RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Pred);

        DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
      } else {
@@ -803,7 +847,8 @@
/// \brief Check whether the access through \p Ptr has a constant stride.
int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
-                       const ValueToValueMap &StridesMap) {
+                       const ValueToValueMap &StridesMap,
+                       SCEVPredicateSet &Pred, bool MakeAssumptions) {
  Type *Ty = Ptr->getType();
  assert(Ty->isPointerTy() && "Unexpected non-ptr");
@@ -815,16 +860,23 @@
    return 0;
  }

-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr);
+  const SCEV *PtrScev = rewriteSCEV(SE, StridesMap, Ptr, nullptr, Lp, Pred);

  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
+  if (!AR && MakeAssumptions) {
+    // It's not an AddRecExpr. Try to force an AddRecExpr by making
+    // assumptions which can be checked at run-time.
+    const SCEV *Retry =
+        convertSCEVToAddRec(SE, StridesMap, Ptr, nullptr, Lp, Pred);
+    AR = dyn_cast<SCEVAddRecExpr>(Retry);
+  }
  if (!AR) {
-    DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer "
-          << *Ptr << " SCEV: " << *PtrScev << "\n");
-    return 0;
+    DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr
+                 << " SCEV: " << *PtrScev << "\n");
+    return 0;
  }

-  // The accesss function must stride over the innermost loop.
+  // The access function must stride over the innermost loop.
  if (Lp != AR->getLoop()) {
    DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop " <<
          *Ptr << " SCEV: " << *PtrScev << "\n");
@@ -1026,11 +1078,15 @@
      BPtr->getType()->getPointerAddressSpace())
    return Dependence::Unknown;

-  const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr);
-  const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr);
-
-  int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides);
-  int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides);
+  const SCEV *AScev =
+      rewriteSCEV(SE, Strides, APtr, nullptr, InnermostLoop, Pred);
+  const SCEV *BScev =
+      rewriteSCEV(SE, Strides, BPtr, nullptr, InnermostLoop, Pred);
+
+  // Make assumptions here, otherwise we're guaranteed to end up with
+  // an unknown dependence.
+ int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Pred, true); + int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Pred, true); const SCEV *Src = AScev; const SCEV *Sink = BScev; @@ -1057,7 +1113,7 @@ // Need accesses with constant stride. We don't want to vectorize // "A[B[i]] += ..." and similar code or pointer arithmetic that could wrap in // the address space. - if (!StrideAPtr || !StrideBPtr || StrideAPtr != StrideBPtr){ + if (!StrideAPtr || !StrideBPtr || StrideAPtr != StrideBPtr) { DEBUG(dbgs() << "Pointer access with non-constant stride\n"); return Dependence::Unknown; } @@ -1325,10 +1381,11 @@ } // ScalarEvolution needs to be able to find the exit count. - const SCEV *ExitCount = SE->getBackedgeTakenCount(TheLoop); + const SCEV *ExitCount = SE->getGuardedBackedgeTakenCount(TheLoop, *Pred); + if (ExitCount == SE->getCouldNotCompute()) { - emitAnalysis(LoopAccessReport() << - "could not determine number of loop iterations"); + emitAnalysis(LoopAccessReport() + << "could not determine number of loop iterations"); DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n"); return false; } @@ -1429,7 +1486,7 @@ MemoryDepChecker::DepCandidates DependentAccesses; AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), - AA, LI, DependentAccesses); + AA, LI, DependentAccesses, *Pred); // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects // multiple times on the same object. If the ptr is accessed twice, once @@ -1480,7 +1537,8 @@ // read a few words, modify, and write a few words, and some of the // words may be written to the same address. bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr).second || !isStridedPtr(SE, Ptr, TheLoop, Strides)) { + if (Seen.insert(Ptr).second || + !isStridedPtr(SE, Ptr, TheLoop, Strides, *Pred, false)) { ++NumReads; IsReadOnlyPtr = true; } @@ -1723,8 +1781,9 @@ const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : PtrRtChecking(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), - TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), + : Pred(new SCEVPredicateSet), PtrRtChecking(SE), DepChecker(SE, L, *Pred), + TheLoop(L), SE(SE), DL(DL), TLI(TLI), + AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { if (canAnalyzeLoop()) @@ -1758,11 +1817,16 @@ OS.indent(Depth) << "Store to invariant address was " << (StoreToLoopInvariantAddress ? 
"" : "not ") << "found in loop.\n"; + + OS.indent(Depth) << "SCEV assumptions:\n"; + Pred->print(OS, Depth); } const LoopAccessInfo & LoopAccessAnalysis::getInfo(Loop *L, const ValueToValueMap &Strides) { - auto &LAI = LoopAccessInfoMap[L]; + DenseMap> &Map = LoopAccessInfoMap; + + auto &LAI = Map[L]; #ifndef NDEBUG assert((!LAI || LAI->NumSymbolicStrides == Strides.size()) && Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -67,6 +67,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -80,6 +81,8 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" @@ -109,6 +112,13 @@ "derived loop"), cl::init(100)); +static cl::opt +OverflowCheckThreshold("force-max-overflow-checks", cl::init(16), + cl::Hidden, + cl::desc("Don't create SCEV predicates with more than " + "this number of assumptions.")); + + // FIXME: Enable this with XDEBUG when the test suite is clean. static cl::opt VerifySCEV("verify-scev", @@ -4696,6 +4706,12 @@ return getBackedgeTakenInfo(L).getExact(this); } +const SCEV * +ScalarEvolution::getGuardedBackedgeTakenCount(const Loop *L, + SCEVPredicateSet &Predicates) { + return getBackedgeTakenInfo(L).getGuardedExact(this, Predicates); +} + /// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except /// return the least SCEV value that is known never to be less than the /// actual backedge taken count. @@ -4883,6 +4899,9 @@ assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); + if (ENT->Pred.isAlwaysFalse() || !ENT->Pred.isAlwaysTrue()) + return SE->getCouldNotCompute(); + if (!BECount) BECount = ENT->ExactNotTaken; else if (BECount != ENT->ExactNotTaken) @@ -4892,14 +4911,47 @@ return BECount; } +const SCEV *ScalarEvolution::BackedgeTakenInfo::getGuardedExact( + ScalarEvolution *SE, SCEVPredicateSet &Predicates) const { + // If any exits were not computable, the loop is not computable. + if (!ExitNotTaken.isCompleteList()) + return SE->getCouldNotCompute(); + + // We need exactly one computable exit. + if (!ExitNotTaken.ExitingBlock) + return SE->getCouldNotCompute(); + assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info"); + + const SCEV *BECount = nullptr; + SCEVPredicateSet Pred; + for (const ExitNotTakenInfo *ENT = &ExitNotTaken; ENT != nullptr; + ENT = ENT->getNextExit()) { + + assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); + + if (!BECount) { + BECount = ENT->ExactNotTaken; + } else if (BECount != ENT->ExactNotTaken) { + return SE->getCouldNotCompute(); + } + Predicates.add(&(ENT->Pred)); + } + assert(BECount && "Invalid not taken count for loop exit"); + + if (Predicates.isAlwaysFalse()) + return SE->getCouldNotCompute(); + + return BECount; +} + /// getExact - Get the exact not taken count for this loop exit. 
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
                                             ScalarEvolution *SE) const {
-  for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
-       ENT != nullptr; ENT = ENT->getNextExit()) {
+  for (const ExitNotTakenInfo *ENT = &ExitNotTaken; ENT != nullptr;
+       ENT = ENT->getNextExit()) {

-    if (ENT->ExitingBlock == ExitingBlock)
+    if (ENT->ExitingBlock == ExitingBlock && ENT->Pred.isAlwaysTrue())
      return ENT->ExactNotTaken;
  }
  return SE->getCouldNotCompute();
}

@@ -4908,6 +4960,11 @@
/// getMax - Get the max backedge taken count for the loop.
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
+  for (const ExitNotTakenInfo *ENT = &ExitNotTaken; ENT != nullptr;
+       ENT = ENT->getNextExit()) {
+    if (!ENT->Pred.isAlwaysTrue())
+      return SE->getCouldNotCompute();
+  }
  return Max ? Max : SE->getCouldNotCompute();
}

@@ -4933,8 +4990,10 @@
/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
/// computable exit into a persistent ExitNotTakenInfo array.
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
-    SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts,
-    bool Complete, const SCEV *MaxCount) : Max(MaxCount) {
+    SmallVectorImpl<std::pair<BasicBlock *, const SCEV *>> &ExitCounts,
+    SmallVectorImpl<SCEVPredicateSet> &ExitPreds, bool Complete,
+    const SCEV *MaxCount)
+    : Max(MaxCount) {

  if (!Complete)
    ExitNotTaken.setIncomplete();
@@ -4944,7 +5003,10 @@
  ExitNotTaken.ExitingBlock = ExitCounts[0].first;
  ExitNotTaken.ExactNotTaken = ExitCounts[0].second;
-  if (NumExits == 1) return;
+  ExitNotTaken.Pred = ExitPreds[0];
+
+  if (NumExits == 1)
+    return;

  // Handle the rare case of multiple computable exits.
  ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1];
@@ -4954,6 +5016,7 @@
    PrevENT->setNextExit(ENT);
    ENT->ExitingBlock = ExitCounts[i].first;
    ENT->ExactNotTaken = ExitCounts[i].second;
+    ENT->Pred = ExitPreds[i];
  }
}

@@ -4972,6 +5035,7 @@
  L->getExitingBlocks(ExitingBlocks);

  SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts;
+  SmallVector<SCEVPredicateSet, 4> ExitCountPreds;
  bool CouldComputeBECount = true;
  BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
  const SCEV *MustExitMaxBECount = nullptr;
@@ -4979,6 +5043,7 @@
  // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
  // and compute maxBECount.
+  // Do a union of all the predicates here.
  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
    BasicBlock *ExitBB = ExitingBlocks[i];
    ExitLimit EL = ComputeExitLimit(L, ExitBB);
@@ -4989,8 +5054,10 @@
      // We couldn't compute an exact value for this exit, so
      // we won't be able to compute an exact value for the loop.
      CouldComputeBECount = false;
-    else
+    else {
      ExitCounts.push_back(std::make_pair(ExitBB, EL.Exact));
+      // Copy the predicate; EL goes out of scope at the end of this iteration.
+      ExitCountPreds.push_back(EL.Pred);
+    }

    // 2. Derive the loop's MaxBECount from each exit's max number of
    // non-exiting iterations. Partition the loop exits into two kinds:
@@ -5018,9 +5085,12 @@
      }
    }
  }
-  const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
-    (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
-  return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount);
+  const SCEV *MaxBECount =
+      MustExitMaxBECount
+          ? MustExitMaxBECount
+          : (MayExitMaxBECount ?
MayExitMaxBECount : getCouldNotCompute()); + return BackedgeTakenInfo(ExitCounts, ExitCountPreds, CouldComputeBECount, + MaxBECount); } /// ComputeExitLimit - Compute the number of times the backedge of the specified @@ -5153,7 +5223,10 @@ BECount = EL0.Exact; } - return ExitLimit(BECount, MaxBECount); + SCEVPredicateSet NP; + NP.add(&EL0.Pred); + NP.add(&EL1.Pred); + return ExitLimit(BECount, MaxBECount, NP); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. @@ -5188,7 +5261,10 @@ BECount = EL0.Exact; } - return ExitLimit(BECount, MaxBECount); + SCEVPredicateSet NP; + NP.add(&EL0.Pred); + NP.add(&EL1.Pred); + return ExitLimit(BECount, MaxBECount, NP); } } @@ -5217,12 +5293,9 @@ /// ComputeExitLimitFromICmp - Compute the number of times the /// backedge of the specified loop will execute if its exit condition /// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB. -ScalarEvolution::ExitLimit -ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, - ICmpInst *ExitCond, - BasicBlock *TBB, - BasicBlock *FBB, - bool ControlsExit) { +ScalarEvolution::ExitLimit ScalarEvolution::ComputeExitLimitFromICmp( + const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, + bool ControlsExit, bool UseAssumptions) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; @@ -5272,30 +5345,37 @@ } switch (Cond) { - case ICmpInst::ICMP_NE: { // while (X != Y) + case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); - if (EL.hasAnyInfo()) return EL; + ExitLimit EL = + HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, UseAssumptions); + if (EL.hasAnyInfo()) + return EL; break; } - case ICmpInst::ICMP_EQ: { // while (X == Y) + case ICmpInst::ICMP_EQ: { // while (X == Y) // Convert to: while (X-Y == 0) - ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); - if (EL.hasAnyInfo()) return EL; + ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L, UseAssumptions); + if (EL.hasAnyInfo()) + return EL; break; } case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_ULT: { // while (X < Y) + case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = Cond == ICmpInst::ICMP_SLT; - ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit); - if (EL.hasAnyInfo()) return EL; + ExitLimit EL = + HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, UseAssumptions); + if (EL.hasAnyInfo()) + return EL; break; } case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_UGT: { // while (X > Y) + case ICmpInst::ICMP_UGT: { // while (X > Y) bool IsSigned = Cond == ICmpInst::ICMP_SGT; - ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit); - if (EL.hasAnyInfo()) return EL; + ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, + UseAssumptions); + if (EL.hasAnyInfo()) + return EL; break; } default: @@ -5309,7 +5389,14 @@ #endif break; } - return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + ExitLimit EL = ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + + if (EL.hasAnyInfo() || UseAssumptions) + return EL; + + // We could not prove what the exit limit is without making + // assumptions. Try to compute it using assumptions. + return ComputeExitLimitFromICmp(L, ExitCond, TBB, FBB, ControlsExit, true); } ScalarEvolution::ExitLimit @@ -6198,8 +6285,12 @@ /// now expressed as a single expression, V = x-y. So the exit test is /// effectively V != 0. 
We know and take advantage of the fact that this
/// expression only being used in a comparison by zero context.
-ScalarEvolution::ExitLimit
-ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
+ScalarEvolution::ExitLimit ScalarEvolution::HowFarToZero(const SCEV *V,
+                                                         const Loop *L,
+                                                         bool ControlsExit,
+                                                         bool UseAssumptions) {
+  SCEVPredicateSet P;
+
  // If the value is a constant
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    // If the value is already zero, the branch will execute zero times.
@@ -6208,6 +6299,20 @@
  }

  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
+  if (!AddRec && UseAssumptions) {
+    // Try to make this a chrec using runtime assumptions.
+    AssumptionResult R = getAddRecWithRTChecks(V, L);
+    if (!R.Res)
+      return getCouldNotCompute();
+    if (R.Pred.isAlwaysFalse())
+      return getCouldNotCompute();
+    AddRec = dyn_cast<SCEVAddRecExpr>(R.Res);
+    if (!AddRec)
+      return getCouldNotCompute();
+    P.add(&R.Pred);
+  }
+
  if (!AddRec || AddRec->getLoop() != L)
    return getCouldNotCompute();

@@ -6236,7 +6341,7 @@
        // should not accept a root of 2.
        const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
        if (Val->isZero())
-          return R1;  // We found a quadratic root!
+          return ExitLimit(R1, R1, P); // We found a quadratic root!
      }
    }
    return getCouldNotCompute();
@@ -6291,9 +6396,9 @@
          ? getConstant(APInt::getMinValue(CR.getBitWidth()))
          : getConstant(APInt::getMaxValue(CR.getBitWidth()));
    else
-      MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
-                                         : -CR.getUnsignedMin());
-    return ExitLimit(Distance, MaxBECount);
+      MaxBECount =
+          getConstant(CountDown ? CR.getUnsignedMax() : -CR.getUnsignedMin());
+    return ExitLimit(Distance, MaxBECount, P);
  }

  // As a special case, handle the instance where Step is a positive power of
@@ -6306,8 +6411,10 @@
    // also returns true if StepV is maximally negative (eg, INT_MIN), but that
    // case is not handled as this code is guarded by !CountDown.
    if (StepV.isPowerOf2() &&
-        GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros())
-      return getUDivExactExpr(Distance, Step);
+        GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) {
+      const SCEV *E = getUDivExactExpr(Distance, Step);
+      return ExitLimit(E, E, P);
+    }
  }

  // If the condition controls loop exit (the loop exits only if the expression
@@ -6318,14 +6425,15 @@
  if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) {
    const SCEV *Exact =
        getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
-    return ExitLimit(Exact, Exact);
+    return ExitLimit(Exact, Exact, P);
  }

  // Then, try to solve the above equation provided that Start is constant.
-  if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
-    return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
-                                        -StartC->getValue()->getValue(),
-                                        *this);
+  if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
+    const SCEV *E = SolveLinEquationWithOverflow(
+        StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
+    return ExitLimit(E, E, P);
+  }
  return getCouldNotCompute();
}

@@ -6333,7 +6441,8 @@
/// HowFarToNonZero - Return the number of times an exit condition checking
/// the specified value for nonzero will execute. If not computable, return
/// CouldNotCompute
ScalarEvolution::ExitLimit
-ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
+ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L,
+                                 bool UseAssumptions) {
  // Loops that look like: while (X == 0) are very strange indeed. We don't
  // handle them yet except for the trivial case. This could be expanded in the
  // future as needed.
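A sketch of how these guarded exit limits are meant to be consumed through the public entry point. getCheckedBTC is a hypothetical caller; only getGuardedBackedgeTakenCount and SCEVPredicateSet are from this patch:

// Illustrative: take the backedge-taken count under assumptions, and only
// use it if the predicate set can be discharged statically or at run time.
static const SCEV *getCheckedBTC(ScalarEvolution *SE, const Loop *L,
                                 SCEVPredicateSet &Preds) {
  const SCEV *BTC = SE->getGuardedBackedgeTakenCount(L, Preds);
  if (BTC == SE->getCouldNotCompute() || Preds.isAlwaysFalse())
    return SE->getCouldNotCompute();
  // BTC is valid wherever Preds holds; either Preds.isAlwaysTrue(), or the
  // caller must emit Preds.generateGuardCond(...) and version the loop.
  return BTC;
}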
@@ -7529,17 +7638,33 @@
ScalarEvolution::ExitLimit
ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool IsSigned,
-                                  bool ControlsExit) {
+                                  bool ControlsExit, bool UseAssumptions) {
+  SCEVPredicateSet P;
+
  // We handle only IV < Invariant
  if (!isLoopInvariant(RHS, L))
    return getCouldNotCompute();

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!IV && UseAssumptions) {
+    // Try to make this a chrec using runtime assumptions.
+    AssumptionResult R = getAddRecWithRTChecks(LHS, L);
+    if (!R.Res)
+      return getCouldNotCompute();
+    if (R.Pred.isAlwaysFalse())
+      return getCouldNotCompute();
+    IV = dyn_cast<SCEVAddRecExpr>(R.Res);
+    if (!IV)
+      return getCouldNotCompute();
+    P.add(&R.Pred);
+  }
+
  // Avoid weird loops
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

+  // FIXME: we can assume NoWrap here if necessary and check at runtime.
  bool NoWrap = ControlsExit &&
                IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
@@ -7603,18 +7728,32 @@
  if (isa<SCEVCouldNotCompute>(MaxBECount))
    MaxBECount = BECount;

-  return ExitLimit(BECount, MaxBECount);
+  return ExitLimit(BECount, MaxBECount, P);
}

ScalarEvolution::ExitLimit
ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
                                     const Loop *L, bool IsSigned,
-                                     bool ControlsExit) {
+                                     bool ControlsExit, bool UseAssumptions) {
+  SCEVPredicateSet P;
+
  // We handle only IV > Invariant
  if (!isLoopInvariant(RHS, L))
    return getCouldNotCompute();

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!IV && UseAssumptions) {
+    // Try to make this a chrec using runtime assumptions.
+    AssumptionResult R = getAddRecWithRTChecks(LHS, L);
+    if (!R.Res)
+      return getCouldNotCompute();
+    if (R.Pred.isAlwaysFalse())
+      return getCouldNotCompute();
+    IV = dyn_cast<SCEVAddRecExpr>(R.Res);
+    if (!IV)
+      return getCouldNotCompute();
+    P.add(&R.Pred);
+  }

  // Avoid weird loops
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();
@@ -7685,7 +7824,7 @@
  if (isa<SCEVCouldNotCompute>(MaxBECount))
    MaxBECount = BECount;

-  return ExitLimit(BECount, MaxBECount);
+  return ExitLimit(BECount, MaxBECount, P);
}

/// getNumIterationsInRange - Return the number of iterations of this loop that
@@ -8825,3 +8964,453 @@
  AU.addRequiredTransitive<DominatorTreeWrapperPass>();
  AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
}
+
+static Value *generateOverflowCheck(const SCEVAddRecExpr *AR, Instruction *Loc,
+                                    bool Signed, ScalarEvolution *SE,
+                                    const DataLayout *DL, SCEVExpander &Exp) {
+  Module *M = Loc->getParent()->getParent()->getParent();
+  IRBuilder<> OFBuilder(Loc);
+  Value *AddF, *MulF;
+  if (Signed) {
+    AddF = Intrinsic::getDeclaration(M, Intrinsic::sadd_with_overflow,
+                                     AR->getType());
+    MulF = Intrinsic::getDeclaration(M, Intrinsic::smul_with_overflow,
+                                     AR->getType());
+  } else {
+    AddF = Intrinsic::getDeclaration(M, Intrinsic::uadd_with_overflow,
+                                     AR->getType());
+    MulF = Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow,
+                                     AR->getType());
+  }
+  Value *Start;
+  Value *Stride;
+
+  SCEVPredicateSet MP;
+  const SCEV *ExitCount = SE->getGuardedBackedgeTakenCount(AR->getLoop(), MP);
+
+  unsigned DstBits = AR->getType()->getPrimitiveSizeInBits();
+  unsigned SrcBits = ExitCount->getType()->getPrimitiveSizeInBits();
+
+  if (SrcBits < DstBits) {
+    // We need to extend.
+    if (Signed)
+      ExitCount = SE->getNoopOrSignExtend(ExitCount, AR->getType());
+    else
+      ExitCount = SE->getNoopOrZeroExtend(ExitCount, AR->getType());
+  }
+
+  assert(ExitCount != SE->getCouldNotCompute() && "Invalid loop count");
+  Value *TripCount = Exp.expandCodeFor(ExitCount, ExitCount->getType(), Loc);
+  Value *TripCountCheck = nullptr;
+
+  // We might need to
truncate TripCount.
+  // If this is the case, we need to make sure that this is legal.
+  if (SrcBits > DstBits) {
+    APInt CmpMaxValue = Signed ? APInt::getSignedMaxValue(DstBits).sext(SrcBits)
+                               : APInt::getMaxValue(DstBits).zext(SrcBits);
+    // The min value only makes sense for signed checks.
+
+    ConstantInt *CTMax = ConstantInt::get(M->getContext(), CmpMaxValue);
+    CmpInst::Predicate P = Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
+    TripCountCheck = OFBuilder.CreateICmp(P, TripCount, CTMax);
+
+    if (Signed) {
+      APInt CmpMinValue = APInt::getSignedMinValue(DstBits).sext(SrcBits);
+      ConstantInt *CTMin = ConstantInt::get(M->getContext(), CmpMinValue);
+      Value *MinCheck =
+          OFBuilder.CreateICmp(ICmpInst::ICMP_SLT, TripCount, CTMin);
+      TripCountCheck = OFBuilder.CreateOr(TripCountCheck, MinCheck);
+    }
+
+    TripCount = OFBuilder.CreateTrunc(TripCount, AR->getType());
+  }
+
+  // We need to truncate or extend TripCount to the type used by the SCEV.
+  // Extension is not a problem.
+  Start = Exp.expandCodeFor(AR->getStart(), AR->getStart()->getType(), Loc);
+
+  // This is an affine expression.
+  Stride =
+      Exp.expandCodeFor(AR->getOperand(1), AR->getOperand(1)->getType(), Loc);
+
+  CallInst *Mul = OFBuilder.CreateCall(MulF, {Stride, TripCount}, "mul");
+  Value *MulV = OFBuilder.CreateExtractValue(Mul, 0, "mul.result");
+  Value *OfMul = OFBuilder.CreateExtractValue(Mul, 1, "mul.overflow");
+  CallInst *Add = OFBuilder.CreateCall(AddF, {MulV, Start}, "uadd");
+  Value *OfAdd = OFBuilder.CreateExtractValue(Add, 1, "add.overflow");
+  Value *Overflow = OFBuilder.CreateOr(OfMul, OfAdd, "overflow");
+
+  if (TripCountCheck) {
+    Overflow = OFBuilder.CreateOr(Overflow, TripCountCheck);
+  }
+
+  return Overflow;
+}
+
+static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
+                                 Instruction *Loc) {
+  if (FirstInst)
+    return FirstInst;
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    return I->getParent() == Loc->getParent() ? I : nullptr;
+  return nullptr;
+}
+
+/// Removes overflows and records the assumptions that were made.
+struct SCEVOverflowRewriter
+    : public SCEVVisitor<SCEVOverflowRewriter, const SCEV *> {
+public:
+  SCEVPredicateSet &P;
+
+  static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+                             ScalarEvolution &SE, SCEVPredicateSet &A,
+                             bool Assume) {
+    SCEVOverflowRewriter Rewriter(L, SE, A, Assume);
+    return Rewriter.visit(Scev);
+  }
+
+  bool addOverflowAssumption(const SCEV *S, SCEV::NoWrapFlags AddedFlags) {
+    SCEVAddRecOverflowPredicate A(S, AddedFlags);
+    if (!MakeAssumptions)
+      // Only succeed if we've already made this assumption.
+      return P.contains(&A);
+    P.add(&A);
+    return true;
+  }
+
+  SCEVOverflowRewriter(const Loop *L, ScalarEvolution &S, SCEVPredicateSet &P,
+                       bool MakeAssumptions)
+      : P(P), SE(S), L(L), MakeAssumptions(MakeAssumptions) {}
+
+  const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
+
+  const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
+    const SCEV *Operand = visit(Expr->getOperand());
+    return SE.getTruncateExpr(Operand, Expr->getType());
+  }
+
+  // We should only need to add assumptions when encountering the
+  // sext/zext expressions, as other expressions will fold into
+  // the AddRecExprs.
+  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+    const SCEV *Operand = visit(Expr->getOperand());
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+    if (AR && AR->getLoop() == L && AR->isAffine()) {
+      // This couldn't be folded because the operand didn't have the nuw
+      // flag. Add the nuw flag as an assumption that we could make.
+      const SCEV *Step = AR->getStepRecurrence(SE);
+      Type *Ty = Expr->getType();
+      // We add the NUW flag to the assumption set but cannot set it on the
+      // expression because it would pollute ScalarEvolution's cache.
+      if (addOverflowAssumption(AR, SCEV::FlagNUW))
+        return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
+                                SE.getZeroExtendExpr(Step, Ty), L,
+                                AR->getNoWrapFlags());
+    }
+    return SE.getZeroExtendExpr(Operand, Expr->getType());
+  }
+
+  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+    const SCEV *Operand = visit(Expr->getOperand());
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+    if (AR && AR->getLoop() == L && AR->isAffine()) {
+      // This couldn't be folded because the operand didn't have the nsw
+      // flag. Add the nsw flag as an assumption that we could make.
+      const SCEV *Step = AR->getStepRecurrence(SE);
+      Type *Ty = Expr->getType();
+      // We add the NSW flag to the assumption set but cannot set it on the
+      // expression because it would pollute ScalarEvolution's cache.
+      if (addOverflowAssumption(AR, SCEV::FlagNSW))
+        return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
+                                SE.getSignExtendExpr(Step, Ty), L,
+                                AR->getNoWrapFlags());
+    }
+    return SE.getSignExtendExpr(Operand, Expr->getType());
+  }
+
+  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+    SmallVector<const SCEV *, 2> Operands;
+    for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+      Operands.push_back(visit(Expr->getOperand(i)));
+    return SE.getAddExpr(Operands);
+  }
+
+  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+    SmallVector<const SCEV *, 2> Operands;
+    for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+      Operands.push_back(visit(Expr->getOperand(i)));
+    return SE.getMulExpr(Operands);
+  }
+
+  const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
+    return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS()));
+  }
+
+  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+    SmallVector<const SCEV *, 2> Operands;
+    for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+      Operands.push_back(visit(Expr->getOperand(i)));
+
+    const Loop *L = Expr->getLoop();
+    return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
+  }
+
+  const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
+    SmallVector<const SCEV *, 2> Operands;
+    for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+      Operands.push_back(visit(Expr->getOperand(i)));
+    return SE.getSMaxExpr(Operands);
+  }
+
+  const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
+    SmallVector<const SCEV *, 2> Operands;
+    for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+      Operands.push_back(visit(Expr->getOperand(i)));
+    return SE.getUMaxExpr(Operands);
+  }
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
+
+  const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
+    return Expr;
+  }
+
+private:
+  ScalarEvolution &SE;
+  const Loop *L;
+  bool MakeAssumptions;
+};
+
+const SCEV *
+ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, const Loop *L,
+                                       SCEVPredicateSet &Preds) {
+  return SCEVOverflowRewriter::rewrite(Scev, L, *this, Preds, false);
+}
+
+AssumptionResult
+ScalarEvolution::getAddRecWithRTChecks(const SCEV *S, const Loop *L) {
+  AssumptionResult Result(S);
+  const SCEV *Ret = SCEVOverflowRewriter::rewrite(S, L, *this,
+                                                  Result.Pred, true);
+  if (isa<SCEVAddRecExpr>(Ret) ||
+      isa<SCEVConstant>(Ret)) {
+    Result.Res = Ret;
+  }
+  return Result;
+}
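Schematically, the rewriter's effect on the common case looks like this; the concrete expression is illustrative, and tryToAddRec is just a hypothetical restatement of getAddRecWithRTChecks:

// Before: (zext i32 {%start,+,%step}<%loop> to i64)   -- not an AddRec
// After:  {(zext i32 %start to i64),+,(zext i32 %step to i64)}<%loop>
// Cost:   records SCEVAddRecOverflowPredicate({%start,+,%step}, FlagNUW)
//
// With Assume == false (rewriteUsingPredicate), the fold is only applied
// when the required no-wrap predicate is already contained in the set.
static const SCEV *tryToAddRec(ScalarEvolution &SE, const SCEV *S,
                               const Loop *L, SCEVPredicateSet &Preds) {
  return SCEVOverflowRewriter::rewrite(S, L, SE, Preds, /*Assume=*/true);
}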
+
+//// SCEV predicates
+
+SCEVPredicate::SCEVPredicate(unsigned short Type) : SCEVPredicateType(Type) {}
+
+SCEVAddRecOverflowPredicate::SCEVAddRecOverflowPredicate(
+    const SCEV *AR, SCEV::NoWrapFlags Flags)
+    : SCEVPredicate(pAddRecOverflow), AR(AR), Flags(Flags) {
+  assert(isa<SCEVAddRecExpr>(AR) &&
+         "Can only create a "
+         "SCEVAddRecOverflowPredicate using a SCEVAddRecExpr");
+}
+
+bool SCEVAddRecOverflowPredicate::contains(const SCEVPredicate *N) const {
+  const SCEVAddRecOverflowPredicate *OP =
+      dyn_cast<SCEVAddRecOverflowPredicate>(N);
+
+  if (!OP)
+    return false;
+
+  if (OP->getExpr() != AR)
+    return false;
+
+  if ((OP->getFlags() & getFlags()) != OP->getFlags())
+    return false;
+
+  return true;
+}
+
+bool SCEVAddRecOverflowPredicate::isAlwaysTrue() const {
+  const SCEVAddRecExpr *A = static_cast<const SCEVAddRecExpr *>(AR);
+  return ScalarEvolution::clearFlags(Flags, A->getNoWrapFlags()) ==
+         SCEV::FlagAnyWrap;
+}
+
+bool SCEVAddRecOverflowPredicate::isAlwaysFalse() const { return false; }
+
+Value *SCEVAddRecOverflowPredicate::generateCheck(Instruction *Loc,
+                                                  ScalarEvolution *SE,
+                                                  const DataLayout *DL,
+                                                  SCEVExpander &Exp) {
+  IRBuilder<> OFBuilder(Loc);
+  const SCEVAddRecExpr *A = static_cast<const SCEVAddRecExpr *>(AR);
+  Value *OverflowRuntimeCheck = nullptr;
+
+  if (Flags & SCEV::FlagNUW) {
+    // Add a check for NUW.
+    Value *Overflow = generateOverflowCheck(A, Loc, false, SE, DL, Exp);
+
+    if (!OverflowRuntimeCheck)
+      OverflowRuntimeCheck = Overflow;
+    else
+      OverflowRuntimeCheck = OFBuilder.CreateOr(OverflowRuntimeCheck, Overflow);
+  }
+  if (Flags & SCEV::FlagNSW) {
+    // Add a check for NSW.
+    Value *Overflow = generateOverflowCheck(A, Loc, true, SE, DL, Exp);
+
+    if (!OverflowRuntimeCheck)
+      OverflowRuntimeCheck = Overflow;
+    else
+      OverflowRuntimeCheck = OFBuilder.CreateOr(OverflowRuntimeCheck, Overflow);
+  }
+  return OverflowRuntimeCheck;
+}
+
+void SCEVAddRecOverflowPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << *getExpr() << " Added Flags: ";
+  if (SCEV::FlagNUW & getFlags())
+    OS << "<nuw>";
+  if (SCEV::FlagNSW & getFlags())
+    OS << "<nsw>";
+  OS << "\n";
+}
+
+SCEVPredicateSet::SCEVPredicateSet() : SCEVPredicate(pSet), Never(false) {}
+
+SCEVPredicateSet::SCEVPredicateSet(const SCEVPredicateSet &Old)
+    : SCEVPredicateSet() {
+  this->Never = Old.Never;
+  if (Never)
+    return;
+  AddRecOverflows = Old.AddRecOverflows;
+  for (unsigned II = 0; II < AddRecOverflows.size(); ++II) {
+    Preds.push_back(&AddRecOverflows[II]);
+  }
+}
+
+bool SCEVPredicateSet::isAlwaysTrue() const {
+  if (Never)
+    return false;
+
+  for (auto II = Preds.begin(), EE = Preds.end(); II != EE; ++II) {
+    const SCEVPredicate *OA = *II;
+    if (!OA->isAlwaysTrue())
+      return false;
+  }
+
+  return true;
+}
+
+bool SCEVPredicateSet::isAlwaysFalse() const { return Never; }
+
+std::pair<Instruction *, Instruction *>
+SCEVPredicateSet::generateGuardCond(Instruction *Loc, ScalarEvolution *SE) {
+  Instruction *tnullptr = nullptr;
+
+  assert(!Never && "Cannot generate a runtime check on "
+                   "a predicate with the Never flag set");
+
+  if (isAlwaysTrue())
+    return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
+
+  IRBuilder<> OFBuilder(Loc);
+  Instruction *FirstInst = nullptr;
+  Module *M = Loc->getParent()->getParent()->getParent();
+  const DataLayout &DL = M->getDataLayout();
+  SCEVExpander Exp(*SE, DL, "start");
+
+  Value *Check = generateCheck(Loc, SE, &DL, Exp);
+
+  if (!Check)
+    return std::make_pair(nullptr, nullptr);
+
+  Instruction *CheckInst =
+      BinaryOperator::CreateOr(Check, ConstantInt::getFalse(M->getContext()));
+  OFBuilder.Insert(CheckInst, "scev.check");
+
+  FirstInst = getFirstInst(FirstInst, CheckInst, Loc);
+  return std::make_pair(FirstInst, CheckInst);
+}
+
+Value *SCEVPredicateSet::generateCheck(Instruction *Loc, ScalarEvolution *SE,
+                                       const DataLayout *DL,
+                                       SCEVExpander &Exp) {
+
+  IRBuilder<> OFBuilder(Loc);
+  Value *AllCheck = nullptr;
+
+  // Loop over all checks in this set.
+  for (auto II = Preds.begin(), EE = Preds.end(); II != EE; ++II) {
+    SCEVPredicate *OA = *II;
+
+    if (OA->isAlwaysTrue())
+      continue;
+
+    Value *CheckResult = OA->generateCheck(Loc, SE, DL, Exp);
+
+    if (!AllCheck)
+      AllCheck = CheckResult;
+    else
+      AllCheck = OFBuilder.CreateOr(AllCheck, CheckResult);
+  }
+
+  return AllCheck;
+}
+
+bool SCEVPredicateSet::contains(const SCEVPredicate *N) const {
+  if (Never)
+    return false;
+
+  if (const SCEVPredicateSet *Set = dyn_cast<SCEVPredicateSet>(N)) {
+    for (auto II = Set->Preds.begin(), EE = Set->Preds.end(); II != EE; ++II) {
+      if (!contains(*II))
+        return false;
+    }
+    return true;
+  }
+  for (auto II = Preds.begin(), EE = Preds.end(); II != EE; ++II) {
+    if ((*II)->contains(N))
+      return true;
+  }
+  return false;
+}
+
+void SCEVPredicateSet::print(raw_ostream &OS, unsigned Depth) const {
+  for (auto II = Preds.begin(), EE = Preds.end(); II != EE; ++II)
+    (*II)->print(OS, Depth);
+}
+
+void SCEVPredicateSet::add(const SCEVPredicate *N) {
+  if (Preds.size() > OverflowCheckThreshold || N->isAlwaysFalse()) {
+    Never = true;
+    return;
+  }
+
+  if (const SCEVAddRecOverflowPredicate *OP =
+          dyn_cast<SCEVAddRecOverflowPredicate>(N)) {
+
+    const SCEVAddRecExpr *AR =
+        static_cast<const SCEVAddRecExpr *>(OP->getExpr());
+    for (unsigned II = 0, EE = AddRecOverflows.size(); II < EE; ++II) {
+      if (AddRecOverflows[II].getExpr() == AR) {
+        AddRecOverflows[II].addFlags(OP->getFlags());
+        return;
+      }
+    }
+    AddRecOverflows.push_back(*OP);
+    // Growing AddRecOverflows may reallocate its storage; rebuild the
+    // pointer vector so that Preds never holds stale pointers.
+    Preds.clear();
+    for (unsigned II = 0, EE = AddRecOverflows.size(); II < EE; ++II)
+      Preds.push_back(&AddRecOverflows[II]);
+
+  } else if (const SCEVPredicateSet *Set =
+                 dyn_cast<SCEVPredicateSet>(N)) {
+    for (auto II = Set->Preds.begin(), EE = Set->Preds.end(); II != EE; ++II) {
+      add(*II);
+    }
+  } else
+    llvm_unreachable("Unknown SCEV predicate type!");
+}
+
+void SCEVAddRecOverflowPredicate::addFlags(SCEV::NoWrapFlags AddedFlags) {
+  Flags = ScalarEvolution::setFlags(Flags, AddedFlags);
+}

Index: lib/Transforms/Scalar/LoopDistribute.cpp
===================================================================
--- lib/Transforms/Scalar/LoopDistribute.cpp
+++ lib/Transforms/Scalar/LoopDistribute.cpp
@@ -595,6 +595,7 @@
    LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LAA = &getAnalysis<LoopAccessAnalysis>();
    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    // Build up a worklist of inner-loops to vectorize. This is necessary as the
    // act of distributing a loop creates new loops and can invalidate iterators
    // across the loops.
@@ -617,6 +618,7 @@
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<LoopAccessAnalysis>();
@@ -790,10 +792,10 @@
    const auto &AllChecks = RtPtrChecking->getChecks();
    auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition,
                                                  RtPtrChecking);
-    if (!Checks.empty()) {
+    if ((!LAI.Pred->isAlwaysTrue()) || !Checks.empty()) {
      DEBUG(dbgs() << "\nPointers:\n");
      DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));
-      LoopVersioning LVer(std::move(Checks), LAI, L, LI, DT);
+      LoopVersioning LVer(std::move(Checks), LAI, L, LI, DT, SE);
      LVer.versionLoop();
      LVer.addPHINodes(DefsUsedOutside);
    }
@@ -821,6 +823,7 @@
  LoopInfo *LI;
  LoopAccessAnalysis *LAA;
  DominatorTree *DT;
+  ScalarEvolution *SE;
};
} // anonymous namespace
@@ -831,6 +834,7 @@
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopDistribute, LDIST_NAME, ldist_name, false, false)

namespace llvm {

Index: lib/Transforms/Utils/LoopVersioning.cpp
===================================================================
--- lib/Transforms/Utils/LoopVersioning.cpp
+++ lib/Transforms/Utils/LoopVersioning.cpp
@@ -24,18 +24,20 @@
LoopVersioning::LoopVersioning(
    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks,
-    const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT)
+    const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT,
+    ScalarEvolution *SE)
    : VersionedLoop(L), NonVersionedLoop(nullptr), Checks(std::move(Checks)),
-      LAI(LAI), LI(LI), DT(DT) {
+      LAI(LAI), LI(LI), DT(DT), SE(SE) {
  assert(L->getExitBlock() && "No single exit block");
  assert(L->getLoopPreheader() && "No preheader");
}

LoopVersioning::LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L,
-                               LoopInfo *LI, DominatorTree *DT)
+                               LoopInfo *LI, DominatorTree *DT,
+                               ScalarEvolution *SE)
    : VersionedLoop(L), NonVersionedLoop(nullptr),
      Checks(LAInfo.getRuntimePointerChecking()->getChecks()), LAI(LAInfo),
-      LI(LI), DT(DT) {
+      LI(LI), DT(DT), SE(SE) {
  assert(L->getExitBlock() && "No single exit block");
  assert(L->getLoopPreheader() && "No preheader");
}
@@ -43,11 +45,26 @@
void LoopVersioning::versionLoop() {
  Instruction *FirstCheckInst;
  Instruction *MemRuntimeCheck;
+  Instruction *OverflowRuntimeCheck;
+  Instruction *RuntimeCheck = nullptr;
+
  // Add the memcheck in the original preheader (this is empty initially).
  BasicBlock *MemCheckBB = VersionedLoop->getLoopPreheader();
  std::tie(FirstCheckInst, MemRuntimeCheck) =
      LAI.addRuntimeChecks(MemCheckBB->getTerminator(), Checks);
-  assert(MemRuntimeCheck && "called even though needsAnyChecking = false");
+  std::tie(FirstCheckInst, OverflowRuntimeCheck) =
+      LAI.Pred->generateGuardCond(MemCheckBB->getTerminator(), SE);
+
+  if (MemRuntimeCheck && OverflowRuntimeCheck) {
+    RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
+                                          OverflowRuntimeCheck, "ldist.safe");
+    RuntimeCheck->insertBefore(MemCheckBB->getTerminator());
+  } else
+    RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : OverflowRuntimeCheck;
+
+  assert(RuntimeCheck && "called even though we don't need "
+                         "any runtime checks");

  // Rename the block to make the IR more readable.
 
   // Rename the block to make the IR more readable.
   MemCheckBB->setName(VersionedLoop->getHeader()->getName() + ".lver.memcheck");
@@ -70,8 +87,7 @@
   // Insert the conditional branch based on the result of the memchecks.
   Instruction *OrigTerm = MemCheckBB->getTerminator();
   BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
-                     VersionedLoop->getLoopPreheader(), MemRuntimeCheck,
-                     OrigTerm);
+                     VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
   OrigTerm->eraseFromParent();
 
   // The loops merge in the original exit block. This is now dominated by the
Index: lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- lib/Transforms/Vectorize/LoopVectorize.cpp
+++ lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -265,11 +265,11 @@
   InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
                       DominatorTree *DT, const TargetLibraryInfo *TLI,
                       const TargetTransformInfo *TTI, unsigned VecWidth,
-                      unsigned UnrollFactor)
+                      unsigned UnrollFactor, SCEVPredicateSet &Pred)
       : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI),
         VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()),
         Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor),
-        Legal(nullptr), AddedSafetyChecks(false) {}
+        Legal(nullptr), AddedSafetyChecks(false), Pred(Pred) {}
 
   // Perform the actual loop widening (vectorization).
   void vectorize(LoopVectorizationLegality *L) {
@@ -309,6 +309,10 @@
   /// pair as (first, last).
   std::pair<Instruction *, Instruction *> addStrideCheck(Instruction *Loc);
 
+  /// Adds code to check the overflow assumptions made by SCEV. Returns the
+  /// first and last emitted instruction as a (first, last) pair.
+  std::pair<Instruction *, Instruction *>
+  addRuntimeOverflowChecks(Instruction *Loc);
+
   /// Create an empty loop, based on the loop ranges of the old loop.
   void createEmptyLoop();
   /// Copy and widen the instructions from the old loop.
@@ -473,14 +477,19 @@
   // Record whether runtime check is added.
   bool AddedSafetyChecks;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  SCEVPredicateSet &Pred;
 };
 
 class InnerLoopUnroller : public InnerLoopVectorizer {
 public:
   InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
                     DominatorTree *DT, const TargetLibraryInfo *TLI,
-                    const TargetTransformInfo *TTI, unsigned UnrollFactor)
-      : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {}
+                    const TargetTransformInfo *TTI, unsigned UnrollFactor,
+                    SCEVPredicateSet &Pred)
+      : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor,
+                            Pred) {}
 
 private:
   void scalarizeInstruction(Instruction *Instr,
@@ -700,8 +709,9 @@
 /// between the member and the group in a map.
 class InterleavedAccessInfo {
 public:
-  InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT)
-      : SE(SE), TheLoop(L), DT(DT) {}
+  InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT,
+                        SCEVPredicateSet &Pred)
+      : SE(SE), TheLoop(L), DT(DT), Pred(Pred) {}
 
   ~InterleavedAccessInfo() {
     SmallSet<InterleaveGroup *, 4> DelSet;
@@ -734,6 +744,8 @@
   ScalarEvolution *SE;
   Loop *TheLoop;
   DominatorTree *DT;
+  /// The SCEV predicate containing all the SCEV-related assumptions.
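+  /// This reference is shared with the rest of the vectorizer; the
+  /// interleaved-access analysis calls isStridedPtr with
+  /// MakeAssumptions = false (see below), so it consults the recorded
+  /// assumptions but does not add new ones of its own.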
+  SCEVPredicateSet &Pred;
 
   /// Holds the relationships between the members and the interleave group.
   DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap;
@@ -1080,11 +1092,13 @@
                               Function *F, const TargetTransformInfo *TTI,
                               LoopAccessAnalysis *LAA,
                               LoopVectorizationRequirements *R,
-                              const LoopVectorizeHints *H)
+                              const LoopVectorizeHints *H,
+                              SCEVPredicateSet &Pred)
       : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F),
-        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT),
-        Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),
-        Requirements(R), Hints(H) {}
+        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr),
+        InterleaveInfo(SE, L, DT, Pred), Induction(nullptr),
+        WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R),
+        Hints(H), Pred(Pred) {}
 
   /// This enum represents the kinds of inductions that we support.
   enum InductionKind {
@@ -1361,7 +1375,10 @@
   /// While vectorizing these instructions we have to generate a
   /// call to the appropriate masked intrinsic
-  SmallPtrSet<const Instruction *, 8> MaskedOp;
+  SmallPtrSet<const Instruction *, 8> MaskedOp;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  SCEVPredicateSet &Pred;
 };
 
 /// LoopVectorizationCostModel - estimates the expected speedups due to
@@ -1377,9 +1394,10 @@
   LoopVectorizationCostModel(Loop *L, ScalarEvolution *SE, LoopInfo *LI,
                              LoopVectorizationLegality *Legal,
                              const TargetTransformInfo &TTI,
                              const TargetLibraryInfo *TLI, AssumptionCache *AC,
-                             const Function *F, const LoopVectorizeHints *Hints)
+                             const Function *F, const LoopVectorizeHints *Hints,
+                             SCEVPredicateSet &Pred)
       : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI),
-        TheFunction(F), Hints(Hints) {
+        TheFunction(F), Hints(Hints), Pred(Pred) {
     CodeMetrics::collectEphemeralValues(L, AC, EphValues);
   }
@@ -1468,6 +1486,9 @@
   const Function *TheFunction;
   // Loop Vectorize Hint.
   const LoopVectorizeHints *Hints;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  SCEVPredicateSet &Pred;
 };
 
 /// \brief This holds vectorization requirements that must be verified late in
@@ -1702,11 +1723,12 @@
         return false;
       }
     }
 
+    SCEVPredicateSet Pred;
 
     // Check if it is legal to vectorize the loop.
     LoopVectorizationRequirements Requirements;
     LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA,
-                                  &Requirements, &Hints);
+                                  &Requirements, &Hints, Pred);
     if (!LVL.canVectorize()) {
       DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");
       emitMissedWarning(F, L, Hints);
@@ -1714,7 +1736,8 @@
     }
 
     // Use the cost model.
-    LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints);
+    LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints,
+                                  Pred);
 
     // Check the function attributes to find out if this function should be
     // optimized for size.
@@ -1824,7 +1847,7 @@
       assert(IC > 1 && "interleave count should not be 1 or 0");
       // If we decided that it is not legal to vectorize the loop then
      // interleave it.
-      InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC);
+      InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Pred);
       Unroller.vectorize(&LVL);
 
       emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, L->getStartLoc(),
@@ -1832,7 +1855,7 @@
                              Twine(IC) + ")");
     } else {
       // If we decided that it is *legal* to vectorize the loop then do it.
-      InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC);
+      InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Pred);
       LB.vectorize(&LVL);
       ++LoopsVectorized;
 
@@ -1978,9 +2001,10 @@
   // We can emit wide load/stores only if the last non-zero index is the
   // induction variable.
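+  //
+  // With SCEV predicates this also covers indices that are only affine under
+  // an assumption. For example (a sketch): an index of the form
+  // zext(i16 {0,+,1}) only folds to the AddRec {0,+,1}<i32> if the i16
+  // recurrence cannot wrap. rewriteUsingPredicate applies the assumptions
+  // already recorded in Pred, and getAddRecWithRTChecks (below) can add the
+  // missing no-wrap assumptions, which are then verified at run time
+  // (cf. the scev-overflow-check.ll tests).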
   const SCEV *Last = nullptr;
-  if (!Strides.count(Gep))
+  if (!Strides.count(Gep)) {
     Last = SE->getSCEV(Gep->getOperand(InductionOperand));
-  else {
+    Last = SE->rewriteUsingPredicate(Last, TheLoop, Pred);
+  } else {
     // Because of the multiplication by a stride we can have a s/zext cast.
     // We are going to replace this stride by 1 so the cast is safe to ignore.
     //
     // %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
     // %0 = trunc i64 %indvars.iv to i32
     // %mul = mul i32 %0, %Stride1
     // %idxprom = zext i32 %mul to i64  << Safe cast.
     // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
     //
-    Last = replaceSymbolicStrideSCEV(SE, Strides,
-                                     Gep->getOperand(InductionOperand), Gep);
+    Last = rewriteSCEV(SE, Strides, Gep->getOperand(InductionOperand), Gep,
+                       TheLoop, Pred);
     if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
       Last = (C->getSCEVType() == scSignExtend ||
               C->getSCEVType() == scZeroExtend)
                  ? C->getOperand()
                  : Last;
   }
+
+  SCEVPredicateSet P = Pred;
+  if (!isa<SCEVAddRecExpr>(Last)) {
+    // Attempt to add new SCEV assumptions to Last in order to
+    // get an AddRecExpr.
+    AssumptionResult R = SE->getAddRecWithRTChecks(Last, TheLoop);
+    R.Pred.add(&Pred);
+
+    // Only commit to the new assumptions if they produced an AddRec and the
+    // resulting predicate set is still checkable at run time.
+    if (R.Res && !R.Pred.isAlwaysFalse()) {
+      Last = R.Res;
+      P = R.Pred;
+    }
+  }
+
   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Last)) {
     const SCEV *Step = AR->getStepRecurrence(*SE);
 
     // The memory is consecutive because the last index is consecutive
     // and all other indices are loop invariant.
-    if (Step->isOne())
+    if (Step->isOne()) {
+      Pred = P;
       return 1;
-    if (Step->isAllOnesValue())
+    }
+    if (Step->isAllOnesValue()) {
+      Pred = P;
       return -1;
+    }
   }
 
   return 0;
@@ -2659,7 +2701,7 @@
   Type *IdxTy = Legal->getWidestInductionType();
 
   // Find the loop boundaries.
-  const SCEV *ExitCount = SE->getBackedgeTakenCount(OrigLoop);
+  const SCEV *ExitCount = SE->getGuardedBackedgeTakenCount(OrigLoop, Pred);
   assert(ExitCount != SE->getCouldNotCompute() && "Invalid loop count");
 
   // The exit count might have the type of i64 while the phi is i32. This can
@@ -2850,6 +2892,29 @@
     VectorPH = NewVectorPH;
   }
 
+  // Generate runtime checks for any SCEV assumptions that we've made.
+  Instruction *OFCheck;
+  std::tie(FirstCheckInst, OFCheck) =
+      Pred.generateGuardCond(VectorPH->getTerminator(), SE);
+  if (OFCheck) {
+    AddedSafetyChecks = true;
+    // Create a new block containing the SCEV checks.
+    VectorPH->setName("vector.scevcheck");
+    NewVectorPH =
+        VectorPH->splitBasicBlock(VectorPH->getTerminator(), "vector.ph");
+
+    if (ParentLoop)
+      ParentLoop->addBasicBlockToLoop(NewVectorPH, *LI);
+    LoopBypassBlocks.push_back(VectorPH);
+
+    // Replace the branch into the SCEV check block with a conditional branch
+    // that falls back to the scalar loop if any of the assumptions is
+    // violated.
+    ReplaceInstWithInst(VectorPH->getTerminator(),
+                        BranchInst::Create(MiddleBlock, NewVectorPH, OFCheck));
+
+    VectorPH = NewVectorPH;
+  }
+
   // We are going to resume the execution of the scalar loop.
   // Go over all of the induction variables that we found and fix the
   // PHIs that are left in the scalar version of the loop.
@@ -4007,10 +4072,10 @@
   }
 
   // ScalarEvolution needs to be able to find the exit count.
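+  // getGuardedBackedgeTakenCount below is the predicated variant of
+  // getBackedgeTakenCount: when the trip count is only computable under
+  // additional assumptions (for instance, that a narrow induction variable
+  // does not wrap), those assumptions are appended to Pred instead of
+  // giving up on the loop.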
-  const SCEV *ExitCount = SE->getBackedgeTakenCount(TheLoop);
+  const SCEV *ExitCount = SE->getGuardedBackedgeTakenCount(TheLoop, Pred);
   if (ExitCount == SE->getCouldNotCompute()) {
-    emitAnalysis(VectorizationReport() <<
-                 "could not determine number of loop iterations");
+    emitAnalysis(VectorizationReport()
+                 << "could not determine number of loop iterations");
     DEBUG(dbgs() << "LV: SCEV could not compute the loop exit count.\n");
     return false;
   }
@@ -4346,6 +4411,15 @@
   }
 
   Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
+  Pred.add(&*LAI->Pred);
+
+  if (Pred.isAlwaysFalse()) {
+    emitAnalysis(VectorizationReport()
+                 << "Too many SCEV assumptions need to be made and checked "
+                 << "at runtime");
+    DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n");
+    return false;
+  }
 
   return true;
 }
@@ -4474,7 +4548,7 @@
     StoreInst *SI = dyn_cast<StoreInst>(I);
     Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
-    int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides);
+    // Do not let interleave analysis introduce additional assumptions.
+    int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Pred, false);
 
     // The factor of the corresponding interleave group.
     unsigned Factor = std::abs(Stride);
Index: test/Transforms/LoopDistribute/distribute-with-overflows.ll
===================================================================
--- /dev/null
+++ test/Transforms/LoopDistribute/distribute-with-overflows.ll
@@ -0,0 +1,102 @@
+; RUN: opt -mtriple=aarch64--linux-gnueabi -basicaa -loop-distribute -verify-loop-info -verify-dom-info -S \
+; RUN:   < %s | FileCheck %s
+
+; RUN: opt -basicaa -loop-distribute -loop-vectorize -force-vector-width=4 \
+; RUN:   -verify-loop-info -verify-dom-info -S < %s | \
+; RUN:   FileCheck --check-prefix=VECTORIZE %s
+
+; The memcheck version of basic.ll, where the induction variables can
+; overflow. We should distribute and vectorize the second part of this loop
+; with 5 memchecks (A+1 x {C, D, E} + C x {A, B}):
+;
+;   for (i = 0; i < n; i++) {
+;     A[i + 1] = A[i] * B[i];
+;     -------------------------------
+;     C[i] = D[i] * E[i];
+;   }
+
+@B = common global i32* null, align 8
+@A = common global i32* null, align 8
+@C = common global i32* null, align 8
+@D = common global i32* null, align 8
+@E = common global i32* null, align 8
+
+define void @f(i64 %n) {
+entry:
+  %a = load i32*, i32** @A, align 8
+  %b = load i32*, i32** @B, align 8
+  %c = load i32*, i32** @C, align 8
+  %d = load i32*, i32** @D, align 8
+  %e = load i32*, i32** @E, align 8
+  br label %for.body
+
+; We have two compares for each array overlap check, for a total of 10
+; compares.
+;
+; CHECK: for.body.lver.memcheck:
+
+; CHECK: icmp ugt i64 %{{[a-zA-Z0-9]+}}, 4294967295
+
+; CHECK: %ldist.safe = or i1 %memcheck.conflict, %scev.check
+; CHECK: br i1 %ldist.safe, label %for.body.ph.lver.orig, label %for.body.ph.ldist1
+
+
+; The non-distributed loop that the memchecks fall back on.
+
+; CHECK: for.body.ph.lver.orig:
+; CHECK: br label %for.body.lver.orig
+; CHECK: for.body.lver.orig:
+; CHECK: br i1 %exitcond.lver.orig, label %for.end, label %for.body.lver.orig
+
+; Verify the two distributed loops.
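+;
+; The ldist1 loop keeps the A[i + 1] = A[i] * B[i] recurrence and stays
+; scalar; the remaining for.body loop computes C[i] = D[i] * E[i], which is
+; the part that can be vectorized (see the VECTORIZE check).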
+
+; CHECK: for.body.ph.ldist1:
+; CHECK: br label %for.body.ldist1
+; CHECK: for.body.ldist1:
+; CHECK: %mulA.ldist1 = mul i32 %loadB.ldist1, %loadA.ldist1
+; CHECK: br i1 %exitcond.ldist1, label %for.body.ph, label %for.body.ldist1
+
+; CHECK: for.body.ph:
+; CHECK: br label %for.body
+; CHECK: for.body:
+; CHECK: %mulC = mul i32 %loadD, %loadE
+; CHECK: for.end:
+
+
+; VECTORIZE: mul <4 x i32>
+
+for.body:                                         ; preds = %for.body, %entry
+  %ind = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %ind_ext = zext i32 %ind to i64
+
+  %arrayidxA = getelementptr inbounds i32, i32* %a, i64 %ind_ext
+  %loadA = load i32, i32* %arrayidxA, align 4
+
+  %arrayidxB = getelementptr inbounds i32, i32* %b, i64 %ind_ext
+  %loadB = load i32, i32* %arrayidxB, align 4
+
+  %mulA = mul i32 %loadB, %loadA
+
+  %add = add i32 %ind, 1
+  %add_ext = zext i32 %add to i64
+
+  %arrayidxA_plus_4 = getelementptr inbounds i32, i32* %a, i64 %add_ext
+  store i32 %mulA, i32* %arrayidxA_plus_4, align 4
+
+  %arrayidxD = getelementptr inbounds i32, i32* %d, i64 %ind_ext
+  %loadD = load i32, i32* %arrayidxD, align 4
+
+  %arrayidxE = getelementptr inbounds i32, i32* %e, i64 %ind_ext
+  %loadE = load i32, i32* %arrayidxE, align 4
+
+  %mulC = mul i32 %loadD, %loadE
+
+  %arrayidxC = getelementptr inbounds i32, i32* %c, i64 %ind_ext
+  store i32 %mulC, i32* %arrayidxC, align 4
+
+  %exitcond = icmp eq i64 %add_ext, %n
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}
Index: test/Transforms/LoopVectorize/safegep.ll
===================================================================
--- test/Transforms/LoopVectorize/safegep.ll
+++ test/Transforms/LoopVectorize/safegep.ll
@@ -9,8 +9,10 @@
 ; PR16592
 
 ; CHECK-LABEL: @safe(
+; CHECK-NOT: vector.scevcheck
 ; CHECK: <4 x float>
+
 define void @safe(float* %A, float* %B, float %K) {
 entry:
   br label %"<bb 3>"
Index: test/Transforms/LoopVectorize/scev-overflow-check.ll
===================================================================
--- /dev/null
+++ test/Transforms/LoopVectorize/scev-overflow-check.ll
@@ -0,0 +1,125 @@
+; RUN: opt -mtriple=aarch64--linux-gnueabi -loop-vectorize < %s -S | FileCheck %s
+
+; CHECK-LABEL: test0
+define void @test0(i32* %A,
+                   i32* %B,
+                   i32* %C, i32 %N) {
+entry:
+  %cmp13 = icmp eq i32 %N, 0
+  br i1 %cmp13, label %for.end, label %for.body.preheader
+
+; If N is greater than 65535, this would loop forever.
+; CHECK: icmp ugt i32 %N, 65535
+
+for.body.preheader:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i16 [ %indvars.next, %for.body ], [ 0, %for.body.preheader ]
+  %indvars.next = add i16 %indvars.iv, 1
+  %indvars.ext = zext i16 %indvars.iv to i32
+
+  %arrayidx = getelementptr inbounds i32, i32* %B, i32 %indvars.ext
+  %0 = load i32, i32* %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, i32* %C, i32 %indvars.ext
+  %1 = load i32, i32* %arrayidx3, align 4
+
+  %mul4 = mul i32 %1, %0
+
+  %arrayidx7 = getelementptr inbounds i32, i32* %A, i32 %indvars.ext
+  store i32 %mul4, i32* %arrayidx7, align 4
+
+  %exitcond = icmp eq i32 %indvars.ext, %N
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+; CHECK-LABEL: test1
+define void @test1(i32* %A,
+                   i32* %B,
+                   i32* %C, i32 %N, i32 %Offset) {
+entry:
+  %cmp13 = icmp eq i32 %N, 0
+  br i1 %cmp13, label %for.end, label %for.body.preheader
+
+; Because of the GEPs, we need to check that Offset + N does not overflow.
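+; A sketch of the expected expansion: the guard computes 1 * N via
+; @llvm.smul.with.overflow.i32 and then adds %Offset via
+; @llvm.sadd.with.overflow.i32; if either operation overflows, the vector
+; loop is bypassed.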
+; CHECK: [[MUL0:%[a-zA-Z_0-9.]+]] = call { i32, i1 } @llvm.smul.with.overflow.i32(i32 1, i32 %N)
+; CHECK: [[MUL1:%[a-zA-Z_0-9.]+]] = extractvalue { i32, i1 } [[MUL0]], 0
+; CHECK: call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 [[MUL1]], i32 %Offset)
+
+for.body.preheader:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i16 [ %indvars.next, %for.body ], [ 0, %for.body.preheader ]
+  %indvars.next = add i16 %indvars.iv, 1
+
+  %indvars.ext = zext i16 %indvars.iv to i32
+  %indvars.access = add i32 %Offset, %indvars.ext
+
+  %arrayidx = getelementptr inbounds i32, i32* %B, i32 %indvars.access
+  %0 = load i32, i32* %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, i32* %C, i32 %indvars.access
+  %1 = load i32, i32* %arrayidx3, align 4
+
+  %mul4 = mul i32 %1, %0
+
+  %arrayidx7 = getelementptr inbounds i32, i32* %A, i32 %indvars.access
+  store i32 %mul4, i32* %arrayidx7, align 4
+
+  %exitcond = icmp eq i32 %indvars.ext, %N
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+; CHECK-LABEL: test2
+define void @test2(i32* %A,
+                   i32* %B,
+                   i32* %C, i32 %N, i32 %Offset) {
+entry:
+  %cmp13 = icmp eq i32 %N, 0
+  br i1 %cmp13, label %for.end, label %for.body.preheader
+
+; With a sign-extended induction variable, N must fit in the i16 range.
+; CHECK: icmp sgt i32 %N, 32767
+; CHECK: icmp slt i32 %N, -32768
+
+for.body.preheader:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i16 [ %indvars.next, %for.body ], [ 0, %for.body.preheader ]
+  %indvars.next = add i16 %indvars.iv, 1
+
+  %indvars.ext = sext i16 %indvars.iv to i32
+  %indvars.access = add i32 %Offset, %indvars.ext
+
+  %arrayidx = getelementptr inbounds i32, i32* %B, i32 %indvars.access
+  %0 = load i32, i32* %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, i32* %C, i32 %indvars.access
+  %1 = load i32, i32* %arrayidx3, align 4
+
+  %mul4 = add i32 %1, %0
+
+  %arrayidx7 = getelementptr inbounds i32, i32* %A, i32 %indvars.access
+  store i32 %mul4, i32* %arrayidx7, align 4
+
+  %exitcond = icmp eq i32 %indvars.ext, %N
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
Index: test/Transforms/LoopVectorize/version-mem-access.ll
===================================================================
--- test/Transforms/LoopVectorize/version-mem-access.ll
+++ test/Transforms/LoopVectorize/version-mem-access.ll
@@ -16,11 +16,12 @@
   %cmp13 = icmp eq i32 %N, 0
   br i1 %cmp13, label %for.end, label %for.body.preheader
 
+; We don't need to check the symbolic stride for B; instead, we can assume
+; that {0,+,BStride} will not overflow.
+
 ; CHECK-DAG: icmp ne i64 %AStride, 1
-; CHECK-DAG: icmp ne i32 %BStride, 1
 ; CHECK-DAG: icmp ne i64 %CStride, 1
 ; CHECK: or
-; CHECK: or
 ; CHECK: br
 
 ; CHECK: vector.body
@@ -56,11 +57,11 @@
 }
 
 ; We used to crash on this function because we removed the fptosi cast when
-; replacing the symbolic stride '%conv'.
-; PR18480
+; replacing the symbolic stride '%conv' (PR18480). However, replacing the
+; symbolic stride is no longer required since we can do an overflow check.
 
 ; CHECK-LABEL: fn1
-; CHECK: load <2 x double>
+; CHECK: store <2 x double>
 
 define void @fn1(double* noalias %x, double* noalias %c, double %a) {
 entry: