Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -1160,7 +1160,14 @@
   const SCEV *getConstant(Type *Ty, uint64_t V, bool isSigned = false);
   const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty);
   const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty);
+  const SCEV *getZeroExtendExprImpl(
+      const SCEV *Op, Type *Ty,
+      DenseMap<std::pair<const SCEV *, Type *>, const SCEV *> &Cache);
+
   const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty);
+  const SCEV *getSignExtendExprImpl(
+      const SCEV *Op, Type *Ty,
+      DenseMap<std::pair<const SCEV *, Type *>, const SCEV *> &Cache);
   const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty);
   const SCEV *getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                          SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -1276,7 +1276,9 @@
 namespace {

 struct ExtendOpTraitsBase {
-  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *);
+  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(
+      const SCEV *, Type *,
+      DenseMap<std::pair<const SCEV *, Type *>, const SCEV *> &Cache);
 };

 // Used to make code generic over signed and unsigned overflow.
@@ -1305,8 +1307,9 @@
   }
 };

-const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
-    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
+const ExtendOpTraitsBase::GetExtendExprTy
+    ExtendOpTraits<SCEVSignExtendExpr>::GetExtendExpr =
+        &ScalarEvolution::getSignExtendExprImpl;

 template <>
 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
@@ -1321,8 +1324,9 @@
   }
 };

-const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
-    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
+const ExtendOpTraitsBase::GetExtendExprTy
+    ExtendOpTraits<SCEVZeroExtendExpr>::GetExtendExpr =
+        &ScalarEvolution::getZeroExtendExprImpl;
 }

 // The recurrence AR has been shown to have no signed/unsigned wrap or something
@@ -1333,8 +1337,9 @@
 // expression "Step + sext/zext(PreIncAR)" is congruent with
 // "sext/zext(PostIncAR)"
 template <typename ExtendOpTy>
-static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
-                                        ScalarEvolution *SE) {
+static const SCEV *getPreStartForExtend(
+    const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE,
+    DenseMap<std::pair<const SCEV *, Type *>, const SCEV *> &Cache) {
   auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

@@ -1381,9 +1386,9 @@
   unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
   Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
   const SCEV *OperandExtendedStart =
-      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy),
-                     (SE->*GetExtendExpr)(Step, WideTy));
-  if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) {
+      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Cache),
+                     (SE->*GetExtendExpr)(Step, WideTy, Cache));
+  if ((SE->*GetExtendExpr)(Start, WideTy, Cache) == OperandExtendedStart) {
     if (PreAR && AR->getNoWrapFlags(WrapType)) {
       // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
       // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
@@ -1407,16 +1412,18 @@

 // Get the normalized zero or sign extended expression for this AddRec's Start.
 template <typename ExtendOpTy>
-static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
-                                        ScalarEvolution *SE) {
+static const SCEV *getExtendAddRecStart(
+    const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE,
+    DenseMap<std::pair<const SCEV *, Type *>, const SCEV *> &Cache) {
   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

-  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE);
+  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Cache);
   if (!PreStart)
-    return (SE->*GetExtendExpr)(AR->getStart(), Ty);
+    return (SE->*GetExtendExpr)(AR->getStart(), Ty, Cache);

-  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty),
-                        (SE->*GetExtendExpr)(PreStart, Ty));
+  return SE->getAddExpr(
+      (SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, Cache),
+      (SE->*GetExtendExpr)(PreStart, Ty, Cache));
 }

 // Try to prove away overflow by looking at "nearby" add recurrences. A
@@ -1496,31 +1503,52 @@
   return false;
 }

-const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
-                                               Type *Ty) {
+const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty) {
+  // Use the local cache to prevent exponential behavior of
+  // getZeroExtendExprImpl.
+  DenseMap<std::pair<const SCEV *, Type *>, const SCEV *> Cache;
+  return getZeroExtendExprImpl(Op, Ty, Cache);
+}
+
+const SCEV *ScalarEvolution::getZeroExtendExprImpl(
+    const SCEV *Op, Type *Ty,
+    DenseMap<std::pair<const SCEV *, Type *>, const SCEV *> &Cache) {
+  auto It = Cache.find({Op, Ty});
+  if (It != Cache.end())
+    return It->second;
+
   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
          "This is not an extending conversion!");
   assert(isSCEVable(Ty) &&
          "This is not a conversion to a SCEVable type!");
   Ty = getEffectiveSCEVType(Ty);

+  const SCEV *ZExt = nullptr;
+  FoldingSetNodeID ID;
+  void *IP = nullptr;
+
   // Fold if the operand is constant.
-  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
-    return getConstant(
-      cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
+  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
+    ZExt = getConstant(
+        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
+    goto RET;
+  }

   // zext(zext(x)) --> zext(x)
-  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
-    return getZeroExtendExpr(SZ->getOperand(), Ty);
+  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) {
+    ZExt = getZeroExtendExprImpl(SZ->getOperand(), Ty, Cache);
+    goto RET;
+  }

-  // Before doing any expensive analysis, check to see if we've already
-  // computed a SCEV for this Op and Ty.
-  FoldingSetNodeID ID;
   ID.AddInteger(scZeroExtend);
   ID.AddPointer(Op);
   ID.AddPointer(Ty);
-  void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  // Before doing any expensive analysis, check to see if we've already
+  // computed a SCEV for this Op and Ty.
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
+    ZExt = S;
+    goto RET;
+  }

   // zext(trunc(x)) --> zext(x) or x or trunc(x)
   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
@@ -1531,8 +1559,10 @@
     unsigned TruncBits = getTypeSizeInBits(ST->getType());
     unsigned NewBits = getTypeSizeInBits(Ty);
     if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
-            CR.zextOrTrunc(NewBits)))
-      return getTruncateOrZeroExtend(X, Ty);
+            CR.zextOrTrunc(NewBits))) {
+      ZExt = getTruncateOrZeroExtend(X, Ty);
+      goto RET;
+    }
   }

   // If the input value is a chrec scev, and we can prove that the value
@@ -1553,10 +1583,12 @@

       // If we have special knowledge that this addrec won't overflow,
       // we don't need to do any further analysis.
-      if (AR->hasNoUnsignedWrap())
-        return getAddRecExpr(
-            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-            getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+      if (AR->hasNoUnsignedWrap()) {
+        ZExt = getAddRecExpr(
+            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+            getZeroExtendExprImpl(Step, Ty, Cache), L, AR->getNoWrapFlags());
+        goto RET;
+      }

       // Check whether the backedge-taken count is SCEVCouldNotCompute.
       // Note that this serves two purposes: It filters out loops that are
@@ -1581,21 +1613,23 @@
           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
           // Check whether Start+Step*MaxBECount has no unsigned overflow.
           const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
-          const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy);
-          const SCEV *WideStart = getZeroExtendExpr(Start, WideTy);
+          const SCEV *ZAdd =
+              getZeroExtendExprImpl(getAddExpr(Start, ZMul), WideTy, Cache);
+          const SCEV *WideStart = getZeroExtendExprImpl(Start, WideTy, Cache);
           const SCEV *WideMaxBECount =
-            getZeroExtendExpr(CastedMaxBECount, WideTy);
-          const SCEV *OperandExtendedAdd =
-            getAddExpr(WideStart,
-                       getMulExpr(WideMaxBECount,
-                                  getZeroExtendExpr(Step, WideTy)));
+              getZeroExtendExprImpl(CastedMaxBECount, WideTy, Cache);
+          const SCEV *OperandExtendedAdd = getAddExpr(
+              WideStart, getMulExpr(WideMaxBECount, getZeroExtendExprImpl(
+                                                        Step, WideTy, Cache)));
           if (ZAdd == OperandExtendedAdd) {
             // Cache knowledge of AR NUW, which is propagated to this AddRec.
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
             // Return the expression with the addrec on the outside.
-            return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-                getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            ZExt = getAddRecExpr(
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+                getZeroExtendExprImpl(Step, Ty, Cache), L,
+                AR->getNoWrapFlags());
+            goto RET;
           }
           // Similar to above, only this time treat the step value as signed.
           // This covers loops that count down.
@@ -1608,9 +1642,10 @@
             // Negative step causes unsigned wrap, but it still can't self-wrap.
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
             // Return the expression with the addrec on the outside.
-            return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
+            ZExt = getAddRecExpr(
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            goto RET;
           }
         }
       }
@@ -1640,9 +1675,11 @@
           // AddRec.
           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
           // Return the expression with the addrec on the outside.
-          return getAddRecExpr(
-              getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-              getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+          ZExt = getAddRecExpr(
+              getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+              getZeroExtendExprImpl(Step, Ty, Cache), L,
+              AR->getNoWrapFlags());
+          goto RET;
         }
       } else if (isKnownNegative(Step)) {
         const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
@@ -1656,18 +1693,20 @@
           // still can't self-wrap.
           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
           // Return the expression with the addrec on the outside.
-          return getAddRecExpr(
-              getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
+          ZExt = getAddRecExpr(
+              getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
               getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+          goto RET;
         }
       }
     }

     if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
       const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
-      return getAddRecExpr(
-          getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-          getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+      ZExt = getAddRecExpr(
+          getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+          getZeroExtendExprImpl(Step, Ty, Cache), L, AR->getNoWrapFlags());
+      goto RET;
     }
   }

@@ -1678,49 +1717,78 @@
       // commute the zero extension with the addition operation.
       SmallVector<const SCEV *, 4> Ops;
       for (const auto *Op : SA->operands())
-        Ops.push_back(getZeroExtendExpr(Op, Ty));
-      return getAddExpr(Ops, SCEV::FlagNUW);
+        Ops.push_back(getZeroExtendExprImpl(Op, Ty, Cache));
+      ZExt = getAddExpr(Ops, SCEV::FlagNUW);
+      goto RET;
     }
   }

   // The cast wasn't folded; create an explicit cast node.
   // Recompute the insert position, as it may have been invalidated.
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
-  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
-                                                   Op, Ty);
-  UniqueSCEVs.InsertNode(S, IP);
-  return S;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
+    ZExt = S;
+    goto RET;
+  }
+  ZExt =
+      new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty);
+  UniqueSCEVs.InsertNode(const_cast<SCEV *>(ZExt), IP);
+RET:
+  Cache.insert({{Op, Ty}, ZExt});
+  return ZExt;
 }

-const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
-                                               Type *Ty) {
+const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty) {
+  // Use the local cache to prevent exponential behavior of
+  // getSignExtendExprImpl.
+  DenseMap<std::pair<const SCEV *, Type *>, const SCEV *> Cache;
+  return getSignExtendExprImpl(Op, Ty, Cache);
+}
+
+const SCEV *ScalarEvolution::getSignExtendExprImpl(
+    const SCEV *Op, Type *Ty,
+    DenseMap<std::pair<const SCEV *, Type *>, const SCEV *> &Cache) {
+  auto It = Cache.find({Op, Ty});
+  if (It != Cache.end())
+    return It->second;
+
   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
          "This is not an extending conversion!");
   assert(isSCEVable(Ty) &&
          "This is not a conversion to a SCEVable type!");
   Ty = getEffectiveSCEVType(Ty);

+  const SCEV *SExt = nullptr;
+  FoldingSetNodeID ID;
+  void *IP = nullptr;
+
   // Fold if the operand is constant.
-  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
-    return getConstant(
-      cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
+  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
+    SExt = getConstant(
+        cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
+    goto RET;
+  }

   // sext(sext(x)) --> sext(x)
-  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
-    return getSignExtendExpr(SS->getOperand(), Ty);
+  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) {
+    SExt = getSignExtendExprImpl(SS->getOperand(), Ty, Cache);
+    goto RET;
+  }

   // sext(zext(x)) --> zext(x)
-  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
-    return getZeroExtendExpr(SZ->getOperand(), Ty);
+  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) {
+    SExt = getZeroExtendExpr(SZ->getOperand(), Ty);
+    goto RET;
+  }

   // Before doing any expensive analysis, check to see if we've already
   // computed a SCEV for this Op and Ty.
-  FoldingSetNodeID ID;
   ID.AddInteger(scSignExtend);
   ID.AddPointer(Op);
   ID.AddPointer(Ty);
-  void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
+    SExt = S;
+    goto RET;
+  }

   // sext(trunc(x)) --> sext(x) or x or trunc(x)
   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
@@ -1731,8 +1799,10 @@
     unsigned TruncBits = getTypeSizeInBits(ST->getType());
     unsigned NewBits = getTypeSizeInBits(Ty);
     if (CR.truncate(TruncBits).signExtend(NewBits).contains(
-            CR.sextOrTrunc(NewBits)))
-      return getTruncateOrSignExtend(X, Ty);
+            CR.sextOrTrunc(NewBits))) {
+      SExt = getTruncateOrSignExtend(X, Ty);
+      goto RET;
+    }
   }

   // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2
@@ -1745,9 +1815,11 @@
           const APInt &C1 = SC1->getAPInt();
           const APInt &C2 = SC2->getAPInt();
           if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
-              C2.ugt(C1) && C2.isPowerOf2())
-            return getAddExpr(getSignExtendExpr(SC1, Ty),
-                              getSignExtendExpr(SMul, Ty));
+              C2.ugt(C1) && C2.isPowerOf2()) {
+            SExt = getAddExpr(getSignExtendExprImpl(SC1, Ty, Cache),
+                              getSignExtendExprImpl(SMul, Ty, Cache));
+            goto RET;
+          }
         }
       }
     }
@@ -1758,8 +1830,9 @@
       // commute the sign extension with the addition operation.
      SmallVector<const SCEV *, 4> Ops;
       for (const auto *Op : SA->operands())
-        Ops.push_back(getSignExtendExpr(Op, Ty));
-      return getAddExpr(Ops, SCEV::FlagNSW);
+        Ops.push_back(getSignExtendExprImpl(Op, Ty, Cache));
+      SExt = getAddExpr(Ops, SCEV::FlagNSW);
+      goto RET;
     }
   }
   // If the input value is a chrec scev, and we can prove that the value
@@ -1780,10 +1853,12 @@

       // If we have special knowledge that this addrec won't overflow,
       // we don't need to do any further analysis.
-      if (AR->hasNoSignedWrap())
-        return getAddRecExpr(
-            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-            getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW);
+      if (AR->hasNoSignedWrap()) {
+        SExt = getAddRecExpr(
+            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+            getSignExtendExprImpl(Step, Ty, Cache), L, SCEV::FlagNSW);
+        goto RET;
+      }

       // Check whether the backedge-taken count is SCEVCouldNotCompute.
       // Note that this serves two purposes: It filters out loops that are
@@ -1808,28 +1883,29 @@
           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
           // Check whether Start+Step*MaxBECount has no signed overflow.
           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
-          const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy);
-          const SCEV *WideStart = getSignExtendExpr(Start, WideTy);
+          const SCEV *SAdd =
+              getSignExtendExprImpl(getAddExpr(Start, SMul), WideTy, Cache);
+          const SCEV *WideStart = getSignExtendExprImpl(Start, WideTy, Cache);
           const SCEV *WideMaxBECount =
-            getZeroExtendExpr(CastedMaxBECount, WideTy);
-          const SCEV *OperandExtendedAdd =
-            getAddExpr(WideStart,
-                       getMulExpr(WideMaxBECount,
-                                  getSignExtendExpr(Step, WideTy)));
+              getZeroExtendExpr(CastedMaxBECount, WideTy);
+          const SCEV *OperandExtendedAdd = getAddExpr(
+              WideStart, getMulExpr(WideMaxBECount, getSignExtendExprImpl(
+                                                        Step, WideTy, Cache)));
           if (SAdd == OperandExtendedAdd) {
             // Cache knowledge of AR NSW, which is propagated to this AddRec.
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
             // Return the expression with the addrec on the outside.
-            return getAddRecExpr(
-                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-                getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            SExt = getAddRecExpr(
+                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+                getSignExtendExprImpl(Step, Ty, Cache), L,
+                AR->getNoWrapFlags());
+            goto RET;
           }
           // Similar to above, only this time treat the step value as unsigned.
           // This covers loops that count up with an unsigned step.
-          OperandExtendedAdd =
-            getAddExpr(WideStart,
-                       getMulExpr(WideMaxBECount,
-                                  getZeroExtendExpr(Step, WideTy)));
+          OperandExtendedAdd = getAddExpr(
+              WideStart,
+              getMulExpr(WideMaxBECount, getZeroExtendExpr(Step, WideTy)));
           if (SAdd == OperandExtendedAdd) {
             // If AR wraps around then
             //
@@ -1842,9 +1918,10 @@
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);

             // Return the expression with the addrec on the outside.
-            return getAddRecExpr(
-                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
+            SExt = getAddRecExpr(
+                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
                 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            goto RET;
           }
         }
       }
@@ -1874,9 +1951,10 @@
                                        OverflowLimit)))) {
         // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
-        return getAddRecExpr(
-            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-            getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+        SExt = getAddRecExpr(
+            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+            getSignExtendExprImpl(Step, Ty, Cache), L, AR->getNoWrapFlags());
+        goto RET;
       }
     }

@@ -1890,33 +1968,39 @@
       const APInt &C1 = SC1->getAPInt();
       const APInt &C2 = SC2->getAPInt();
       if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) {
-        Start = getSignExtendExpr(Start, Ty);
+        Start = getSignExtendExprImpl(Start, Ty, Cache);
         const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L,
                                           AR->getNoWrapFlags());
-        return getAddExpr(Start, getSignExtendExpr(NewAR, Ty));
+        SExt = getAddExpr(Start, getSignExtendExprImpl(NewAR, Ty, Cache));
+        goto RET;
       }
     }

     if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
       const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
-      return getAddRecExpr(
-          getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-          getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+      SExt = getAddRecExpr(
+          getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+          getSignExtendExprImpl(Step, Ty, Cache), L, AR->getNoWrapFlags());
+      goto RET;
     }
   }

   // If the input value is provably positive and we could not simplify
   // away the sext build a zext instead.
-  if (isKnownNonNegative(Op))
-    return getZeroExtendExpr(Op, Ty);
+  if (isKnownNonNegative(Op)) {
+    SExt = getZeroExtendExpr(Op, Ty);
+    goto RET;
+  }

   // The cast wasn't folded; create an explicit cast node.
   // Recompute the insert position, as it may have been invalidated.
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
-  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
-                                                   Op, Ty);
-  UniqueSCEVs.InsertNode(S, IP);
-  return S;
+  SExt =
+      new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), Op, Ty);
+  UniqueSCEVs.InsertNode(const_cast<SCEV *>(SExt), IP);
+RET:
+  Cache.insert({{Op, Ty}, SExt});
+  return SExt;
 }

 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
Index: unittests/Analysis/ScalarEvolutionTest.cpp
===================================================================
--- unittests/Analysis/ScalarEvolutionTest.cpp
+++ unittests/Analysis/ScalarEvolutionTest.cpp
@@ -600,5 +600,93 @@
   EXPECT_NE(nullptr, SE.getSCEV(Mul1));
 }

+// Expect the call of getZeroExtendExpr will not cost exponential time.
+TEST_F(ScalarEvolutionsTest, SCEVZeroExtendExpr) {
+  LLVMContext C;
+  SMDiagnostic Err;
+
+  // Generate a function like below:
+  // define void @foo() {
+  // entry:
+  //   br label %for.cond
+  //
+  // for.cond:
+  //   %0 = phi i64 [ 100, %entry ], [ %dec, %for.inc ]
+  //   %cmp = icmp sgt i64 %0, 90
+  //   br i1 %cmp, label %for.inc, label %for.cond1
+  //
+  // for.inc:
+  //   %dec = add nsw i64 %0, -1
+  //   br label %for.cond
+  //
+  // for.cond1:
+  //   %1 = phi i64 [ 100, %for.cond ], [ %dec5, %for.inc2 ]
+  //   %cmp3 = icmp sgt i64 %1, 90
+  //   br i1 %cmp3, label %for.inc2, label %for.cond4
+  //
+  // for.inc2:
+  //   %dec5 = add nsw i64 %1, -1
+  //   br label %for.cond1
+  //
+  // ......
+  //
+  // for.cond89:
+  //   %19 = phi i64 [ 100, %for.cond84 ], [ %dec94, %for.inc92 ]
+  //   %cmp93 = icmp sgt i64 %19, 90
+  //   br i1 %cmp93, label %for.inc92, label %for.end
+  //
+  // for.inc92:
+  //   %dec94 = add nsw i64 %19, -1
+  //   br label %for.cond89
+  //
+  // for.end:
+  //   %gep = getelementptr i8, i8* null, i64 %dec
+  //   %gep6 = getelementptr i8, i8* %gep, i64 %dec5
+  //   ......
+  //   %gep95 = getelementptr i8, i8* %gep91, i64 %dec94
+  //   ret void
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), {}, false);
+  Function *F = cast<Function>(M.getOrInsertFunction("foo", FTy));
+
+  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
+  BasicBlock *CondBB = BasicBlock::Create(Context, "for.cond", F);
+  BasicBlock *EndBB = BasicBlock::Create(Context, "for.end", F);
+  BranchInst::Create(CondBB, EntryBB);
+  BasicBlock *PrevBB = EntryBB;
+
+  Type *I64Ty = Type::getInt64Ty(Context);
+  Type *I8Ty = Type::getInt8Ty(Context);
+  Type *I8PtrTy = Type::getInt8PtrTy(Context);
+  Value *Accum = Constant::getNullValue(I8PtrTy);
+  int Iters = 20;
+  for (int i = 0; i < Iters; i++) {
+    BasicBlock *IncBB = BasicBlock::Create(Context, "for.inc", F, EndBB);
+    auto *PN = PHINode::Create(I64Ty, 2, "", CondBB);
+    PN->addIncoming(ConstantInt::get(Context, APInt(64, 100)), PrevBB);
+    auto *Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, PN,
+                                ConstantInt::get(Context, APInt(64, 90)), "cmp",
+                                CondBB);
+    BasicBlock *NextBB;
+    if (i != Iters - 1)
+      NextBB = BasicBlock::Create(Context, "for.cond", F, EndBB);
+    else
+      NextBB = EndBB;
+    BranchInst::Create(IncBB, NextBB, Cmp, CondBB);
+    auto *Dec = BinaryOperator::CreateNSWAdd(
+        PN, ConstantInt::get(Context, APInt(64, -1)), "dec", IncBB);
+    PN->addIncoming(Dec, IncBB);
+    BranchInst::Create(CondBB, IncBB);
+
+    Accum = GetElementPtrInst::Create(I8Ty, Accum, Dec, "gep", EndBB);
+
+    PrevBB = CondBB;
+    CondBB = NextBB;
+  }
+  ReturnInst::Create(Context, nullptr, EndBB);
+  ScalarEvolution SE = buildSE(*F);
+  const SCEV *S = SE.getSCEV(Accum);
+  Type *I128Ty = Type::getInt128Ty(Context);
+  SE.getZeroExtendExpr(S, I128Ty);
+}
 }  // end anonymous namespace
 }  // end namespace llvm
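The core technique in the patch above is memoization: each top-level call to getZeroExtendExpr/getSignExtendExpr allocates a local DenseMap keyed on the {Op, Ty} pair and threads it through the recursion, so a subexpression reachable along many paths of the SCEV DAG is extended only once per query instead of once per path. The sketch below is a minimal, self-contained illustration of that pattern, not LLVM code; Node, buildChain, countLeavesSlow, and countLeavesCached are names invented for this example, and std::map stands in for llvm::DenseMap. The uncached walk takes time proportional to the number of paths (2^N), while the cached walk visits each distinct node once.

// Illustrative sketch only (C++14) -- not part of the patch and not LLVM code.
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <vector>

struct Node {
  const Node *Left = nullptr;  // null in the leaf
  const Node *Right = nullptr;
};

// Build N interior nodes where both children of level i point at the single
// node of level i-1: a small DAG with 2^N root-to-leaf paths.
static std::vector<std::unique_ptr<Node>> buildChain(unsigned N) {
  std::vector<std::unique_ptr<Node>> Nodes;
  Nodes.push_back(std::make_unique<Node>());  // the shared leaf
  for (unsigned I = 0; I < N; ++I) {
    auto Inner = std::make_unique<Node>();
    Inner->Left = Inner->Right = Nodes.back().get();
    Nodes.push_back(std::move(Inner));
  }
  return Nodes;
}

// Uncached: one recursive call per path, i.e. 2^N calls for buildChain(N).
static uint64_t countLeavesSlow(const Node *N) {
  if (!N->Left)
    return 1;
  return countLeavesSlow(N->Left) + countLeavesSlow(N->Right);
}

// Cached: the map plays the role of the patch's per-query Cache (keyed on
// {Op, Ty} there, on the node pointer here); each node is computed once.
static uint64_t countLeavesCached(const Node *N,
                                  std::map<const Node *, uint64_t> &Cache) {
  auto It = Cache.find(N);
  if (It != Cache.end())
    return It->second;  // already computed for this shared node
  uint64_t Result = 1;
  if (N->Left)
    Result = countLeavesCached(N->Left, Cache) +
             countLeavesCached(N->Right, Cache);
  Cache.insert({N, Result});
  return Result;
}

int main() {
  // On a small chain both versions agree.
  auto Small = buildChain(10);
  std::map<const Node *, uint64_t> SmallCache;
  std::cout << countLeavesSlow(Small.back().get()) << " == "
            << countLeavesCached(Small.back().get(), SmallCache) << "\n";

  // On a deep chain only the cached version is practical:
  // countLeavesSlow(Big.back().get()) would need ~2^50 calls.
  auto Big = buildChain(50);
  std::map<const Node *, uint64_t> BigCache;  // local to this query
  std::cout << countLeavesCached(Big.back().get(), BigCache) << "\n";
  return 0;
}

This mirrors the intent of the unit test above: the chained GEPs produce a SCEV with many shared sub-expressions, and the test expects getZeroExtendExpr to return without exponential blow-up.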