Index: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h =================================================================== --- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h +++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h @@ -1140,7 +1140,8 @@ const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty); const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty); const SCEV *getAddExpr(SmallVectorImpl &Ops, - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap); + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0); const SCEV *getAddExpr(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) { SmallVector Ops = {LHS, RHS}; @@ -1613,6 +1614,10 @@ bool doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap); + /// Return the uniqued add expr for Ops, creating it if absent; sets Flags. + const SCEV *getOrCreateAddExpr(SmallVectorImpl &Ops, + SCEV::NoWrapFlags Flags); + private: FoldingSet UniqueSCEVs; FoldingSet UniquePreds; Index: llvm/trunk/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/trunk/lib/Analysis/ScalarEvolution.cpp +++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp @@ -137,6 +137,11 @@ cl::desc("Maximum depth of recursive compare complexity"), cl::init(32)); +static cl::opt + MaxAddExprDepth("scalar-evolution-max-addexpr-depth", cl::Hidden, + cl::desc("Maximum depth of recursive AddExpr"), + cl::init(32)); + static cl::opt MaxConstantEvolvingDepth( "scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32)); @@ -2100,7 +2105,8 @@ /// Get a canonical add expression, or something simpler if possible. 
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, - SCEV::NoWrapFlags Flags) { + SCEV::NoWrapFlags Flags, + unsigned Depth) { assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty add!"); @@ -2139,6 +2145,10 @@ if (Ops.size() == 1) return Ops[0]; } + // Limit recursion depth: beyond MaxAddExprDepth, skip further folding. + if (Depth > MaxAddExprDepth) + return getOrCreateAddExpr(Ops, Flags); + // Okay, check to see if the same value occurs in the operand list more than // once. If so, merge them together into an multiply expression. Since we // sorted the list, these values are required to be adjacent. @@ -2210,7 +2220,7 @@ } if (Ok) { // Evaluate the expression in the larger type. - const SCEV *Fold = getAddExpr(LargeOps, Flags); + const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1); // If it folds to something simple, use it. Otherwise, don't. if (isa(Fold) || isa(Fold)) return getTruncateExpr(Fold, DstType); @@ -2239,7 +2249,7 @@ // and they are not necessarily sorted. Recurse to resort and resimplify // any operands we just acquired. if (DeletedAdd) - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Skip over the add expression until we get to a multiply. 
@@ -2274,13 +2284,14 @@ Ops.push_back(getConstant(AccumulatedConstant)); for (auto &MulOp : MulOpLists) if (MulOp.first != 0) - Ops.push_back(getMulExpr(getConstant(MulOp.first), - getAddExpr(MulOp.second))); + Ops.push_back(getMulExpr( + getConstant(MulOp.first), + getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1))); if (Ops.empty()) return getZero(Ty); if (Ops.size() == 1) return Ops[0]; - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } } @@ -2305,8 +2316,8 @@ MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); InnerMul = getMulExpr(MulOps); } - const SCEV *One = getOne(Ty); - const SCEV *AddOne = getAddExpr(One, InnerMul); + SmallVector TwoOps = {getOne(Ty), InnerMul}; + const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { @@ -2317,7 +2328,7 @@ Ops.erase(Ops.begin()+AddOp-1); } Ops.push_back(OuterMul); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Check this multiply against other multiplies being added together. @@ -2345,13 +2356,15 @@ MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); InnerMul2 = getMulExpr(MulOps); } - const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2); + SmallVector TwoOps = {InnerMul1, InnerMul2}; + const SCEV *InnerMulSum = + getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); Ops.push_back(OuterMul); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } } } @@ -2387,7 +2400,7 @@ // This follows from the fact that the no-wrap flags on the outer add // expression are applicable on the 0th iteration, when the add recurrence // will be equal to its start value. 
- AddRecOps[0] = getAddExpr(LIOps, Flags); + AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1); // Build the new addrec. Propagate the NUW and NSW flags if both the // outer add and the inner addrec are guaranteed to have no overflow. @@ -2404,7 +2417,7 @@ Ops[i] = NewRec; break; } - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Okay, if there weren't any loop invariants to be folded, check to see if @@ -2428,14 +2441,15 @@ OtherAddRec->op_end()); break; } - AddRecOps[i] = getAddExpr(AddRecOps[i], - OtherAddRec->getOperand(i)); + SmallVector TwoOps = { + AddRecOps[i], OtherAddRec->getOperand(i)}; + AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); } Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; } // Step size has changed, so we cannot guarantee no self-wraparound. Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Otherwise couldn't fold anything into this recurrence. Move onto the @@ -2444,18 +2458,24 @@ // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. 
+ return getOrCreateAddExpr(Ops, Flags); +} + +const SCEV * +ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl &Ops, + SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; ID.AddInteger(scAddExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = nullptr; SCEVAddExpr *S = - static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { const SCEV **O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); - S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator), - O, Ops.size()); + S = new (SCEVAllocator) + SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); } S->setNoWrapFlags(Flags); Index: llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp =================================================================== --- llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp +++ llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp @@ -532,5 +532,33 @@ EXPECT_NE(nullptr, SE.getSCEV(Acc[0])); } +TEST_F(ScalarEvolutionsTest, SCEVAddExpr) { + Type *Ty32 = Type::getInt32Ty(Context); + Type *ArgTys[] = {Type::getInt64Ty(Context), Ty32}; + + FunctionType *FTy = + FunctionType::get(Type::getVoidTy(Context), ArgTys, false); + Function *F = cast(M.getOrInsertFunction("f", FTy)); + + Argument *A1 = &*F->arg_begin(); + Argument *A2 = &*(std::next(F->arg_begin())); + BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F); + + Instruction *Trunc = CastInst::CreateTruncOrBitCast(A1, Ty32, "", EntryBB); + Instruction *Mul1 = BinaryOperator::CreateMul(Trunc, A2, "", EntryBB); + Instruction *Add1 = BinaryOperator::CreateAdd(Mul1, Trunc, "", EntryBB); + Mul1 = BinaryOperator::CreateMul(Add1, Trunc, "", EntryBB); + Instruction *Add2 = BinaryOperator::CreateAdd(Mul1, Add1, "", EntryBB); + for (int i = 0; i < 1000; i++) { + Mul1 = BinaryOperator::CreateMul(Add2, Add1, "", EntryBB); + Add1 = Add2; + Add2 = 
BinaryOperator::CreateAdd(Mul1, Add1, "", EntryBB); + } + + ReturnInst::Create(Context, nullptr, EntryBB); + ScalarEvolution SE = buildSE(*F); + EXPECT_NE(nullptr, SE.getSCEV(Mul1)); +} + } // end anonymous namespace } // end namespace llvm