diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -39,6 +39,19 @@
 bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint,
                       ScalarEvolution &SE);
 
+/// Struct for holding enough information to help calculate the cost of the
+/// given SCEV when expanded into IR.
+struct SCEVOperand {
+  explicit SCEVOperand(unsigned Opc, int Idx, const SCEV *S)
+      : ParentOpcode(Opc), OperandIdx(Idx), S(S) {}
+  /// LLVM instruction opcode that uses the operand.
+  unsigned ParentOpcode;
+  /// The use index of an expanded instruction.
+  int OperandIdx;
+  /// The SCEV operand to be costed.
+  const SCEV *S;
+};
+
 /// This class uses information about analyze scalars to rewrite expressions
 /// in canonical form.
 ///
@@ -220,14 +233,14 @@
     assert(At && "This function requires At instruction to be provided.");
     if (!TTI)      // In assert-less builds, avoid crashing
       return true; // by always claiming to be high-cost.
-    SmallVector<const SCEV *, 8> Worklist;
+    SmallVector<SCEVOperand, 8> Worklist;
     SmallPtrSet<const SCEV *, 8> Processed;
     int BudgetRemaining = Budget * TargetTransformInfo::TCC_Basic;
-    Worklist.emplace_back(Expr);
+    Worklist.emplace_back(-1, -1, Expr);
     while (!Worklist.empty()) {
-      const SCEV *S = Worklist.pop_back_val();
-      if (isHighCostExpansionHelper(S, L, *At, BudgetRemaining, *TTI, Processed,
-                                    Worklist))
+      const SCEVOperand WorkItem = Worklist.pop_back_val();
+      if (isHighCostExpansionHelper(WorkItem, L, *At, BudgetRemaining,
+                                    *TTI, Processed, Worklist))
         return true;
     }
     assert(BudgetRemaining >= 0 && "Should have returned from inner loop.");
@@ -394,11 +407,11 @@
   Value *expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *I, bool Root);
 
   /// Recursive helper function for isHighCostExpansion.
-  bool isHighCostExpansionHelper(const SCEV *S, Loop *L, const Instruction &At,
-                                 int &BudgetRemaining,
-                                 const TargetTransformInfo &TTI,
-                                 SmallPtrSetImpl<const SCEV *> &Processed,
-                                 SmallVectorImpl<const SCEV *> &Worklist);
+  bool isHighCostExpansionHelper(
+    const SCEVOperand &WorkItem, Loop *L, const Instruction &At,
+    int &BudgetRemaining, const TargetTransformInfo &TTI,
+    SmallPtrSetImpl<const SCEV *> &Processed,
+    SmallVectorImpl<SCEVOperand> &Worklist);
 
   /// Insert the specified binary operator, doing a small amount of work to
   /// avoid inserting an obviously redundant operation, and hoisting to an
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2177,13 +2177,136 @@
   return None;
 }
 
