Index: include/llvm/Analysis/LoopAccessAnalysis.h =================================================================== --- include/llvm/Analysis/LoopAccessAnalysis.h +++ include/llvm/Analysis/LoopAccessAnalysis.h @@ -656,8 +656,10 @@ /// /// If necessary this method will version the stride of the pointer according /// to \p PtrToStride and therefore add a new predicate to \p Preds. +/// The \p Assume parameter indicates if we are allowed to make additional +/// run-time assumptions. int isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap); + const ValueToValueMap &StridesMap, bool Assume = false); /// \brief This analysis provides dependence information for the memory accesses /// of a loop. Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -179,7 +179,7 @@ FoldingSetNodeIDRef FastID; public: - enum SCEVPredicateKind { P_Union, P_Equal }; + enum SCEVPredicateKind { P_Union, P_Equal, P_AddRecOverflow }; protected: SCEVPredicateKind Kind; @@ -269,6 +269,49 @@ } }; + /// SCEVWrapPredicate - This class represents an assumption + /// made on an AddRec expression. Given an affine AddRec expression + /// {a,+,b}, we assume that it has nsw or nuw flags. + class SCEVWrapPredicate final : public SCEVPredicate { + public: + /// Similar to SCEV::NoWrapFlags, but with slightly different semantics + /// for FlagNUW. The increment is considered to be signed, and a + b + /// (where b is the increment) is considered to wrap if: + /// a + b != zext(a) + sext(b) + /// + /// In the integer domain this is equivalent to 0 <= a + b < 2^n + /// + /// FlagNSW has identical semantics with SCEV::FlagNSW + enum NoWrapFlags { + FlagAnyWrap = 0, // No guarantee. + FlagNUW = (1 << 0), // No unsigned wrap (signed increment). + FlagNSW = (1 << 1), // No signed wrap. 
+ NoWrapMask = (1 << 2) - 1 + }; + + private: + const SCEVAddRecExpr *AR; + NoWrapFlags Flags; + + public: + SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, + NoWrapFlags Flags); + + /// \brief Returns the set of assumed no-overflow flags. + NoWrapFlags getFlags() const { return Flags; } + + /// Implementation of the SCEVPredicate interface + const SCEV *getExpr() const override; + bool implies(const SCEVPredicate *N) const override; + void print(raw_ostream &OS, unsigned Depth = 0) const override; + bool isAlwaysTrue() const override; + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVPredicate *P) { + return P->getKind() == P_AddRecOverflow; + } + }; + /// SCEVUnionPredicate - This class represents a composition of other /// SCEV predicates, and is the class that most clients will interact with. /// This is equivalent to a logical "AND" of all the predicates in the union. @@ -1248,8 +1291,18 @@ const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS, const SCEVConstant *RHS); + const SCEVPredicate * + getAddRecOverflowPredicate(const SCEVAddRecExpr *AR, + SCEVWrapPredicate::NoWrapFlags AddedFlags); + /// Re-writes the SCEV according to the Predicates in \p Preds. - const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A); + const SCEV *rewriteUsingPredicate(const SCEV *Scev, const Loop *L, + SCEVUnionPredicate &A); + /// Tries to convert the \p Scev expression to an AddRec expression, + /// adding additional predicates to \p Preds as required. + const SCEV *convertSCEVToAddRecWithPredicates(const SCEV *Scev, + const Loop *L, + SCEVUnionPredicate &Preds); private: /// Compute the backedge taken count knowing the interval difference, the @@ -1340,7 +1393,7 @@ /// - lowers the number of expression rewrites. 
class PredicatedScalarEvolution { public: - PredicatedScalarEvolution(ScalarEvolution &SE); + PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L); const SCEVUnionPredicate &getUnionPredicate() const; /// \brief Returns the SCEV expression of V, in the context of the current /// SCEV predicate. @@ -1350,6 +1403,14 @@ const SCEV *getSCEV(Value *V); /// \brief Adds a new predicate. void addPredicate(const SCEVPredicate &Pred); + /// \brief Attempts to produce an AddRecExpr for V by adding additional + /// SCEV predicates. + const SCEV *getAsAddRec(Value *V); + /// \brief Proves that V doesn't overflow by adding SCEV predicate. + void setNoOverflow(Value *V, SCEVWrapPredicate::NoWrapFlags Flags); + /// \brief Returns true if we've proved that V doesn't wrap by means of a + /// SCEV predicate. + bool hasNoOverflow(Value *V, SCEVWrapPredicate::NoWrapFlags Flags); /// \brief Returns the ScalarEvolution analysis used. ScalarEvolution *getSE() const { return &SE; } @@ -1366,8 +1427,12 @@ /// rewrites, we will rewrite the previous result instead of the original /// SCEV. DenseMap RewriteMap; + /// Records what NoWrap flags we've added to a Value *. + DenseMap FlagsMap; /// The ScalarEvolution analysis. ScalarEvolution &SE; + /// The analyzed Loop. + Loop &L; /// The SCEVPredicate that forms our context. We will rewrite all /// expressions assuming that this predicate true. SCEVUnionPredicate Preds; Index: include/llvm/Analysis/ScalarEvolutionExpander.h =================================================================== --- include/llvm/Analysis/ScalarEvolutionExpander.h +++ include/llvm/Analysis/ScalarEvolutionExpander.h @@ -162,6 +162,16 @@ Value *expandEqualPredicate(const SCEVEqualPredicate *Pred, Instruction *Loc); + /// \brief Generates code that evaluates if the \p AR expression will + /// overflow. 
+ Value *generateOverflowCheck(const SCEVAddRecExpr *AR, Instruction *Loc, + bool Signed); + + /// \brief A specialized variant of expandCodeForPredicate, handling the + /// case when we are expanding code for a SCEVWrapPredicate. + Value *expandAddRecOverflowPredicate(const SCEVWrapPredicate *P, + Instruction *Loc); + /// \brief A specialized variant of expandCodeForPredicate, handling the /// case when we are expanding code for a SCEVUnionPredicate. Value *expandUnionPredicate(const SCEVUnionPredicate *Pred, Index: lib/Analysis/LoopAccessAnalysis.cpp =================================================================== --- lib/Analysis/LoopAccessAnalysis.cpp +++ lib/Analysis/LoopAccessAnalysis.cpp @@ -773,7 +773,7 @@ /// \brief Return true if an AddRec pointer \p Ptr is unsigned non-wrapping, /// i.e. monotonically increasing/decreasing. static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, - ScalarEvolution *SE, const Loop *L) { + PredicatedScalarEvolution &PSE, const Loop *L) { // FIXME: This should probably only return true for NUW. if (AR->getNoWrapFlags(SCEV::NoWrapMask)) return true; @@ -809,7 +809,7 @@ // Assume constant for other the operand so that the AddRec can be // easily found. isa(OBO->getOperand(1))) { - auto *OpScev = SE->getSCEV(OBO->getOperand(0)); + auto *OpScev = PSE.getSCEV(OBO->getOperand(0)); if (auto *OpAR = dyn_cast(OpScev)) return OpAR->getLoop() == L && OpAR->getNoWrapFlags(SCEV::FlagNSW); @@ -820,31 +820,35 @@ /// \brief Check whether the access through \p Ptr has a constant stride. int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, - const Loop *Lp, const ValueToValueMap &StridesMap) { + const Loop *Lp, const ValueToValueMap &StridesMap, + bool Assume) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); // Make sure that the pointer does not point to aggregate types. 
auto *PtrTy = cast(Ty); if (PtrTy->getElementType()->isAggregateType()) { - DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type" - << *Ptr << "\n"); + DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type" << *Ptr + << "\n"); return 0; } const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr); const SCEVAddRecExpr *AR = dyn_cast(PtrScev); + if (Assume && !AR) + AR = dyn_cast(PSE.getAsAddRec(Ptr)); + if (!AR) { - DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " - << *Ptr << " SCEV: " << *PtrScev << "\n"); + DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr + << " SCEV: " << *PtrScev << "\n"); return 0; } // The accesss function must stride over the innermost loop. if (Lp != AR->getLoop()) { - DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop " << - *Ptr << " SCEV: " << *PtrScev << "\n"); + DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop " + << *Ptr << " SCEV: " << *PtrScev << "\n"); } // The address calculation must not wrap. Otherwise, a dependence could be @@ -855,12 +859,21 @@ // to access the pointer value "0" which is undefined behavior in address // space 0, therefore we can also vectorize this case. 
bool IsInBoundsGEP = isInBoundsGep(Ptr); - bool IsNoWrapAddRec = isNoWrapAddRec(Ptr, AR, PSE.getSE(), Lp); + bool IsNoWrapAddRec = PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::FlagNUW) || + isNoWrapAddRec(Ptr, AR, PSE, Lp); bool IsInAddressSpaceZero = PtrTy->getAddressSpace() == 0; if (!IsNoWrapAddRec && !IsInBoundsGEP && !IsInAddressSpaceZero) { - DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " - << *Ptr << " SCEV: " << *PtrScev << "\n"); - return 0; + if (Assume) { + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::FlagNUW); + IsNoWrapAddRec = true; + DEBUG(dbgs() << "LAA: Pointer may wrap in the address space " << *Ptr + << " SCEV: " << *PtrScev + << " added an overflow assumption\n"); + } else { + DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " + << *Ptr << " SCEV: " << *PtrScev << "\n"); + return 0; + } } // Check the step is constant. @@ -894,8 +907,13 @@ // know we can't "wrap around the address space". In case of address space // zero we know that this won't happen without triggering undefined behavior. if (!IsNoWrapAddRec && (IsInBoundsGEP || IsInAddressSpaceZero) && - Stride != 1 && Stride != -1) - return 0; + Stride != 1 && Stride != -1) { + if (Assume) { + // We can avoid this case by adding a run-time check. 
+ PSE.setNoOverflow(Ptr, SCEVWrapPredicate::FlagNUW); + } else + return 0; + } return Stride; } @@ -1050,8 +1068,8 @@ const SCEV *AScev = replaceSymbolicStrideSCEV(PSE, Strides, APtr); const SCEV *BScev = replaceSymbolicStrideSCEV(PSE, Strides, BPtr); - int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides); - int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides); + int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides, true); + int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides, true); const SCEV *Src = AScev; const SCEV *Sink = BScev; @@ -1750,7 +1768,7 @@ const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : PSE(*SE), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL), + : PSE(*SE, *L), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL), TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -9591,17 +9591,36 @@ return Eq; } +const SCEVPredicate *ScalarEvolution::getAddRecOverflowPredicate( + const SCEVAddRecExpr *AR, SCEVWrapPredicate::NoWrapFlags AddedFlags) { + FoldingSetNodeID ID; + // Unique this node based on the arguments + ID.AddInteger(SCEVPredicate::P_AddRecOverflow); + ID.AddPointer(AR); + ID.AddInteger(AddedFlags); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + auto *OF = new (SCEVAllocator) + SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags); + UniquePreds.InsertNode(OF, IP); + return OF; +} + namespace { + class SCEVPredicateRewriter : public SCEVRewriteVisitor { public: - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, - SCEVUnionPredicate &A) { - SCEVPredicateRewriter Rewriter(SE, 
A); + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE, SCEVUnionPredicate &A, + bool Assume) { + SCEVPredicateRewriter Rewriter(L, SE, A, Assume); return Rewriter.visit(Scev); } - SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P) - : SCEVRewriteVisitor(SE), P(P) {} + SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, + SCEVUnionPredicate &P, bool Assume) + : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { auto ExprPreds = P.getPredicatesForExpr(Expr); @@ -9613,14 +9632,73 @@ return Expr; } + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + const SCEVAddRecExpr *AR = dyn_cast(Operand); + if (AR && AR->getLoop() == L && AR->isAffine()) { + // This couldn't be folded because the operand didn't have the nuw + // flag. Add the nuw flag as an assumption that we could make. + const SCEV *Step = AR->getStepRecurrence(SE); + Type *Ty = Expr->getType(); + // We would also like to add the NUW flag here. We add the NUW + // flag to the assumption set but cannot set it on the expression + // because it would pollute ScalarEvolution's cache. + if (addOverflowAssumption(AR, SCEVWrapPredicate::FlagNUW)) + return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty), + SE.getSignExtendExpr(Step, Ty), L, + AR->getNoWrapFlags()); + } + return SE.getZeroExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + const SCEVAddRecExpr *AR = dyn_cast(Operand); + if (AR && AR->getLoop() == L && AR->isAffine()) { + // This couldn't be folded because the operand didn't have the nsw + // flag. Add the nsw flag as an assumption that we could make. + const SCEV *Step = AR->getStepRecurrence(SE); + Type *Ty = Expr->getType(); + // We would also like to add the NSW flag here. 
We add the NSW + // flag to the assumption set but cannot set it on the expression + // because it would pollute ScalarEvolution's cache. + if (addOverflowAssumption(AR, SCEVWrapPredicate::FlagNSW)) + return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty), + SE.getSignExtendExpr(Step, Ty), L, + AR->getNoWrapFlags()); + } + return SE.getSignExtendExpr(Operand, Expr->getType()); + } + private: + bool addOverflowAssumption(const SCEVAddRecExpr *AR, + SCEVWrapPredicate::NoWrapFlags AddedFlags) { + auto *A = SE.getAddRecOverflowPredicate(AR, AddedFlags); + if (!Assume) { + // Check if we've already made this assumption. + if (P.implies(A)) + return true; + return false; + } + P.add(A); + return true; + } + SCEVUnionPredicate &P; + const Loop *L; + bool Assume; }; } // end anonymous namespace const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, + const Loop *L, SCEVUnionPredicate &Preds) { - return SCEVPredicateRewriter::rewrite(Scev, *this, Preds); + return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, false); +} + +const SCEV *ScalarEvolution::convertSCEVToAddRecWithPredicates( + const SCEV *Scev, const Loop *L, SCEVUnionPredicate &Preds) { + return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, true); } /// SCEV predicates @@ -9652,6 +9730,35 @@ OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; } +SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, + const SCEVAddRecExpr *AR, + NoWrapFlags Flags) + : SCEVPredicate(ID, P_AddRecOverflow), AR(AR), Flags(Flags) {} + +const SCEV *SCEVWrapPredicate::getExpr() const { return AR; } + +bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { + const auto *Op = dyn_cast(N); + + if (!Op) + return false; + + return Op->AR == AR && Op->Flags == Flags; +} + +bool SCEVWrapPredicate::isAlwaysTrue() const { + return (Flags & ~AR->getNoWrapFlags()) == FlagAnyWrap; +} + +void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const { + 
OS.indent(Depth) << *getExpr() << " Added Flags: "; + if (SCEVWrapPredicate::FlagNUW & getFlags()) + OS << ""; + if (SCEVWrapPredicate::FlagNSW & getFlags()) + OS << ""; + OS << "\n"; +} + /// Union predicates don't get cached so create a dummy set ID for it. SCEVUnionPredicate::SCEVUnionPredicate() : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} @@ -9708,8 +9815,9 @@ Preds.push_back(N); } -PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE) - : SE(SE), Generation(0) {} +PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, + Loop &L) + : SE(SE), L(L), Generation(0) {} const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { const SCEV *Expr = SE.getSCEV(V); @@ -9724,7 +9832,7 @@ if (Entry.second) Expr = Entry.second; - const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds); + const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds); Entry = {Generation, NewSCEV}; return NewSCEV; @@ -9746,7 +9854,47 @@ if (++Generation == 0) { for (auto &II : RewriteMap) { const SCEV *Rewritten = II.second.second; - II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)}; + II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)}; } } } + +void PredicatedScalarEvolution::setNoOverflow( + Value *V, SCEVWrapPredicate::NoWrapFlags Flags) { + const SCEV *Expr = getSCEV(V); + const auto *AR = cast(Expr); + addPredicate(*SE.getAddRecOverflowPredicate(AR, Flags)); + + auto II = FlagsMap.insert({V, Flags}); + if (!II.second) + II.first->second = + (SCEVWrapPredicate::NoWrapFlags)(Flags | II.first->second); +} + +bool PredicatedScalarEvolution::hasNoOverflow( + Value *V, SCEVWrapPredicate::NoWrapFlags Flags) { + const SCEV *Expr = getSCEV(V); + const auto *AR = static_cast(Expr); + + // The NSW flag has the same meaning for both ScalarEvolution and + // SCEVWrapPredicate. 
+ if (ScalarEvolution::setFlags(AR->getNoWrapFlags(), SCEV::FlagNSW) == + AR->getNoWrapFlags()) + Flags = + (SCEVWrapPredicate::NoWrapFlags)(Flags & ~SCEVWrapPredicate::FlagNSW); + + auto II = FlagsMap.find(V); + + if (II != FlagsMap.end()) + Flags = (SCEVWrapPredicate::NoWrapFlags)(Flags & ~II->second); + + return Flags == SCEVWrapPredicate::FlagAnyWrap; +} + +const SCEV *PredicatedScalarEvolution::getAsAddRec(Value *V) { + const SCEV *Expr = this->getSCEV(V); + const SCEV *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds); + updateGeneration(); + RewriteMap[SE.getSCEV(V)] = {Generation, New}; + return New; +} Index: lib/Analysis/ScalarEvolutionExpander.cpp =================================================================== --- lib/Analysis/ScalarEvolutionExpander.cpp +++ lib/Analysis/ScalarEvolutionExpander.cpp @@ -1947,6 +1947,10 @@ return expandUnionPredicate(cast(Pred), IP); case SCEVPredicate::P_Equal: return expandEqualPredicate(cast(Pred), IP); + case SCEVPredicate::P_AddRecOverflow: { + auto *AddRecPred = cast(Pred); + return expandAddRecOverflowPredicate(AddRecPred, IP); + } } llvm_unreachable("Unknown SCEV predicate type"); } @@ -1961,6 +1965,63 @@ return I; } +Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, + Instruction *Loc, bool Signed) { + assert(AR->isAffine() && "Cannot generate RT check for " + "non-affine expression"); + + const SCEV *ExitCount = SE.getBackedgeTakenCount(AR->getLoop()); + const SCEV *Step = AR->getStepRecurrence(SE); + const SCEV *Start = AR->getStart(); + + unsigned DstBits = SE.getTypeSizeInBits(AR->getType()); + unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType()); + unsigned MaxBits = 2 * std::max(DstBits, SrcBits); + + auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType()); + IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits); + + assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count"); + + const auto *ExtendedTripCount = 
SE.getZeroExtendExpr(ExitCount, MaxTy); + const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy); + const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy) + : SE.getZeroExtendExpr(Start, MaxTy); + + const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step)); + const SCEV *RHS = Signed ? SE.getSignExtendExpr(End, MaxTy) + : SE.getZeroExtendExpr(End, MaxTy); + + const SCEV *LHS = SE.getAddExpr( + ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep)); + + // Do all SCEV expansions now. + Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc); + Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc); + + Builder.SetInsertPoint(Loc); + + return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal); +} + +Value * +SCEVExpander::expandAddRecOverflowPredicate(const SCEVWrapPredicate *Pred, + Instruction *IP) { + const auto *A = static_cast(Pred->getExpr()); + auto *BoolType = IntegerType::get(IP->getContext(), 1); + Value *Check = ConstantInt::getNullValue(BoolType); + + // Add a check for NUW + if (Pred->getFlags() & SCEVWrapPredicate::FlagNUW) + Check = generateOverflowCheck(A, IP, false); + + // Add a check for NSW + if (Pred->getFlags() & SCEVWrapPredicate::FlagNSW) + Check = Builder.CreateOr(generateOverflowCheck(A, IP, true), Check); + + return Check; +} + Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union, Instruction *IP) { auto *BoolType = IntegerType::get(IP->getContext(), 1); Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1745,7 +1745,7 @@ } } - PredicatedScalarEvolution PSE(*SE); + PredicatedScalarEvolution PSE(*SE, *L); // Check if it is legal to vectorize the loop. 
LoopVectorizationRequirements Requirements; Index: test/Transforms/LoopVectorize/same-base-access.ll =================================================================== --- test/Transforms/LoopVectorize/same-base-access.ll +++ test/Transforms/LoopVectorize/same-base-access.ll @@ -62,11 +62,9 @@ } - -; We don't vectorize this function because A[i*7] is scalarized, and the -; different scalars can in theory wrap around and overwrite other scalar -; elements. At the moment we only allow read/write access to arrays -; that are consecutive. +; A[i*7] is scalarized, and the different scalars can in theory wrap +; around and overwrite other scalar elements. However we can still +; vectorize because we can version the loop to avoid this case. ; ; void foo(int *a) { ; for (int i=0; i<256; ++i) { @@ -78,7 +76,7 @@ ; } ; CHECK-LABEL: @func2( -; CHECK-NOT: <4 x i32> +; CHECK: <4 x i32> ; CHECK: ret define i32 @func2(i32* nocapture %a) nounwind uwtable ssp { br label %1