Index: llvm/include/llvm/Analysis/ScalarEvolution.h =================================================================== --- llvm/include/llvm/Analysis/ScalarEvolution.h +++ llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1851,18 +1851,6 @@ bool doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap); - /// Get add expr already created or create a new one. - const SCEV *getOrCreateAddExpr(ArrayRef Ops, - SCEV::NoWrapFlags Flags); - - /// Get mul expr already created or create a new one. - const SCEV *getOrCreateMulExpr(ArrayRef Ops, - SCEV::NoWrapFlags Flags); - - // Get addrec expr already created or create a new one. - const SCEV *getOrCreateAddRecExpr(ArrayRef Ops, - const Loop *L, SCEV::NoWrapFlags Flags); - /// Return x if \p Val is f(x) where f is a 1-1 function. const SCEV *stripInjectiveFunctions(const SCEV *Val) const; @@ -1879,15 +1867,28 @@ /// Assign A and B to LHS and RHS, respectively. bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS); - /// Look for a SCEV expression with type `SCEVType` and operands `Ops` in - /// `UniqueSCEVs`. - /// - /// The first component of the returned tuple is the SCEV if found and null - /// otherwise. The second component is the `FoldingSetNodeID` that was - /// constructed to look up the SCEV and the third component is the insertion - /// point. - std::tuple - findExistingSCEVInCache(int SCEVType, ArrayRef Ops); + /// Key for uniquely allocating SCEVs. + struct SCEVKey { + FoldingSetNodeID NodeID; + void *InsertPos = nullptr; + + SCEVKey() = default; + template SCEVKey(int SCEVType, TOps const &... Ops); + template void reset(int SCEVType, TOps const &... Ops); + }; + + template struct SCEVCtor; + + /// Return an existing SCEV for Key if there is one, otherwise return nullptr. + template TExpr *getExistingSCEV(SCEVKey &Key); + + /// Allocate a new SCEV for Key, and given arguments. + template + TExpr *allocateSCEV(SCEVKey &Key, TOps const &... Ops); + + /// Return an existing SCEV for Key, or allocate a new one, if doesn't exist. + template + TExpr *getOrAllocateSCEV(SCEVKey &Key, TOps const &... Ops); FoldingSet UniqueSCEVs; FoldingSet UniquePreds; Index: llvm/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/lib/Analysis/ScalarEvolution.cpp +++ llvm/lib/Analysis/ScalarEvolution.cpp @@ -418,14 +418,8 @@ } const SCEV *ScalarEvolution::getConstant(ConstantInt *V) { - FoldingSetNodeID ID; - ID.AddInteger(scConstant); - ID.AddPointer(V); - void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V); - UniqueSCEVs.InsertNode(S, IP); - return S; + SCEVKey Key(scConstant, V); + return getOrAllocateSCEV(Key, V); } const SCEV *ScalarEvolution::getConstant(const APInt &Val) { @@ -1259,13 +1253,6 @@ "This is not a conversion to a SCEVable type!"); Ty = getEffectiveSCEVType(Ty); - FoldingSetNodeID ID; - ID.AddInteger(scTruncate); - ID.AddPointer(Op); - ID.AddPointer(Ty); - void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) return getConstant( @@ -1283,13 +1270,12 @@ if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1); - if (Depth > MaxCastDepth) { - SCEV *S = - new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); + SCEVKey Key(scTruncate, Op, Ty); + if (const SCEV *S = getExistingSCEV(Key)) return S; - } + + if (Depth > MaxCastDepth) + return allocateSCEV(Key, Op, Ty); // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN), @@ -1317,7 +1303,7 @@ // Although we checked in the beginning that ID is not in the cache, it is // possible that during recursion and different modification ID was inserted // into the cache. So if we find it, just return it. - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + if (const SCEV *S = getExistingSCEV(Key)) return S; } @@ -1332,11 +1318,7 @@ // The cast wasn't folded; create an explicit cast node. We can reuse // the existing insert position since if we get here, we won't have // made any changes which would invalidate it. - SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), - Op, Ty); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - return S; + return allocateSCEV(Key, Op, Ty); } // Get the limit of a recurrence such that incrementing by Step cannot cause @@ -1573,14 +1555,8 @@ for (unsigned Delta : {-2, -1, 1, 2}) { const SCEV *PreStart = getConstant(StartAI - Delta); - FoldingSetNodeID ID; - ID.AddInteger(scAddRecExpr); - ID.AddPointer(PreStart); - ID.AddPointer(Step); - ID.AddPointer(L); - void *IP = nullptr; - const auto *PreAR = - static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + SCEVKey Key(scAddRecExpr, PreStart, Step, L); + const auto *PreAR = getExistingSCEV(Key); // Give up if we don't already have the add recurrence we need because // actually constructing an add recurrence is relatively expensive. @@ -1653,19 +1629,11 @@ // Before doing any expensive analysis, check to see if we've already // computed a SCEV for this Op and Ty. - FoldingSetNodeID ID; - ID.AddInteger(scZeroExtend); - ID.AddPointer(Op); - ID.AddPointer(Ty); - void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - if (Depth > MaxCastDepth) { - SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), - Op, Ty); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); + SCEVKey Key(scZeroExtend, Op, Ty); + if (const SCEV *S = getExistingSCEV(Key)) return S; - } + if (Depth > MaxCastDepth) + return allocateSCEV(Key, Op, Ty); // zext(trunc(x)) --> zext(x) or x or trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast(Op)) { @@ -1932,12 +1900,7 @@ // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), - Op, Ty); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - return S; + return getOrAllocateSCEV(Key, Op, Ty); } const SCEV * @@ -1963,20 +1926,12 @@ // Before doing any expensive analysis, check to see if we've already // computed a SCEV for this Op and Ty. - FoldingSetNodeID ID; - ID.AddInteger(scSignExtend); - ID.AddPointer(Op); - ID.AddPointer(Ty); - void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - // Limit recursion depth. - if (Depth > MaxCastDepth) { - SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), - Op, Ty); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); + SCEVKey Key(scSignExtend, Op, Ty); + if (const SCEV *S = getExistingSCEV(Key)) return S; - } + // Limit recursion depth. + if (Depth > MaxCastDepth) + return allocateSCEV(Key, Op, Ty); // sext(trunc(x)) --> sext(x) or x or trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast(Op)) { @@ -2184,12 +2139,7 @@ // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), - Op, Ty); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - return S; + return getOrAllocateSCEV(Key, Op, Ty); } /// getAnyExtendExpr - Return a SCEV for the given operand extended with @@ -2444,9 +2394,18 @@ if (Ops.size() == 1) return Ops[0]; } + SCEVKey Key(scAddExpr, Ops); + if (auto *S = getExistingSCEV(Key)) { + S->setNoWrapFlags(Flags); + return S; + } + // Limit recursion calls depth. - if (Depth > MaxArithDepth || hasHugeExpression(Ops)) - return getOrCreateAddExpr(Ops, Flags); + if (Depth > MaxArithDepth || hasHugeExpression(Ops)) { + auto *S = allocateSCEV(Key, Ops); + S->setNoWrapFlags(Flags); + return S; + } // Okay, check to see if the same value occurs in the operand list more than // once. If so, merge them together into an multiply expression. Since we @@ -2779,73 +2738,104 @@ // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. - return getOrCreateAddExpr(Ops, Flags); -} - -const SCEV * -ScalarEvolution::getOrCreateAddExpr(ArrayRef Ops, - SCEV::NoWrapFlags Flags) { - FoldingSetNodeID ID; - ID.AddInteger(scAddExpr); - for (const SCEV *Op : Ops) - ID.AddPointer(Op); - void *IP = nullptr; - SCEVAddExpr *S = - static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); - if (!S) { - const SCEV **O = SCEVAllocator.Allocate(Ops.size()); - std::uninitialized_copy(Ops.begin(), Ops.end(), O); - S = new (SCEVAllocator) - SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - } + Key.reset(scAddExpr, Ops); + auto *S = getOrAllocateSCEV(Key, Ops); S->setNoWrapFlags(Flags); return S; } -const SCEV * -ScalarEvolution::getOrCreateAddRecExpr(ArrayRef Ops, - const Loop *L, SCEV::NoWrapFlags Flags) { - FoldingSetNodeID ID; - ID.AddInteger(scAddRecExpr); - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - ID.AddPointer(Ops[i]); - ID.AddPointer(L); - void *IP = nullptr; - SCEVAddRecExpr *S = - static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); - if (!S) { - const SCEV **O = SCEVAllocator.Allocate(Ops.size()); - std::uninitialized_copy(Ops.begin(), Ops.end(), O); - S = new (SCEVAllocator) - SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); +template +inline typename std::enable_if::value, void>::type +addDataToNodeID(FoldingSetNodeID &ID, T I) { + ID.AddInteger(I); +} + +template <> inline void addDataToNodeID(FoldingSetNodeID &ID, bool B) { + ID.AddBoolean(B); +} + +template inline void addDataToNodeID(FoldingSetNodeID &ID, T *P) { + ID.AddPointer(P); +} + +template +inline void addDataToNodeID(FoldingSetNodeID &ID, const ArrayRef &Arr) { + for (auto &E : Arr) + addDataToNodeID(ID, E); +} + +template +inline void addDataToNodeID(FoldingSetNodeID &ID, + const SmallVectorImpl &Arr) { + addDataToNodeID(ID, makeArrayRef(Arr)); +} + +template +inline void addDataToNodeID(FoldingSetNodeID &ID, const T &V, TArgs &... Args) { + addDataToNodeID(ID, V); + addDataToNodeID(ID, Args...); +} + +template +ScalarEvolution::SCEVKey::SCEVKey(int SCEVType, TOps const &... Ops) { + addDataToNodeID(NodeID, SCEVType, Ops...); +} + +template +void ScalarEvolution::SCEVKey::reset(int SCEVType, TOps const &... Ops) { + NodeID.clear(); + InsertPos = nullptr; + addDataToNodeID(NodeID, SCEVType, Ops...); +} + +struct SCEVCtorCtx { + BumpPtrAllocator &Allocator; + FoldingSetNodeID &NodeID; +}; + +template struct ScalarEvolution::SCEVCtor { + template + static auto construct(const SCEVCtorCtx &Ctx, TOps &... Ops, TArgs &... Args); + + template + static auto construct(const SCEVCtorCtx &Ctx, TOps const &... Ops, T const &V, + TArgs const &... Args) { + return SCEVCtor::construct(Ctx, Ops..., V, Args...); } - S->setNoWrapFlags(Flags); + + template + static auto construct(const SCEVCtorCtx &Ctx, TOps const &... Ops, + const SmallVectorImpl &V, TArgs const &... Args) { + T *O = Ctx.Allocator.Allocate(V.size()); + std::uninitialized_copy(V.begin(), V.end(), O); + return SCEVCtor::construct(Ctx, Ops..., O, + V.size(), Args...); + } + + static auto construct(const SCEVCtorCtx &Ctx, TOps const &... Ops) { + return new (Ctx.Allocator) TExpr(Ctx.NodeID.Intern(Ctx.Allocator), Ops...); + } +}; + +template TExpr *ScalarEvolution::getExistingSCEV(SCEVKey &Key) { + return cast_or_null( + UniqueSCEVs.FindNodeOrInsertPos(Key.NodeID, Key.InsertPos)); +} + +template +TExpr *ScalarEvolution::allocateSCEV(SCEVKey &Key, TOps const &... Ops) { + SCEVCtorCtx Ctx = {SCEVAllocator, Key.NodeID}; + auto *S = SCEVCtor::construct(Ctx, Ops...); + UniqueSCEVs.InsertNode(S, Key.InsertPos); + addToLoopUseLists(S); return S; } -const SCEV * -ScalarEvolution::getOrCreateMulExpr(ArrayRef Ops, - SCEV::NoWrapFlags Flags) { - FoldingSetNodeID ID; - ID.AddInteger(scMulExpr); - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - ID.AddPointer(Ops[i]); - void *IP = nullptr; - SCEVMulExpr *S = - static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); - if (!S) { - const SCEV **O = SCEVAllocator.Allocate(Ops.size()); - std::uninitialized_copy(Ops.begin(), Ops.end(), O); - S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator), - O, Ops.size()); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - } - S->setNoWrapFlags(Flags); +template +TExpr *ScalarEvolution::getOrAllocateSCEV(SCEVKey &Key, TOps const &... Ops) { + auto *S = getExistingSCEV(Key); + if (!S) + S = allocateSCEV(Key, Ops...); return S; } @@ -2923,9 +2913,19 @@ Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); + // Check if we have created the same expression before. + SCEVKey Key(scMulExpr, Ops); + if (auto *S = getExistingSCEV(Key)) { + S->setNoWrapFlags(Flags); + return S; + } + // Limit recursion calls depth. - if (Depth > MaxArithDepth || hasHugeExpression(Ops)) - return getOrCreateMulExpr(Ops, Flags); + if (Depth > MaxArithDepth || hasHugeExpression(Ops)) { + auto *S = allocateSCEV(Key, Ops); + S->setNoWrapFlags(Flags); + return S; + } // If there are any constants, fold them together. unsigned Idx = 0; @@ -3151,7 +3151,10 @@ // Okay, it looks like we really DO need an mul expr. Check to see if we // already have one, otherwise create a new one. - return getOrCreateMulExpr(Ops, Flags); + Key.reset(scMulExpr, Ops); + auto *S = getOrAllocateSCEV(Key, Ops); + S->setNoWrapFlags(Flags); + return S; } /// Represents an unsigned remainder expression based on unsigned division. @@ -3190,6 +3193,10 @@ getEffectiveSCEVType(RHS->getType()) && "SCEVUDivExpr operand types don't match!"); + SCEVKey Key(scUDivExpr, LHS, RHS); + if (const SCEV *S = getExistingSCEV(Key)) + return S; + if (const SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->isOne()) return LHS; // X udiv 1 --> x @@ -3236,9 +3243,21 @@ AR->getLoop(), SCEV::FlagAnyWrap)) { const APInt &StartInt = StartC->getAPInt(); const APInt &StartRem = StartInt.urem(StepInt); - if (StartRem != 0) - LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, - AR->getLoop(), SCEV::FlagNW); + if (StartRem != 0) { + const SCEV *NewLHS = + getAddRecExpr(getConstant(StartInt - StartRem), Step, + AR->getLoop(), SCEV::FlagNW); + + if (LHS != NewLHS) { + LHS = NewLHS; + + // Create the new ID with the new LHS, and check if it is + // already cached. + Key.reset(scUDivExpr, LHS, RHS); + if (const SCEV *S = getExistingSCEV(Key)) + return S; + } + } } } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. @@ -3303,17 +3322,10 @@ } } - FoldingSetNodeID ID; - ID.AddInteger(scUDivExpr); - ID.AddPointer(LHS); - ID.AddPointer(RHS); - void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), - LHS, RHS); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - return S; + // Although we checked in the beginning that ID is not in the cache, it is + // possible that during recursion and different modification, ID was inserted + // into the cache. So if we find it, just return it. + return getOrAllocateSCEV(Key, LHS, RHS); } static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { @@ -3479,7 +3491,10 @@ // Okay, it looks like we really DO need an addrec expr. Check to see if we // already have one, otherwise create a new one. - return getOrCreateAddRecExpr(Operands, L, Flags); + SCEVKey Key(scAddRecExpr, Operands, L); + auto *S = getOrAllocateSCEV(Key, Operands, L); + S->setNoWrapFlags(Flags); + return S; } const SCEV * @@ -3534,18 +3549,6 @@ return getAddExpr(BaseExpr, TotalOffset, Wrap); } -std::tuple -ScalarEvolution::findExistingSCEVInCache(int SCEVType, - ArrayRef Ops) { - FoldingSetNodeID ID; - void *IP = nullptr; - ID.AddInteger(SCEVType); - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - ID.AddPointer(Ops[i]); - return std::tuple( - UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP); -} - const SCEV *ScalarEvolution::getMinMaxExpr(unsigned Kind, SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!"); @@ -3564,9 +3567,9 @@ GroupByComplexity(Ops, &LI, DT); // Check if we have created the same expression before. - if (const SCEV *S = std::get<0>(findExistingSCEVInCache(Kind, Ops))) { + SCEVKey Key(Kind, Ops); + if (const SCEV *S = getExistingSCEV(Key)) return S; - } // If there are any constants, fold them together. unsigned Idx = 0; @@ -3662,20 +3665,9 @@ // Okay, it looks like we really DO need an expr. Check to see if we // already have one, otherwise create a new one. - const SCEV *ExistingSCEV; - FoldingSetNodeID ID; - void *IP; - std::tie(ExistingSCEV, ID, IP) = findExistingSCEVInCache(Kind, Ops); - if (ExistingSCEV) - return ExistingSCEV; - const SCEV **O = SCEVAllocator.Allocate(Ops.size()); - std::uninitialized_copy(Ops.begin(), Ops.end(), O); - SCEV *S = new (SCEVAllocator) SCEVMinMaxExpr( - ID.Intern(SCEVAllocator), static_cast(Kind), O, Ops.size()); - - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - return S; + Key.reset(Kind, Ops); + return getOrAllocateSCEV(Key, static_cast(Kind), + Ops); } const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { @@ -3739,20 +3731,12 @@ // interesting possibilities, and any other code that calls getUnknown // is doing so in order to hide a value from SCEV canonicalization. - FoldingSetNodeID ID; - ID.AddInteger(scUnknown); - ID.AddPointer(V); - void *IP = nullptr; - if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) { - assert(cast(S)->getValue() == V && - "Stale SCEVUnknown in uniquing map!"); + SCEVKey Key(scUnknown, V); + if (auto *S = getExistingSCEV(Key)) { + assert(S->getValue() == V && "Stale SCEVUnknown in uniquing map!"); return S; } - SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this, - FirstUnknown); - FirstUnknown = cast(S); - UniqueSCEVs.InsertNode(S, IP); - return S; + return FirstUnknown = allocateSCEV(Key, V, this, FirstUnknown); } //===----------------------------------------------------------------------===//