Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -1159,8 +1159,20 @@ const SCEV *getConstant(const APInt &Val); const SCEV *getConstant(Type *Ty, uint64_t V, bool isSigned = false); const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty); + + typedef SmallDenseMap, const SCEV *, 8> + CacheTypeForExtend; const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty); + const SCEV *getZeroExtendExprCached(const SCEV *Op, Type *Ty, + CacheTypeForExtend &Cache); + const SCEV *getZeroExtendExprImpl(const SCEV *Op, Type *Ty, + CacheTypeForExtend &Cache); + const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty); + const SCEV *getSignExtendExprCached(const SCEV *Op, Type *Ty, + CacheTypeForExtend &Cache); + const SCEV *getSignExtendExprImpl(const SCEV *Op, Type *Ty, + CacheTypeForExtend &Cache); const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty); const SCEV *getAddExpr(SmallVectorImpl &Ops, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -1276,7 +1276,8 @@ namespace { struct ExtendOpTraitsBase { - typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *); + typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)( + const SCEV *, Type *, ScalarEvolution::CacheTypeForExtend &Cache); }; // Used to make code generic over signed and unsigned overflow. 
@@ -1305,8 +1306,9 @@ } }; -const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< - SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr; +const ExtendOpTraitsBase::GetExtendExprTy + ExtendOpTraits::GetExtendExpr = + &ScalarEvolution::getSignExtendExprCached; template <> struct ExtendOpTraits : public ExtendOpTraitsBase { @@ -1321,8 +1323,9 @@ } }; -const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< - SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr; +const ExtendOpTraitsBase::GetExtendExprTy + ExtendOpTraits::GetExtendExpr = + &ScalarEvolution::getZeroExtendExprCached; } // The recurrence AR has been shown to have no signed/unsigned wrap or something @@ -1333,8 +1336,9 @@ // expression "Step + sext/zext(PreIncAR)" is congruent with // "sext/zext(PostIncAR)" template -static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE) { +static const SCEV * +getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, + ScalarEvolution::CacheTypeForExtend &Cache) { auto WrapType = ExtendOpTraits::WrapType; auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; @@ -1381,9 +1385,9 @@ unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); const SCEV *OperandExtendedStart = - SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy), - (SE->*GetExtendExpr)(Step, WideTy)); - if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) { + SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Cache), + (SE->*GetExtendExpr)(Step, WideTy, Cache)); + if ((SE->*GetExtendExpr)(Start, WideTy, Cache) == OperandExtendedStart) { if (PreAR && AR->getNoWrapFlags(WrapType)) { // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then @@ -1407,16 +1411,18 @@ // Get the normalized zero or sign extended expression for this AddRec's Start. 
template -static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE) { +static const SCEV * +getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, + ScalarEvolution::CacheTypeForExtend &Cache) { auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; - const SCEV *PreStart = getPreStartForExtend(AR, Ty, SE); + const SCEV *PreStart = getPreStartForExtend(AR, Ty, SE, Cache); if (!PreStart) - return (SE->*GetExtendExpr)(AR->getStart(), Ty); + return (SE->*GetExtendExpr)(AR->getStart(), Ty, Cache); - return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty), - (SE->*GetExtendExpr)(PreStart, Ty)); + return SE->getAddExpr( + (SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, Cache), + (SE->*GetExtendExpr)(PreStart, Ty, Cache)); } // Try to prove away overflow by looking at "nearby" add recurrences. A @@ -1496,8 +1502,30 @@ return false; } -const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, - Type *Ty) { +const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty) { + // Use the local cache to prevent exponential behavior of + // getZeroExtendExprImpl. + CacheTypeForExtend Cache; + return getZeroExtendExprCached(Op, Ty, Cache); +} + +/// Query \p Cache before calling getZeroExtendExprImpl. If there is no +/// related entry in the \p Cache, call getZeroExtendExprImpl and save +/// the result in the \p Cache. +const SCEV * +ScalarEvolution::getZeroExtendExprCached(const SCEV *Op, Type *Ty, + CacheTypeForExtend &Cache) { + auto It = Cache.find({Op, Ty}); + if (It != Cache.end()) + return It->second; + const SCEV *ZExt = getZeroExtendExprImpl(Op, Ty, Cache); + Cache.insert({{Op, Ty}, ZExt}); + return ZExt; +} + +/// The real implementation of getZeroExtendExpr. 
+const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, + CacheTypeForExtend &Cache) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1507,11 +1535,11 @@ // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) return getConstant( - cast(ConstantExpr::getZExt(SC->getValue(), Ty))); + cast(ConstantExpr::getZExt(SC->getValue(), Ty))); // zext(zext(x)) --> zext(x) if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) - return getZeroExtendExpr(SZ->getOperand(), Ty); + return getZeroExtendExprCached(SZ->getOperand(), Ty, Cache); // Before doing any expensive analysis, check to see if we've already // computed a SCEV for this Op and Ty. @@ -1555,8 +1583,8 @@ // we don't need to do any further analysis. if (AR->hasNoUnsignedWrap()) return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart(AR, Ty, this, Cache), + getZeroExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags()); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1581,21 +1609,22 @@ Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. 
const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step); - const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy); - const SCEV *WideStart = getZeroExtendExpr(Start, WideTy); + const SCEV *ZAdd = + getZeroExtendExprCached(getAddExpr(Start, ZMul), WideTy, Cache); + const SCEV *WideStart = getZeroExtendExprCached(Start, WideTy, Cache); const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy); - const SCEV *OperandExtendedAdd = - getAddExpr(WideStart, - getMulExpr(WideMaxBECount, - getZeroExtendExpr(Step, WideTy))); + getZeroExtendExprCached(CastedMaxBECount, WideTy, Cache); + const SCEV *OperandExtendedAdd = getAddExpr( + WideStart, getMulExpr(WideMaxBECount, getZeroExtendExprCached( + Step, WideTy, Cache))); if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart(AR, Ty, this, Cache), + getZeroExtendExprCached(Step, Ty, Cache), L, + AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as signed. // This covers loops that count down. @@ -1609,7 +1638,7 @@ const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), + getExtendAddRecStart(AR, Ty, this, Cache), getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1641,8 +1670,9 @@ const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. 
return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart(AR, Ty, this, Cache), + getZeroExtendExprCached(Step, Ty, Cache), L, + AR->getNoWrapFlags()); } } else if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - @@ -1657,7 +1687,7 @@ const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), + getExtendAddRecStart(AR, Ty, this, Cache), getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1666,8 +1696,8 @@ if (proveNoWrapByVaryingStart(Start, Step, L)) { const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart(AR, Ty, this, Cache), + getZeroExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags()); } } @@ -1678,7 +1708,7 @@ // commute the zero extension with the addition operation. SmallVector Ops; for (const auto *Op : SA->operands()) - Ops.push_back(getZeroExtendExpr(Op, Ty)); + Ops.push_back(getZeroExtendExprCached(Op, Ty, Cache)); return getAddExpr(Ops, SCEV::FlagNUW); } } @@ -1692,8 +1722,30 @@ return S; } -const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, - Type *Ty) { +const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty) { + // Use the local cache to prevent exponential behavior of + // getSignExtendExprImpl. + CacheTypeForExtend Cache; + return getSignExtendExprCached(Op, Ty, Cache); +} + +/// Query \p Cache before calling getSignExtendExprImpl. If there is no +/// related entry in the \p Cache, call getSignExtendExprImpl and save +/// the result in the \p Cache. 
+const SCEV * +ScalarEvolution::getSignExtendExprCached(const SCEV *Op, Type *Ty, + CacheTypeForExtend &Cache) { + auto It = Cache.find({Op, Ty}); + if (It != Cache.end()) + return It->second; + const SCEV *SExt = getSignExtendExprImpl(Op, Ty, Cache); + Cache.insert({{Op, Ty}, SExt}); + return SExt; +} + +/// The real implementation of getSignExtendExpr. +const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, + CacheTypeForExtend &Cache) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1703,11 +1755,11 @@ // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) return getConstant( - cast(ConstantExpr::getSExt(SC->getValue(), Ty))); + cast(ConstantExpr::getSExt(SC->getValue(), Ty))); // sext(sext(x)) --> sext(x) if (const SCEVSignExtendExpr *SS = dyn_cast(Op)) - return getSignExtendExpr(SS->getOperand(), Ty); + return getSignExtendExprCached(SS->getOperand(), Ty, Cache); // sext(zext(x)) --> zext(x) if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) @@ -1746,8 +1798,8 @@ const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) - return getAddExpr(getSignExtendExpr(SC1, Ty), - getSignExtendExpr(SMul, Ty)); + return getAddExpr(getSignExtendExprCached(SC1, Ty, Cache), + getSignExtendExprCached(SMul, Ty, Cache)); } } } @@ -1758,7 +1810,7 @@ // commute the sign extension with the addition operation. SmallVector Ops; for (const auto *Op : SA->operands()) - Ops.push_back(getSignExtendExpr(Op, Ty)); + Ops.push_back(getSignExtendExprCached(Op, Ty, Cache)); return getAddExpr(Ops, SCEV::FlagNSW); } } @@ -1782,8 +1834,8 @@ // we don't need to do any further analysis. 
if (AR->hasNoSignedWrap()) return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); + getExtendAddRecStart(AR, Ty, this, Cache), + getSignExtendExprCached(Step, Ty, Cache), L, SCEV::FlagNSW); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1808,21 +1860,22 @@ Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow. const SCEV *SMul = getMulExpr(CastedMaxBECount, Step); - const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy); - const SCEV *WideStart = getSignExtendExpr(Start, WideTy); + const SCEV *SAdd = + getSignExtendExprCached(getAddExpr(Start, SMul), WideTy, Cache); + const SCEV *WideStart = getSignExtendExprCached(Start, WideTy, Cache); const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy); - const SCEV *OperandExtendedAdd = - getAddExpr(WideStart, - getMulExpr(WideMaxBECount, - getSignExtendExpr(Step, WideTy))); + getZeroExtendExpr(CastedMaxBECount, WideTy); + const SCEV *OperandExtendedAdd = getAddExpr( + WideStart, getMulExpr(WideMaxBECount, getSignExtendExprCached( + Step, WideTy, Cache))); if (SAdd == OperandExtendedAdd) { // Cache knowledge of AR NSW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart(AR, Ty, this, Cache), + getSignExtendExprCached(Step, Ty, Cache), L, + AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as unsigned. // This covers loops that count up with an unsigned step. @@ -1843,7 +1896,7 @@ // Return the expression with the addrec on the outside. 
return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), + getExtendAddRecStart(AR, Ty, this, Cache), getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1875,8 +1928,9 @@ // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart(AR, Ty, this, Cache), + getSignExtendExprCached(Step, Ty, Cache), L, + AR->getNoWrapFlags()); } } @@ -1890,18 +1944,18 @@ const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) { - Start = getSignExtendExpr(Start, Ty); + Start = getSignExtendExprCached(Start, Ty, Cache); const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L, AR->getNoWrapFlags()); - return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); + return getAddExpr(Start, getSignExtendExprCached(NewAR, Ty, Cache)); } } if (proveNoWrapByVaryingStart(Start, Step, L)) { const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); return getAddRecExpr( - getExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart(AR, Ty, this, Cache), + getSignExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags()); } } Index: unittests/Analysis/ScalarEvolutionTest.cpp =================================================================== --- unittests/Analysis/ScalarEvolutionTest.cpp +++ unittests/Analysis/ScalarEvolutionTest.cpp @@ -600,5 +600,93 @@ EXPECT_NE(nullptr, SE.getSCEV(Mul1)); } +// Expect that a call to getZeroExtendExpr does not take exponential time. 
+TEST_F(ScalarEvolutionsTest, SCEVZeroExtendExpr) { + LLVMContext C; + SMDiagnostic Err; + + // Generate a function like below: + // define void @foo() { + // entry: + // br label %for.cond + // + // for.cond: + // %0 = phi i64 [ 100, %entry ], [ %dec, %for.inc ] + // %cmp = icmp sgt i64 %0, 90 + // br i1 %cmp, label %for.inc, label %for.cond1 + // + // for.inc: + // %dec = add nsw i64 %0, -1 + // br label %for.cond + // + // for.cond1: + // %1 = phi i64 [ 100, %for.cond ], [ %dec5, %for.inc2 ] + // %cmp3 = icmp sgt i64 %1, 90 + // br i1 %cmp3, label %for.inc2, label %for.cond4 + // + // for.inc2: + // %dec5 = add nsw i64 %1, -1 + // br label %for.cond1 + // + // ...... + // + // for.cond89: + // %19 = phi i64 [ 100, %for.cond84 ], [ %dec94, %for.inc92 ] + // %cmp93 = icmp sgt i64 %19, 90 + // br i1 %cmp93, label %for.inc92, label %for.end + // + // for.inc92: + // %dec94 = add nsw i64 %19, -1 + // br label %for.cond89 + // + // for.end: + // %gep = getelementptr i8, i8* null, i64 %dec + // %gep6 = getelementptr i8, i8* %gep, i64 %dec5 + // ...... 
+ // %gep95 = getelementptr i8, i8* %gep91, i64 %dec94 + // ret void + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), {}, false); + Function *F = cast(M.getOrInsertFunction("foo", FTy)); + + BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F); + BasicBlock *CondBB = BasicBlock::Create(Context, "for.cond", F); + BasicBlock *EndBB = BasicBlock::Create(Context, "for.end", F); + BranchInst::Create(CondBB, EntryBB); + BasicBlock *PrevBB = EntryBB; + + Type *I64Ty = Type::getInt64Ty(Context); + Type *I8Ty = Type::getInt8Ty(Context); + Type *I8PtrTy = Type::getInt8PtrTy(Context); + Value *Accum = Constant::getNullValue(I8PtrTy); + int Iters = 20; + for (int i = 0; i < Iters; i++) { + BasicBlock *IncBB = BasicBlock::Create(Context, "for.inc", F, EndBB); + auto *PN = PHINode::Create(I64Ty, 2, "", CondBB); + PN->addIncoming(ConstantInt::get(Context, APInt(64, 100)), PrevBB); + auto *Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, PN, + ConstantInt::get(Context, APInt(64, 90)), "cmp", + CondBB); + BasicBlock *NextBB; + if (i != Iters - 1) + NextBB = BasicBlock::Create(Context, "for.cond", F, EndBB); + else + NextBB = EndBB; + BranchInst::Create(IncBB, NextBB, Cmp, CondBB); + auto *Dec = BinaryOperator::CreateNSWAdd( + PN, ConstantInt::get(Context, APInt(64, -1)), "dec", IncBB); + PN->addIncoming(Dec, IncBB); + BranchInst::Create(CondBB, IncBB); + + Accum = GetElementPtrInst::Create(I8Ty, Accum, Dec, "gep", EndBB); + + PrevBB = CondBB; + CondBB = NextBB; + } + ReturnInst::Create(Context, nullptr, EndBB); + ScalarEvolution SE = buildSE(*F); + const SCEV *S = SE.getSCEV(Accum); + Type *I128Ty = Type::getInt128Ty(Context); + SE.getZeroExtendExpr(S, I128Ty); +} } // end anonymous namespace } // end namespace llvm