Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -85,6 +85,9 @@
   const unsigned short SCEVType;
 
 protected:
+  /// Estimated size of this node's expression tree; see getExpressionSize().
+  const unsigned short ExpressionSize;
+
   /// This field is initialized to zero and may be used in subclasses to store
   /// miscellaneous information.
   unsigned short SubclassData = 0;
@@ -116,8 +119,9 @@
     NoWrapMask = (1 << 3) - 1
   };
 
-  explicit SCEV(const FoldingSetNodeIDRef ID, unsigned SCEVTy)
-      : FastID(ID), SCEVType(SCEVTy) {}
+  explicit SCEV(const FoldingSetNodeIDRef ID, unsigned SCEVTy,
+                unsigned short ExpressionSize)
+      : FastID(ID), SCEVType(SCEVTy), ExpressionSize(ExpressionSize) {}
   SCEV(const SCEV &) = delete;
   SCEV &operator=(const SCEV &) = delete;
 
@@ -138,6 +142,19 @@
   /// Return true if the specified scev is negated, but not a constant.
   bool isNonConstantNegative() const;
 
+  /// Returns the estimated size of the mathematical expression represented by
+  /// this SCEV. The size is computed as follows:
+  /// 1) a SCEV without operands (constants, SCEVUnknown) has size 1;
+  /// 2) a SCEV with operands Op1, Op2, ..., OpN has size
+  ///    1 + Size(Op1) + ... + Size(OpN);
+  /// 3) SCEVCouldNotCompute has size 0.
+  /// The sum uses 16-bit saturating addition, so sizes are capped at 65535.
+  /// The value estimates the time needed to traverse this SCEV and all its
+  /// operands recursively. It may be used to avoid performing heavy
+  /// transformations on SCEVs of excessive size in order to save compilation
+  /// time.
+  unsigned getExpressionSize() const { return ExpressionSize; }
+
   /// Print out the internal representation of this scalar to the specified
   /// stream. This should really only be used for debugging purposes.
   void print(raw_ostream &OS) const;
Index: include/llvm/Analysis/ScalarEvolutionExpressions.h
===================================================================
--- include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -51,7 +51,7 @@
     ConstantInt *V;
 
     SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) :
-      SCEV(ID, scConstant), V(v) {}
+      SCEV(ID, scConstant, 1), V(v) {}
 
   public:
     ConstantInt *getValue() const { return V; }
@@ -67,6 +67,13 @@
 
   /// This is the base class for unary cast operator classes.
   class SCEVCastExpr : public SCEV {
+    // 1 + the size of the operand, saturating at 16 bits.
+    static unsigned short computeExpressionSize(const SCEV *Operand) {
+      APInt Size(16, 1);
+      Size = Size.uadd_sat(APInt(16, Operand->getExpressionSize()));
+      return static_cast<unsigned short>(Size.getZExtValue());
+    }
+
   protected:
     const SCEV *Op;
     Type *Ty;
@@ -134,6 +141,15 @@
   /// This node is a base class providing common functionality for
   /// n'ary operators.
   class SCEVNAryExpr : public SCEV {
+    // 1 + the sizes of the N operands in O, saturating at 16 bits.
+    static unsigned short computeExpressionSize(size_t N,
+                                                const SCEV *const *O) {
+      APInt Size(16, 1);
+      for (size_t I = 0; I < N; ++I)
+        Size = Size.uadd_sat(APInt(16, O[I]->getExpressionSize()));
+      return static_cast<unsigned short>(Size.getZExtValue());
+    }
+
   protected:
     // Since SCEVs are immutable, ScalarEvolution allocates operand
     // arrays with its SCEVAllocator, so this class just needs a simple
@@ -142,9 +158,10 @@
     const SCEV *const *Operands;
     size_t NumOperands;
 
-    SCEVNAryExpr(const FoldingSetNodeIDRef ID,
-                 enum SCEVTypes T, const SCEV *const *O, size_t N)
-      : SCEV(ID, T), Operands(O), NumOperands(N) {}
+    SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
+                 const SCEV *const *O, size_t N)
+        : SCEV(ID, T, computeExpressionSize(N, O)), Operands(O),
+          NumOperands(N) {}
 
   public:
     size_t getNumOperands() const { return NumOperands; }
@@ -257,8 +274,18 @@
     const SCEV *LHS;
     const SCEV *RHS;
 
+    // 1 + the sizes of both operands, saturating at 16 bits.
+    static unsigned short computeExpressionSize(const SCEV *Op1,
+                                                const SCEV *Op2) {
+      APInt Size(16, 1);
+      Size = Size.uadd_sat(APInt(16, Op1->getExpressionSize()));
+      Size = Size.uadd_sat(APInt(16, Op2->getExpressionSize()));
+      return static_cast<unsigned short>(Size.getZExtValue());
+    }
+
     SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
-      : SCEV(ID, scUDivExpr), LHS(lhs), RHS(rhs) {}
+        : SCEV(ID, scUDivExpr, computeExpressionSize(lhs, rhs)), LHS(lhs),
+          RHS(rhs) {}
 
   public:
     const SCEV *getLHS() const { return LHS; }
@@ -411,7 +438,7 @@
 
     SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V,
                 ScalarEvolution *se, SCEVUnknown *next) :
