Thread a cost budget through SCEVExpander::isHighCostExpansion()

Adds a hidden `-scev-cheap-expansion-budget` command-line option
(cl::opt<unsigned>, default 4) and a new `unsigned Budget` parameter on
SCEVExpander::isHighCostExpansion(). The budget is scaled into
TargetTransformInfo cost units (Budget * TargetTransformInfo::TCC_Basic) and
passed by reference as `int &BudgetRemaining` into every recursive
isHighCostExpansionHelper() call.

This change only plumbs the value through. None of the hunks below read or
decrement BudgetRemaining; they only forward it. The high-cost decision is
therefore unchanged.

All callers touched here (IndVarSimplify, LoopUnrollRuntime, LoopUtils,
SimplifyIndVar) now pass SCEVCheapExpansionBudget explicitly.

NOTE(review): this patch was rebuilt from a copy whose newlines, indentation
and template argument lists had been stripped. Per-line indentation of
context lines is reconstructed. If `git apply` rejects a hunk on whitespace
only, retry with `git apply --ignore-whitespace`.

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h
@@ -19,11 +19,13 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
 #include "llvm/Analysis/TargetFolder.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/ValueHandle.h"
+#include "llvm/Support/CommandLine.h"
 
 namespace llvm {
-  class TargetTransformInfo;
+  extern cl::opt<unsigned> SCEVCheapExpansionBudget;
 
   /// Return true if the given expression is safe to expand in the sense that
   /// all materialized values are safe to speculate anywhere their operands are
@@ -177,11 +179,13 @@
     /// At is an optional parameter which specifies point in code where user is
     /// going to expand this expression. Sometimes this knowledge can lead to a
     /// more accurate cost estimation.
-    bool isHighCostExpansion(const SCEV *Expr, Loop *L,
+    bool isHighCostExpansion(const SCEV *Expr, Loop *L, unsigned Budget,
                              const TargetTransformInfo *TTI,
                              const Instruction *At = nullptr) {
       SmallPtrSet<const SCEV *, 8> Processed;
-      return isHighCostExpansionHelper(Expr, L, At, TTI, Processed);
+      int BudgetRemaining = Budget * TargetTransformInfo::TCC_Basic;
+      return isHighCostExpansionHelper(Expr, L, At, BudgetRemaining, TTI,
+                                       Processed);
     }
 
     /// This method returns the canonical induction variable of the specified
@@ -324,7 +328,7 @@
 
     /// Recursive helper function for isHighCostExpansion.
     bool isHighCostExpansionHelper(const SCEV *S, Loop *L,
-                                   const Instruction *At,
+                                   const Instruction *At, int &BudgetRemaining,
                                    const TargetTransformInfo *TTI,
                                    SmallPtrSetImpl<const SCEV *> &Processed);
 
diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
--- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
@@ -24,10 +24,17 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
+
+cl::opt<unsigned> llvm::SCEVCheapExpansionBudget(
+    "scev-cheap-expansion-budget", cl::Hidden, cl::init(4),
+    cl::desc("When performing SCEV expansion only if it is cheap to do, this "
+             "controls the budget that is considered cheap (default = 4)"));
+
 using namespace PatternMatch;
 
 /// ReuseOrCreateCast - Arrange for there to be a cast of V to Ty at IP,
@@ -2129,7 +2136,7 @@
 }
 
 bool SCEVExpander::isHighCostExpansionHelper(
-    const SCEV *S, Loop *L, const Instruction *At,
+    const SCEV *S, Loop *L, const Instruction *At, int &BudgetRemaining,
     const TargetTransformInfo *TTI, SmallPtrSetImpl<const SCEV *> &Processed) {
   // If we can find an existing value for this scev available at the point "At"
   // then consider the expression cheap.
@@ -2143,13 +2150,13 @@
     return false;
   case scTruncate:
     return isHighCostExpansionHelper(cast<SCEVTruncateExpr>(S)->getOperand(), L,
-                                     At, TTI, Processed);
+                                     At, BudgetRemaining, TTI, Processed);
   case scZeroExtend:
     return isHighCostExpansionHelper(cast<SCEVZeroExtendExpr>(S)->getOperand(),
-                                     L, At, TTI, Processed);
+                                     L, At, BudgetRemaining, TTI, Processed);
   case scSignExtend:
     return isHighCostExpansionHelper(cast<SCEVSignExtendExpr>(S)->getOperand(),
-                                     L, At, TTI, Processed);
+                                     L, At, BudgetRemaining, TTI, Processed);
   }
 
   if (!Processed.insert(S).second)
@@ -2162,8 +2169,8 @@
     // lowered into a right shift.
     if (auto *SC = dyn_cast<SCEVConstant>(UDivExpr->getRHS()))
       if (SC->getAPInt().isPowerOf2()) {
-        if (isHighCostExpansionHelper(UDivExpr->getLHS(), L, At, TTI,
-                                      Processed))
+        if (isHighCostExpansionHelper(UDivExpr->getLHS(), L, At,
+                                      BudgetRemaining, TTI, Processed))
           return true;
         const DataLayout &DL =
             L->getHeader()->getParent()->getParent()->getDataLayout();
@@ -2200,7 +2207,7 @@
   // they are not too expensive rematerialize.
   if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) {
     for (auto *Op : NAry->operands())
-      if (isHighCostExpansionHelper(Op, L, At, TTI, Processed))
+      if (isHighCostExpansionHelper(Op, L, At, BudgetRemaining, TTI, Processed))
         return true;
   }
 
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -2759,7 +2759,8 @@
 
     // Avoid high cost expansions.  Note: This heuristic is questionable in
     // that our definition of "high cost" is not exactly principled.
-    if (Rewriter.isHighCostExpansion(ExitCount, L, TTI))
+    if (Rewriter.isHighCostExpansion(ExitCount, L, SCEVCheapExpansionBudget,
+                                     TTI))
       continue;
 
     // Check preconditions for proper SCEVExpander operation. SCEV does not
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -635,7 +635,8 @@
   const DataLayout &DL = Header->getModule()->getDataLayout();
   SCEVExpander Expander(*SE, DL, "loop-unroll");
   if (!AllowExpensiveTripCount &&
-      Expander.isHighCostExpansion(TripCountSC, L, TTI, PreHeaderBR)) {
+      Expander.isHighCostExpansion(TripCountSC, L, SCEVCheapExpansionBudget,
+                                   TTI, PreHeaderBR)) {
     LLVM_DEBUG(dbgs() << "High cost for expanding trip count scev!\n");
     return false;
   }
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1361,7 +1361,8 @@
             hasHardUserWithinLoop(L, Inst))
           continue;
 
-        bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, TTI, Inst);
+        bool HighCost = Rewriter.isHighCostExpansion(
+            ExitValue, L, SCEVCheapExpansionBudget, TTI, Inst);
         Value *ExitVal = Rewriter.expandCodeFor(ExitValue, PN->getType(), Inst);
 
         LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = "
diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
--- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -669,7 +669,7 @@
     return false;
 
   // Do not generate something ridiculous even if S is loop invariant.
-  if (Rewriter.isHighCostExpansion(S, L, TTI, I))
+  if (Rewriter.isHighCostExpansion(S, L, SCEVCheapExpansionBudget, TTI, I))
     return false;
 
   auto *IP = GetLoopInvariantInsertPosition(L, I);