Index: llvm/include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolution.h
+++ llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1888,8 +1888,10 @@
   bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R,
                       SCEV::NoWrapFlags &Flags);
 
-  /// Drop memoized information for all \p SCEVs.
-  void forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs);
+  /// Drop memoized information for all \p SCEVs. If \p ForgetValues is
+  /// enabled, also drop the corresponding Value <-> SCEV mappings.
+  void forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs,
+                             bool ForgetValues = false);
 
   /// Helper for forgetMemoizedResults.
   void forgetMemoizedResultsImpl(const SCEV *S);
@@ -1897,10 +1899,6 @@
   /// Return an existing SCEV for V if there is one, otherwise return nullptr.
   const SCEV *getExistingSCEV(Value *V);
 
-  /// Return false iff given SCEV contains a SCEVUnknown with NULL value-
-  /// pointer.
-  bool checkValidity(const SCEV *S) const;
-
   /// Return true if `ExtendOpTy`({`Start`,+,`Step`}) can be proved to be
   /// equal to {`ExtendOpTy`(`Start`),+,`ExtendOpTy`(`Step`)}. This is
   /// equivalent to proving no signed (resp. unsigned) wrap in
Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -503,7 +503,7 @@
 
 void SCEVUnknown::deleted() {
   // Clear this SCEVUnknown from various maps.
-  SE->forgetMemoizedResults(this);
+  SE->forgetMemoizedResults(this, /*ForgetValues=*/true);
 
   // Remove this SCEVUnknown from the uniquing map.
   SE->UniqueSCEVs.RemoveNode(this);
@@ -3999,15 +3999,6 @@
   return CouldNotCompute.get();
 }
 
-bool ScalarEvolution::checkValidity(const SCEV *S) const {
-  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
-    auto *SU = dyn_cast<SCEVUnknown>(S);
-    return SU && SU->getValue() == nullptr;
-  });
-
-  return !ContainsNulls;
-}
-
 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
   HasRecMapType::iterator I = HasRecMap.find(S);
   if (I != HasRecMap.end())
@@ -4115,14 +4106,7 @@
   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
 
   ValueExprMapType::iterator I = ValueExprMap.find_as(V);
-  if (I != ValueExprMap.end()) {
-    const SCEV *S = I->second;
-    if (checkValidity(S))
-      return S;
-    eraseValueFromMap(V);
-    forgetMemoizedResults(S);
-  }
-  return nullptr;
+  return I != ValueExprMap.end() ? I->second : nullptr;
 }
 
 /// Return a SCEV corresponding to -V = -1*V
@@ -12736,7 +12720,8 @@
   return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
 }
 
-void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
+void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs,
+                                            bool ForgetValues) {
   SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
   SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
 
@@ -12775,6 +12760,24 @@
 
   RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
   RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
+
+  if (ForgetValues) {
+    // Also drop the Value <-> SCEV mappings of every forgotten expression:
+    // getExistingSCEV() returns cached ValueExprMap entries unchecked, so
+    // they must not outlive the SCEVs being forgotten here.
+    for (const SCEV *S : ToForget) {
+      auto ExprIt = ExprValueMap.find(S);
+      if (ExprIt == ExprValueMap.end())
+        continue;
+
+      for (auto &ValueAndOffset : ExprIt->second) {
+        auto ValueIt = ValueExprMap.find_as(ValueAndOffset.first);
+        if (ValueIt != ValueExprMap.end())
+          ValueExprMap.erase(ValueIt);
+      }
+      ExprValueMap.erase(ExprIt);
+    }
+  }
 }
 
 void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {