Index: include/llvm/Analysis/LoopAccessAnalysis.h =================================================================== --- include/llvm/Analysis/LoopAccessAnalysis.h +++ include/llvm/Analysis/LoopAccessAnalysis.h @@ -32,6 +32,7 @@ class ScalarEvolution; class Loop; class SCEV; +class SCEVUnionPredicate; /// Optimization analysis message produced during vectorization. Messages inform /// the user why vectorization did not occur. @@ -176,10 +177,11 @@ const SmallVectorImpl &Instrs) const; }; - MemoryDepChecker(ScalarEvolution *Se, const Loop *L) + MemoryDepChecker(ScalarEvolution *Se, const Loop *L, + SCEVUnionPredicate &Preds) : SE(Se), InnermostLoop(L), AccessIdx(0), ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true), - RecordInterestingDependences(true) {} + RecordInterestingDependences(true), Preds(Preds) {} /// \brief Register the location (instructions are given increasing numbers) /// of a write access. @@ -289,6 +291,9 @@ /// \brief Check whether the data dependence could prevent store-load /// forwarding. bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize); + + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVUnionPredicate &Preds; }; /// \brief Holds information about the memory runtime legality checks to verify @@ -331,7 +336,8 @@ /// Insert a pointer and calculate the start and end SCEVs. void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, - unsigned ASId, const ValueToValueMap &Strides); + unsigned ASId, const ValueToValueMap &Strides, + SCEVUnionPredicate &Preds); /// \brief No run-time memory checking is necessary. bool empty() const { return Pointers.empty(); } @@ -537,6 +543,9 @@ return StoreToLoopInvariantAddress; } + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVUnionPredicate Preds; + private: /// \brief Analyze the loop. Substitute symbolic strides using Strides. 
void analyzeLoop(const ValueToValueMap &Strides); @@ -590,12 +599,13 @@ /// stride as collected by LoopVectorizationLegality::collectStridedAccess. const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE, const ValueToValueMap &PtrToStride, - Value *Ptr, Value *OrigPtr = nullptr); + SCEVUnionPredicate &Preds, Value *Ptr, + Value *OrigPtr = nullptr); /// \brief Check the stride of the pointer and ensure that it does not wrap in /// the address space. int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap); + const ValueToValueMap &StridesMap, SCEVUnionPredicate &Preds); /// \brief This analysis provides dependence information for the memory accesses /// of a loop. @@ -622,7 +632,8 @@ /// of symbolic strides, \p Strides provides the mapping (see /// replaceSymbolicStrideSCEV). If there is no cached result available run /// the analysis. - const LoopAccessInfo &getInfo(Loop *L, const ValueToValueMap &Strides); + const LoopAccessInfo &getInfo(Loop *L, const ValueToValueMap &Strides, + const unsigned MaxSCEVPredicates = 0); void releaseMemory() override { // Invalidate the cache when the pass is freed. Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -50,8 +50,13 @@ class Operator; class SCEVUnknown; class SCEVAddRecExpr; + class SCEVConstant; class SCEV; + class SCEVExpander; + class SCEVPredicate; + template<> struct FoldingSetTrait; + template<> struct FoldingSetTrait; /// This class represents an analyzed expression in the program. These are /// opaque objects that the client is not allowed to do much with directly. 
@@ -166,6 +171,167 @@ static bool classof(const SCEV *S); }; + //===--------------------------------------------------------------------===// + /// SCEVPredicate - This class represents an assumption made using SCEV + /// expressions which can be checked at run-time. + /// + class SCEVPredicate : public FoldingSetNode { + friend struct FoldingSetTrait; + + /// A reference to an Interned FoldingSetNodeID for this node. The + /// ScalarEvolution's BumpPtrAllocator holds the data. + FoldingSetNodeIDRef FastID; + + protected: + unsigned short SCEVPredicateType; + enum SCEVPredicateTypes { PSET, PEQUAL }; + + public: + SCEVPredicate(const FoldingSetNodeIDRef ID, unsigned short Type); + + virtual ~SCEVPredicate() {} + + unsigned short getType() const { return SCEVPredicateType; } + + /// \brief Returns the estimated complexity of this predicate. + /// This is roughly measured in the number of run-time checks required. + virtual unsigned getComplexity() { return 1; } + + /// \brief Returns true if the predicate is always true. This means that no + /// assumptions were made and nothing needs to be checked at run-time. + virtual bool isAlwaysTrue() const = 0; + + /// \brief Returns true if this predicate implies \p N. + virtual bool implies(const SCEVPredicate *N) const = 0; + + /// \brief Prints a textual representation of this predicate. + virtual void print(raw_ostream &OS, unsigned Depth) const = 0; + + /// \brief Generates a run-time check for this predicate. The return value + /// of this check indicates if the predicate is not true. + /// In case no checks are needed nullptr is returned. + virtual Value *generateCheck(Instruction *Loc, ScalarEvolution &SE, + const DataLayout &DL, + SCEVExpander &Exp) const = 0; + /// \brief Returns the SCEV to which this predicate applies, or nullptr + /// if this is a SCEVUnionPredicate. 
+ virtual const SCEV *getExpr() const = 0; + }; + + // Specialize FoldingSetTrait for SCEVPredicate to avoid needing to compute + // temporary FoldingSetNodeID values. + template <> + struct FoldingSetTrait + : DefaultFoldingSetTrait { + + static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) { + ID = X.FastID; + } + + static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID, + unsigned IDHash, FoldingSetNodeID &TempID) { + return ID == X.FastID; + } + static unsigned ComputeHash(const SCEVPredicate &X, + FoldingSetNodeID &TempID) { + return X.FastID.ComputeHash(); + } + }; + + //===--------------------------------------------------------------------===// + /// SCEVEqualPredicate - This class represents an assumption that two SCEV + /// expressions are equal, and this can be checked at run-time. We assume + /// that the left hand side is a SCEVUnknown and the right hand side a + /// constant. + /// + class SCEVEqualPredicate : public SCEVPredicate { + /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a + /// constant. + const SCEVUnknown *LHS; + const SCEVConstant *RHS; + + public: + SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS, + const SCEVConstant *RHS); + + /// Implementation of the SCEVPredicate interface + bool implies(const SCEVPredicate *N) const override; + void print(raw_ostream &OS, unsigned Depth) const override; + bool isAlwaysTrue() const override; + const SCEV *getExpr() const; + + Value *generateCheck(Instruction *Loc, ScalarEvolution &SE, + const DataLayout &DL, + SCEVExpander &Exp) const override; + + /// \brief returns the left hand side of the equality. + const SCEVUnknown *getLHS() const { return LHS; } + + /// \brief returns the right hand side of the equality. 
+ const SCEVConstant *getRHS() const { return RHS; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVPredicate *P) { + return P->getType() == PEQUAL; + } + }; + + //===--------------------------------------------------------------------===// + /// SCEVUnionPredicate - This class represents a composition of other + /// SCEV predicates, and is the class that most clients will interact with. + /// This is equivalent to a logical "AND" of all the predicates in the union. + class SCEVUnionPredicate : public SCEVPredicate { + private: + typedef std::map> + PredicateMap; + + /// Vector with references to all predicates in this union. + SmallVector Preds; + /// Maps SCEVs to predicates for quick look-ups. + PredicateMap SCEVToPreds; + + public: + SCEVUnionPredicate(); + + /// \brief Adds a predicate to this union. + void add(const SCEVPredicate *N); + + /// \brief Generates a run-time check for all the contained predicates + /// using \p Loc as the instruction insertion point. + /// This is a wrapper around generateCheck, and provides an interface + /// similar to other run-time checks used for versioning. Like the + /// generateCheck method, the returned values indicate if any of its + /// sub-predicates are not true. It will return a pair formed by the + /// first and last instructions in the check. If no instructions are + /// generated we will return a pair of nullptr. + std::pair + generateGuardCond(Instruction *Loc, ScalarEvolution &SE); + + /// \brief Returns a pointer to a vector containing all predicates + /// which apply to \p Expr, or nullptr if there are none. 
+ SmallVectorImpl * + getPredicatesForExpr(const SCEV *Expr); + + /// Implementation of the SCEVPredicate interface + bool isAlwaysTrue() const override; + bool implies(const SCEVPredicate *N) const override; + void print(raw_ostream &OS, unsigned Depth) const; + const SCEV *getExpr() const override; + + Value *generateCheck(Instruction *Loc, ScalarEvolution &SE, + const DataLayout &DL, + SCEVExpander &Exp) const override; + + /// \brief We estimate the complexity of a union predicate as the + /// number of predicates in the union. + unsigned getComplexity() override { return Preds.size(); } + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVPredicate *P) { + return P->getType() == PSET; + } + }; + /// The main scalar evolution driver. Because client code (intentionally) /// can't do much with the SCEV objects directly, they must ask this class /// for services. @@ -1076,6 +1242,12 @@ SmallVectorImpl &Sizes, const SCEV *ElementSize); + const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS, + const SCEVConstant *RHS); + + /// Re-writes the SCEV according to the Predicates in \p A. + const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A); + private: /// Compute the backedge taken count knowing the interval difference, the /// stride and presence of the equality in the comparison. 
@@ -1096,6 +1268,7 @@ private: FoldingSet UniqueSCEVs; + FoldingSet UniquePreds; BumpPtrAllocator SCEVAllocator; /// The head of a linked list of all SCEVUnknown values that have been Index: lib/Analysis/LoopAccessAnalysis.cpp =================================================================== --- lib/Analysis/LoopAccessAnalysis.cpp +++ lib/Analysis/LoopAccessAnalysis.cpp @@ -89,8 +89,8 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE, const ValueToValueMap &PtrToStride, + SCEVUnionPredicate &Preds, Value *Ptr, Value *OrigPtr) { - const SCEV *OrigSCEV = SE->getSCEV(Ptr); // If there is an entry in the map return the SCEV of the pointer with the @@ -108,26 +108,31 @@ ValueToValueMap RewriteMap; RewriteMap[StrideVal] = One; - const SCEV *ByOne = - SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true); + const auto *U = static_cast(SE->getSCEV(StrideVal)); + const auto *CT = static_cast( + SE->getConstant(StrideVal->getType(), 1, true)); + + Preds.add(SE->getEqualPredicate(U, CT)); + + const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, Preds); DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne << "\n"); return ByOne; } // Otherwise, just return the SCEV of the original pointer. - return SE->getSCEV(Ptr); + return OrigSCEV; } void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, unsigned ASId, - const ValueToValueMap &Strides) { + const ValueToValueMap &Strides, + SCEVUnionPredicate &Preds) { // Get the stride replaced scev. 
- const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr); const SCEVAddRecExpr *AR = dyn_cast(Sc); assert(AR && "Invalid addrec expression"); const SCEV *Ex = SE->getBackedgeTakenCount(Lp); - const SCEV *ScStart = AR->getStart(); const SCEV *ScEnd = AR->evaluateAtIteration(Ex, *SE); const SCEV *Step = AR->getStepRecurrence(*SE); @@ -417,9 +422,9 @@ typedef SmallPtrSet MemAccessInfoSet; AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI, - MemoryDepChecker::DepCandidates &DA) - : DL(Dl), AST(*AA), LI(LI), DepCands(DA), - IsRTCheckAnalysisNeeded(false) {} + MemoryDepChecker::DepCandidates &DA, SCEVUnionPredicate &Preds) + : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false), + Preds(Preds) {} /// \brief Register a load and whether it is only read from. void addLoad(MemoryLocation &Loc, bool IsReadOnly) { @@ -504,14 +509,18 @@ /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared /// while this remains set if we have potentially dependent accesses. bool IsRTCheckAnalysisNeeded; + + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVUnionPredicate &Preds; }; } // end anonymous namespace /// \brief Check whether a pointer can participate in a runtime bounds check. 
static bool hasComputableBounds(ScalarEvolution *SE, - const ValueToValueMap &Strides, Value *Ptr) { - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const ValueToValueMap &Strides, Value *Ptr, + Loop *L, SCEVUnionPredicate &Preds) { + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr); const SCEVAddRecExpr *AR = dyn_cast(PtrScev); if (!AR) return false; @@ -554,11 +563,11 @@ else ++NumReadPtrChecks; - if (hasComputableBounds(SE, StridesMap, Ptr) && + if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Preds) && // When we run after a failing dependency check we have to make sure // we don't have wrapping pointers. (!ShouldCheckStride || - isStridedPtr(SE, Ptr, TheLoop, StridesMap) == 1)) { + isStridedPtr(SE, Ptr, TheLoop, StridesMap, Preds) == 1)) { // The id of the dependence set. unsigned DepId; @@ -572,7 +581,7 @@ // Each access has its own dependence set. DepId = RunningDepId++; - RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap); + RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Preds); DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); } else { @@ -803,7 +812,8 @@ /// \brief Check whether the access through \p Ptr has a constant stride. 
int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap) { + const ValueToValueMap &StridesMap, + SCEVUnionPredicate &Preds) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); @@ -815,7 +825,7 @@ return 0; } - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr); + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Preds, Ptr); const SCEVAddRecExpr *AR = dyn_cast(PtrScev); if (!AR) { @@ -1026,11 +1036,11 @@ BPtr->getType()->getPointerAddressSpace()) return Dependence::Unknown; - const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr); - const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr); + const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, APtr); + const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, BPtr); - int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides); - int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides); + int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Preds); + int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Preds); const SCEV *Src = AScev; const SCEV *Sink = BScev; @@ -1429,7 +1439,7 @@ MemoryDepChecker::DepCandidates DependentAccesses; AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), - AA, LI, DependentAccesses); + AA, LI, DependentAccesses, Preds); // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects // multiple times on the same object. If the ptr is accessed twice, once @@ -1480,7 +1490,8 @@ // read a few words, modify, and write a few words, and some of the // words may be written to the same address. 
bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr).second || !isStridedPtr(SE, Ptr, TheLoop, Strides)) { + if (Seen.insert(Ptr).second || + !isStridedPtr(SE, Ptr, TheLoop, Strides, Preds)) { ++NumReads; IsReadOnlyPtr = true; } @@ -1728,7 +1739,7 @@ const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : PtrRtChecking(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), + : PtrRtChecking(SE), DepChecker(SE, L, Preds), TheLoop(L), SE(SE), DL(DL), TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { @@ -1763,10 +1774,14 @@ OS.indent(Depth) << "Store to invariant address was " << (StoreToLoopInvariantAddress ? "" : "not ") << "found in loop.\n"; + + OS.indent(Depth) << "SCEV assumptions:\n"; + Preds.print(OS, Depth); } const LoopAccessInfo & -LoopAccessAnalysis::getInfo(Loop *L, const ValueToValueMap &Strides) { +LoopAccessAnalysis::getInfo(Loop *L, const ValueToValueMap &Strides, + const unsigned MaxSCEVPredicates) { auto &LAI = LoopAccessInfoMap[L]; #ifndef NDEBUG @@ -1776,8 +1791,8 @@ if (!LAI) { const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - LAI = llvm::make_unique(L, SE, DL, TLI, AA, DT, LI, - Strides); + LAI = + llvm::make_unique(L, SE, DL, TLI, AA, DT, LI, Strides); #ifndef NDEBUG LAI->NumSymbolicStrides = Strides.size(); #endif Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -67,6 +67,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -9238,3 +9239,202 @@ AU.addRequiredTransitive(); 
AU.addRequiredTransitive(); } + +static Instruction *getFirstInst(Instruction *FirstInst, Value *V, + Instruction *Loc) { + if (FirstInst) + return FirstInst; + if (auto I = dyn_cast(V)) + return I->getParent() == Loc->getParent() ? I : nullptr; + return nullptr; +} + +const SCEVPredicate * +ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, + const SCEVConstant *RHS) { + FoldingSetNodeID ID; + // Unique this node based on the arguments + ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + SCEVEqualPredicate *Eq = new (SCEVAllocator) + SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); + UniquePreds.InsertNode(Eq, IP); + return Eq; +} + +class SCEVPredicateRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, + SCEVUnionPredicate &A) { + SCEVPredicateRewriter Rewriter(SE, A); + return Rewriter.visit(Scev); + } + + SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P) + : SCEVRewriteVisitor(SE), P(P) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + auto ExprPreds = P.getPredicatesForExpr(Expr); + if (ExprPreds) + for (auto *Pred : *ExprPreds) + if (const SCEVEqualPredicate *IPred = + dyn_cast(Pred)) { + if (IPred->getLHS() == Expr) + return IPred->getRHS(); + } + + return Expr; + } + +private: + SCEVUnionPredicate &P; +}; + +const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, + SCEVUnionPredicate &Preds) { + return SCEVPredicateRewriter::rewrite(Scev, *this, Preds); +} + +//// SCEV predicates +SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, unsigned short Type) + : FastID(ID), SCEVPredicateType(Type) {} + +SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, + const SCEVUnknown *LHS, + const SCEVConstant *RHS) + : SCEVPredicate(ID, PEQUAL), LHS(LHS), RHS(RHS) {} + +bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { + const 
auto *Op = dyn_cast(N); + + if (!Op) + return false; + + return Op->LHS == LHS && Op->RHS == RHS; +} + +bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } + +Value *SCEVEqualPredicate::generateCheck(Instruction *Loc, ScalarEvolution &SE, + const DataLayout &DL, + SCEVExpander &Exp) const { + IRBuilder<> Builder(Loc); + + Value *Expr0 = Exp.expandCodeFor(LHS, LHS->getType(), Loc); + Value *Expr1 = Exp.expandCodeFor(RHS, RHS->getType(), Loc); + + return Builder.CreateICmpNE(Expr0, Expr1, "ident.check"); +} + +const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } + +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; +} + +/// Union predicates don't get cached so create a dummy set ID for it. +SCEVUnionPredicate::SCEVUnionPredicate() + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), PSET) {} + +bool SCEVUnionPredicate::isAlwaysTrue() const { + return std::all_of(Preds.begin(), Preds.end(), + [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); +} + +std::pair +SCEVUnionPredicate::generateGuardCond(Instruction *Loc, ScalarEvolution &SE) { + IRBuilder<> GuardBuilder(Loc); + Instruction *FirstInst = nullptr; + Module *M = Loc->getParent()->getParent()->getParent(); + const DataLayout &DL = M->getDataLayout(); + SCEVExpander Exp(SE, DL, "start"); + + Value *Check = generateCheck(Loc, SE, DL, Exp); + + if (!Check) + return std::make_pair(nullptr, nullptr); + + // IRBuilder might fold the checks to constant expressions. Make sure we have + // at least one instruction by performing an OR with False. 
+ Instruction *CheckInst = + BinaryOperator::CreateOr(Check, ConstantInt::getFalse(M->getContext())); + GuardBuilder.Insert(CheckInst, "scev.check"); + + FirstInst = getFirstInst(FirstInst, CheckInst, Loc); + return std::make_pair(FirstInst, CheckInst); +} + +SmallVectorImpl * +SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { + auto I = SCEVToPreds.find(Expr); + if (I == SCEVToPreds.end()) + return nullptr; + return &(I->second); +} + +Value *SCEVUnionPredicate::generateCheck(Instruction *Loc, ScalarEvolution &SE, + const DataLayout &DL, + SCEVExpander &Exp) const { + IRBuilder<> CheckBuilder(Loc); + Value *AllCheck = nullptr; + + // Loop over all checks in this set. + for (auto Pred : Preds) { + if (Pred->isAlwaysTrue()) + continue; + + Value *CheckResult = Pred->generateCheck(Loc, SE, DL, Exp); + + if (!AllCheck) + AllCheck = CheckResult; + else + AllCheck = CheckBuilder.CreateOr(AllCheck, CheckResult); + } + + return AllCheck; +} + +bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { + if (const auto *Set = dyn_cast(N)) + return std::all_of( + Set->Preds.begin(), Set->Preds.end(), + [this](const SCEVPredicate *I) { return this->implies(I); }); + + auto ScevPredsIt = SCEVToPreds.find(N->getExpr()); + if (ScevPredsIt == SCEVToPreds.end()) + return false; + auto &SCEVPreds = ScevPredsIt->second; + + return std::any_of(SCEVPreds.begin(), SCEVPreds.end(), + [N](const SCEVPredicate *I) { return I->implies(N); }); +} + +const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; } + +void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { + for (auto Pred : Preds) + Pred->print(OS, Depth); +} + +void SCEVUnionPredicate::add(const SCEVPredicate *N) { + if (const auto *Set = dyn_cast(N)) { + for (auto Pred : Set->Preds) + add(Pred); + return; + } + + if (implies(N)) + return; + + const SCEV *Key = N->getExpr(); + assert(Key && "Only SCEVUnionPredicate doesn't have an " + " associated expression!"); + + auto &Predicates = 
SCEVToPreds[Key]; + Predicates.push_back(N); + + Preds.push_back(N); +} Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -221,6 +221,15 @@ cl::desc("The maximum allowed number of runtime memory checks with a " "vectorize(enable) pragma.")); +static cl::opt VectorizeSCEVCheckThreshold( + "vectorize-scev-check-threshold", cl::init(16), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed.")); + +static cl::opt PragmaVectorizeSCEVCheckThreshold( + "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed with a " + "vectorize(enable) pragma")); + namespace { // Forward declarations. @@ -272,12 +281,12 @@ InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, unsigned VecWidth, - unsigned UnrollFactor) + unsigned UnrollFactor, SCEVUnionPredicate &Preds) : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()), Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor), TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr), - AddedSafetyChecks(false) {} + AddedSafetyChecks(false), Preds(Preds) {} // Perform the actual loop widening (vectorization). void vectorize(LoopVectorizationLegality *L) { @@ -309,12 +318,6 @@ typedef DenseMap, VectorParts> EdgeMaskCache; - /// \brief Add checks for strides that were assumed to be 1. - /// - /// Returns the last check instruction and the first check instruction in the - /// pair as (first, last). - std::pair addStrideCheck(Instruction *Loc); - /// Create an empty loop, based on the loop ranges of the old loop. void createEmptyLoop(); /// Create a new induction variable inside L. 
@@ -395,11 +398,11 @@ void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass); /// Emit a bypass check to see if the vector trip count is nonzero. void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass); - /// Emit bypass checks to check if strides we've assumed to be one really are. - void emitStrideChecks(Loop *L, BasicBlock *Bypass); + + void emitSCEVChecks(Loop *L, BasicBlock *Bypass); /// Emit bypass checks to check any memory assumptions we may have made. void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); - + /// This is a helper class that holds the vectorizer state. It maps scalar /// instructions to vector instructions. When the code is 'unrolled' then /// then a single scalar value is mapped to multiple vector parts. The parts @@ -503,14 +506,19 @@ // Record whether runtime check is added. bool AddedSafetyChecks; + + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVUnionPredicate &Preds; }; class InnerLoopUnroller : public InnerLoopVectorizer { public: InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, unsigned UnrollFactor) - : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {} + const TargetTransformInfo *TTI, unsigned UnrollFactor, + SCEVUnionPredicate &Preds) + : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor, + Preds) {} private: void scalarizeInstruction(Instruction *Instr, @@ -731,8 +739,9 @@ /// between the member and the group in a map. 
class InterleavedAccessInfo { public: - InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT) - : SE(SE), TheLoop(L), DT(DT) {} + InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT, + SCEVUnionPredicate &Preds) + : SE(SE), TheLoop(L), DT(DT), Preds(Preds) {} ~InterleavedAccessInfo() { SmallSet DelSet; @@ -766,6 +775,9 @@ Loop *TheLoop; DominatorTree *DT; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVUnionPredicate &Preds; + /// Holds the relationships between the members and the interleave group. DenseMap InterleaveGroupMap; @@ -1128,11 +1140,14 @@ Function *F, const TargetTransformInfo *TTI, LoopAccessAnalysis *LAA, LoopVectorizationRequirements *R, - const LoopVectorizeHints *H) + const LoopVectorizeHints *H, + unsigned SCEVCheckThreshold, + SCEVUnionPredicate &Preds) : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F), - TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT), - Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), - Requirements(R), Hints(H) {} + TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), + InterleaveInfo(SE, L, DT, Preds), Induction(nullptr), + WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H), + SCEVCheckThreshold(SCEVCheckThreshold), Preds(Preds) {} /// ReductionList contains the reduction descriptors for all /// of the reductions that were found in the loop. @@ -1331,7 +1346,11 @@ /// While vectorizing these instructions we have to generate a /// call to the appropriate masked intrinsic - SmallPtrSet MaskedOp; + SmallPtrSet MaskedOp; + + unsigned SCEVCheckThreshold; + /// The SCEV predicate containing all the SCEV-related assumptions. 
+ SCEVUnionPredicate &Preds; }; /// LoopVectorizationCostModel - estimates the expected speedups due to @@ -1348,9 +1367,11 @@ const TargetTransformInfo &TTI, const TargetLibraryInfo *TLI, AssumptionCache *AC, const Function *F, const LoopVectorizeHints *Hints, - SmallPtrSetImpl &ValuesToIgnore) + SmallPtrSetImpl &ValuesToIgnore, + SCEVUnionPredicate &Preds) : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), - TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {} + TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore), + Preds(Preds) {} /// Information about vectorization costs struct VectorizationFactor { @@ -1436,6 +1457,8 @@ const LoopVectorizeHints *Hints; // Values to ignore in the cost model. const SmallPtrSetImpl &ValuesToIgnore; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVUnionPredicate &Preds; }; /// \brief This holds vectorization requirements that must be verified late in @@ -1666,10 +1689,16 @@ } } + unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; + if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) + SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; + + SCEVUnionPredicate Preds; + // Check if it is legal to vectorize the loop. LoopVectorizationRequirements Requirements; LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA, - &Requirements, &Hints); + &Requirements, &Hints, SCEVThreshold, Preds); if (!LVL.canVectorize()) { DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); emitMissedWarning(F, L, Hints); @@ -1688,7 +1717,7 @@ // Use the cost model. LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints, - ValuesToIgnore); + ValuesToIgnore, Preds); // Check the function attributes to find out if this function should be // optimized for size. @@ -1799,7 +1828,7 @@ assert(IC > 1 && "interleave count should not be 1 or 0"); // If we decided that it is not legal to vectorize the loop then // interleave it. 
- InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC); + InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds); Unroller.vectorize(&LVL); emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), @@ -1807,7 +1836,7 @@ Twine(IC) + ")"); } else { // If we decided that it is *legal* to vectorize the loop then do it. - InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC); + InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds); LB.vectorize(&LVL); ++LoopsVectorized; @@ -1967,7 +1996,7 @@ // %idxprom = zext i32 %mul to i64 << Safe cast. // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom // - Last = replaceSymbolicStrideSCEV(SE, Strides, + Last = replaceSymbolicStrideSCEV(SE, Strides, Preds, Gep->getOperand(InductionOperand), Gep); if (const SCEVCastExpr *C = dyn_cast(Last)) Last = @@ -2525,56 +2554,8 @@ } } -static Instruction *getFirstInst(Instruction *FirstInst, Value *V, - Instruction *Loc) { - if (FirstInst) - return FirstInst; - if (Instruction *I = dyn_cast(V)) - return I->getParent() == Loc->getParent() ? I : nullptr; - return nullptr; -} - -std::pair -InnerLoopVectorizer::addStrideCheck(Instruction *Loc) { - Instruction *tnullptr = nullptr; - if (!Legal->mustCheckStrides()) - return std::pair(tnullptr, tnullptr); - - IRBuilder<> ChkBuilder(Loc); - - // Emit checks. - Value *Check = nullptr; - Instruction *FirstInst = nullptr; - for (SmallPtrSet::iterator SI = Legal->strides_begin(), - SE = Legal->strides_end(); - SI != SE; ++SI) { - Value *Ptr = stripIntegerCast(*SI); - Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1), - "stride.chk"); - // Store the first instruction we create. - FirstInst = getFirstInst(FirstInst, C, Loc); - if (Check) - Check = ChkBuilder.CreateOr(Check, C); - else - Check = C; - } - - // We have to do this trickery because the IRBuilder might fold the check to a - // constant expression in which case there is no Instruction anchored in a - // the block. 
- LLVMContext &Ctx = Loc->getContext(); - Instruction *TheCheck = - BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx)); - ChkBuilder.Insert(TheCheck, "stride.not.one"); - FirstInst = getFirstInst(FirstInst, TheCheck, Loc); - - return std::make_pair(FirstInst, TheCheck); -} - -PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, - Value *Start, - Value *End, - Value *Step, +PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, + Value *End, Value *Step, Instruction *DL) { BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); @@ -2709,26 +2690,26 @@ LoopBypassBlocks.push_back(BB); } -void InnerLoopVectorizer::emitStrideChecks(Loop *L, - BasicBlock *Bypass) { +void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { BasicBlock *BB = L->getLoopPreheader(); - + // Generate the code to check that the strides we assumed to be one are really // one. We want the new basic block to start at the first instruction in a // sequence of instructions that form a check. - Instruction *StrideCheck; + Instruction *SCEVCheck; Instruction *FirstCheckInst; - std::tie(FirstCheckInst, StrideCheck) = addStrideCheck(BB->getTerminator()); - if (!StrideCheck) + std::tie(FirstCheckInst, SCEVCheck) = + Preds.generateGuardCond(BB->getTerminator(), *SE); + if (!SCEVCheck) return; // Create a new block containing the stride check. - BB->setName("vector.stridecheck"); + BB->setName("vector.scevcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); if (L->getParentLoop()) L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); ReplaceInstWithInst(BB->getTerminator(), - BranchInst::Create(Bypass, NewBB, StrideCheck)); + BranchInst::Create(Bypass, NewBB, SCEVCheck)); LoopBypassBlocks.push_back(BB); AddedSafetyChecks = true; } @@ -2848,10 +2829,10 @@ // Now, compare the new count to zero. If it is zero skip the vector loop and // jump to the scalar loop. 
emitVectorLoopEnteredCheck(Lp, ScalarPH); - // Generate the code to check that the strides we assumed to be one are really - // one. We want the new basic block to start at the first instruction in a - // sequence of instructions that form a check. - emitStrideChecks(Lp, ScalarPH); + // Generate the code to check any assumptions that we've made for SCEV + // expressions. + emitSCEVChecks(Lp, ScalarPH); + // Generate the code that checks in runtime if arrays overlap. We put the // checks into a separate block to make the more common case of few elements // faster. @@ -3987,7 +3968,15 @@ // Analyze interleaved memory accesses. if (UseInterleaved) - InterleaveInfo.analyzeInterleaving(Strides); + InterleaveInfo.analyzeInterleaving(Strides); + + if (Preds.getComplexity() > SCEVCheckThreshold) { + emitAnalysis(VectorizationReport() + << "Too many SCEV assumptions need to be made and checked " + << "at runtime"); + DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); + return false; + } // Okay! We can vectorize. At this point we don't have any other mem analysis // which may limit our maximum vectorization factor, so just return true with @@ -4293,6 +4282,7 @@ } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); + Preds.add(&LAI->Preds); return true; } @@ -4407,7 +4397,7 @@ StoreInst *SI = dyn_cast(I); Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); - int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides); + int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Preds); // The factor of the corresponding interleave group. unsigned Factor = std::abs(Stride); @@ -4416,7 +4406,7 @@ if (Factor < 2 || Factor > MaxInterleaveGroupFactor) continue; - const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr); PointerType *PtrTy = dyn_cast(Ptr->getType()); unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());