diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1899,7 +1899,7 @@ /// otherwise. The second component is the `FoldingSetNodeID` that was /// constructed to look up the SCEV and the third component is the insertion /// point. - std::tuple + std::tuple findExistingSCEVInCache(int SCEVType, ArrayRef Ops); FoldingSet UniqueSCEVs; diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -2452,6 +2452,11 @@ if (Depth > MaxArithDepth || hasHugeExpression(Ops)) return getOrCreateAddExpr(Ops, Flags); + if (SCEV *S = std::get<0>(findExistingSCEVInCache(scAddExpr, Ops))) { + static_cast(S)->setNoWrapFlags(Flags); + return S; + } + // Okay, check to see if the same value occurs in the operand list more than // once. If so, merge them together into an multiply expression. Since we // sorted the list, these values are required to be adjacent. @@ -2931,6 +2936,11 @@ if (Depth > MaxArithDepth || hasHugeExpression(Ops)) return getOrCreateMulExpr(Ops, Flags); + if (SCEV *S = std::get<0>(findExistingSCEVInCache(scMulExpr, Ops))) { + static_cast(S)->setNoWrapFlags(Flags); + return S; + } + // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { @@ -3193,6 +3203,14 @@ getEffectiveSCEVType(RHS->getType()) && "SCEVUDivExpr operand types don't match!"); + FoldingSetNodeID ID; + ID.AddInteger(scUDivExpr); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; + if (const SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->isOne()) return LHS; // X udiv 1 --> x @@ -3239,9 +3257,24 @@ AR->getLoop(), SCEV::FlagAnyWrap)) { const APInt &StartInt = StartC->getAPInt(); const APInt &StartRem = StartInt.urem(StepInt); - if (StartRem != 0) - LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, - AR->getLoop(), SCEV::FlagNW); + if (StartRem != 0) { + const SCEV *NewLHS = + getAddRecExpr(getConstant(StartInt - StartRem), Step, + AR->getLoop(), SCEV::FlagNW); + if (LHS != NewLHS) { + LHS = NewLHS; + + // Reset the ID to include the new LHS, and check if it is + // already cached. + ID.clear(); + ID.AddInteger(scUDivExpr); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; + } + } } } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. @@ -3306,12 +3339,6 @@ } } - FoldingSetNodeID ID; - ID.AddInteger(scUDivExpr); - ID.AddPointer(LHS); - ID.AddPointer(RHS); - void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), LHS, RHS); UniqueSCEVs.InsertNode(S, IP); @@ -3537,7 +3564,7 @@ return getAddExpr(BaseExpr, TotalOffset, Wrap); } -std::tuple +std::tuple ScalarEvolution::findExistingSCEVInCache(int SCEVType, ArrayRef Ops) { FoldingSetNodeID ID; @@ -3545,7 +3572,7 @@ ID.AddInteger(SCEVType); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); - return std::tuple( + return std::tuple( UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP); }