+/// Return the cost of the instruction(s) needed to expand the SCEV held by
+/// \p WorkItem, excluding its operands, and push every operand onto
+/// \p Worklist once per emitted opcode (tagged with that opcode and index).
+template <typename T> static int costAndCollectOperands(
+  const SCEVOperand &WorkItem, const TargetTransformInfo &TTI,
+  TargetTransformInfo::TargetCostKind CostKind,
+  SmallVectorImpl<SCEVOperand> &Worklist) {
+
+  const T *S = cast<T>(WorkItem.S);
+  int Cost = 0;
+  // Collect the opcodes of all the instructions that will be needed to expand
+  // the SCEVExpr. This is so that when we come to cost the operands, we know
+  // what the generated user(s) will be.
+  SmallVector<unsigned, 2> Opcodes;
+
+  auto CastCost = [&](unsigned Opcode) {
+    Opcodes.push_back(Opcode);
+    return TTI.getCastInstrCost(Opcode, S->getType(),
+                                S->getOperand(0)->getType(),
+                                TTI::CastContextHint::None, CostKind);
+  };
+
+  auto ArithCost = [&](unsigned Opcode, unsigned NumRequired) {
+    Opcodes.push_back(Opcode);
+    return NumRequired *
+           TTI.getArithmeticInstrCost(Opcode, S->getType(), CostKind);
+  };
+
+  auto CmpSelCost = [&](unsigned Opcode, unsigned NumRequired) {
+    Opcodes.push_back(Opcode);
+    Type *OpType = S->getOperand(0)->getType();
+    return NumRequired *
+           TTI.getCmpSelInstrCost(Opcode, OpType,
+                                  CmpInst::makeCmpResultType(OpType), CostKind);
+  };
+
+  switch (S->getSCEVType()) {
+  default:
+    llvm_unreachable("No other scev expressions possible.");
+  case scUnknown:
+  case scConstant:
+    return 0;
+  case scTruncate:
+    Cost = CastCost(Instruction::Trunc);
+    break;
+  case scZeroExtend:
+    Cost = CastCost(Instruction::ZExt);
+    break;
+  case scSignExtend:
+    Cost = CastCost(Instruction::SExt);
+    break;
+  case scUDivExpr: {
+    unsigned Opcode = Instruction::UDiv;
+    if (auto *SC = dyn_cast<SCEVConstant>(S->getOperand(1)))
+      if (SC->getAPInt().isPowerOf2())
+        Opcode = Instruction::LShr;
+    Cost = ArithCost(Opcode, 1);
+    break;
+  }
+  case scAddExpr:
+    Cost = ArithCost(Instruction::Add, S->getNumOperands() - 1);
+    break;
+  case scMulExpr:
+    // TODO: this is a very pessimistic cost modelling for Mul,
+    // because of Bin Pow algorithm actually used by the expander,
+    // see SCEVExpander::visitMulExpr(), ExpandOpBinPowN().
+    Cost = ArithCost(Instruction::Mul, S->getNumOperands() - 1);
+    break;
+  case scSMaxExpr:
+  case scUMaxExpr:
+  case scSMinExpr:
+  case scUMinExpr: {
+    Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1);
+    Cost += CmpSelCost(Instruction::Select, S->getNumOperands() - 1);
+    break;
+  }
+  case scAddRecExpr: {
+    // In this polynomial, we may have some zero operands, and we shouldn't
+    // really charge for those. So how many non-zero coefficients are there?
+    int NumTerms = llvm::count_if(S->operands(), [](const SCEV *Op) {
+                                    return !Op->isZero();
+                                  });
+
+    assert(NumTerms >= 1 && "Polynominal should have at least one term.");
+    assert(!(*std::prev(S->operands().end()))->isZero() &&
+           "Last operand should not be zero");
+
+    // Ignoring the constant term (operand 0), how many coefficients are u> 1?
+    int NumNonZeroDegreeNonOneTerms = llvm::count_if(
+        make_range(std::next(S->operands().begin()), S->operands().end()),
+        [](const SCEV *Op) {
+          auto *SConst = dyn_cast<SCEVConstant>(Op);
+          return !SConst || SConst->getAPInt().ugt(1);
+        });
+
+    // What is the degree of this polynomial?
+    int PolyDegree = S->getNumOperands() - 1;
+    assert(PolyDegree >= 1 && "Should be at least affine.");
+
+    // Much like with normal add expr, the polynomial will require
+    // one less addition than the number of its terms.
+    int AddCost = ArithCost(Instruction::Add, NumTerms - 1);
+    // Each non-one coefficient needs a multiplication, and the final term:
+    //   Op_{PolyDegree} * x ^ {PolyDegree}
+    // additionally needs PolyDegree-1 muls to form  x ^ {PolyDegree}.
+    // Note that  x ^ {PolyDegree} = x * x ^ {PolyDegree-1}  so charging for
+    // x ^ {PolyDegree}  will give us  x ^ {2} .. x ^ {PolyDegree-1}  for free.
+    // FIXME: this is conservatively correct, but might be overly pessimistic.
+    int MulCost = ArithCost(Instruction::Mul,
+                            NumNonZeroDegreeNonOneTerms + (PolyDegree - 1));
+    Cost = AddCost + MulCost;
+    break;
+  }
+  }
+
+  for (unsigned Opc : Opcodes)
+    for (auto I : enumerate(S->operands()))
+      Worklist.emplace_back(Opc, I.index(), I.value());
+  return Cost;
+}
+
 bool SCEVExpander::isHighCostExpansionHelper(
-    const SCEV *S, Loop *L, const Instruction &At, int &BudgetRemaining,
-    const TargetTransformInfo &TTI, SmallPtrSetImpl<const SCEV *> &Processed,
-    SmallVectorImpl<const SCEV *> &Worklist) {
+    const SCEVOperand &WorkItem, Loop *L, const Instruction &At,
+    int &BudgetRemaining, const TargetTransformInfo &TTI,
+    SmallPtrSetImpl<const SCEV *> &Processed,
+    SmallVectorImpl<SCEVOperand> &Worklist) {
   if (BudgetRemaining < 0)
     return true; // Already run out of budget, give up.
 
+  const SCEV *S = WorkItem.S;
   // Was the cost of expansion of this expression already accounted for?
   if (!Processed.insert(S).second)
     return false; // We have already accounted for this expression.
@@ -2202,44 +2325,12 @@
   TargetTransformInfo::TargetCostKind CostKind =
       TargetTransformInfo::TCK_RecipThroughput;
 
-  if (auto *CastExpr = dyn_cast<SCEVCastExpr>(S)) {
-    unsigned Opcode;
-    switch (S->getSCEVType()) {
-    case scTruncate:
-      Opcode = Instruction::Trunc;
-      break;
-    case scZeroExtend:
-      Opcode = Instruction::ZExt;
-      break;
-    case scSignExtend:
-      Opcode = Instruction::SExt;
-      break;
-    default:
-      llvm_unreachable("There are no other cast types.");
-    }
-    const SCEV *Op = CastExpr->getOperand();
-    BudgetRemaining -= TTI.getCastInstrCost(
-        Opcode, /*Dst=*/S->getType(),
-        /*Src=*/Op->getType(), TTI::CastContextHint::None, CostKind);
-    Worklist.emplace_back(Op);
+  if (isa<SCEVCastExpr>(S)) {
+    int Cost =
+        costAndCollectOperands<SCEVCastExpr>(WorkItem, TTI, CostKind, Worklist);
+    BudgetRemaining -= Cost;
     return false; // Will answer upon next entry into this function.
-  }
-
-  if (auto *UDivExpr = dyn_cast<SCEVUDivExpr>(S)) {
-    // If the divisor is a power of two count this as a logical right-shift.
-    if (auto *SC = dyn_cast<SCEVConstant>(UDivExpr->getRHS())) {
-      if (SC->getAPInt().isPowerOf2()) {
-        BudgetRemaining -=
-            TTI.getArithmeticInstrCost(Instruction::LShr, S->getType(),
-                                       CostKind);
-        // Note that we don't count the cost of RHS, because it is a constant,
-        // and we consider those to be free. But if that changes, we would need
-        // to log2() it first before calling isHighCostExpansionHelper().
-        Worklist.emplace_back(UDivExpr->getLHS());
-        return false; // Will answer upon next entry into this function.
-      }
-    }
-
+  } else if (isa<SCEVUDivExpr>(S)) {
     // UDivExpr is very likely a UDiv that ScalarEvolution's HowFarToZero or
     // HowManyLessThans produced to compute a precise expression, rather than a
     // UDiv from the user's code. If we can't find a UDiv in the code with some
@@ -2252,117 +2343,28 @@
         SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), &At, L))
       return false; // Consider it to be free.
 
