Index: include/llvm/Analysis/LoopAccessAnalysis.h
===================================================================
--- include/llvm/Analysis/LoopAccessAnalysis.h
+++ include/llvm/Analysis/LoopAccessAnalysis.h
@@ -656,8 +656,10 @@
 ///
 /// If necessary this method will version the stride of the pointer according
 /// to \p PtrToStride and therefore add a new predicate to \p Preds.
+/// The \p Assume parameter indicates whether we are allowed to make additional
+/// run-time assumptions.
 int isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, const Loop *Lp,
-                 const ValueToValueMap &StridesMap);
+                 const ValueToValueMap &StridesMap, bool Assume = false);
 
 /// \brief Returns true if the memory operations \p A and \p B are consecutive.
 /// This is a simple API that does not depend on the analysis pass.
Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -30,6 +30,7 @@
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/ValueHandle.h"
+#include "llvm/IR/ValueMap.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/DataTypes.h"
@@ -178,7 +179,7 @@
   FoldingSetNodeIDRef FastID;
 
 public:
-  enum SCEVPredicateKind { P_Union, P_Equal };
+  enum SCEVPredicateKind { P_Union, P_Equal, P_Wrap };
 
 protected:
   SCEVPredicateKind Kind;
@@ -268,6 +269,98 @@
   }
 };
 
+/// SCEVWrapPredicate - This class represents an assumption
+/// made on an AddRec expression. Given an affine AddRec expression
+/// {a,+,b}, we assume that it has the nssw or nusw flags (defined
+/// below).
+class SCEVWrapPredicate final : public SCEVPredicate {
+public:
+  /// Similar to SCEV::NoWrapFlags, but with slightly different semantics
+  /// for FlagNUSW. The increment is considered to be signed, and a + b
+  /// (where b is the increment) is considered to wrap if:
+  ///    zext(a + b) != zext(a) + sext(b)
+  ///
+  /// If Signed is a function that takes an n-bit tuple and maps to the
+  /// integer domain as the tuple's value interpreted as two's complement,
+  /// and Unsigned a function that takes an n-bit tuple and maps to the
+  /// integer domain as the base-two value of the input tuple, then a + b
+  /// has IncrementNUSW iff:
+  ///
+  ///    0 <= Unsigned(a) + Signed(b) < 2^n
+  ///
+  /// The IncrementNSSW flag has identical semantics to SCEV::FlagNSW.
+  ///
+  /// Note that the IncrementNUSW flag is not commutative: if base + inc
+  /// has IncrementNUSW, then inc + base doesn't necessarily have this
+  /// property. The reason for this is that this is used for sign/zero
+  /// extending affine AddRec SCEV expressions when a SCEVWrapPredicate is
+  /// assumed. A {base,+,inc} expression is already non-commutative with
+  /// regard to base and inc, since it is interpreted as:
+  ///     (((base + inc) + inc) + inc) ...
+  enum IncrementWrapFlags {
+    IncrementAnyWrap = 0,     // No guarantee.
+    IncrementNUSW = (1 << 0), // No unsigned with signed increment wrap.
+    IncrementNSSW = (1 << 1), // No signed with signed increment wrap
+                              // (equivalent to SCEV::NSW).
+    IncrementNoWrapMask = (1 << 2) - 1
+  };
+
+  /// Convenient IncrementWrapFlags manipulation methods.
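+  /// For example (illustrative only), with the helpers declared below:
+  ///   setFlags(IncrementNUSW, IncrementNSSW)  sets both bits,
+  ///   clearFlags(Both, IncrementNUSW)         leaves only IncrementNSSW, and
+  ///   maskFlags(Both, IncrementNSSW)          also yields IncrementNSSW,
+  /// where Both denotes the value with both bits set.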
+  static SCEVWrapPredicate::IncrementWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT
+  clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags,
+             SCEVWrapPredicate::IncrementWrapFlags OffFlags) {
+    assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!");
+    assert((OffFlags & IncrementNoWrapMask) == OffFlags &&
+           "Invalid flags value!");
+    return (SCEVWrapPredicate::IncrementWrapFlags)(Flags & ~OffFlags);
+  }
+
+  static SCEVWrapPredicate::IncrementWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT
+  maskFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, int Mask) {
+    assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!");
+    assert((Mask & IncrementNoWrapMask) == Mask && "Invalid mask value!");
+
+    return (SCEVWrapPredicate::IncrementWrapFlags)(Flags & Mask);
+  }
+
+  static SCEVWrapPredicate::IncrementWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT
+  setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags,
+           SCEVWrapPredicate::IncrementWrapFlags OnFlags) {
+    assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!");
+    assert((OnFlags & IncrementNoWrapMask) == OnFlags &&
+           "Invalid flags value!");
+
+    return (SCEVWrapPredicate::IncrementWrapFlags)(Flags | OnFlags);
+  }
+
+  /// \brief Returns the set of SCEVWrapPredicate no wrap flags implied
+  /// by a SCEVAddRecExpr.
+  static SCEVWrapPredicate::IncrementWrapFlags
+  getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE);
+
+private:
+  const SCEVAddRecExpr *AR;
+  IncrementWrapFlags Flags;
+
+public:
+  explicit SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
+                             const SCEVAddRecExpr *AR,
+                             IncrementWrapFlags Flags);
+
+  /// \brief Returns the set assumed no overflow flags.
+  IncrementWrapFlags getFlags() const { return Flags; }
+  /// Implementation of the SCEVPredicate interface
+  const SCEV *getExpr() const override;
+  bool implies(const SCEVPredicate *N) const override;
+  void print(raw_ostream &OS, unsigned Depth = 0) const override;
+  bool isAlwaysTrue() const override;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast:
+  static inline bool classof(const SCEVPredicate *P) {
+    return P->getKind() == P_Wrap;
+  }
+};
+
 /// SCEVUnionPredicate - This class represents a composition of other
 /// SCEV predicates, and is the class that most clients will interact with.
 /// This is equivalent to a logical "AND" of all the predicates in the union.
@@ -1251,8 +1344,18 @@
   const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
                                          const SCEVConstant *RHS);
 
+  const SCEVPredicate *
+  getWrapPredicate(const SCEVAddRecExpr *AR,
+                   SCEVWrapPredicate::IncrementWrapFlags AddedFlags);
+
   /// Re-writes the SCEV according to the Predicates in \p Preds.
-  const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A);
+  const SCEV *rewriteUsingPredicate(const SCEV *Scev, const Loop *L,
+                                    SCEVUnionPredicate &A);
+  /// Tries to convert the \p Scev expression to an AddRec expression,
+  /// adding additional predicates to \p Preds as required.
+  const SCEV *convertSCEVToAddRecWithPredicates(const SCEV *Scev,
+                                                const Loop *L,
+                                                SCEVUnionPredicate &Preds);
 
 private:
   /// Compute the backedge taken count knowing the interval difference, the
@@ -1343,7 +1446,7 @@
 /// - lowers the number of expression rewrites.
 class PredicatedScalarEvolution {
 public:
-  PredicatedScalarEvolution(ScalarEvolution &SE);
+  PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L);
   const SCEVUnionPredicate &getUnionPredicate() const;
   /// \brief Returns the SCEV expression of V, in the context of the current
   /// SCEV predicate.
@@ -1353,9 +1456,18 @@
   const SCEV *getSCEV(Value *V);
   /// \brief Adds a new predicate.
   void addPredicate(const SCEVPredicate &Pred);
+  /// \brief Attempts to produce an AddRecExpr for V by adding additional
+  /// SCEV predicates.
+  const SCEV *getAsAddRec(Value *V);
+  /// \brief Proves that V doesn't overflow by adding a SCEV predicate.
+  void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags);
+  /// \brief Returns true if we've proved that V doesn't wrap by means of a
+  /// SCEV predicate.
+  bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags);
   /// \brief Returns the ScalarEvolution analysis used.
   ScalarEvolution *getSE() const { return &SE; }
-
+  /// We need to explicitly define the copy constructor because of FlagsMap.
+  PredicatedScalarEvolution(const PredicatedScalarEvolution &);
 private:
   /// \brief Increments the version number of the predicate.
   /// This needs to be called every time the SCEV predicate changes.
@@ -1369,8 +1481,12 @@
   /// rewrites, we will rewrite the previous result instead of the original
   /// SCEV.
   DenseMap<const SCEV *, RewriteEntry> RewriteMap;
+  /// Records what NoWrap flags we've added to a Value *.
+  ValueMap<const Value *, SCEVWrapPredicate::IncrementWrapFlags> FlagsMap;
   /// The ScalarEvolution analysis.
   ScalarEvolution &SE;
+  /// The analyzed Loop.
+  const Loop &L;
   /// The SCEVPredicate that forms our context. We will rewrite all
   /// expressions assuming that this predicate true.
   SCEVUnionPredicate Preds;
Index: include/llvm/Analysis/ScalarEvolutionExpander.h
===================================================================
--- include/llvm/Analysis/ScalarEvolutionExpander.h
+++ include/llvm/Analysis/ScalarEvolutionExpander.h
@@ -162,6 +162,15 @@
     Value *expandEqualPredicate(const SCEVEqualPredicate *Pred,
                                 Instruction *Loc);
 
+    /// \brief Generates code that evaluates if the \p AR expression will
+    /// overflow.
+    Value *generateOverflowCheck(const SCEVAddRecExpr *AR, Instruction *Loc,
+                                 bool Signed);
+
+    /// \brief A specialized variant of expandCodeForPredicate, handling the
+    /// case when we are expanding code for a SCEVWrapPredicate.
+    Value *expandWrapPredicate(const SCEVWrapPredicate *P, Instruction *Loc);
+
     /// \brief A specialized variant of expandCodeForPredicate, handling the
     /// case when we are expanding code for a SCEVUnionPredicate.
     Value *expandUnionPredicate(const SCEVUnionPredicate *Pred,
Index: lib/Analysis/LoopAccessAnalysis.cpp
===================================================================
--- lib/Analysis/LoopAccessAnalysis.cpp
+++ lib/Analysis/LoopAccessAnalysis.cpp
@@ -773,7 +773,7 @@
 /// \brief Return true if an AddRec pointer \p Ptr is unsigned non-wrapping,
 /// i.e. monotonically increasing/decreasing.
 static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
-                           ScalarEvolution *SE, const Loop *L) {
+                           PredicatedScalarEvolution &PSE, const Loop *L) {
   // FIXME: This should probably only return true for NUW.
   if (AR->getNoWrapFlags(SCEV::NoWrapMask))
     return true;
@@ -809,7 +809,7 @@
       // Assume constant for other the operand so that the AddRec can be
       // easily found.
       isa<ConstantInt>(OBO->getOperand(1))) {
-    auto *OpScev = SE->getSCEV(OBO->getOperand(0));
+    auto *OpScev = PSE.getSCEV(OBO->getOperand(0));
 
     if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpScev))
       return OpAR->getLoop() == L && OpAR->getNoWrapFlags(SCEV::FlagNSW);
@@ -820,24 +820,28 @@
 /// \brief Check whether the access through \p Ptr has a constant stride.
 int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr,
-                       const Loop *Lp, const ValueToValueMap &StridesMap) {
+                       const Loop *Lp, const ValueToValueMap &StridesMap,
+                       bool Assume) {
   Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Unexpected non-ptr");
 
   // Make sure that the pointer does not point to aggregate types.
   auto *PtrTy = cast<PointerType>(Ty);
   if (PtrTy->getElementType()->isAggregateType()) {
-    DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type"
-                 << *Ptr << "\n");
+    DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type" << *Ptr
+                 << "\n");
     return 0;
   }
 
   const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
 
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
+  if (Assume && !AR)
+    AR = dyn_cast<SCEVAddRecExpr>(PSE.getAsAddRec(Ptr));
+
   if (!AR) {
-    DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer "
-                 << *Ptr << " SCEV: " << *PtrScev << "\n");
+    DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr
+                 << " SCEV: " << *PtrScev << "\n");
     return 0;
   }
@@ -856,12 +860,22 @@
   // to access the pointer value "0" which is undefined behavior in address
   // space 0, therefore we can also vectorize this case.
   bool IsInBoundsGEP = isInBoundsGep(Ptr);
-  bool IsNoWrapAddRec = isNoWrapAddRec(Ptr, AR, PSE.getSE(), Lp);
+  bool IsNoWrapAddRec =
+      PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) ||
+      isNoWrapAddRec(Ptr, AR, PSE, Lp);
   bool IsInAddressSpaceZero = PtrTy->getAddressSpace() == 0;
   if (!IsNoWrapAddRec && !IsInBoundsGEP && !IsInAddressSpaceZero) {
-    DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
-                 << *Ptr << " SCEV: " << *PtrScev << "\n");
-    return 0;
+    if (Assume) {
+      PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+      IsNoWrapAddRec = true;
+      DEBUG(dbgs() << "LAA: Pointer may wrap in the address space " << *Ptr
+                   << " SCEV: " << *PtrScev
+                   << " added an overflow assumption\n");
+    } else {
+      DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
+                   << *Ptr << " SCEV: " << *PtrScev << "\n");
+      return 0;
+    }
   }
 
   // Check the step is constant.
@@ -895,8 +909,17 @@
   // know we can't "wrap around the address space". In case of address space
   // zero we know that this won't happen without triggering undefined behavior.
   if (!IsNoWrapAddRec && (IsInBoundsGEP || IsInAddressSpaceZero) &&
-      Stride != 1 && Stride != -1)
-    return 0;
+      Stride != 1 && Stride != -1) {
+    if (Assume) {
+      // We can avoid this case by adding a run-time check.
+ DEBUG(dbgs() << "LAA: Non unit strided pointer which is not either " + << "inbouds or in address space 0 may wrap: " << *Ptr + << " SCEV: " << *PtrScev + << " added an overflow assumption\n"); + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + } else + return 0; + } return Stride; } @@ -1123,8 +1146,8 @@ const SCEV *AScev = replaceSymbolicStrideSCEV(PSE, Strides, APtr); const SCEV *BScev = replaceSymbolicStrideSCEV(PSE, Strides, BPtr); - int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides); - int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides); + int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides, true); + int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides, true); const SCEV *Src = AScev; const SCEV *Sink = BScev; @@ -1824,7 +1847,7 @@ const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : PSE(*SE), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL), + : PSE(*SE, *L), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL), TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -9545,17 +9545,40 @@ return Eq; } +const SCEVPredicate *ScalarEvolution::getWrapPredicate( + const SCEVAddRecExpr *AR, + SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { + FoldingSetNodeID ID; + // Unique this node based on the arguments + ID.AddInteger(SCEVPredicate::P_Wrap); + ID.AddPointer(AR); + ID.AddInteger(AddedFlags); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + auto *OF = new (SCEVAllocator) + SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags); + UniquePreds.InsertNode(OF, IP); + return OF; +} + namespace { + class SCEVPredicateRewriter : public SCEVRewriteVisitor { public: - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, - SCEVUnionPredicate &A) { - SCEVPredicateRewriter Rewriter(SE, A); + // Rewrites Scev in the context of a loop L and the predicate A. + // If Assume is true, rewrite is free to add further predicates to A + // such that the result will be an AddRecExpr. + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE, SCEVUnionPredicate &A, + bool Assume) { + SCEVPredicateRewriter Rewriter(L, SE, A, Assume); return Rewriter.visit(Scev); } - SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P) - : SCEVRewriteVisitor(SE), P(P) {} + SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, + SCEVUnionPredicate &P, bool Assume) + : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { auto ExprPreds = P.getPredicatesForExpr(Expr); @@ -9567,14 +9590,67 @@ return Expr; } + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + const SCEVAddRecExpr *AR = dyn_cast(Operand); + if (AR && AR->getLoop() == L && AR->isAffine()) { + // This couldn't be folded because the operand didn't have the nuw + // flag. Add the nusw flag as an assumption that we could make. 
+      const SCEV *Step = AR->getStepRecurrence(SE);
+      Type *Ty = Expr->getType();
+      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
+        return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
+                                SE.getSignExtendExpr(Step, Ty), L,
+                                AR->getNoWrapFlags());
+    }
+    return SE.getZeroExtendExpr(Operand, Expr->getType());
+  }
+
+  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+    const SCEV *Operand = visit(Expr->getOperand());
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+    if (AR && AR->getLoop() == L && AR->isAffine()) {
+      // This couldn't be folded because the operand didn't have the nsw
+      // flag. Add the nssw flag as an assumption that we could make.
+      const SCEV *Step = AR->getStepRecurrence(SE);
+      Type *Ty = Expr->getType();
+      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
+        return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
+                                SE.getSignExtendExpr(Step, Ty), L,
+                                AR->getNoWrapFlags());
+    }
+    return SE.getSignExtendExpr(Operand, Expr->getType());
+  }
+
 private:
+  bool addOverflowAssumption(const SCEVAddRecExpr *AR,
+                             SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+    auto *A = SE.getWrapPredicate(AR, AddedFlags);
+    if (!Assume) {
+      // Check if we've already made this assumption.
+      if (P.implies(A))
+        return true;
+      return false;
+    }
+    P.add(A);
+    return true;
+  }
+
   SCEVUnionPredicate &P;
+  const Loop *L;
+  bool Assume;
 };
 } // end anonymous namespace
 
 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+                                                   const Loop *L,
                                                    SCEVUnionPredicate &Preds) {
-  return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+  return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, false);
+}
+
+const SCEV *ScalarEvolution::convertSCEVToAddRecWithPredicates(
+    const SCEV *Scev, const Loop *L, SCEVUnionPredicate &Preds) {
+  return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, true);
 }
 
 /// SCEV predicates
@@ -9604,6 +9680,59 @@
   OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
 }
 
+SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
+                                     const SCEVAddRecExpr *AR,
+                                     IncrementWrapFlags Flags)
+    : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
+
+const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
+
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+
+  return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+}
+
+bool SCEVWrapPredicate::isAlwaysTrue() const {
+  SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
+  IncrementWrapFlags IFlags = Flags;
+
+  if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
+    IFlags = clearFlags(IFlags, IncrementNSSW);
+
+  return IFlags == IncrementAnyWrap;
+}
+
+void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << *getExpr() << " Added Flags: ";
+  if (SCEVWrapPredicate::IncrementNUSW & getFlags())
+    OS << "<nusw>";
+  if (SCEVWrapPredicate::IncrementNSSW & getFlags())
+    OS << "<nssw>";
+  OS << "\n";
+}
+
+SCEVWrapPredicate::IncrementWrapFlags
+SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
+                                   ScalarEvolution &SE) {
+  IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
+  SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
+
+  // We can safely transfer the NSW flag as NSSW.
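+  // (IncrementNSSW is defined to have the same semantics as SCEV::FlagNSW;
+  // see SCEVWrapPredicate::IncrementWrapFlags.)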
+  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
+    ImpliedFlags = IncrementNSSW;
+
+  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
+    // If the increment is positive, the SCEV NUW flag will also imply the
+    // WrapPredicate NUSW flag.
+    if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
+      if (Step->getValue()->getValue().isNonNegative())
+        ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
+  }
+
+  return ImpliedFlags;
+}
+
 /// Union predicates don't get cached so create a dummy set ID for it.
 SCEVUnionPredicate::SCEVUnionPredicate()
     : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
@@ -9660,8 +9789,9 @@
   Preds.push_back(N);
 }
 
-PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE)
-    : SE(SE), Generation(0) {}
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
+                                                     Loop &L)
+    : SE(SE), L(L), Generation(0) {}
 
 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
   const SCEV *Expr = SE.getSCEV(V);
@@ -9676,7 +9806,7 @@
   if (Entry.second)
     Expr = Entry.second;
 
-  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds);
+  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
   Entry = {Generation, NewSCEV};
 
   return NewSCEV;
@@ -9698,7 +9828,54 @@
   if (++Generation == 0) {
     for (auto &II : RewriteMap) {
       const SCEV *Rewritten = II.second.second;
-      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)};
+      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
     }
   }
 }
+
+void PredicatedScalarEvolution::setNoOverflow(
+    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+  const SCEV *Expr = getSCEV(V);
+  const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+  auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
+
+  // Clear the statically implied flags.
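+  // Anything getImpliedFlags already proves does not need to be added to the
+  // predicate (and later checked at run time).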
+  Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
+  addPredicate(*SE.getWrapPredicate(AR, Flags));
+
+  auto II = FlagsMap.insert({V, Flags});
+  if (!II.second)
+    II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
+}
+
+bool PredicatedScalarEvolution::hasNoOverflow(
+    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+  const SCEV *Expr = getSCEV(V);
+  const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+  Flags = SCEVWrapPredicate::clearFlags(
+      Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
+
+  auto II = FlagsMap.find(V);
+
+  if (II != FlagsMap.end())
+    Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
+
+  return Flags == SCEVWrapPredicate::IncrementAnyWrap;
+}
+
+const SCEV *PredicatedScalarEvolution::getAsAddRec(Value *V) {
+  const SCEV *Expr = this->getSCEV(V);
+  const SCEV *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds);
+  updateGeneration();
+  RewriteMap[SE.getSCEV(V)] = {Generation, New};
+  return New;
+}
+
+PredicatedScalarEvolution::
+PredicatedScalarEvolution(const PredicatedScalarEvolution &Init) :
+    RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds) {
+  for (auto I = Init.FlagsMap.begin(), E = Init.FlagsMap.end(); I != E; ++I)
+    FlagsMap.insert(*I);
+}
Index: lib/Analysis/ScalarEvolutionExpander.cpp
===================================================================
--- lib/Analysis/ScalarEvolutionExpander.cpp
+++ lib/Analysis/ScalarEvolutionExpander.cpp
@@ -1940,6 +1940,10 @@
     return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
   case SCEVPredicate::P_Equal:
     return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+  case SCEVPredicate::P_Wrap: {
+    auto *AddRecPred = cast<SCEVWrapPredicate>(Pred);
+    return expandWrapPredicate(AddRecPred, IP);
+  }
   }
   llvm_unreachable("Unknown SCEV predicate type");
 }
@@ -1954,6 +1958,70 @@
   return I;
 }
 
+Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
+                                           Instruction *Loc, bool Signed) {
+  assert(AR->isAffine() && "Cannot generate RT check for "
+                           "non-affine expression");
+
+  const SCEV *ExitCount = SE.getBackedgeTakenCount(AR->getLoop());
+  const SCEV *Step = AR->getStepRecurrence(SE);
+  const SCEV *Start = AR->getStart();
+
+  unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
+  unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
+  unsigned MaxBits = 2 * std::max(DstBits, SrcBits);
+
+  auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType());
+  IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits);
+
+  assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+
+  const auto *ExtendedTripCount = SE.getZeroExtendExpr(ExitCount, MaxTy);
+  const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy);
+  const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy)
+                                     : SE.getZeroExtendExpr(Start, MaxTy);
+
+  const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step));
+  const SCEV *RHS = Signed ? SE.getSignExtendExpr(End, MaxTy)
+                           : SE.getZeroExtendExpr(End, MaxTy);
+
+  const SCEV *LHS = SE.getAddExpr(
+      ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep));
+
+  // Do all SCEV expansions now.
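+  // The check below is LHS != RHS: the value of the AddRec after the last
+  // iteration computed in the wide MaxTy type, compared with the same value
+  // computed in the original type and then extended. A mismatch means the
+  // narrow computation wrapped.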
+  Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc);
+  Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc);
+
+  Builder.SetInsertPoint(Loc);
+
+  return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal);
+}
+
+Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred,
+                                         Instruction *IP) {
+  const auto *A = cast<SCEVAddRecExpr>(Pred->getExpr());
+  Value *NSSWCheck = nullptr, *NUSWCheck = nullptr;
+
+  // Add a check for NUSW
+  if (Pred->getFlags() & SCEVWrapPredicate::IncrementNUSW)
+    NUSWCheck = generateOverflowCheck(A, IP, false);
+
+  // Add a check for NSSW
+  if (Pred->getFlags() & SCEVWrapPredicate::IncrementNSSW)
+    NSSWCheck = generateOverflowCheck(A, IP, true);
+
+  if (NUSWCheck && NSSWCheck)
+    return Builder.CreateOr(NUSWCheck, NSSWCheck);
+
+  if (NUSWCheck)
+    return NUSWCheck;
+
+  if (NSSWCheck)
+    return NSSWCheck;
+
+  return ConstantInt::getFalse(IP->getContext());
+}
+
 Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
                                           Instruction *IP) {
   auto *BoolType = IntegerType::get(IP->getContext(), 1);
Index: lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- lib/Transforms/Vectorize/LoopVectorize.cpp
+++ lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1751,7 +1751,7 @@
       }
     }
 
-    PredicatedScalarEvolution PSE(*SE);
+    PredicatedScalarEvolution PSE(*SE, *L);
 
     // Check if it is legal to vectorize the loop.
     LoopVectorizationRequirements Requirements;
Index: test/Analysis/LoopAccessAnalysis/wrapping-pointer-versioning.ll
===================================================================
--- /dev/null
+++ test/Analysis/LoopAccessAnalysis/wrapping-pointer-versioning.ll
@@ -0,0 +1,241 @@
+; RUN: opt -basicaa -loop-accesses -analyze < %s | FileCheck %s -check-prefix=LAA
+; RUN: opt -loop-vectorize -force-vector-interleave=1 -force-vector-width=4 -S < %s | FileCheck %s -check-prefix=LV
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+
+; For this loop:
+;   unsigned index = 0;
+;   for (int i = 0; i < n; i++) {
+;     A[2 * index] = A[2 * index] + B[i];
+;     index++;
+;   }
+;
+; SCEV is unable to prove that A[2 * i] does not overflow.
+;
+; Analyzing the IR does not help us because the GEPs are not
+; affine AddRecExprs. However, we can turn them into AddRecExprs
+; using SCEV Predicates.
+;
+; Once we have an affine expression we need to add an additional NUSW
+; to check that the pointers don't wrap since the GEPs are not
+; inbounds.
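+;
+; The generated run-time check (see the vector.scevcheck block in the LV
+; output below) conceptually recomputes the value of each AddRec after the
+; last iteration in a wider type (i128 here) and compares it with the value
+; computed in the original type; the loop is only vectorized if they match.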
+
+; LAA-LABEL: f1
+; LAA: Memory dependences are safe{{$}}
+; LAA: SCEV assumptions:
+; LAA-NEXT: {0,+,2}<%for.body> Added Flags: <nusw>
+; LAA-NEXT: {%a,+,4}<%for.body> Added Flags: <nusw>
+
+; The expression for %mul_ext as analyzed by SCEV is
+;    (zext i32 {0,+,2}<%for.body> to i64)
+; We have added the nusw flag to turn this expression into i64 {0,+,2}<%for.body>
+
+; LV-LABEL: f1
+; LV-LABEL: vector.scevcheck
+; LV: [[PredCheck0:%[^ ]*]] = icmp ne i128
+; LV: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
+; LV: [[PredCheck1:%[^ ]*]] = icmp ne i128
+; LV: [[FinalCheck:%[^ ]*]] = or i1 [[Or0]], [[PredCheck1]]
+; LV: br i1 [[FinalCheck]], label %scalar.ph, label %vector.ph
+define void @f1(i16* noalias %a,
+                i16* noalias %b, i64 %N) {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %ind = phi i64 [ 0, %entry ], [ %inc, %for.body ]
+  %ind1 = phi i32 [ 0, %entry ], [ %inc1, %for.body ]
+
+  %mul = mul i32 %ind1, 2
+  %mul_ext = zext i32 %mul to i64
+
+  %arrayidxA = getelementptr i16, i16* %a, i64 %mul_ext
+  %loadA = load i16, i16* %arrayidxA, align 2
+
+  %arrayidxB = getelementptr i16, i16* %b, i64 %ind
+  %loadB = load i16, i16* %arrayidxB, align 2
+
+  %add = mul i16 %loadA, %loadB
+
+  store i16 %add, i16* %arrayidxA, align 2
+
+  %inc = add nuw nsw i64 %ind, 1
+  %inc1 = add i32 %ind1, 1
+
+  %exitcond = icmp eq i64 %inc, %N
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}
+
+; For this loop:
+;   unsigned index = n;
+;   for (int i = 0; i < n; i++) {
+;     A[2 * index] = A[2 * index] + B[i];
+;     index--;
+;   }
+;
+; the SCEV expression for 2 * index is not an AddRecExpr
+; (and implicitly not affine). However, we are able to make assumptions
+; that will turn the expression into an affine one and continue the
+; analysis.
+;
+; Once we have an affine expression we need to add an additional NUSW
+; to check that the pointers don't wrap since the GEPs are not
+; inbounds.
+;
+; This loop has a negative stride for A, and the nusw flag is required in
+; order to properly extend the increment from i32 -4 to i64 -4.
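+;
+; With <nusw> the start of the AddRec is zero-extended while the step is
+; sign-extended (zext(a + b) == zext(a) + sext(b)), which is what allows the
+; negative i32 step to be widened correctly to i64.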
+
+; LAA-LABEL: f2
+; LAA: Memory dependences are safe{{$}}
+; LAA: SCEV assumptions:
+; LAA-NEXT: {(2 * (trunc i64 %N to i32)),+,-2}<%for.body> Added Flags: <nusw>
+; LAA-NEXT: {((2 * (zext i32 (2 * (trunc i64 %N to i32)) to i64)) + %a),+,-4}<%for.body> Added Flags: <nusw>
+
+; The expression for %mul_ext as analyzed by SCEV is
+;    (zext i32 {(2 * (trunc i64 %N to i32)),+,-2}<%for.body> to i64)
+; We have added the nusw flag to turn this expression into i64 {zext i32 (2 * (trunc i64 %N to i32)) to i64,+,-2}<%for.body>
+
+; LV-LABEL: f2
+; LV-LABEL: vector.scevcheck
+; LV: [[PredCheck0:%[^ ]*]] = icmp ne i128
+; LV: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
+; LV: [[PredCheck1:%[^ ]*]] = icmp ne i128
+; LV: [[FinalCheck:%[^ ]*]] = or i1 [[Or0]], [[PredCheck1]]
+; LV: br i1 [[FinalCheck]], label %scalar.ph, label %vector.ph
+define void @f2(i16* noalias %a,
+                i16* noalias %b, i64 %N) {
+entry:
+  %TruncN = trunc i64 %N to i32
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %ind = phi i64 [ 0, %entry ], [ %inc, %for.body ]
+  %ind1 = phi i32 [ %TruncN, %entry ], [ %dec, %for.body ]
+
+  %mul = mul i32 %ind1, 2
+  %mul_ext = zext i32 %mul to i64
+
+  %arrayidxA = getelementptr i16, i16* %a, i64 %mul_ext
+  %loadA = load i16, i16* %arrayidxA, align 2
+
+  %arrayidxB = getelementptr i16, i16* %b, i64 %ind
+  %loadB = load i16, i16* %arrayidxB, align 2
+
+  %add = mul i16 %loadA, %loadB
+
+  store i16 %add, i16* %arrayidxA, align 2
+
+  %inc = add nuw nsw i64 %ind, 1
+  %dec = sub i32 %ind1, 1
+
+  %exitcond = icmp eq i64 %inc, %N
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}
+
+; We replicate the tests above, but this time sign extend 2 * index instead
+; of zero extending it.
+
+; LAA-LABEL: f3
+; LAA: Memory dependences are safe{{$}}
+; LAA: SCEV assumptions:
+; LAA-NEXT: {0,+,2}<%for.body> Added Flags: <nssw>
+; LAA-NEXT: {%a,+,4}<%for.body> Added Flags: <nusw>
+
+; The expression for %mul_ext as analyzed by SCEV is
+;    i64 (sext i32 {0,+,2}<%for.body> to i64)
+; We have added the nssw flag to turn this expression into i64 {0,+,2}<%for.body>
+
+; LV-LABEL: f3
+; LV-LABEL: vector.scevcheck
+; LV: [[PredCheck0:%[^ ]*]] = icmp ne i128
+; LV: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
+; LV: [[PredCheck1:%[^ ]*]] = icmp ne i128
+; LV: [[FinalCheck:%[^ ]*]] = or i1 [[Or0]], [[PredCheck1]]
+; LV: br i1 [[FinalCheck]], label %scalar.ph, label %vector.ph
+define void @f3(i16* noalias %a,
+                i16* noalias %b, i64 %N) {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %ind = phi i64 [ 0, %entry ], [ %inc, %for.body ]
+  %ind1 = phi i32 [ 0, %entry ], [ %inc1, %for.body ]
+
+  %mul = mul i32 %ind1, 2
+  %mul_ext = sext i32 %mul to i64
+
+  %arrayidxA = getelementptr i16, i16* %a, i64 %mul_ext
+  %loadA = load i16, i16* %arrayidxA, align 2
+
+  %arrayidxB = getelementptr i16, i16* %b, i64 %ind
+  %loadB = load i16, i16* %arrayidxB, align 2
+
+  %add = mul i16 %loadA, %loadB
+
+  store i16 %add, i16* %arrayidxA, align 2
+
+  %inc = add nuw nsw i64 %ind, 1
+  %inc1 = add i32 %ind1, 1
+
+  %exitcond = icmp eq i64 %inc, %N
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}
+
+; LAA-LABEL: f4
+; LAA: Memory dependences are safe{{$}}
+; LAA: SCEV assumptions:
+; LAA-NEXT: {(2 * (trunc i64 %N to i32)),+,-2}<%for.body> Added Flags: <nssw>
+; LAA-NEXT: {((2 * (sext i32 (2 * (trunc i64 %N to i32)) to i64)) + %a),+,-4}<%for.body> Added Flags: <nusw>
+
+; The expression for %mul_ext as analyzed by SCEV is
+;    i64 (sext i32 {(2 * (trunc i64 %N to i32)),+,-2}<%for.body> to i64)
+; We have added the nssw flag to turn this expression into i64 {sext i32 (2 * (trunc i64 %N to i32)) to i64,+,-2}<%for.body>
+
+; LV-LABEL: f4
+; LV-LABEL: vector.scevcheck
+; LV: [[PredCheck0:%[^ ]*]] = icmp ne i128
+; LV: [[Or0:%[^ ]*]] = or i1 false, [[PredCheck0]]
+; LV: [[PredCheck1:%[^ ]*]] = icmp ne i128
+; LV: [[FinalCheck:%[^ ]*]] = or i1 [[Or0]], [[PredCheck1]]
+; LV: br i1 [[FinalCheck]], label %scalar.ph, label %vector.ph
+define void @f4(i16* noalias %a,
+                i16* noalias %b, i64 %N) {
+entry:
+  %TruncN = trunc i64 %N to i32
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %ind = phi i64 [ 0, %entry ], [ %inc, %for.body ]
+  %ind1 = phi i32 [ %TruncN, %entry ], [ %dec, %for.body ]
+
+  %mul = mul i32 %ind1, 2
+  %mul_ext = sext i32 %mul to i64
+
+  %arrayidxA = getelementptr i16, i16* %a, i64 %mul_ext
+  %loadA = load i16, i16* %arrayidxA, align 2
+
+  %arrayidxB = getelementptr i16, i16* %b, i64 %ind
+  %loadB = load i16, i16* %arrayidxB, align 2
+
+  %add = mul i16 %loadA, %loadB
+
+  store i16 %add, i16* %arrayidxA, align 2
+
+  %inc = add nuw nsw i64 %ind, 1
+  %dec = sub i32 %ind1, 1
+
+  %exitcond = icmp eq i64 %inc, %N
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}
Index: test/Transforms/LoopVectorize/same-base-access.ll
===================================================================
--- test/Transforms/LoopVectorize/same-base-access.ll
+++ test/Transforms/LoopVectorize/same-base-access.ll
@@ -62,11 +62,9 @@
 }
 
-
-
-; We don't vectorize this function because A[i*7] is scalarized, and the
-; different scalars can in theory wrap around and overwrite other scalar
-; elements. At the moment we only allow read/write access to arrays
-; that are consecutive.
+; A[i*7] is scalarized, and the different scalars can in theory wrap
+; around and overwrite other scalar elements. However, we can still
+; vectorize because we can version the loop to avoid this case.
 ;
 ; void foo(int *a) {
 ;   for (int i=0; i<256; ++i) {
 ;   }
 ; }
 
 ; CHECK-LABEL: @func2(
-; CHECK-NOT: <4 x i32>
+; CHECK: <4 x i32>
 ; CHECK: ret
 define i32 @func2(i32* nocapture %a) nounwind uwtable ssp {
   br label %1