Index: include/llvm/Analysis/LoopAccessAnalysis.h =================================================================== --- include/llvm/Analysis/LoopAccessAnalysis.h +++ include/llvm/Analysis/LoopAccessAnalysis.h @@ -32,6 +32,7 @@ class ScalarEvolution; class Loop; class SCEV; +class SCEVUnionPredicate; /// Optimization analysis message produced during vectorization. Messages inform /// the user why vectorization did not occur. @@ -176,10 +177,11 @@ const SmallVectorImpl &Instrs) const; }; - MemoryDepChecker(ScalarEvolution *Se, const Loop *L) + MemoryDepChecker(ScalarEvolution *Se, const Loop *L, + SCEVUnionPredicate &Preds) : SE(Se), InnermostLoop(L), AccessIdx(0), ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true), - RecordInterestingDependences(true) {} + RecordInterestingDependences(true), Preds(Preds) {} /// \brief Register the location (instructions are given increasing numbers) /// of a write access. @@ -289,6 +291,15 @@ /// \brief Check whether the data dependence could prevent store-load /// forwarding. bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize); + + /// The SCEV predicate containing all the SCEV-related assumptions. + /// The dependence checker needs this in order to convert SCEVs of pointers + /// to more accurate expressions in the context of existing assumptions. + /// We also need this in case assumptions about SCEV expressions need to + /// be made in order to avoid unknown dependences. For example we might + /// assume a unit stride for a pointer in order to prove that a memory access + /// is strided and doesn't wrap. + SCEVUnionPredicate &Preds; }; /// \brief Holds information about the memory runtime legality checks to verify @@ -330,8 +341,13 @@ } /// Insert a pointer and calculate the start and end SCEVs. + /// We need \p Preds in order to compute the SCEV expression of the pointer + /// according to the assumptions that we've made during the analysis.
+ /// The method might also version the pointer stride according to \p Strides, + /// and change \p Preds. void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, - unsigned ASId, const ValueToValueMap &Strides); + unsigned ASId, const ValueToValueMap &Strides, + SCEVUnionPredicate &Preds); /// \brief No run-time memory checking is necessary. bool empty() const { return Pointers.empty(); } @@ -537,6 +553,15 @@ return StoreToLoopInvariantAddress; } + /// The SCEV predicate contains all the SCEV-related assumptions. + /// This is used to keep track of the minimal set of assumptions on SCEV + /// expressions that the analysis needs to make in order to return a + /// meaningful result. All SCEV expressions during the analysis should be + /// re-written (and therefore simplified) according to Preds. + /// A user of LoopAccessAnalysis will need to emit the runtime checks + /// associated with this predicate. + SCEVUnionPredicate Preds; + private: /// \brief Analyze the loop. Substitute symbolic strides using Strides. void analyzeLoop(const ValueToValueMap &Strides); @@ -583,19 +608,26 @@ Value *stripIntegerCast(Value *V); ///\brief Return the SCEV corresponding to a pointer with the symbolic stride -///replaced with constant one. +/// replaced with constant one, assuming \p Preds is true. +/// +/// If necessary this method will version the stride of the pointer according +/// to \p PtrToStride and therefore add a new predicate to \p Preds. /// /// If \p OrigPtr is not null, use it to look up the stride value instead of \p /// Ptr. \p PtrToStride provides the mapping between the pointer value and its /// stride as collected by LoopVectorizationLegality::collectStridedAccess.
const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE, const ValueToValueMap &PtrToStride, - Value *Ptr, Value *OrigPtr = nullptr); + SCEVUnionPredicate &Preds, Value *Ptr, + Value *OrigPtr = nullptr); /// \brief Check the stride of the pointer and ensure that it does not wrap in -/// the address space. +/// the address space, assuming \p Preds is true. +/// +/// If necessary this method will version the stride of the pointer according +/// to \p PtrToStride and therefore add a new predicate to \p Preds. int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap); + const ValueToValueMap &StridesMap, SCEVUnionPredicate &Preds); /// \brief This analysis provides dependence information for the memory accesses /// of a loop. Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -48,10 +48,15 @@ class Loop; class LoopInfo; class Operator; - class SCEVUnknown; - class SCEVAddRecExpr; class SCEV; - template<> struct FoldingSetTrait; + class SCEVAddRecExpr; + class SCEVConstant; + class SCEVExpander; + class SCEVPredicate; + class SCEVUnknown; + + template <> struct FoldingSetTrait; + template <> struct FoldingSetTrait; /// This class represents an analyzed expression in the program. These are /// opaque objects that the client is not allowed to do much with directly. @@ -164,6 +169,148 @@ static bool classof(const SCEV *S); }; + /// SCEVPredicate - This class represents an assumption made using SCEV + /// expressions which can be checked at run-time. + class SCEVPredicate : public FoldingSetNode { + friend struct FoldingSetTrait; + + /// A reference to an Interned FoldingSetNodeID for this node. The + /// ScalarEvolution's BumpPtrAllocator holds the data.
+ FoldingSetNodeIDRef FastID; + + public: + enum SCEVPredicateKind { P_Union, P_Equal }; + + protected: + SCEVPredicateKind Kind; + + public: + SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind); + + virtual ~SCEVPredicate() {} + + SCEVPredicateKind getKind() const { return Kind; } + + /// \brief Returns the estimated complexity of this predicate. + /// This is roughly measured in the number of run-time checks required. + virtual unsigned getComplexity() { return 1; } + + /// \brief Returns true if the predicate is always true. This means that no + /// assumptions were made and nothing needs to be checked at run-time. + virtual bool isAlwaysTrue() const = 0; + + /// \brief Returns true if this predicate implies \p N. + virtual bool implies(const SCEVPredicate *N) const = 0; + + /// \brief Prints a textual representation of this predicate with an + /// indentation of \p Depth. + virtual void print(raw_ostream &OS, unsigned Depth = 0) const = 0; + + /// \brief Returns the SCEV to which this predicate applies, or nullptr + /// if this is a SCEVUnionPredicate. + virtual const SCEV *getExpr() const = 0; + }; + + inline raw_ostream &operator<<(raw_ostream &OS, const SCEVPredicate &P) { + P.print(OS); + return OS; + } + + // Specialize FoldingSetTrait for SCEVPredicate to avoid needing to compute + // temporary FoldingSetNodeID values. + template <> + struct FoldingSetTrait + : DefaultFoldingSetTrait { + + static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) { + ID = X.FastID; + } + + static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID, + unsigned IDHash, FoldingSetNodeID &TempID) { + return ID == X.FastID; + } + static unsigned ComputeHash(const SCEVPredicate &X, + FoldingSetNodeID &TempID) { + return X.FastID.ComputeHash(); + } + }; + + /// SCEVEqualPredicate - This class represents an assumption that two SCEV + /// expressions are equal, and this can be checked at run-time.
We assume + /// that the left hand side is a SCEVUnknown and the right hand side a + /// constant. + class SCEVEqualPredicate : public SCEVPredicate { + /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a + /// constant. + const SCEVUnknown *LHS; + const SCEVConstant *RHS; + + public: + SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS, + const SCEVConstant *RHS); + + /// Implementation of the SCEVPredicate interface + bool implies(const SCEVPredicate *N) const override; + void print(raw_ostream &OS, unsigned Depth = 0) const override; + bool isAlwaysTrue() const override; + const SCEV *getExpr() const override; + + /// \brief Returns the left hand side of the equality. + const SCEVUnknown *getLHS() const { return LHS; } + + /// \brief Returns the right hand side of the equality. + const SCEVConstant *getRHS() const { return RHS; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVPredicate *P) { + return P->getKind() == P_Equal; + } + }; + + /// SCEVUnionPredicate - This class represents a composition of other + /// SCEV predicates, and is the class that most clients will interact with. + /// This is equivalent to a logical "AND" of all the predicates in the union. + class SCEVUnionPredicate : public SCEVPredicate { + private: + typedef DenseMap> + PredicateMap; + + /// Vector with references to all predicates in this union. + SmallVector Preds; + /// Maps SCEVs to predicates for quick look-ups. + PredicateMap SCEVToPreds; + + public: + SCEVUnionPredicate(); + + const SmallVectorImpl &getPredicates() const { + return Preds; + } + + /// \brief Adds a predicate to this union. + void add(const SCEVPredicate *N); + + /// \brief Returns a reference to a vector containing all predicates + /// which apply to \p Expr.
+ ArrayRef getPredicatesForExpr(const SCEV *Expr); + + /// Implementation of the SCEVPredicate interface + bool isAlwaysTrue() const override; + bool implies(const SCEVPredicate *N) const override; + void print(raw_ostream &OS, unsigned Depth) const override; + const SCEV *getExpr() const override; + + /// \brief We estimate the complexity of a union predicate as the + /// number of predicates in the union. + unsigned getComplexity() override { return Preds.size(); } + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVPredicate *P) { + return P->getKind() == P_Union; + } + }; + /// The main scalar evolution driver. Because client code (intentionally) /// can't do much with the SCEV objects directly, they must ask this class /// for services. @@ -1097,6 +1244,12 @@ return F.getParent()->getDataLayout(); } + const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS, + const SCEVConstant *RHS); + + /// Re-writes the SCEV according to the Predicates in \p A. + const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A); + private: /// Compute the backedge taken count knowing the interval difference, the /// stride and presence of the equality in the comparison. @@ -1117,6 +1270,7 @@ private: FoldingSet UniqueSCEVs; + FoldingSet UniquePreds; BumpPtrAllocator SCEVAllocator; /// The head of a linked list of all SCEVUnknown values that have been Index: include/llvm/Analysis/ScalarEvolutionExpander.h =================================================================== --- include/llvm/Analysis/ScalarEvolutionExpander.h +++ include/llvm/Analysis/ScalarEvolutionExpander.h @@ -151,6 +151,22 @@ /// block. Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I); + /// \brief Generates a code sequence that evaluates this predicate. + /// The inserted instructions will be at position \p Loc.
+ /// The result will be of type i1 and will have a value of 1 when the + /// predicate is false and 0 otherwise. + Value *expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *Loc); + + /// \brief A specialized variant of expandCodeForPredicate, handling the + /// case when we are expanding code for a SCEVEqualPredicate. + Value *expandEqualPredicate(const SCEVEqualPredicate *Pred, + Instruction *Loc); + + /// \brief A specialized variant of expandCodeForPredicate, handling the + /// case when we are expanding code for a SCEVUnionPredicate. + Value *expandUnionPredicate(const SCEVUnionPredicate *Pred, + Instruction *Loc); + /// \brief Set the current IV increment loop and position. void setIVIncInsertPos(const Loop *L, Instruction *Pos) { assert(!CanonicalMode && Index: lib/Analysis/LoopAccessAnalysis.cpp =================================================================== --- lib/Analysis/LoopAccessAnalysis.cpp +++ lib/Analysis/LoopAccessAnalysis.cpp @@ -89,8 +89,8 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE, const ValueToValueMap &PtrToStride, + SCEVUnionPredicate &Preds, Value *Ptr, Value *OrigPtr) { - const SCEV *OrigSCEV = SE->getSCEV(Ptr); // If there is an entry in the map return the SCEV of the pointer with the @@ -108,22 +108,28 @@ ValueToValueMap RewriteMap; RewriteMap[StrideVal] = One; - const SCEV *ByOne = - SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true); + const auto *U = cast(SE->getSCEV(StrideVal)); + const auto *CT = + static_cast(SE->getOne(StrideVal->getType())); + + Preds.add(SE->getEqualPredicate(U, CT)); + + const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, Preds); DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne << "\n"); return ByOne; } // Otherwise, just return the SCEV of the original pointer.
- return SE->getSCEV(Ptr); + return OrigSCEV; } void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, unsigned ASId, - const ValueToValueMap &Strides) { + const ValueToValueMap &Strides, + SCEVUnionPredicate &Preds) { // Get the stride replaced scev. - const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr); const SCEVAddRecExpr *AR = dyn_cast(Sc); assert(AR && "Invalid addrec expression"); const SCEV *Ex = SE->getBackedgeTakenCount(Lp); @@ -417,9 +423,9 @@ typedef SmallPtrSet MemAccessInfoSet; AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI, - MemoryDepChecker::DepCandidates &DA) - : DL(Dl), AST(*AA), LI(LI), DepCands(DA), - IsRTCheckAnalysisNeeded(false) {} + MemoryDepChecker::DepCandidates &DA, SCEVUnionPredicate &Preds) + : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false), + Preds(Preds) {} /// \brief Register a load and whether it is only read from. void addLoad(MemoryLocation &Loc, bool IsReadOnly) { @@ -504,14 +510,18 @@ /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared /// while this remains set if we have potentially dependent accesses. bool IsRTCheckAnalysisNeeded; + + /// The SCEV predicate containing all the SCEV-related assumptions. + SCEVUnionPredicate &Preds; }; } // end anonymous namespace /// \brief Check whether a pointer can participate in a runtime bounds check.
static bool hasComputableBounds(ScalarEvolution *SE, - const ValueToValueMap &Strides, Value *Ptr) { - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const ValueToValueMap &Strides, Value *Ptr, + Loop *L, SCEVUnionPredicate &Preds) { + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr); const SCEVAddRecExpr *AR = dyn_cast(PtrScev); if (!AR) return false; @@ -554,11 +564,11 @@ else ++NumReadPtrChecks; - if (hasComputableBounds(SE, StridesMap, Ptr) && + if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Preds) && // When we run after a failing dependency check we have to make sure // we don't have wrapping pointers. (!ShouldCheckStride || - isStridedPtr(SE, Ptr, TheLoop, StridesMap) == 1)) { + isStridedPtr(SE, Ptr, TheLoop, StridesMap, Preds) == 1)) { // The id of the dependence set. unsigned DepId; @@ -572,7 +582,7 @@ // Each access has its own dependence set. DepId = RunningDepId++; - RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap); + RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Preds); DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); } else { @@ -803,7 +813,8 @@ /// \brief Check whether the access through \p Ptr has a constant stride.
int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap) { + const ValueToValueMap &StridesMap, + SCEVUnionPredicate &Preds) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); @@ -815,7 +826,7 @@ return 0; } - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr); + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Preds, Ptr); const SCEVAddRecExpr *AR = dyn_cast(PtrScev); if (!AR) { @@ -1026,11 +1037,11 @@ BPtr->getType()->getPointerAddressSpace()) return Dependence::Unknown; - const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr); - const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr); + const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, APtr); + const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, BPtr); - int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides); - int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides); + int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Preds); + int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Preds); const SCEV *Src = AScev; const SCEV *Sink = BScev; @@ -1429,7 +1440,7 @@ MemoryDepChecker::DepCandidates DependentAccesses; AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), - AA, LI, DependentAccesses); + AA, LI, DependentAccesses, Preds); // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects // multiple times on the same object. If the ptr is accessed twice, once @@ -1480,7 +1491,8 @@ // read a few words, modify, and write a few words, and some of the // words may be written to the same address.
bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr).second || !isStridedPtr(SE, Ptr, TheLoop, Strides)) { + if (Seen.insert(Ptr).second || + !isStridedPtr(SE, Ptr, TheLoop, Strides, Preds)) { ++NumReads; IsReadOnlyPtr = true; } @@ -1728,7 +1740,7 @@ const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : PtrRtChecking(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), + : PtrRtChecking(SE), DepChecker(SE, L, Preds), TheLoop(L), SE(SE), DL(DL), TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { @@ -1763,6 +1775,9 @@ OS.indent(Depth) << "Store to invariant address was " << (StoreToLoopInvariantAddress ? "" : "not ") << "found in loop.\n"; + + OS.indent(Depth) << "SCEV assumptions:\n"; + Preds.print(OS, Depth); } const LoopAccessInfo & @@ -1776,8 +1791,8 @@ if (!LAI) { const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - LAI = llvm::make_unique(L, SE, DL, TLI, AA, DT, LI, - Strides); + LAI = + llvm::make_unique(L, SE, DL, TLI, AA, DT, LI, Strides); #ifndef NDEBUG LAI->NumSymbolicStrides = Strides.size(); #endif Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -8878,6 +8878,7 @@ UnsignedRanges(std::move(Arg.UnsignedRanges)), SignedRanges(std::move(Arg.SignedRanges)), UniqueSCEVs(std::move(Arg.UniqueSCEVs)), + UniquePreds(std::move(Arg.UniquePreds)), SCEVAllocator(std::move(Arg.SCEVAllocator)), FirstUnknown(Arg.FirstUnknown) { Arg.FirstUnknown = nullptr; @@ -9381,3 +9382,134 @@ AU.addRequiredTransitive(); AU.addRequiredTransitive(); } + +const SCEVPredicate * +ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, + const SCEVConstant *RHS) { + FoldingSetNodeID ID; + // Unique this node based on the arguments + ID.AddInteger(SCEVPredicate::P_Equal);
+ ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + SCEVEqualPredicate *Eq = new (SCEVAllocator) + SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); + UniquePreds.InsertNode(Eq, IP); + return Eq; +} + +class SCEVPredicateRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, + SCEVUnionPredicate &A) { + SCEVPredicateRewriter Rewriter(SE, A); + return Rewriter.visit(Scev); + } + + SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P) + : SCEVRewriteVisitor(SE), P(P) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + auto ExprPreds = P.getPredicatesForExpr(Expr); + for (auto *Pred : ExprPreds) + if (const auto *IPred = dyn_cast(Pred)) + if (IPred->getLHS() == Expr) + return IPred->getRHS(); + + return Expr; + } + +private: + SCEVUnionPredicate &P; +}; + +const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, + SCEVUnionPredicate &Preds) { + return SCEVPredicateRewriter::rewrite(Scev, *this, Preds); +} + +/// SCEV predicates +SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, + SCEVPredicateKind Kind) + : FastID(ID), Kind(Kind) {} + +SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, + const SCEVUnknown *LHS, + const SCEVConstant *RHS) + : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {} + +bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { + const auto *Op = dyn_cast(N); + + if (!Op) + return false; + + return Op->LHS == LHS && Op->RHS == RHS; +} + +bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } + +const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } + +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; +} + +/// Union predicates don't get cached so create a dummy set ID for it.
+SCEVUnionPredicate::SCEVUnionPredicate() + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} + +bool SCEVUnionPredicate::isAlwaysTrue() const { + return std::all_of(Preds.begin(), Preds.end(), + [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); +} + +ArrayRef +SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { + auto I = SCEVToPreds.find(Expr); + if (I == SCEVToPreds.end()) + return ArrayRef(); + return I->second; +} + +bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { + if (const auto *Set = dyn_cast(N)) + return std::all_of( + Set->Preds.begin(), Set->Preds.end(), + [this](const SCEVPredicate *I) { return this->implies(I); }); + + auto ScevPredsIt = SCEVToPreds.find(N->getExpr()); + if (ScevPredsIt == SCEVToPreds.end()) + return false; + auto &SCEVPreds = ScevPredsIt->second; + + return std::any_of(SCEVPreds.begin(), SCEVPreds.end(), + [N](const SCEVPredicate *I) { return I->implies(N); }); +} + +const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; } + +void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { + for (auto Pred : Preds) + Pred->print(OS, Depth); +} + +void SCEVUnionPredicate::add(const SCEVPredicate *N) { + if (const auto *Set = dyn_cast(N)) { + for (auto Pred : Set->Preds) + add(Pred); + return; + } + + if (implies(N)) + return; + + const SCEV *Key = N->getExpr(); + assert(Key && "Only SCEVUnionPredicate doesn't have an " + "associated expression!"); + + SCEVToPreds[Key].push_back(N); + Preds.push_back(N); +} Index: lib/Analysis/ScalarEvolutionExpander.cpp =================================================================== --- lib/Analysis/ScalarEvolutionExpander.cpp +++ lib/Analysis/ScalarEvolutionExpander.cpp @@ -1944,6 +1944,43 @@ return false; } +Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, + Instruction *IP) { + assert(IP); + switch (Pred->getKind()) { + case SCEVPredicate::P_Union: + return expandUnionPredicate(cast(Pred), IP); + case
SCEVPredicate::P_Equal: + return expandEqualPredicate(cast(Pred), IP); + } + llvm_unreachable("Unknown SCEV predicate type"); +} + +Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred, + Instruction *IP) { + Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP); + Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP); + + Builder.SetInsertPoint(IP); + auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check"); + return I; +} + +Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union, + Instruction *IP) { + auto *BoolType = IntegerType::get(IP->getContext(), 1); + Value *Check = ConstantInt::getNullValue(BoolType); + + // Loop over all checks in this set. + for (auto Pred : Union->getPredicates()) { + auto *NextCheck = expandCodeForPredicate(Pred, IP); + Builder.SetInsertPoint(IP); + Check = Builder.CreateOr(Check, NextCheck); + } + + return Check; +} + namespace { // Search for a SCEV subexpression that is not safe to expand. Any expression // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -222,6 +222,15 @@ cl::desc("The maximum allowed number of runtime memory checks with a " "vectorize(enable) pragma.")); +static cl::opt VectorizeSCEVCheckThreshold( + "vectorize-scev-check-threshold", cl::init(16), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed.")); + +static cl::opt PragmaVectorizeSCEVCheckThreshold( + "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed with a " + "vectorize(enable) pragma")); + namespace { // Forward declarations.
@@ -273,12 +282,12 @@ InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, unsigned VecWidth, - unsigned UnrollFactor) + unsigned UnrollFactor, SCEVUnionPredicate &Preds) : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()), Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor), TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr), - AddedSafetyChecks(false) {} + AddedSafetyChecks(false), Preds(Preds) {} // Perform the actual loop widening (vectorization). // MinimumBitWidths maps scalar integer values to the smallest bitwidth they @@ -315,12 +324,6 @@ typedef DenseMap, VectorParts> EdgeMaskCache; - /// \brief Add checks for strides that were assumed to be 1. - /// - /// Returns the last check instruction and the first check instruction in the - /// pair as (first, last). - std::pair addStrideCheck(Instruction *Loc); - /// Create an empty loop, based on the loop ranges of the old loop. void createEmptyLoop(); /// Create a new induction variable inside L. @@ -404,11 +407,12 @@ void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass); /// Emit a bypass check to see if the vector trip count is nonzero. void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass); - /// Emit bypass checks to check if strides we've assumed to be one really are. - void emitStrideChecks(Loop *L, BasicBlock *Bypass); + /// Emit a bypass check to see if all of the SCEV assumptions we've + /// had to make are correct. + void emitSCEVChecks(Loop *L, BasicBlock *Bypass); /// Emit bypass checks to check any memory assumptions we may have made. void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); - + /// This is a helper class that holds the vectorizer state. It maps scalar /// instructions to vector instructions.
When the code is 'unrolled' then /// then a single scalar value is mapped to multiple vector parts. The parts @@ -516,14 +520,23 @@ // Record whether runtime check is added. bool AddedSafetyChecks; + + /// The SCEV predicate containing all the SCEV-related assumptions. + /// The predicate is used to simplify SCEV expressions in the + /// context of existing SCEV assumptions. Since legality checking is + /// not done here, we don't need to use this predicate to record + /// further assumptions. + SCEVUnionPredicate &Preds; }; class InnerLoopUnroller : public InnerLoopVectorizer { public: InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, unsigned UnrollFactor) - : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {} + const TargetTransformInfo *TTI, unsigned UnrollFactor, + SCEVUnionPredicate &Preds) + : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor, + Preds) {} private: void scalarizeInstruction(Instruction *Instr, @@ -744,8 +757,9 @@ /// between the member and the group in a map. class InterleavedAccessInfo { public: - InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT) - : SE(SE), TheLoop(L), DT(DT) {} + InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT, + SCEVUnionPredicate &Preds) + : SE(SE), TheLoop(L), DT(DT), Preds(Preds) {} ~InterleavedAccessInfo() { SmallSet DelSet; @@ -779,6 +793,13 @@ Loop *TheLoop; DominatorTree *DT; + /// The SCEV predicate containing all the SCEV-related assumptions. + /// The predicate is used to simplify SCEV expressions in the + /// context of existing SCEV assumptions. The interleaved access + /// analysis can also add new predicates (for example by versioning + /// strides of pointers). + SCEVUnionPredicate &Preds; + /// Holds the relationships between the members and the interleave group.
DenseMap InterleaveGroupMap; @@ -1141,11 +1162,13 @@ Function *F, const TargetTransformInfo *TTI, LoopAccessAnalysis *LAA, LoopVectorizationRequirements *R, - const LoopVectorizeHints *H) + const LoopVectorizeHints *H, + SCEVUnionPredicate &Preds) : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F), - TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT), - Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), - Requirements(R), Hints(H) {} + TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), + InterleaveInfo(SE, L, DT, Preds), Induction(nullptr), + WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H), + Preds(Preds) {} /// ReductionList contains the reduction descriptors for all /// of the reductions that were found in the loop. @@ -1344,7 +1367,14 @@ /// While vectorizing these instructions we have to generate a /// call to the appropriate masked intrinsic - SmallPtrSet MaskedOp; + SmallPtrSet MaskedOp; + + /// The SCEV predicate containing all the SCEV-related assumptions. + /// The predicate is used to simplify SCEV expressions in the + /// context of existing SCEV assumptions. The analysis will also + /// add a minimal set of new predicates if this is required to + /// enable vectorization/unrolling.
+ SCEVUnionPredicate &Preds; }; /// LoopVectorizationCostModel - estimates the expected speedups due to @@ -1360,9 +1390,10 @@ LoopVectorizationLegality *Legal, const TargetTransformInfo &TTI, const TargetLibraryInfo *TLI, DemandedBits *DB, - AssumptionCache *AC, - const Function *F, const LoopVectorizeHints *Hints, - SmallPtrSetImpl &ValuesToIgnore) + AssumptionCache *AC, const Function *F, + const LoopVectorizeHints *Hints, + SmallPtrSetImpl &ValuesToIgnore, + SCEVUnionPredicate &Preds) : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {} @@ -1690,10 +1721,12 @@ } } + SCEVUnionPredicate Preds; + // Check if it is legal to vectorize the loop. LoopVectorizationRequirements Requirements; LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA, - &Requirements, &Hints); + &Requirements, &Hints, Preds); if (!LVL.canVectorize()) { DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); emitMissedWarning(F, L, Hints); @@ -1712,7 +1745,7 @@ // Use the cost model. LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints, - ValuesToIgnore); + ValuesToIgnore, Preds); // Check the function attributes to find out if this function should be // optimized for size. @@ -1823,7 +1856,7 @@ assert(IC > 1 && "interleave count should not be 1 or 0"); // If we decided that it is not legal to vectorize the loop then // interleave it. - InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC); + InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds); Unroller.vectorize(&LVL, CM.MinBWs); emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), @@ -1831,7 +1864,7 @@ Twine(IC) + ")"); } else { // If we decided that it is *legal* to vectorize the loop then do it.
- InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC); + InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds); LB.vectorize(&LVL, CM.MinBWs); ++LoopsVectorized; @@ -1992,7 +2025,7 @@ // %idxprom = zext i32 %mul to i64 << Safe cast. // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom // - Last = replaceSymbolicStrideSCEV(SE, Strides, + Last = replaceSymbolicStrideSCEV(SE, Strides, Preds, Gep->getOperand(InductionOperand), Gep); if (const SCEVCastExpr *C = dyn_cast(Last)) Last = @@ -2551,56 +2584,8 @@ } } -static Instruction *getFirstInst(Instruction *FirstInst, Value *V, - Instruction *Loc) { - if (FirstInst) - return FirstInst; - if (Instruction *I = dyn_cast(V)) - return I->getParent() == Loc->getParent() ? I : nullptr; - return nullptr; -} - -std::pair -InnerLoopVectorizer::addStrideCheck(Instruction *Loc) { - Instruction *tnullptr = nullptr; - if (!Legal->mustCheckStrides()) - return std::pair(tnullptr, tnullptr); - - IRBuilder<> ChkBuilder(Loc); - - // Emit checks. - Value *Check = nullptr; - Instruction *FirstInst = nullptr; - for (SmallPtrSet::iterator SI = Legal->strides_begin(), - SE = Legal->strides_end(); - SI != SE; ++SI) { - Value *Ptr = stripIntegerCast(*SI); - Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1), - "stride.chk"); - // Store the first instruction we create. - FirstInst = getFirstInst(FirstInst, C, Loc); - if (Check) - Check = ChkBuilder.CreateOr(Check, C); - else - Check = C; - } - - // We have to do this trickery because the IRBuilder might fold the check to a - // constant expression in which case there is no Instruction anchored in a - // the block.
- LLVMContext &Ctx = Loc->getContext(); - Instruction *TheCheck = - BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx)); - ChkBuilder.Insert(TheCheck, "stride.not.one"); - FirstInst = getFirstInst(FirstInst, TheCheck, Loc); - - return std::make_pair(FirstInst, TheCheck); -} - -PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, - Value *Start, - Value *End, - Value *Step, +PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, + Value *End, Value *Step, Instruction *DL) { BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); @@ -2735,26 +2720,26 @@ LoopBypassBlocks.push_back(BB); } -void InnerLoopVectorizer::emitStrideChecks(Loop *L, - BasicBlock *Bypass) { +void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { BasicBlock *BB = L->getLoopPreheader(); - - // Generate the code to check that the strides we assumed to be one are really - // one. We want the new basic block to start at the first instruction in a + + // Generate the code to check that the SCEV assumptions that we made. + // We want the new basic block to start at the first instruction in a // sequence of instructions that form a check. - Instruction *StrideCheck; - Instruction *FirstCheckInst; - std::tie(FirstCheckInst, StrideCheck) = addStrideCheck(BB->getTerminator()); - if (!StrideCheck) - return; + SCEVExpander Exp(*SE, Bypass->getModule()->getDataLayout(), "scev.check"); + Value *SCEVCheck = Exp.expandCodeForPredicate(&Preds, BB->getTerminator()); + + if (auto *C = dyn_cast(SCEVCheck)) + if (C->isZero()) + return; // Create a new block containing the stride check. 
- BB->setName("vector.stridecheck"); + BB->setName("vector.scevcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); if (L->getParentLoop()) L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); ReplaceInstWithInst(BB->getTerminator(), - BranchInst::Create(Bypass, NewBB, StrideCheck)); + BranchInst::Create(Bypass, NewBB, SCEVCheck)); LoopBypassBlocks.push_back(BB); AddedSafetyChecks = true; } @@ -2874,10 +2859,10 @@ // Now, compare the new count to zero. If it is zero skip the vector loop and // jump to the scalar loop. emitVectorLoopEnteredCheck(Lp, ScalarPH); - // Generate the code to check that the strides we assumed to be one are really - // one. We want the new basic block to start at the first instruction in a - // sequence of instructions that form a check. - emitStrideChecks(Lp, ScalarPH); + // Generate the code to check any assumptions that we've made for SCEV + // expressions. + emitSCEVChecks(Lp, ScalarPH); + // Generate the code that checks in runtime if arrays overlap. We put the // checks into a separate block to make the more common case of few elements // faster. @@ -4130,7 +4115,19 @@ // Analyze interleaved memory accesses. if (UseInterleaved) - InterleaveInfo.analyzeInterleaving(Strides); + InterleaveInfo.analyzeInterleaving(Strides); + + unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; + if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) + SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; + + if (Preds.getComplexity() > SCEVThreshold) { + emitAnalysis(VectorizationReport() + << "Too many SCEV assumptions need to be made and checked " + << "at runtime"); + DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); + return false; + } // Okay! We can vectorize. 
At this point we don't have any other mem analysis // which may limit our maximum vectorization factor, so just return true with @@ -4436,6 +4433,7 @@ } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); + Preds.add(&LAI->Preds); return true; } @@ -4550,7 +4548,7 @@ StoreInst *SI = dyn_cast(I); Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); - int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides); + int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Preds); // The factor of the corresponding interleave group. unsigned Factor = std::abs(Stride); @@ -4559,7 +4557,7 @@ if (Factor < 2 || Factor > MaxInterleaveGroupFactor) continue; - const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Ptr); + const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr); PointerType *PtrTy = dyn_cast(Ptr->getType()); unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());