Index: llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h =================================================================== --- llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h +++ llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h @@ -117,9 +117,13 @@ /// \brief Return true for expressions that may incur non-trivial cost to /// evaluate at runtime. - bool isHighCostExpansion(const SCEV *Expr, Loop *L) { + /// 'At' is an optional parameter which specifies point in code where user + /// is going to expand this expression. Sometimes this knowledge can lead to + /// a more accurate cost estimation. + bool isHighCostExpansion(const SCEV *Expr, Loop *L, + const Instruction *At = nullptr) { SmallPtrSet Processed; - return isHighCostExpansionHelper(Expr, L, Processed); + return isHighCostExpansionHelper(Expr, L, At, Processed); } /// \brief This method returns the canonical induction variable of the @@ -193,11 +197,20 @@ void setChainedPhi(PHINode *PN) { ChainedPhis.insert(PN); } + /// \brief Try to find LLVM IR value for 'S' available at the point 'At'. + /// 'L' is a hint which tells in which loop to look for the suitable value. + /// On success return value which is equivalent to the expanded 'S' at point + /// 'At'. Return nullptr if value was not found. + /// Note that this function does not perform exhaustive search. I.e. if it + /// didn't find any value it does not mean that there is no such value. + Value *findExistingExpansion(const SCEV *S, const Instruction *At, Loop *L); + private: LLVMContext &getContext() const { return SE.getContext(); } /// \brief Recursive helper function for isHighCostExpansion. 
bool isHighCostExpansionHelper(const SCEV *S, Loop *L, + const Instruction *At, SmallPtrSetImpl &Processed); /// \brief Insert the specified binary operator, doing a small amount Index: llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp =================================================================== --- llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp +++ llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp @@ -1812,8 +1812,46 @@ return NumElim; } +Value *SCEVExpander::findExistingExpansion(const SCEV *S, + const Instruction *At, Loop *L) { + using namespace llvm::PatternMatch; + + SmallVector Latches; + L->getLoopLatches(Latches); + + // Look for suitable value in simple conditions at the loop latches. + for (BasicBlock *BB : Latches) { + ICmpInst::Predicate Pred; + Instruction *LHS, *RHS; + BasicBlock *TrueBB, *FalseBB; + + if (!match(BB->getTerminator(), + m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)), + TrueBB, FalseBB))) + continue; + + if (SE.getSCEV(LHS) == S && SE.DT->dominates(LHS, At)) + return LHS; + + if (SE.getSCEV(RHS) == S && SE.DT->dominates(RHS, At)) + return RHS; + } + + // There is potential to make this significantly smarter, but this simple + // heuristic already gets some interesting cases. + + // Cannot find a suitable value. + return nullptr; +} + bool SCEVExpander::isHighCostExpansionHelper( - const SCEV *S, Loop *L, SmallPtrSetImpl &Processed) { + const SCEV *S, Loop *L, const Instruction *At, + SmallPtrSetImpl &Processed) { + + // If we can find an existing value for this SCEV available at the point "At" + // then consider the expression cheap. 
+ if (At && findExistingExpansion(S, At, L) != nullptr) + return false; // Zero/One operand expressions switch (S->getSCEVType()) { @@ -1821,14 +1859,14 @@ case scConstant: return false; case scTruncate: - return isHighCostExpansionHelper(cast(S)->getOperand(), L, - Processed); + return isHighCostExpansionHelper(cast(S)->getOperand(), + L, At, Processed); case scZeroExtend: return isHighCostExpansionHelper(cast(S)->getOperand(), - L, Processed); + L, At, Processed); case scSignExtend: return isHighCostExpansionHelper(cast(S)->getOperand(), - L, Processed); + L, At, Processed); } if (!Processed.insert(S).second) @@ -1884,7 +1922,7 @@ if (const SCEVNAryExpr *NAry = dyn_cast(S)) { for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); I != E; ++I) { - if (isHighCostExpansionHelper(*I, L, Processed)) + if (isHighCostExpansionHelper(*I, L, At, Processed)) return true; } } Index: llvm/trunk/lib/Transforms/Scalar/IndVarSimplify.cpp =================================================================== --- llvm/trunk/lib/Transforms/Scalar/IndVarSimplify.cpp +++ llvm/trunk/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -138,8 +138,7 @@ void SinkUnusedInvariants(Loop *L); Value *ExpandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, Loop *L, - Instruction *InsertPt, Type *Ty, - bool &IsHighCostExpansion); + Instruction *InsertPt, Type *Ty); }; } @@ -503,47 +502,13 @@ Value *IndVarSimplify::ExpandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, Loop *L, Instruction *InsertPt, - Type *ResultTy, - bool &IsHighCostExpansion) { - using namespace llvm::PatternMatch; - - if (!Rewriter.isHighCostExpansion(S, L)) { - IsHighCostExpansion = false; - return Rewriter.expandCodeFor(S, ResultTy, InsertPt); - } - + Type *ResultTy) { // Before expanding S into an expensive LLVM expression, see if we can use an - // already existing value as the expansion for S. 
There is potential to make - // this significantly smarter, but this simple heuristic already gets some - // interesting cases. - - SmallVector Latches; - L->getLoopLatches(Latches); - - for (BasicBlock *BB : Latches) { - ICmpInst::Predicate Pred; - Instruction *LHS, *RHS; - BasicBlock *TrueBB, *FalseBB; - - if (!match(BB->getTerminator(), - m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)), - TrueBB, FalseBB))) - continue; - - if (SE->getSCEV(LHS) == S && DT->dominates(LHS, InsertPt)) { - IsHighCostExpansion = false; - return LHS; - } - - if (SE->getSCEV(RHS) == S && DT->dominates(RHS, InsertPt)) { - IsHighCostExpansion = false; - return RHS; - } - } + // already existing value as the expansion for S. + if (Value *RetValue = Rewriter.findExistingExpansion(S, InsertPt, L)) + return RetValue; // We didn't find anything, fall back to using SCEVExpander. - assert(Rewriter.isHighCostExpansion(S, L) && "this should not have changed!"); - IsHighCostExpansion = true; return Rewriter.expandCodeFor(S, ResultTy, InsertPt); } @@ -679,9 +644,9 @@ continue; } - bool HighCost = false; - Value *ExitVal = ExpandSCEVIfNeeded(Rewriter, ExitValue, L, Inst, - PN->getType(), HighCost); + bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, Inst); + Value *ExitVal = + ExpandSCEVIfNeeded(Rewriter, ExitValue, L, Inst, PN->getType()); DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal << '\n' << " LoopVal = " << *Inst << "\n");