-      SCEV(ID, scUnknown), CallbackVH(V), SE(se), Next(next) {}
+      SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
 
     // Implement CallbackVH.
     void deleted() override;
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -393,7 +393,7 @@
 }
 
 SCEVCouldNotCompute::SCEVCouldNotCompute() :
-  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}
+  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
 
 bool SCEVCouldNotCompute::classof(const SCEV *S) {
   return S->getSCEVType() == scCouldNotCompute;
@@ -422,7 +422,7 @@
 
 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
                            unsigned SCEVTy, const SCEV *op, Type *ty)
-  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}
+  : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
 
 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
                                    const SCEV *op, Type *ty)
Index: unittests/Analysis/ScalarEvolutionTest.cpp
===================================================================
--- unittests/Analysis/ScalarEvolutionTest.cpp
+++ unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1390,5 +1390,55 @@
   EXPECT_FALSE(I->hasNoSignedWrap());
 }
 
+// Check logic of SCEV expression size computation.
+TEST_F(ScalarEvolutionsTest, SCEVComputeExpressionSize) {
+  /*
+   * Create the following code:
+   * define void @func(i64 %a, i64 %b)
+   * entry:
+   *  %s1 = add i64 %a, 1
+   *  %s2 = udiv i64 %s1, %b
+   *  br label %exit
+   * exit:
+   *  ret void
+   */
+
+  // Create a module.
+  Module M("SCEVComputeExpressionSize", Context);
+
+  Type *T_int64 = Type::getInt64Ty(Context);
+
+  FunctionType *FTy =
+      FunctionType::get(Type::getVoidTy(Context), {T_int64, T_int64}, false);
+  Function *F = cast<Function>(M.getOrInsertFunction("func", FTy));
+  Argument *A = &*F->arg_begin();
+  Argument *B = &*std::next(F->arg_begin());
+  ConstantInt *C = ConstantInt::get(Context, APInt(64, 1));
+
+  BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
+  BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
+
+  IRBuilder<> Builder(Entry);
+  auto *S1 = cast<Instruction>(Builder.CreateAdd(A, C, "s1"));
+  auto *S2 = cast<Instruction>(Builder.CreateUDiv(S1, B, "s2"));
+  Builder.CreateBr(Exit);
+
+  Builder.SetInsertPoint(Exit);
+  Builder.CreateRetVoid();
+
+  ScalarEvolution SE = buildSE(*F);
+  // Leaves have size 1; any other node has size 1 + sizes of its operands.
+  const SCEV *AS = SE.getSCEV(A);
+  const SCEV *BS = SE.getSCEV(B);
+  const SCEV *CS = SE.getSCEV(C);
+  const SCEV *S1S = SE.getSCEV(S1);
+  const SCEV *S2S = SE.getSCEV(S2);
+  EXPECT_EQ(AS->getExpressionSize(), 1u);
+  EXPECT_EQ(BS->getExpressionSize(), 1u);
+  EXPECT_EQ(CS->getExpressionSize(), 1u);
+  EXPECT_EQ(S1S->getExpressionSize(), 3u);
+  EXPECT_EQ(S2S->getExpressionSize(), 5u);
+}
+
 }  // end anonymous namespace
 }  // end namespace llvm