Index: include/llvm/Analysis/LoopAccessAnalysis.h =================================================================== --- include/llvm/Analysis/LoopAccessAnalysis.h +++ include/llvm/Analysis/LoopAccessAnalysis.h @@ -32,6 +32,7 @@ class ScalarEvolution; class Loop; class SCEV; +class SCEVPredicateSet; /// Optimization analysis message produced during vectorization. Messages inform /// the user why vectorization did not occur. @@ -176,10 +177,12 @@ const SmallVectorImpl &Instrs) const; }; - MemoryDepChecker(ScalarEvolution *Se, const Loop *L) + MemoryDepChecker(ScalarEvolution *Se, const Loop *L, + unsigned SCEVCheckThreshold, SCEVPredicateSet &Pred) : SE(Se), InnermostLoop(L), AccessIdx(0), ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true), - RecordInterestingDependences(true) {} + RecordInterestingDependences(true), + SCEVCheckThreshold(SCEVCheckThreshold), Preds(Pred) {} /// \brief Register the location (instructions are given increasing numbers) /// of a write access. @@ -289,6 +292,11 @@ /// \brief Check whether the data dependence could prevent store-load /// forwarding. bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize); + + /// Indicates the maximum complexity of the SCEV predicate. + unsigned SCEVCheckThreshold; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Preds; }; /// \brief Holds information about the memory runtime legality checks to verify @@ -331,7 +339,8 @@ /// Insert a pointer and calculate the start and end SCEVs. void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, - unsigned ASId, const ValueToValueMap &Strides); + unsigned ASId, const ValueToValueMap &Strides, + unsigned SCEVCheckThreshold, SCEVPredicateSet &Pred); /// \brief No run-time memory checking is necessary. 
bool empty() const { return Pointers.empty(); } @@ -461,7 +470,8 @@ LoopAccessInfo(Loop *L, ScalarEvolution *SE, const DataLayout &DL, const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, - const ValueToValueMap &Strides); + const ValueToValueMap &Strides, + const unsigned SCEVComplexity = 16); /// Return true we can analyze the memory accesses in the loop and there are /// no memory dependence cycles. @@ -537,6 +547,13 @@ return StoreToLoopInvariantAddress; } + /// Indicates the maximum complexity of the SCEV predicate that the + /// analysis is allowed to create. We expect all clients of the analysis + /// to use the same value. + unsigned SCEVCheckThreshold; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet Preds; + private: /// \brief Analyze the loop. Substitute symbolic strides using Strides. void analyzeLoop(const ValueToValueMap &Strides); @@ -590,12 +607,15 @@ /// stride as collected by LoopVectorizationLegality::collectStridedAccess. const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE, const ValueToValueMap &PtrToStride, + unsigned PredicateThreshold, + SCEVPredicateSet &Pred, const Loop *L, Value *Ptr, Value *OrigPtr = nullptr); /// \brief Check the stride of the pointer and ensure that it does not wrap in /// the address space. int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap); + const ValueToValueMap &StridesMap, + unsigned SCEVCheckThreshold, SCEVPredicateSet &Pred); /// \brief This analysis provides dependence information for the memory accesses /// of a loop. @@ -622,7 +642,8 @@ /// of symbolic strides, \p Strides provides the mapping (see /// replaceSymbolicStrideSCEV). If there is no cached result available run /// the analysis. 
- const LoopAccessInfo &getInfo(Loop *L, const ValueToValueMap &Strides); + const LoopAccessInfo &getInfo(Loop *L, const ValueToValueMap &Strides, + const unsigned MaxSCEVPredicates = 0); void releaseMemory() override { // Invalidate the cache when the pass is freed. Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -51,6 +51,8 @@ class SCEVUnknown; class SCEVAddRecExpr; class SCEV; + class SCEVExpander; + template<> struct FoldingSetTrait; /// This class represents an analyzed expression in the program. These are @@ -166,6 +168,116 @@ static bool classof(const SCEV *S); }; + //===--------------------------------------------------------------------===// + /// SCEVPredicate - This class represents an assumption made using SCEV + /// expressions which can be checked at run-time. + /// + class SCEVPredicate { + protected: + unsigned short SCEVPredicateType; + enum SCEVPredicateTypes { PSET, PEQUAL }; + + public: + SCEVPredicate(unsigned short Type); + virtual ~SCEVPredicate() {} + unsigned short getType() const { return SCEVPredicateType; } + // Returns the estimated complexity of this predicate. + // This is roughly measured in the number of run-time checks required. + virtual unsigned getComplexity() { return 1; } + /// Returns true if the predicate is always true. This means that no + /// assumptions were made and nothing needs to be checked at run-time. + virtual bool isAlwaysTrue() const = 0; + /// Returns true if this predicate implies \p N. + virtual bool contains(const SCEVPredicate *N) const = 0; + /// Prints a textual representation of this predicate. + virtual void print(raw_ostream &OS, unsigned Depth) const = 0; + /// Generates a run-time check for this predicate. 
+ virtual Value *generateCheck(Instruction *Loc, ScalarEvolution *SE, + const DataLayout *DL, SCEVExpander &Exp) = 0; + }; + + //===--------------------------------------------------------------------===// + /// SCEVEqualPredicate - This class represents an assumption that two SCEV + /// expressions are equal, and this can be checked at run-time. We assume + /// that the left hand side is a SCEVUnknown. + /// + class SCEVEqualPredicate : public SCEVPredicate { + // We assume that E0 == E1 + const SCEV *E0; + const SCEV *E1; + + public: + SCEVEqualPredicate() + : SCEVPredicate(PEQUAL), E0(nullptr), E1(nullptr) {} + SCEVEqualPredicate(const SCEV *E0, const SCEV *E1); + + /// Implementation of the SCEVPredicate interface + bool contains(const SCEVPredicate *N) const override; + void print(raw_ostream &OS, unsigned Depth) const override; + bool isAlwaysTrue() const override; + + const SCEV *getLHS() { return E0; } + const SCEV *getRHS() { return E1; } + + Value *generateCheck(Instruction *Loc, ScalarEvolution *SE, + const DataLayout *DL, SCEVExpander &Exp) override; + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVPredicate *P) { + return P->getType() == PEQUAL; + } + }; + + //===--------------------------------------------------------------------===// + /// SCEVPredicateSet - This class represents a composition of other + /// SCEV predicates, and is the class that most clients will interact with. + /// + class SCEVPredicateSet : public SCEVPredicate { + protected: + /// Storage for different predicates that make up this Predicate Set. + SmallVector IdPreds; + + public: + SCEVPredicateSet(); + /// The copy constructor. + SCEVPredicateSet(const SCEVPredicateSet &Old); + /// Adds a predicate to this predicate set. + void add(const SCEVPredicate *N); + + /// Vector with references to all predicates in this set. + SmallVector Preds; + + /// Generates a run-time check for all the contained predicates. 
+ /// This is a wrapper around generateCheck, and provides an interface + /// similar to other run-time checks used for versioning. + std::pair + generateGuardCond(Instruction *Loc, ScalarEvolution *SE); + + /// Implementation of the SCEVPredicate interface + bool isAlwaysTrue() const override; + bool contains(const SCEVPredicate *N) const override; + void print(raw_ostream &OS, unsigned Depth) const override; + Value *generateCheck(Instruction *Loc, ScalarEvolution *SE, + const DataLayout *DL, SCEVExpander &Exp) override; + + /// Because we hold a set of predicates we need to override this method. + unsigned getComplexity() override {return Preds.size();} + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVPredicate *P) { + return P->getType() == PSET; + } + + /// The copy assignment operator. Preds is rebuilt to point into our own IdPreds. + const SCEVPredicateSet &operator=(const SCEVPredicateSet &RHS) { + IdPreds = RHS.IdPreds; + Preds.clear(); + for (unsigned II = 0; II < IdPreds.size(); ++II) + Preds.push_back(&IdPreds[II]); + assert(Preds.size() == RHS.Preds.size() && "Wrong Preds size after copy"); + return *this; + } + }; + /// The main scalar evolution driver. Because client code (intentionally) /// can't do much with the SCEV objects directly, they must ask this class /// for services. @@ -1054,6 +1166,10 @@ SmallVectorImpl &Sizes, const SCEV *ElementSize); + /// Re-writes the SCEV according to the Predicates in \p A. + const SCEV *rewriteUsingPredicate(const SCEV *Scev, const Loop *L, + SCEVPredicateSet &A); + private: /// Compute the backedge taken count knowing the interval difference, the /// stride and presence of the equality in the comparison. 
Index: lib/Analysis/LoopAccessAnalysis.cpp =================================================================== --- lib/Analysis/LoopAccessAnalysis.cpp +++ lib/Analysis/LoopAccessAnalysis.cpp @@ -15,6 +15,7 @@ #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DiagnosticInfo.h" @@ -89,8 +90,9 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE, const ValueToValueMap &PtrToStride, + unsigned SCEVThreshold, + SCEVPredicateSet &Pred, const Loop *L, Value *Ptr, Value *OrigPtr) { - const SCEV *OrigSCEV = SE->getSCEV(Ptr); // If there is an entry in the map return the SCEV of the pointer with the @@ -108,26 +110,36 @@ ValueToValueMap RewriteMap; RewriteMap[StrideVal] = One; - const SCEV *ByOne = - SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true); + SCEVPredicateSet P(Pred); + const SCEVEqualPredicate SI(SE->getSCEV(StrideVal), + SE->getConstant(StrideVal->getType(), 1, true)); + P.add(&SI); + + // Check that we haven't created too many predicates. + if (P.getComplexity() <= SCEVThreshold) + Pred = P; + + const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, L, Pred); DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne << "\n"); return ByOne; } // Otherwise, just return the SCEV of the original pointer. - return SE->getSCEV(Ptr); + return OrigSCEV; } void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, unsigned ASId, - const ValueToValueMap &Strides) { + const ValueToValueMap &Strides, + unsigned SCEVCheckThreshold, + SCEVPredicateSet &Pred) { // Get the stride replaced scev. 
- const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, SCEVCheckThreshold, + Pred, Lp, Ptr); const SCEVAddRecExpr *AR = dyn_cast(Sc); assert(AR && "Invalid addrec expression"); const SCEV *Ex = SE->getBackedgeTakenCount(Lp); - const SCEV *ScStart = AR->getStart(); const SCEV *ScEnd = AR->evaluateAtIteration(Ex, *SE); const SCEV *Step = AR->getStepRecurrence(*SE); @@ -417,9 +429,10 @@ typedef SmallPtrSet MemAccessInfoSet; AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI, - MemoryDepChecker::DepCandidates &DA) - : DL(Dl), AST(*AA), LI(LI), DepCands(DA), - IsRTCheckAnalysisNeeded(false) {} + MemoryDepChecker::DepCandidates &DA, + unsigned SCEVCheckThreshold, SCEVPredicateSet &Preds) + : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false), + SCEVCheckThreshold(SCEVCheckThreshold), Preds(Preds) {} /// \brief Register a load and whether it is only read from. void addLoad(MemoryLocation &Loc, bool IsReadOnly) { @@ -504,14 +517,23 @@ /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared /// while this remains set if we have potentially dependent accesses. bool IsRTCheckAnalysisNeeded; + + /// Indicates the maximum complexity of the SCEV predicate. + unsigned SCEVCheckThreshold; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Preds; }; } // end anonymous namespace /// \brief Check whether a pointer can participate in a runtime bounds check. 
static bool hasComputableBounds(ScalarEvolution *SE, - const ValueToValueMap &Strides, Value *Ptr) { - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const ValueToValueMap &Strides, Value *Ptr, + Loop *L, unsigned SCEVCheckThreshold, + SCEVPredicateSet &Pred) { + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, + SCEVCheckThreshold, Pred, + L, Ptr); const SCEVAddRecExpr *AR = dyn_cast(PtrScev); if (!AR) return false; @@ -554,11 +576,13 @@ else ++NumReadPtrChecks; - if (hasComputableBounds(SE, StridesMap, Ptr) && + if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, + SCEVCheckThreshold, Preds) && // When we run after a failing dependency check we have to make sure // we don't have wrapping pointers. (!ShouldCheckStride || - isStridedPtr(SE, Ptr, TheLoop, StridesMap) == 1)) { + isStridedPtr(SE, Ptr, TheLoop, StridesMap, + SCEVCheckThreshold, Preds) == 1)) { // The id of the dependence set. unsigned DepId; @@ -572,7 +596,8 @@ // Each access has its own dependence set. DepId = RunningDepId++; - RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap); + RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, + SCEVCheckThreshold, Preds); DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); } else { @@ -803,7 +828,8 @@ /// \brief Check whether the access through \p Ptr has a constant stride. 
int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap) { + const ValueToValueMap &StridesMap, + unsigned SCEVCheckThreshold, SCEVPredicateSet &Pred) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); @@ -815,7 +841,9 @@ return 0; } - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr); + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, + SCEVCheckThreshold, Pred, + Lp, Ptr); const SCEVAddRecExpr *AR = dyn_cast(PtrScev); if (!AR) { @@ -1026,11 +1054,17 @@ BPtr->getType()->getPointerAddressSpace()) return Dependence::Unknown; - const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr); - const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr); - - int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides); - int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides); + const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, + SCEVCheckThreshold, Preds, + InnermostLoop, APtr); + const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, + SCEVCheckThreshold, Preds, + InnermostLoop, BPtr); + + int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, + SCEVCheckThreshold, Preds); + int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, + SCEVCheckThreshold, Preds); const SCEV *Src = AScev; const SCEV *Sink = BScev; @@ -1429,7 +1463,8 @@ MemoryDepChecker::DepCandidates DependentAccesses; AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), - AA, LI, DependentAccesses); + AA, LI, DependentAccesses, SCEVCheckThreshold, + Preds); // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects // multiple times on the same object. If the ptr is accessed twice, once @@ -1480,7 +1515,8 @@ // read a few words, modify, and write a few words, and some of the // words may be written to the same address. 
bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr).second || !isStridedPtr(SE, Ptr, TheLoop, Strides)) { + if (Seen.insert(Ptr).second || + !isStridedPtr(SE, Ptr, TheLoop, Strides, SCEVCheckThreshold, Preds)) { ++NumReads; IsReadOnlyPtr = true; } @@ -1727,9 +1763,12 @@ const DataLayout &DL, const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, - const ValueToValueMap &Strides) - : PtrRtChecking(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), - TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), + const ValueToValueMap &Strides, + const unsigned MaxSCEVComplexity) + : SCEVCheckThreshold(MaxSCEVComplexity), + PtrRtChecking(SE), DepChecker(SE, L, MaxSCEVComplexity, Preds), + TheLoop(L), SE(SE), DL(DL), TLI(TLI), + AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { if (canAnalyzeLoop()) @@ -1763,21 +1802,27 @@ OS.indent(Depth) << "Store to invariant address was " << (StoreToLoopInvariantAddress ? 
"" : "not ") << "found in loop.\n"; + + OS.indent(Depth) << "SCEV assumptions:\n"; + Preds.print(OS, Depth); } const LoopAccessInfo & -LoopAccessAnalysis::getInfo(Loop *L, const ValueToValueMap &Strides) { +LoopAccessAnalysis::getInfo(Loop *L, const ValueToValueMap &Strides, + const unsigned MaxSCEVPredicates) { auto &LAI = LoopAccessInfoMap[L]; #ifndef NDEBUG assert((!LAI || LAI->NumSymbolicStrides == Strides.size()) && "Symbolic strides changed for loop"); + assert((!LAI || LAI->SCEVCheckThreshold == MaxSCEVPredicates) && + "Number of maximum SCEV predicates changed for the loop"); #endif if (!LAI) { const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); LAI = llvm::make_unique(L, SE, DL, TLI, AA, DT, LI, - Strides); + Strides, MaxSCEVPredicates); #ifndef NDEBUG LAI->NumSymbolicStrides = Strides.size(); #endif Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -67,6 +67,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -80,6 +81,8 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" @@ -9052,3 +9055,244 @@ AU.addRequiredTransitive(); AU.addRequiredTransitive(); } + +static Instruction *getFirstInst(Instruction *FirstInst, Value *V, + Instruction *Loc) { + if (FirstInst) + return FirstInst; + if (Instruction *I = dyn_cast(V)) + return I->getParent() == Loc->getParent() ? 
I : nullptr; + return nullptr; +} + +struct SCEVPredicateRewriter + : public SCEVVisitor { +public: + SCEVPredicateSet &P; + + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE, SCEVPredicateSet &A) { + SCEVPredicateRewriter Rewriter(L, SE, A); + return Rewriter.visit(Scev); + } + + SCEVPredicateRewriter(const Loop *L, ScalarEvolution &S, SCEVPredicateSet &P) + : P(P), SE(S), L(L) {} + + const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } + + const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getTruncateExpr(Operand, Expr->getType()); + } + + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getZeroExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getSignExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getAddExpr(Operands); + } + + const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getMulExpr(Operands); + } + + const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { + return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS())); + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + + const Loop *L = Expr->getLoop(); + return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); + } + + const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { + SmallVector 
Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getSMaxExpr(Operands); + } + + const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getUMaxExpr(Operands); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + for (SCEVPredicate *Pred : P.Preds) { + if (SCEVEqualPredicate *IPred = dyn_cast(Pred)) { + if (IPred->getLHS() == Expr) + return IPred->getRHS(); + } + } + return Expr; + } + + const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return Expr; + } + +private: + ScalarEvolution &SE; + const Loop *L; +}; + +const SCEV * +ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, const Loop *L, + SCEVPredicateSet &Pred) { + return SCEVPredicateRewriter::rewrite(Scev, L, *this, Pred); +} + +//// SCEV predicates +SCEVPredicate::SCEVPredicate(unsigned short Type) : SCEVPredicateType(Type) {} + +SCEVEqualPredicate::SCEVEqualPredicate( + const SCEV *E0, const SCEV *E1) + : SCEVPredicate(PEQUAL), E0(E0), E1(E1) {} + +bool SCEVEqualPredicate::contains(const SCEVPredicate *N) const { + const SCEVEqualPredicate *Op = + dyn_cast(N); + + if (!Op) + return false; + + if ((Op->E0 == E0 && Op->E1 == E1) || + (Op->E0 == E1 && Op->E1 == E0)) + return true; + + return false; +} + +bool SCEVEqualPredicate::isAlwaysTrue() const { + return E0 == E1; +} + +Value *SCEVEqualPredicate::generateCheck(Instruction *Loc, + ScalarEvolution *SE, + const DataLayout *DL, + SCEVExpander &Exp) { + IRBuilder<> Builder(Loc); + + Value *Expr0 = Exp.expandCodeFor(E0, E0->getType(), Loc); + Value *Expr1 = Exp.expandCodeFor(E1, E1->getType(), Loc); + + Value *C = Builder.CreateICmpNE(Expr0, Expr1, "ident.check"); + return C; +} + +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << "Equal predicate: " << *E0 << " == 
" << *E1 << "\n"; +} + +SCEVPredicateSet::SCEVPredicateSet() : SCEVPredicate(PSET) {} + +SCEVPredicateSet::SCEVPredicateSet(const SCEVPredicateSet &Old) + : SCEVPredicateSet() { + IdPreds = Old.IdPreds; + for (auto &Pred : IdPreds) + Preds.push_back(&Pred); +} + +bool SCEVPredicateSet::isAlwaysTrue() const { + return std::all_of(Preds.begin(), Preds.end(), + [](SCEVPredicate *I){return I->isAlwaysTrue();}); +} + +std::pair +SCEVPredicateSet::generateGuardCond(Instruction *Loc, ScalarEvolution *SE) { + if (isAlwaysTrue()) + return std::pair(nullptr, nullptr); + + IRBuilder<> GuardBuilder(Loc); + Instruction *FirstInst = nullptr; + Module *M = Loc->getParent()->getParent()->getParent(); + const DataLayout &DL = M->getDataLayout(); + SCEVExpander Exp(*SE, DL, "start"); + + Value *Check = generateCheck(Loc, SE, &DL, Exp); + + if (!Check) + return std::make_pair(nullptr, nullptr); + + Instruction *CheckInst = + BinaryOperator::CreateOr(Check, ConstantInt::getFalse(M->getContext())); + GuardBuilder.Insert(CheckInst, "scev.check"); + + FirstInst = getFirstInst(FirstInst, CheckInst, Loc); + return std::make_pair(FirstInst, CheckInst); +} + +Value *SCEVPredicateSet::generateCheck(Instruction *Loc, ScalarEvolution *SE, + const DataLayout *DL, + SCEVExpander &Exp) { + + IRBuilder<> CheckBuilder(Loc); + Value *AllCheck = nullptr; + + // Loop over all checks in this set. 
+ for (auto Pred : Preds) { + if (Pred->isAlwaysTrue()) + continue; + + Value *CheckResult = Pred->generateCheck(Loc, SE, DL, Exp); + + if (!AllCheck) + AllCheck = CheckResult; + else + AllCheck = CheckBuilder.CreateOr(AllCheck, CheckResult); + } + + return AllCheck; +} + +bool SCEVPredicateSet::contains(const SCEVPredicate *N) const { + if (const SCEVPredicateSet *Set = dyn_cast(N)) + return std::all_of(Set->Preds.begin(), Set->Preds.end(), + [this](SCEVPredicate *I){return this->contains(I);}); + return std::any_of(Preds.begin(), Preds.end(), + [N](SCEVPredicate *I){return I->contains(N);}); +} + +void SCEVPredicateSet::print(raw_ostream &OS, unsigned Depth) const { + for (auto Pred : Preds) + Pred->print(OS, Depth); +} + +void SCEVPredicateSet::add(const SCEVPredicate *N) { + if (const SCEVEqualPredicate *EP = dyn_cast(N)) { + for (const auto &Pred : IdPreds) + if (Pred.contains(EP)) + return; + IdPreds.push_back(*EP); // May reallocate, so re-point all of Preds. + Preds.clear(); + for (auto &Pred : IdPreds) + Preds.push_back(&Pred); + return; + } + if (const SCEVPredicateSet *Set = dyn_cast(N)) { + for (auto Pred : Set->Preds) + add(Pred); + return; + } + llvm_unreachable("Unknown SCEV predicate type!"); +} Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -221,6 +221,15 @@ cl::desc("The maximum allowed number of runtime memory checks with a " "vectorize(enable) pragma.")); +static cl::opt VectorizeSCEVCheckThreshold( + "vectorize-scev-check-threshold", cl::init(16), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed.")); + +static cl::opt PragmaVectorizeSCEVCheckThreshold( + "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed with a " + "vectorize(enable) pragma")); + namespace { // Forward declarations. 
@@ -272,12 +281,12 @@ InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, unsigned VecWidth, - unsigned UnrollFactor) + unsigned UnrollFactor, SCEVPredicateSet &Preds) : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()), Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor), TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr), - AddedSafetyChecks(false) {} + AddedSafetyChecks(false), Preds(Preds) {} // Perform the actual loop widening (vectorization). void vectorize(LoopVectorizationLegality *L) { @@ -309,12 +318,6 @@ typedef DenseMap, VectorParts> EdgeMaskCache; - /// \brief Add checks for strides that were assumed to be 1. - /// - /// Returns the last check instruction and the first check instruction in the - /// pair as (first, last). - std::pair addStrideCheck(Instruction *Loc); - /// Create an empty loop, based on the loop ranges of the old loop. void createEmptyLoop(); /// Create a new induction variable inside L. @@ -395,8 +398,8 @@ void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass); /// Emit a bypass check to see if the vector trip count is nonzero. void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass); - /// Emit bypass checks to check if strides we've assumed to be one really are. - void emitStrideChecks(Loop *L, BasicBlock *Bypass); + + void emitSCEVChecks(Loop *L, BasicBlock *Bypass); /// Emit bypass checks to check any memory assumptions we may have made. void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); @@ -503,14 +506,19 @@ // Record whether runtime check is added. bool AddedSafetyChecks; + + /// The SCEV predicate containing all the SCEV-related assumptions. 
+ SCEVPredicateSet &Preds; }; class InnerLoopUnroller : public InnerLoopVectorizer { public: InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, unsigned UnrollFactor) - : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {} + const TargetTransformInfo *TTI, unsigned UnrollFactor, + SCEVPredicateSet &Preds) + : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor, + Preds) {} private: void scalarizeInstruction(Instruction *Instr, @@ -731,8 +739,10 @@ /// between the member and the group in a map. class InterleavedAccessInfo { public: - InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT) - : SE(SE), TheLoop(L), DT(DT) {} + InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT, + unsigned SCEVCheckThreshold, SCEVPredicateSet &Preds) + : SE(SE), TheLoop(L), DT(DT), SCEVCheckThreshold(SCEVCheckThreshold), + Preds(Preds) {} ~InterleavedAccessInfo() { SmallSet DelSet; @@ -765,6 +775,10 @@ ScalarEvolution *SE; Loop *TheLoop; DominatorTree *DT; + /// Indicates the maximum complexity of the SCEV predicate. + unsigned SCEVCheckThreshold; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Preds; /// Holds the relationships between the members and the interleave group. 
DenseMap InterleaveGroupMap; @@ -1128,11 +1142,15 @@ Function *F, const TargetTransformInfo *TTI, LoopAccessAnalysis *LAA, LoopVectorizationRequirements *R, - const LoopVectorizeHints *H) + const LoopVectorizeHints *H, + unsigned SCEVCheckThreshold, + SCEVPredicateSet &Preds) : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F), - TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT), + TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), + InterleaveInfo(SE, L, DT, SCEVCheckThreshold, Preds), Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), - Requirements(R), Hints(H) {} + Requirements(R), Hints(H), SCEVCheckThreshold(SCEVCheckThreshold), + Preds(Preds) {} /// ReductionList contains the reduction descriptors for all /// of the reductions that were found in the loop. @@ -1331,7 +1349,11 @@ /// While vectorizing these instructions we have to generate a /// call to the appropriate masked intrinsic - SmallPtrSet MaskedOp; + SmallPtrSet MaskedOp; + + unsigned SCEVCheckThreshold; + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVPredicateSet &Preds; }; /// LoopVectorizationCostModel - estimates the expected speedups due to @@ -1348,9 +1370,11 @@ const TargetTransformInfo &TTI, const TargetLibraryInfo *TLI, AssumptionCache *AC, const Function *F, const LoopVectorizeHints *Hints, - SmallPtrSetImpl &ValuesToIgnore) + SmallPtrSetImpl &ValuesToIgnore, + SCEVPredicateSet &Preds) : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), - TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {} + TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore), + Preds(Preds) {} /// Information about vectorization costs struct VectorizationFactor { @@ -1436,6 +1460,8 @@ const LoopVectorizeHints *Hints; // Values to ignore in the cost model. const SmallPtrSetImpl &ValuesToIgnore; + /// The SCEV predicate containing all the SCEV-related assumptions. 
+ SCEVPredicateSet &Preds; }; /// \brief This holds vectorization requirements that must be verified late in @@ -1666,10 +1692,16 @@ } } + unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; + if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) + SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; + + SCEVPredicateSet Preds; + // Check if it is legal to vectorize the loop. LoopVectorizationRequirements Requirements; LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA, - &Requirements, &Hints); + &Requirements, &Hints, SCEVThreshold, Preds); if (!LVL.canVectorize()) { DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); emitMissedWarning(F, L, Hints); @@ -1688,7 +1720,7 @@ // Use the cost model. LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints, - ValuesToIgnore); + ValuesToIgnore, Preds); // Check the function attributes to find out if this function should be // optimized for size. @@ -1799,7 +1831,7 @@ assert(IC > 1 && "interleave count should not be 1 or 0"); // If we decided that it is not legal to vectorize the loop then // interleave it. - InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC); + InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds); Unroller.vectorize(&LVL); emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), @@ -1807,7 +1839,7 @@ Twine(IC) + ")"); } else { // If we decided that it is *legal* to vectorize the loop then do it. - InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC); + InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds); LB.vectorize(&LVL); ++LoopsVectorized; @@ -1967,7 +1999,8 @@ // %idxprom = zext i32 %mul to i64 << Safe cast. 
// %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom // - Last = replaceSymbolicStrideSCEV(SE, Strides, + Last = replaceSymbolicStrideSCEV(SE, Strides, SCEVCheckThreshold, Preds, + TheLoop, Gep->getOperand(InductionOperand), Gep); if (const SCEVCastExpr *C = dyn_cast(Last)) Last = @@ -2525,52 +2558,6 @@ } } -static Instruction *getFirstInst(Instruction *FirstInst, Value *V, - Instruction *Loc) { - if (FirstInst) - return FirstInst; - if (Instruction *I = dyn_cast(V)) - return I->getParent() == Loc->getParent() ? I : nullptr; - return nullptr; -} - -std::pair -InnerLoopVectorizer::addStrideCheck(Instruction *Loc) { - Instruction *tnullptr = nullptr; - if (!Legal->mustCheckStrides()) - return std::pair(tnullptr, tnullptr); - - IRBuilder<> ChkBuilder(Loc); - - // Emit checks. - Value *Check = nullptr; - Instruction *FirstInst = nullptr; - for (SmallPtrSet::iterator SI = Legal->strides_begin(), - SE = Legal->strides_end(); - SI != SE; ++SI) { - Value *Ptr = stripIntegerCast(*SI); - Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1), - "stride.chk"); - // Store the first instruction we create. - FirstInst = getFirstInst(FirstInst, C, Loc); - if (Check) - Check = ChkBuilder.CreateOr(Check, C); - else - Check = C; - } - - // We have to do this trickery because the IRBuilder might fold the check to a - // constant expression in which case there is no Instruction anchored in a - // the block. 
- LLVMContext &Ctx = Loc->getContext(); - Instruction *TheCheck = - BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx)); - ChkBuilder.Insert(TheCheck, "stride.not.one"); - FirstInst = getFirstInst(FirstInst, TheCheck, Loc); - - return std::make_pair(FirstInst, TheCheck); -} - PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, Value *End, @@ -2709,26 +2696,26 @@ LoopBypassBlocks.push_back(BB); } -void InnerLoopVectorizer::emitStrideChecks(Loop *L, - BasicBlock *Bypass) { +void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { BasicBlock *BB = L->getLoopPreheader(); - + // Generate the code to check that the strides we assumed to be one are really // one. We want the new basic block to start at the first instruction in a // sequence of instructions that form a check. - Instruction *StrideCheck; + Instruction *SCEVCheck; Instruction *FirstCheckInst; - std::tie(FirstCheckInst, StrideCheck) = addStrideCheck(BB->getTerminator()); - if (!StrideCheck) + std::tie(FirstCheckInst, SCEVCheck) = + Preds.generateGuardCond(BB->getTerminator(), SE); + if (!SCEVCheck) return; // Create a new block containing the stride check. - BB->setName("vector.stridecheck"); + BB->setName("vector.scevcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); if (L->getParentLoop()) L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); ReplaceInstWithInst(BB->getTerminator(), - BranchInst::Create(Bypass, NewBB, StrideCheck)); + BranchInst::Create(Bypass, NewBB, SCEVCheck)); LoopBypassBlocks.push_back(BB); AddedSafetyChecks = true; } @@ -2848,10 +2835,10 @@ // Now, compare the new count to zero. If it is zero skip the vector loop and // jump to the scalar loop. emitVectorLoopEnteredCheck(Lp, ScalarPH); - // Generate the code to check that the strides we assumed to be one are really - // one. We want the new basic block to start at the first instruction in a - // sequence of instructions that form a check. 
- emitStrideChecks(Lp, ScalarPH); + // Generate the code to check any assumptions that we've made for SCEV + // expressions. + emitSCEVChecks(Lp, ScalarPH); + // Generate the code that checks in runtime if arrays overlap. We put the // checks into a separate block to make the more common case of few elements // faster. @@ -4277,7 +4264,7 @@ } bool LoopVectorizationLegality::canVectorizeMemory() { - LAI = &LAA->getInfo(TheLoop, Strides); + LAI = &LAA->getInfo(TheLoop, Strides, SCEVCheckThreshold); auto &OptionalReport = LAI->getReport(); if (OptionalReport) emitAnalysis(VectorizationReport(*OptionalReport)); @@ -4293,6 +4280,16 @@ } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); + Preds.add(&LAI->Preds); + + if (Preds.getComplexity() > SCEVCheckThreshold) { + emitAnalysis(VectorizationReport() + << "Too many SCEV assumptions need to be made and checked " + << "at runtime"); + DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); + + return false; + } return true; } @@ -4407,7 +4404,8 @@ StoreInst *SI = dyn_cast(I); Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); - int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides); + int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, SCEVCheckThreshold, + Preds); // The factor of the corresponding interleave group. unsigned Factor = std::abs(Stride); @@ -4416,7 +4414,9 @@ if (Factor < 2 || Factor > MaxInterleaveGroupFactor) continue; - const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, + SCEVCheckThreshold, + Preds, TheLoop, Ptr); PointerType *PtrTy = dyn_cast(Ptr->getType()); unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());