Index: include/llvm/Analysis/ScalarEvolutionExpressions.h =================================================================== --- include/llvm/Analysis/ScalarEvolutionExpressions.h +++ include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -553,64 +553,56 @@ T.visitAll(Root); } - typedef DenseMap ValueToValueMap; - - /// The SCEVParameterRewriter takes a scalar evolution expression and updates - /// the SCEVUnknown components following the Map (Value -> Value). - struct SCEVParameterRewriter - : public SCEVVisitor { + /// Recursively visits a SCEV expression and re-writes it. + template + class SCEVRewriteVisitor : public SCEVVisitor { + protected: + ScalarEvolution &SE; public: - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, - ValueToValueMap &Map, - bool InterpretConsts = false) { - SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts); - return Rewriter.visit(Scev); - } - - SCEVParameterRewriter(ScalarEvolution &S, ValueToValueMap &M, bool C) - : SE(S), Map(M), InterpretConsts(C) {} + SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); + const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); return SE.getTruncateExpr(Operand, Expr->getType()); } const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); + const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); return SE.getZeroExtendExpr(Operand, Expr->getType()); } const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); + const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); return SE.getSignExtendExpr(Operand, Expr->getType()); } const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - 
Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getAddExpr(Operands); } const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getMulExpr(Operands); } const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { - return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS())); + return SE.getUDivExpr(((SC*)this)->visit(Expr->getLHS()), + ((SC*)this)->visit(Expr->getRHS())); } const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getAddRecExpr(Operands, Expr->getLoop(), Expr->getNoWrapFlags()); } @@ -618,18 +610,43 @@ const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getSMaxExpr(Operands); } const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getUMaxExpr(Operands); } const SCEV *visitUnknown(const SCEVUnknown *Expr) { + return Expr; + } + + const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return Expr; + } + }; + + typedef DenseMap ValueToValueMap; + + /// The SCEVParameterRewriter takes a scalar evolution expression and updates + /// the SCEVUnknown components following the Map (Value -> Value). 
+ class SCEVParameterRewriter : public SCEVRewriteVisitor { + public: + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, + ValueToValueMap &Map, + bool InterpretConsts = false) { + SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts); + return Rewriter.visit(Scev); + } + + SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C) + : SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { Value *V = Expr->getValue(); if (Map.count(V)) { Value *NV = Map[V]; @@ -640,12 +657,7 @@ return Expr; } - const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { - return Expr; - } - private: - ScalarEvolution &SE; ValueToValueMap &Map; bool InterpretConsts; }; @@ -654,8 +666,7 @@ /// The SCEVApplyRewriter takes a scalar evolution expression and applies /// the Map (Loop -> SCEV) to all AddRecExprs. - struct SCEVApplyRewriter - : public SCEVVisitor { + class SCEVApplyRewriter : public SCEVRewriteVisitor { public: static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, ScalarEvolution &SE) { @@ -663,45 +674,8 @@ return Rewriter.visit(Scev); } - SCEVApplyRewriter(ScalarEvolution &S, LoopToScevMapT &M) - : SE(S), Map(M) {} - - const SCEV *visitConstant(const SCEVConstant *Constant) { - return Constant; - } - - const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); - return SE.getTruncateExpr(Operand, Expr->getType()); - } - - const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); - return SE.getZeroExtendExpr(Operand, Expr->getType()); - } - - const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); - return SE.getSignExtendExpr(Operand, Expr->getType()); - } - - const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { - SmallVector Operands; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - 
Operands.push_back(visit(Expr->getOperand(i))); - return SE.getAddExpr(Operands); - } - - const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { - SmallVector Operands; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); - return SE.getMulExpr(Operands); - } - - const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { - return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS())); - } + SCEVApplyRewriter(ScalarEvolution &SE, LoopToScevMapT &M) + : SCEVRewriteVisitor(SE), Map(M) {} const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { SmallVector Operands; @@ -718,30 +692,7 @@ return Rec->evaluateAtIteration(Map[L], SE); } - const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { - SmallVector Operands; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); - return SE.getSMaxExpr(Operands); - } - - const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { - SmallVector Operands; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); - return SE.getUMaxExpr(Operands); - } - - const SCEV *visitUnknown(const SCEVUnknown *Expr) { - return Expr; - } - - const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { - return Expr; - } - private: - ScalarEvolution &SE; LoopToScevMapT &Map; };