Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -486,6 +486,12 @@
   ///
   LoopInfo &LI;
 
+  /// The generation count of the analysis.
+  ///
+  /// This is used by various parts of the analysis to track which SCEVs have
+  /// been lazily updated after some partial invalidation event.
+  uint64_t Generation;
+
   /// This SCEV is used to represent unknown trip counts and things.
   std::unique_ptr<SCEVCouldNotCompute> CouldNotCompute;
 
@@ -1045,9 +1051,15 @@
   /// Return an existing SCEV for V if there is one, otherwise return nullptr.
   const SCEV *getExistingSCEV(Value *V);
 
-  /// Return false iff given SCEV contains a SCEVUnknown with NULL value-
-  /// pointer.
-  bool checkValidity(const SCEV *S) const;
+  /// Validate a SCEV's subgraph.
+  ///
+  /// This will ensure that the given SCEV and all SCEVs it transitively
+  /// references are valid. If any part of the subgraph is detected to be
+  /// invalid, this will remove those SCEVs from the analysis and return false.
+  /// Note that this may not remove *all* invalid SCEVs from the analysis, but
+  /// it will definitively remove `S` if invalid, and will remove others
+  /// opportunistically as they are discovered.
+  bool validateSCEVSubgraph(const SCEV *S);
 
   /// Return true if `ExtendOpTy`({`Start`,+,`Step`}) can be proved to be
   /// equal to {`ExtendOpTy`(`Start`),+,`ExtendOpTy`(`Step`)}. This is
@@ -1125,6 +1137,11 @@
   /// Return the Value set from which the SCEV expr is generated.
   SetVector<ValueOffsetPair> *getSCEVValues(const SCEV *S);
 
+  /// Increment the SCEV generation.
+  ///
+  /// This shifts the generation count forward, marking previously validated
+  /// SCEVs stale so that they are re-validated at their next lookup.
+  void incrementSCEVGeneration();
+
   /// Erase Value from ValueExprMap and ExprValueMap.
   void eraseValueFromMap(Value *V);
 
@@ -1618,6 +1635,13 @@
   const SCEV *getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                                  SCEV::NoWrapFlags Flags);
 
+  SCEV *lookupUniqueSCEV(const FoldingSetNodeID &ID, void *&IP) {
+    if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
+      if (validateSCEVSubgraph(S))
+        return S;
+    return nullptr;
+  }
+
 private:
   FoldingSet<SCEV> UniqueSCEVs;
   FoldingSet<SCEVPredicate> UniquePreds;

Index: include/llvm/Analysis/ScalarEvolutionExpressions.h
===================================================================
--- include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -51,14 +51,64 @@
     }
   };
 
+  /// This is the base class for all expression classes.
+  ///
+  /// What these classes have in common is that they reference other SCEVs as
+  /// part of their definition, so we need to be able to invalidate them when
+  /// part of their expression graph changes.
+  class SCEVExpr : public SCEV {
+  private:
+    // This is mutable as it inherently doesn't make up part of the SCEV state.
+    mutable uint64_t Generation;
+
+  public:
+    explicit SCEVExpr(const FoldingSetNodeIDRef ID, unsigned SCEVTy,
+                      uint64_t CurrentGeneration)
+        : SCEV(ID, SCEVTy), Generation(CurrentGeneration) {}
+
+    uint64_t getGeneration() const { return Generation; }
+
+    void setGeneration(uint64_t NewGen) const { Generation = NewGen; }
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEV *S) {
+      switch (S->getSCEVType()) {
+      case scTruncate:
+      case scZeroExtend:
+      case scSignExtend:
+      case scAddExpr:
+      case scMulExpr:
+      case scSMaxExpr:
+      case scUMaxExpr:
+      case scAddRecExpr:
+      case scUDivExpr:
+        return true;
+
+      case scUnknown:
+      case scConstant:
+      case scCouldNotCompute:
+        return false;
+      }
+      llvm_unreachable("Uncovered SCEV type!");
+    }
+
+#ifndef NDEBUG
+    /// Helper to assert (in debug builds) that SCEVs have the current
+    /// generation.
+    static void verifyGeneration(ArrayRef<const SCEV *> SCEVs,
+                                 uint64_t CurrentGeneration);
+#endif
+  };
+
   /// This is the base class for unary cast operator classes.
-  class SCEVCastExpr : public SCEV {
+  class SCEVCastExpr : public SCEVExpr {
   protected:
     const SCEV *Op;
     Type *Ty;
 
     SCEVCastExpr(const FoldingSetNodeIDRef ID,
-                 unsigned SCEVTy, const SCEV *op, Type *ty);
+                 unsigned SCEVTy, const SCEV *op, Type *ty,
+                 uint64_t CurrentGeneration);
 
   public:
     const SCEV *getOperand() const { return Op; }
@@ -78,7 +128,8 @@
     friend class ScalarEvolution;
 
     SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
-                     const SCEV *op, Type *ty);
+                     const SCEV *op, Type *ty,
+                     uint64_t CurrentGeneration);
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
@@ -93,7 +144,8 @@
     friend class ScalarEvolution;
 
     SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
-                       const SCEV *op, Type *ty);
+                       const SCEV *op, Type *ty,
+                       uint64_t CurrentGeneration);
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
@@ -108,7 +160,8 @@
     friend class ScalarEvolution;
 
     SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
-                       const SCEV *op, Type *ty);
+                       const SCEV *op, Type *ty,
+                       uint64_t CurrentGeneration);
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
@@ -120,7 +173,7 @@
 
   /// This node is a base class providing common functionality for
   /// n'ary operators.
-  class SCEVNAryExpr : public SCEV {
+  class SCEVNAryExpr : public SCEVExpr {
   protected:
     // Since SCEVs are immutable, ScalarEvolution allocates operand
     // arrays with its SCEVAllocator, so this class just needs a simple
@@ -129,9 +182,13 @@
     const SCEV *const *Operands;
     size_t NumOperands;
 
-    SCEVNAryExpr(const FoldingSetNodeIDRef ID,
-                 enum SCEVTypes T, const SCEV *const *O, size_t N)
-        : SCEV(ID, T), Operands(O), NumOperands(N) {}
+    SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
+                 const SCEV *const *O, size_t N, uint64_t CurrentGeneration)
+        : SCEVExpr(ID, T, CurrentGeneration), Operands(O), NumOperands(N) {
+#ifndef NDEBUG
+      verifyGeneration(makeArrayRef(Operands, NumOperands), CurrentGeneration);
+#endif
+    }
 
   public:
     size_t getNumOperands() const { return NumOperands; }
@@ -180,8 +237,9 @@
   class SCEVCommutativeExpr : public SCEVNAryExpr {
   protected:
     SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,
-                        enum SCEVTypes T, const SCEV *const *O, size_t N)
-        : SCEVNAryExpr(ID, T, O, N) {}
+                        enum SCEVTypes T, const SCEV *const *O, size_t N,
+                        uint64_t CurrentGeneration)
+        : SCEVNAryExpr(ID, T, O, N, CurrentGeneration) {}
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
@@ -204,8 +262,9 @@
     friend class ScalarEvolution;
 
     SCEVAddExpr(const FoldingSetNodeIDRef ID,
-                const SCEV *const *O, size_t N)
-        : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
+                const SCEV *const *O, size_t N,
+                uint64_t CurrentGeneration)
+        : SCEVCommutativeExpr(ID, scAddExpr, O, N, CurrentGeneration) {
     }
 
   public:
@@ -228,8 +287,9 @@
     friend class ScalarEvolution;
 
     SCEVMulExpr(const FoldingSetNodeIDRef ID,
-                const SCEV *const *O, size_t N)
-        : SCEVCommutativeExpr(ID, scMulExpr, O, N) {
+                const SCEV *const *O, size_t N,
+                uint64_t CurrentGeneration)
+        : SCEVCommutativeExpr(ID, scMulExpr, O, N, CurrentGeneration) {
     }
 
   public:
@@ -241,13 +301,18 @@
 
   /// This class represents a binary unsigned division operation.
-  class SCEVUDivExpr : public SCEV {
+  class SCEVUDivExpr : public SCEVExpr {
     friend class ScalarEvolution;
 
     const SCEV *LHS;
     const SCEV *RHS;
 
-    SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
-        : SCEV(ID, scUDivExpr), LHS(lhs), RHS(rhs) {}
+    SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs,
+                 uint64_t CurrentGeneration)
+        : SCEVExpr(ID, scUDivExpr, CurrentGeneration), LHS(lhs), RHS(rhs) {
+#ifndef NDEBUG
+      verifyGeneration({LHS, RHS}, CurrentGeneration);
+#endif
+    }
 
   public:
     const SCEV *getLHS() const { return LHS; }
@@ -283,8 +348,9 @@
     const Loop *L;
 
     SCEVAddRecExpr(const FoldingSetNodeIDRef ID,
-                   const SCEV *const *O, size_t N, const Loop *l)
-        : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
+                   const SCEV *const *O, size_t N, const Loop *l,
+                   uint64_t CurrentGeneration)
+        : SCEVNAryExpr(ID, scAddRecExpr, O, N, CurrentGeneration), L(l) {}
 
   public:
     const SCEV *getStart() const { return Operands[0]; }
@@ -355,8 +421,9 @@
     friend class ScalarEvolution;
 
     SCEVSMaxExpr(const FoldingSetNodeIDRef ID,
-                 const SCEV *const *O, size_t N)
-        : SCEVCommutativeExpr(ID, scSMaxExpr, O, N) {
+                 const SCEV *const *O, size_t N,
+                 uint64_t CurrentGeneration)
+        : SCEVCommutativeExpr(ID, scSMaxExpr, O, N, CurrentGeneration) {
       // Max never overflows.
      setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
    }
 
@@ -374,8 +441,9 @@
     friend class ScalarEvolution;
 
     SCEVUMaxExpr(const FoldingSetNodeIDRef ID,
-                 const SCEV *const *O, size_t N)
-        : SCEVCommutativeExpr(ID, scUMaxExpr, O, N) {
+                 const SCEV *const *O, size_t N,
+                 uint64_t CurrentGeneration)
+        : SCEVCommutativeExpr(ID, scUMaxExpr, O, N, CurrentGeneration) {
       // Max never overflows.
       setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
     }

Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -338,7 +338,7 @@
   ID.AddInteger(scConstant);
   ID.AddPointer(V);
   void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = lookupUniqueSCEV(ID, IP)) return S;
   SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
   UniqueSCEVs.InsertNode(S, IP);
   return S;
@@ -354,29 +354,66 @@
   return getConstant(ConstantInt::get(ITy, V, isSigned));
 }
 
-SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
-                           unsigned SCEVTy, const SCEV *op, Type *ty)
-    : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}
+#ifndef NDEBUG
+void SCEVExpr::verifyGeneration(ArrayRef<const SCEV *> SCEVs,
+                                uint64_t CurrentGeneration) {
+  for (const SCEV *S : SCEVs) {
+    switch (S->getSCEVType()) {
+    case scUnknown:
+    case scConstant:
+      // No generation.
+      continue;
+
+    case scTruncate:
+    case scZeroExtend:
+    case scSignExtend:
+    case scAddExpr:
+    case scMulExpr:
+    case scSMaxExpr:
+    case scUMaxExpr:
+    case scAddRecExpr:
+    case scUDivExpr:
+      assert(cast<SCEVExpr>(S)->getGeneration() == CurrentGeneration &&
+             "Unexpected generation!");
+      continue;
+
+    case scCouldNotCompute:
+      llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+    }
+    llvm_unreachable("Unknown SCEV kind!");
+  }
+}
+#endif
+
+SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, unsigned SCEVTy,
+                           const SCEV *op, Type *ty, uint64_t CurrentGeneration)
+    : SCEVExpr(ID, SCEVTy, CurrentGeneration), Op(op), Ty(ty) {
+#ifndef NDEBUG
+  verifyGeneration({Op}, CurrentGeneration);
+#endif
+}
 
-SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
-                                   const SCEV *op, Type *ty)
-    : SCEVCastExpr(ID, scTruncate, op, ty) {
+SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
+                                   Type *ty, uint64_t CurrentGeneration)
+    : SCEVCastExpr(ID, scTruncate, op, ty, CurrentGeneration) {
   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot truncate non-integer value!");
 }
 
 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
-                                       const SCEV *op, Type *ty)
-    : SCEVCastExpr(ID, scZeroExtend, op, ty) {
+                                       const SCEV *op, Type *ty,
+                                       uint64_t CurrentGeneration)
+    : SCEVCastExpr(ID, scZeroExtend, op, ty, CurrentGeneration) {
   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot zero extend non-integer value!");
 }
 
 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
-                                       const SCEV *op, Type *ty)
-    : SCEVCastExpr(ID, scSignExtend, op, ty) {
+                                       const SCEV *op, Type *ty,
+                                       uint64_t CurrentGeneration)
+    : SCEVCastExpr(ID, scSignExtend, op, ty, CurrentGeneration) {
   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot sign extend non-integer value!");
@@ -389,6 +426,9 @@
   // Remove this SCEVUnknown from the uniquing map.
   SE->UniqueSCEVs.RemoveNode(this);
 
+  // Increment the generation so that users will be updated when next queried.
+  SE->incrementSCEVGeneration();
+
   // Release the value.
   setValPtr(nullptr);
 }
 
@@ -1162,7 +1202,7 @@
   ID.AddPointer(Op);
   ID.AddPointer(Ty);
   void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = lookupUniqueSCEV(ID, IP)) return S;
 
   // Fold if the operand is constant.
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
@@ -1225,7 +1265,7 @@
   // the existing insert position since if we get here, we won't have
   // made any changes which would invalidate it.
   SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
-                                                 Op, Ty);
+                                                 Op, Ty, Generation);
   UniqueSCEVs.InsertNode(S, IP);
   return S;
 }
 
@@ -1468,8 +1508,7 @@
   ID.AddPointer(Step);
   ID.AddPointer(L);
   void *IP = nullptr;
-  const auto *PreAR =
-      static_cast<const SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+  const auto *PreAR = static_cast<const SCEVAddRecExpr *>(lookupUniqueSCEV(ID, IP));
 
   // Give up if we don't already have the add recurrence we need because
   // actually constructing an add recurrence is relatively expensive.
@@ -1510,7 +1549,7 @@
   ID.AddPointer(Op);
   ID.AddPointer(Ty);
   void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = lookupUniqueSCEV(ID, IP)) return S;
 
   // zext(trunc(x)) --> zext(x) or x or trunc(x)
   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
@@ -1675,9 +1714,9 @@
   // The cast wasn't folded; create an explicit cast node.
   // Recompute the insert position, as it may have been invalidated.
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = lookupUniqueSCEV(ID, IP)) return S;
   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
-                                                   Op, Ty);
+                                                   Op, Ty, Generation);
   UniqueSCEVs.InsertNode(S, IP);
   return S;
 }
 
@@ -1710,7 +1749,7 @@
   ID.AddPointer(Op);
   ID.AddPointer(Ty);
   void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = lookupUniqueSCEV(ID, IP)) return S;
 
   // sext(trunc(x)) --> sext(x) or x or trunc(x)
   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
@@ -1902,9 +1941,9 @@
   // The cast wasn't folded; create an explicit cast node.
   // Recompute the insert position, as it may have been invalidated.
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = lookupUniqueSCEV(ID, IP)) return S;
   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
-                                                   Op, Ty);
+                                                   Op, Ty, Generation);
   UniqueSCEVs.InsertNode(S, IP);
   return S;
 }
 
@@ -2469,13 +2508,12 @@
   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
     ID.AddPointer(Ops[i]);
   void *IP = nullptr;
-  SCEVAddExpr *S =
-      static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+  auto *S = static_cast<SCEVAddExpr *>(lookupUniqueSCEV(ID, IP));
   if (!S) {
     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
     S = new (SCEVAllocator)
-        SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
+        SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size(), Generation);
     UniqueSCEVs.InsertNode(S, IP);
   }
   S->setNoWrapFlags(Flags);
@@ -2761,13 +2799,12 @@
   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
     ID.AddPointer(Ops[i]);
   void *IP = nullptr;
-  SCEVMulExpr *S =
-      static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+  auto *S = static_cast<SCEVMulExpr *>(lookupUniqueSCEV(ID, IP));
   if (!S) {
     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
-                                        O, Ops.size());
+                                        O, Ops.size(), Generation);
     UniqueSCEVs.InsertNode(S, IP);
   }
   S->setNoWrapFlags(Flags);
@@ -2885,9 +2922,9 @@
   ID.AddPointer(LHS);
   ID.AddPointer(RHS);
   void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = lookupUniqueSCEV(ID, IP)) return S;
   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
-                                             LHS, RHS);
+                                             LHS, RHS, Generation);
   UniqueSCEVs.InsertNode(S, IP);
   return S;
 }
@@ -3061,13 +3098,12 @@
     ID.AddPointer(Operands[i]);
   ID.AddPointer(L);
   void *IP = nullptr;
-  SCEVAddRecExpr *S =
-      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+  auto *S = static_cast<SCEVAddRecExpr *>(lookupUniqueSCEV(ID, IP));
   if (!S) {
     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
     std::uninitialized_copy(Operands.begin(), Operands.end(), O);
     S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
-                                           O, Operands.size(), L);
+                                           O, Operands.size(), L, Generation);
     UniqueSCEVs.InsertNode(S, IP);
   }
   S->setNoWrapFlags(Flags);
@@ -3218,11 +3254,11 @@
   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
     ID.AddPointer(Ops[i]);
   void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = lookupUniqueSCEV(ID, IP)) return S;
   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
   SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
-                                             O, Ops.size());
+                                             O, Ops.size(), Generation);
   UniqueSCEVs.InsertNode(S, IP);
   return S;
 }
@@ -3319,11 +3355,11 @@
   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
     ID.AddPointer(Ops[i]);
   void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = lookupUniqueSCEV(ID, IP)) return S;
   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
-  SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
-                                             O, Ops.size());
+  SCEV *S = new (SCEVAllocator)
+      SCEVUMaxExpr(ID.Intern(SCEVAllocator), O, Ops.size(), Generation);
   UniqueSCEVs.InsertNode(S, IP);
   return S;
 }
@@ -3367,7 +3403,7 @@
   ID.AddInteger(scUnknown);
   ID.AddPointer(V);
   void *IP = nullptr;
-  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
+  if (const SCEV *S =
+          lookupUniqueSCEV(ID, IP)) {
     assert(cast<SCEVUnknown>(S)->getValue() == V &&
            "Stale SCEVUnknown in uniquing map!");
     return S;
   }
@@ -3417,13 +3453,95 @@
   return CouldNotCompute.get();
 }
 
-bool ScalarEvolution::checkValidity(const SCEV *S) const {
-  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
-    auto *SU = dyn_cast<SCEVUnknown>(S);
-    return SU && SU->getValue() == nullptr;
-  });
+bool ScalarEvolution::validateSCEVSubgraph(const SCEV *S) {
+  // We do a depth-first search of the SCEV expression DAG looking for
+  // potentially invalid expressions. Any expression with a generation older
+  // than the current one needs to be recursively validated. Once recursively
+  // validated, we can update the generation count to prune subsequent walks.
+  SmallVector<std::pair<const SCEV *, size_t>, 8> Stack;
+  size_t Index = 0;
+  for (;;) {
+    switch (S->getSCEVType()) {
+    case scConstant:
+      break;
+    case scUnknown: {
+      const auto *SU = cast<SCEVUnknown>(S);
+      if (SU->getValue() == nullptr) {
+        while (!Stack.empty()) {
+          const SCEV *S = Stack.pop_back_val().first;
+          if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S))
+            for (auto &VO : *SV)
+              ValueExprMap.erase(VO.first);
+          forgetMemoizedResults(S);
+        }
+        return false;
+      }
+      break;
+    }
+    case scTruncate:
+    case scZeroExtend:
+    case scSignExtend: {
+      const auto *Cast = cast<SCEVCastExpr>(S);
+      if (Index != 0) {
+        // Finished visiting the subgraph and it remained valid; update the
+        // generation and break.
+        Cast->setGeneration(Generation);
+        break;
+      }
+      // First visit: test, and if necessary recurse.
+      if (Cast->getGeneration() == Generation)
+        break;
+
+      Stack.push_back({S, Index + 1});
+      S = Cast->getOperand();
+      Index = 0;
+      continue;
+    }
+    case scAddExpr:
+    case scMulExpr:
+    case scSMaxExpr:
+    case scUMaxExpr:
+    case scAddRecExpr: {
+      const auto *NAry = cast<SCEVNAryExpr>(S);
+      if (Index == NAry->getNumOperands()) {
+        NAry->setGeneration(Generation);
+        break;
+      }
+      if (NAry->getGeneration() == Generation)
+        break;
 
-  return !ContainsNulls;
+      Stack.push_back({S, Index + 1});
+      S = NAry->getOperand(Index);
+      Index = 0;
+      continue;
+    }
+    case scUDivExpr: {
+      const auto *UDiv = cast<SCEVUDivExpr>(S);
+      if (Index == 2) {
+        UDiv->setGeneration(Generation);
+        break;
+      }
+
+      if (Index == 0 && UDiv->getGeneration() == Generation)
+        break;
+
+      Stack.push_back({S, Index + 1});
+      S = Index == 0 ? UDiv->getLHS() : UDiv->getRHS();
+      Index = 0;
+      continue;
+    }
+    case scCouldNotCompute:
+      llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+    default:
+      llvm_unreachable("Unknown SCEV kind!");
+    }
+
+    if (Stack.empty())
+      break;
+    std::tie(S, Index) = Stack.pop_back_val();
+  }
+
+  return true;
 }
 
 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
@@ -3471,6 +3589,17 @@
   return &SI->second;
 }
 
+void ScalarEvolution::incrementSCEVGeneration() {
+  // Note that we never wrap the generation count. That would require calling
+  // this function 2^64 times; at one increment per nanosecond, that is more
+  // than five centuries of doing nothing else. Wrapping would not even be a
+  // *semantic* problem, as the generation count is only compared for equality
+  // rather than ordered; the assert below exists because reaching the maximum
+  // value almost certainly indicates memory corruption.
+  assert(Generation != UINT64_MAX && "Likely memory corruption: the generation "
+                                     "count has reached the max value!");
+  ++Generation;
+}
+
 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
 /// cannot be used separately. eraseValueFromMap should be used to remove
 /// V from ValueExprMap and ExprValueMap at the same time.
@@ -3534,10 +3663,8 @@
   ValueExprMapType::iterator I = ValueExprMap.find_as(V);
   if (I != ValueExprMap.end()) {
     const SCEV *S = I->second;
-    if (checkValidity(S))
+    if (validateSCEVSubgraph(S))
       return S;
-    eraseValueFromMap(V);
-    forgetMemoizedResults(S);
   }
   return nullptr;
 }
@@ -9483,7 +9610,7 @@
 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
                                  AssumptionCache &AC, DominatorTree &DT,
                                  LoopInfo &LI)
-    : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
+    : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), Generation(0),
       CouldNotCompute(new SCEVCouldNotCompute()),
       WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
       ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64),
@@ -9506,7 +9633,8 @@
 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
     : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
-      LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
+      LI(Arg.LI), Generation(Arg.Generation),
+      CouldNotCompute(std::move(Arg.CouldNotCompute)),
       ValueExprMap(std::move(Arg.ValueExprMap)),
       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
       WalkingBEDominatingConds(false), ProvingSplitPredicate(false),

Index: lib/Transforms/Scalar/LoopRerollPass.cpp
===================================================================
--- lib/Transforms/Scalar/LoopRerollPass.cpp
+++ lib/Transforms/Scalar/LoopRerollPass.cpp
@@ -365,8 +365,7 @@
     bool validate(ReductionTracker &Reductions);
 
     /// Stage 3: Assuming validate() returned true, perform the
     /// replacement.
-    /// @param IterCount The maximum iteration count of L.
-    void replace(const SCEV *IterCount);
+    void replace();
 
   protected:
     typedef MapVector<Instruction *, BitVector> UsesTy;
 
@@ -396,7 +395,7 @@
     bool instrDependsOn(Instruction *I, UsesTy::iterator Start,
                         UsesTy::iterator End);
-    void replaceIV(Instruction *Inst, Instruction *IV, const SCEV *IterCount);
+    void replaceIV(Instruction *Inst, Instruction *IV);
     void updateNonLoopCtrlIncr();
 
     LoopReroll *Parent;
@@ -443,7 +442,7 @@
     void collectPossibleIVs(Loop *L, SmallInstructionVector &PossibleIVs);
     void collectPossibleReductions(Loop *L, ReductionTracker &Reductions);
-    bool reroll(Instruction *IV, Loop *L, BasicBlock *Header, const SCEV *IterCount,
+    bool reroll(Instruction *IV, Loop *L, BasicBlock *Header,
                 ReductionTracker &Reductions);
   };
 }
@@ -1417,7 +1416,7 @@
   return true;
 }
 
-void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) {
+void LoopReroll::DAGRootTracker::replace() {
   BasicBlock *Header = L->getHeader();
   // Remove instructions associated with non-base iterations.
   for (BasicBlock::reverse_iterator J = Header->rbegin(), JE = Header->rend();
@@ -1431,17 +1430,20 @@
       ++J;
   }
 
+  // Flush the cached information about this loop in SCEV as we've mutated it
+  // heavily.
+  SE->forgetLoop(L);
 
   bool HasTwoIVs = LoopControlIV && LoopControlIV != IV;
 
   if (HasTwoIVs) {
     updateNonLoopCtrlIncr();
-    replaceIV(LoopControlIV, LoopControlIV, IterCount);
+    replaceIV(LoopControlIV, LoopControlIV);
   } else
     // We need to create a new induction variable for each different BaseInst.
    for (auto &DRS : RootSets)
      // Insert the new induction variable.
-      replaceIV(DRS.BaseInst, IV, IterCount);
+      replaceIV(DRS.BaseInst, IV);
 
   SimplifyInstructionsInBlock(Header, TLI);
   DeleteDeadPHIs(Header, TLI);
@@ -1478,8 +1480,7 @@
 }
 
 void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst,
-                                           Instruction *InstIV,
-                                           const SCEV *IterCount) {
+                                           Instruction *InstIV) {
   BasicBlock *Header = L->getHeader();
   int64_t Inc = IVToIncMap[InstIV];
   bool NeedNewIV = InstIV == LoopControlIV;
@@ -1516,6 +1517,9 @@
   if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) {
     // FIXME: Why do we need this check?
     if (Uses[BI].find_first() == IL_All) {
+      const SCEV *LIBETC = SE->getBackedgeTakenCount(L);
+      const SCEV *IterCount =
+          SE->getAddExpr(LIBETC, SE->getOne(LIBETC->getType()));
       const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE);
 
       if (NeedNewIV)
@@ -1676,7 +1680,6 @@
 // f(%iv) or part of some f(%iv.i). If all of that is true (and all reductions
 // have been validated), then we reroll the loop.
 bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header,
-                        const SCEV *IterCount,
                         ReductionTracker &Reductions) {
   DAGRootTracker DAGRoots(this, L, IV, SE, AA, TLI, DT, LI, PreserveLCSSA,
                           IVToIncMap, LoopControlIV);
@@ -1694,7 +1697,7 @@
   // making changes!
   Reductions.replaceSelected();
 
-  DAGRoots.replace(IterCount);
+  DAGRoots.replace();
 
   ++NumRerolledLoops;
   return true;
@@ -1723,10 +1726,12 @@
   if (!SE->hasLoopInvariantBackedgeTakenCount(L))
     return false;
 
-  const SCEV *LIBETC = SE->getBackedgeTakenCount(L);
-  const SCEV *IterCount = SE->getAddExpr(LIBETC, SE->getOne(LIBETC->getType()));
   DEBUG(dbgs() << "\n Before Reroll:\n" << *(L->getHeader()) << "\n");
-  DEBUG(dbgs() << "LRR: iteration count = " << *IterCount << "\n");
+  DEBUG(dbgs() << "LRR: iteration count = "
+               << *SE->getAddExpr(
+                      SE->getBackedgeTakenCount(L),
+                      SE->getOne(SE->getBackedgeTakenCount(L)->getType()))
+               << "\n");
 
   // First, we need to find the induction variable with respect to which we can
   // reroll (there may be several possible options).
@@ -1747,7 +1752,7 @@
   // For each possible IV, collect the associated possible set of 'root' nodes
   // (i+1, i+2, etc.).
   for (Instruction *PossibleIV : PossibleIVs)
-    if (reroll(PossibleIV, L, Header, IterCount, Reductions)) {
+    if (reroll(PossibleIV, L, Header, Reductions)) {
      Changed = true;
      break;
    }
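
Note for reviewers: the heart of this patch is the generation count used for lazy, amortized invalidation. incrementSCEVGeneration() makes an invalidation event O(1) by merely bumping the counter, and the cost of finding stale expressions is deferred to the next lookup, which re-validates and re-stamps only the subgraph it actually touches. Below is a minimal, self-contained sketch of that pattern. It is not LLVM code; all names (Node, Analysis, Dead, ...) are hypothetical stand-ins, and the recursive walk is simplified relative to validateSCEVSubgraph above.

  #include <cstdint>
  #include <memory>
  #include <utility>
  #include <vector>

  struct Node {
    // Mutable for the same reason as SCEVExpr::Generation in the patch: the
    // stamp records validation state and is not part of the node's identity.
    mutable uint64_t Generation;
    std::vector<const Node *> Operands;
    bool Dead = false; // Stands in for a SCEVUnknown whose Value was deleted.

    explicit Node(uint64_t Gen) : Generation(Gen) {}
  };

  class Analysis {
    uint64_t Generation = 0;
    std::vector<std::unique_ptr<Node>> Nodes;

  public:
    // An invalidation event only bumps the counter; nothing is walked here.
    // Every node stamped with an older generation is now suspect.
    void invalidate() { ++Generation; }

    // New nodes are built from already-validated operands, so they are
    // stamped with the current generation up front.
    Node *create(std::vector<const Node *> Ops) {
      Nodes.push_back(std::make_unique<Node>(Generation));
      Nodes.back()->Operands = std::move(Ops);
      return Nodes.back().get();
    }

    // Lazy re-validation at lookup time. A node already stamped with the
    // current generation was checked since the last invalidation, so its
    // whole subgraph can be trusted without walking it again.
    bool validate(const Node *N) const {
      if (N->Generation == Generation)
        return true;
      if (N->Dead)
        return false; // The patch also erases such SCEVs from its maps here.
      for (const Node *Op : N->Operands)
        if (!validate(Op))
          return false;
      N->Generation = Generation; // Stamp to prune later walks.
      return true;
    }
  };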
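
The recursion above is the simplest way to express the walk, but SCEV expression graphs can be deep enough that validateSCEVSubgraph instead maintains an explicit stack of (expression, next-operand-index) pairs. Roughly, and continuing the hypothetical sketch above (Index == 0 means a node is being seen for the first time; a popped entry records which operand to resume with):

  #include <tuple>

  bool validateIterative(uint64_t CurrentGen, const Node *N) {
    std::vector<std::pair<const Node *, size_t>> Stack;
    size_t Index = 0;
    for (;;) {
      if (N->Dead) {
        // validateSCEVSubgraph additionally purges everything still on the
        // stack from its maps, since each entry references the invalid node.
        return false;
      }
      if (N->Generation != CurrentGen && Index < N->Operands.size()) {
        // Descend into the next operand, remembering where to resume in N.
        Stack.push_back({N, Index + 1});
        N = N->Operands[Index];
        Index = 0;
        continue;
      }
      // All operands verified (or N was already stamped): mark N current.
      N->Generation = CurrentGen;
      if (Stack.empty())
        return true;
      std::tie(N, Index) = Stack.back();
      Stack.pop_back();
    }
  }

The re-stamping on the way back up is what makes repeated queries cheap: after one full walk, every reachable expression carries the current generation, so the next lookup of any expression sharing this subgraph stops at its root.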