+    int Cost =
+        costAndCollectOperands<SCEVUDivExpr>(WorkItem, TTI, CostKind, Worklist);
     // Need to count the cost of this UDiv.
-    BudgetRemaining -=
-        TTI.getArithmeticInstrCost(Instruction::UDiv, S->getType(),
-                                   CostKind);
-    Worklist.insert(Worklist.end(), {UDivExpr->getLHS(), UDivExpr->getRHS()});
+    BudgetRemaining -= Cost;
     return false; // Will answer upon next entry into this function.
-  }
-
-  if (const auto *NAry = dyn_cast<SCEVAddRecExpr>(S)) {
-    Type *OpType = NAry->getType();
-
-    assert(NAry->getNumOperands() >= 2 &&
-           "Polynomial should be at least linear");
-
-    int AddCost =
-        TTI.getArithmeticInstrCost(Instruction::Add, OpType, CostKind);
-    int MulCost =
-        TTI.getArithmeticInstrCost(Instruction::Mul, OpType, CostKind);
-
-    // In this polynominal, we may have some zero operands, and we shouldn't
-    // really charge for those. So how many non-zero coeffients are there?
-    int NumTerms = llvm::count_if(NAry->operands(),
-                                  [](const SCEV *S) { return !S->isZero(); });
-    assert(NumTerms >= 1 && "Polynominal should have at least one term.");
-    assert(!(*std::prev(NAry->operands().end()))->isZero() &&
-           "Last operand should not be zero");
-
-    // Much like with normal add expr, the polynominal will require
-    // one less addition than the number of it's terms.
-    BudgetRemaining -= AddCost * (NumTerms - 1);
-    if (BudgetRemaining < 0)
-      return true;
-
-    // Ignoring constant term (operand 0), how many of the coeffients are u> 1?
-    int NumNonZeroDegreeNonOneTerms =
-        llvm::count_if(make_range(std::next(NAry->op_begin()), NAry->op_end()),
-                       [](const SCEV *S) {
-                         auto *SConst = dyn_cast<SCEVConstant>(S);
-                         return !SConst || SConst->getAPInt().ugt(1);
-                       });
-    // Here, *each* one of those will require a multiplication.
-    BudgetRemaining -= MulCost * NumNonZeroDegreeNonOneTerms;
-    if (BudgetRemaining < 0)
-      return true;
-
-    // What is the degree of this polynominal?
-    int PolyDegree = NAry->getNumOperands() - 1;
-    assert(PolyDegree >= 1 && "Should be at least affine.");
-
-    // The final term will be:
-    //   Op_{PolyDegree} * x ^ {PolyDegree}
-    // Where  x ^ {PolyDegree}  will again require PolyDegree-1 mul operations.
-    // Note that  x ^ {PolyDegree} = x * x ^ {PolyDegree-1}  so charging for
-    // x ^ {PolyDegree}  will give us  x ^ {2} .. x ^ {PolyDegree-1}  for free.
-    // FIXME: this is conservatively correct, but might be overly pessimistic.
-    BudgetRemaining -= MulCost * (PolyDegree - 1);
-    if (BudgetRemaining < 0)
-      return true;
-
-    // And finally, the operands themselves should fit within the budget.
-    Worklist.insert(Worklist.end(), NAry->operands().begin(),
-                    NAry->operands().end());
-    return false; // So far so good, though ops may be too costly?
-  }
-
-  if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) {
-    Type *OpType = NAry->getType();
-
-    int PairCost;
-    switch (S->getSCEVType()) {
-    case scAddExpr:
-      PairCost =
-          TTI.getArithmeticInstrCost(Instruction::Add, OpType, CostKind);
-      break;
-    case scMulExpr:
-      // TODO: this is a very pessimistic cost modelling for Mul,
-      // because of Bin Pow algorithm actually used by the expander,
-      // see SCEVExpander::visitMulExpr(), ExpandOpBinPowN().
-      PairCost =
-          TTI.getArithmeticInstrCost(Instruction::Mul, OpType, CostKind);
-      break;
-    case scSMaxExpr:
-    case scUMaxExpr:
-    case scSMinExpr:
-    case scUMinExpr:
-      PairCost = TTI.getCmpSelInstrCost(Instruction::ICmp, OpType,
-                                        CmpInst::makeCmpResultType(OpType),
-                                        CostKind) +
-                 TTI.getCmpSelInstrCost(Instruction::Select, OpType,
-                                        CmpInst::makeCmpResultType(OpType),
-                                        CostKind);
-      break;
-    default:
-      llvm_unreachable("There are no other variants here.");
-    }
-
+  } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
+    // Must be checked before SCEVNAryExpr: SCEVAddRecExpr derives from it.
+    assert(AddRec->getNumOperands() >= 2 &&
+           "Polynomial should be at least linear");
+    BudgetRemaining -= costAndCollectOperands<SCEVAddRecExpr>(
+        WorkItem, TTI, CostKind, Worklist);
+    return BudgetRemaining < 0;
+  } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) {
     assert(NAry->getNumOperands() > 1 &&
            "Nary expr should have more than 1 operand.");
     // The simple nary expr will require one less op (or pair of ops)
     // than the number of it's terms.
-    BudgetRemaining -= PairCost * (NAry->getNumOperands() - 1);
-    if (BudgetRemaining < 0)
-      return true;
-
-    // And finally, the operands themselves should fit within the budget.
-    Worklist.insert(Worklist.end(), NAry->operands().begin(),
-                    NAry->operands().end());
-    return false; // So far so good, though ops may be too costly?
-  }
-
-  llvm_unreachable("No other scev expressions possible.");
+    BudgetRemaining -= costAndCollectOperands<SCEVNAryExpr>(
+        WorkItem, TTI, CostKind, Worklist);
+    return BudgetRemaining < 0;
+  }
+  llvm_unreachable("No other scev expressions possible.");
 }
 
 Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,