Index: llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -82,6 +82,11 @@
 
   public:
     const SCEV *getOperand() const { return Op; }
+    const SCEV *getOperand(unsigned i) const {
+      assert(i == 0 && "Operand index out of range!");
+      return Op;
+    }
+    size_t getNumOperands() const { return 1; }
     Type *getType() const { return Ty; }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
@@ -273,6 +278,11 @@
   public:
     const SCEV *getLHS() const { return LHS; }
     const SCEV *getRHS() const { return RHS; }
+    size_t getNumOperands() const { return 2; }
+    const SCEV *getOperand(unsigned i) const {
+      assert((i == 0 || i == 1) && "Operand index out of range!");
+      return i == 0 ? LHS : RHS;
+    }
 
     Type *getType() const {
       // In most cases the types of LHS and RHS will be the same, but in some
Index: llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -39,6 +39,16 @@
 bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint,
                       ScalarEvolution &SE);
 
+/// Holds enough information to help calculate the cost of a SCEV when it is
+/// expanded into IR: the SCEV, and which operand of which user it becomes.
+struct SCEVOperand {
+  explicit SCEVOperand(unsigned Opc, int Idx, const SCEV *S)
+      : ParentOpcode(Opc), OperandIdx(Idx), S(S) {}
+  unsigned ParentOpcode; ///< Opcode of the using instruction (0 for root).
+  int OperandIdx;        ///< Operand index of S within that instruction.
+  const SCEV *S;         ///< The SCEV being costed.
+};
+
 /// This class uses information about analyze scalars to rewrite expressions
 /// in canonical form.
 ///
@@ -193,14 +203,14 @@
     assert(At && "This function requires At instruction to be provided.");
     if (!TTI)      // In assert-less builds, avoid crashing
       return true; // by always claiming to be high-cost.
-    SmallVector<const SCEV *, 8> Worklist;
+    SmallVector<SCEVOperand, 8> Worklist;
     SmallPtrSet<const SCEV *, 8> Processed;
     int BudgetRemaining = Budget * TargetTransformInfo::TCC_Basic;
-    Worklist.emplace_back(Expr);
+    Worklist.emplace_back(/*ParentOpcode=*/0, /*OperandIdx=*/0, Expr);
     while (!Worklist.empty()) {
-      const SCEV *S = Worklist.pop_back_val();
-      if (isHighCostExpansionHelper(S, L, *At, BudgetRemaining, *TTI, Processed,
-                                    Worklist))
+      SCEVOperand WorkItem = Worklist.pop_back_val();
+      if (isHighCostExpansionHelper(WorkItem, L, *At, BudgetRemaining, *TTI,
+                                    Processed, Worklist))
         return true;
     }
     assert(BudgetRemaining >= 0 && "Should have returned from inner loop.");
@@ -366,11 +376,11 @@
   Value *expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *I, bool Root);
 
   /// Recursive helper function for isHighCostExpansion.
-  bool isHighCostExpansionHelper(const SCEV *S, Loop *L, const Instruction &At,
-                                 int &BudgetRemaining,
-                                 const TargetTransformInfo &TTI,
-                                 SmallPtrSetImpl<const SCEV *> &Processed,
-                                 SmallVectorImpl<const SCEV *> &Worklist);
+  bool isHighCostExpansionHelper(const SCEVOperand &WorkItem, Loop *L,
+                                 const Instruction &At, int &BudgetRemaining,
+                                 const TargetTransformInfo &TTI,
+                                 SmallPtrSetImpl<const SCEV *> &Processed,
+                                 SmallVectorImpl<SCEVOperand> &Worklist);
 
   /// Insert the specified binary operator, doing a small amount of work to
   /// avoid inserting an obviously redundant operation, and hoisting to an
Index: llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
===================================================================
--- llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2172,13 +2172,101 @@
   return None;
 }
 
+/// Compute the cost of the IR instruction(s) that will be emitted when the
+/// SCEV held in \p WorkItem is expanded, and push each of its operands onto
+/// \p Worklist, tagged with the opcode of the instruction that will use it.
+/// SCEV kinds not handled by the switch below cost 0 and queue nothing.
+template <typename T>
+static int costAndCollectOperands(const SCEVOperand &WorkItem,
+                                  const TargetTransformInfo &TTI,
+                                  TargetTransformInfo::TargetCostKind CostKind,
+                                  SmallVectorImpl<SCEVOperand> &Worklist) {
+  const T *S = cast<T>(WorkItem.S);
+  int Cost = 0;
+  // Opcodes of the instructions that will consume the operands of S.
+  SmallVector<unsigned, 2> Opcodes;
+
+  auto CastCost = [&](unsigned Opcode) {
+    Opcodes.push_back(Opcode);
+    return TTI.getCastInstrCost(Opcode, S->getType(),
+                                S->getOperand(0)->getType(),
+                                TTI::CastContextHint::None, CostKind);
+  };
+  auto ArithCost = [&](unsigned Opcode) {
+    Opcodes.push_back(Opcode);
+    return TTI.getArithmeticInstrCost(Opcode, S->getType(), CostKind);
+  };
+
+  switch (S->getSCEVType()) {
+  default:
+    return 0;
+  case scTruncate:
+    Cost = CastCost(Instruction::Trunc);
+    break;
+  case scZeroExtend:
+    Cost = CastCost(Instruction::ZExt);
+    break;
+  case scSignExtend:
+    Cost = CastCost(Instruction::SExt);
+    break;
+  case scUDivExpr: {
+    // If the divisor is a power of two, count this as a logical right-shift.
+    unsigned Opcode = Instruction::UDiv;
+    if (auto *SC = dyn_cast<SCEVConstant>(S->getOperand(1)))
+      if (SC->getAPInt().isPowerOf2())
+        Opcode = Instruction::LShr;
+    Cost = ArithCost(Opcode);
+    break;
+  }
+  case scAddExpr:
+    Cost = ArithCost(Instruction::Add);
+    break;
+  case scMulExpr:
+    // TODO: this is a very pessimistic cost modelling for Mul,
+    // because of Bin Pow algorithm actually used by the expander,
+    // see SCEVExpander::visitMulExpr(), ExpandOpBinPowN().
+    Cost = ArithCost(Instruction::Mul);
+    break;
+  case scSMaxExpr:
+  case scUMaxExpr:
+  case scSMinExpr:
+  case scUMinExpr: {
+    // Each min/max pair is costed as an icmp plus a select.
+    Type *OpType = S->getOperand(0)->getType();
+    Opcodes.push_back(Instruction::ICmp);
+    Opcodes.push_back(Instruction::Select);
+    Cost = TTI.getCmpSelInstrCost(Instruction::ICmp, OpType,
+                                  CmpInst::makeCmpResultType(OpType),
+                                  CostKind) +
+           TTI.getCmpSelInstrCost(Instruction::Select, OpType,
+                                  CmpInst::makeCmpResultType(OpType),
+                                  CostKind);
+    break;
+  }
+  case scAddRecExpr:
+    // The polynomial itself is costed by the caller; only record that its
+    // operands feed the expanded adds and muls so that they still get queued
+    // and their own expansion cost is accounted for.
+    Opcodes.push_back(Instruction::Add);
+    Opcodes.push_back(Instruction::Mul);
+    break;
+  }
+
+  for (unsigned Opc : Opcodes)
+    for (unsigned OpIdx = 0, E = S->getNumOperands(); OpIdx != E; ++OpIdx)
+      Worklist.emplace_back(Opc, OpIdx, S->getOperand(OpIdx));
+  return Cost;
+}
+
 bool SCEVExpander::isHighCostExpansionHelper(
-    const SCEV *S, Loop *L, const Instruction &At, int &BudgetRemaining,
+    const SCEVOperand &WorkItem, Loop *L, const Instruction &At,
+    int &BudgetRemaining,
     const TargetTransformInfo &TTI, SmallPtrSetImpl<const SCEV *> &Processed,
-    SmallVectorImpl<const SCEV *> &Worklist) {
+    SmallVectorImpl<SCEVOperand> &Worklist) {
   if (BudgetRemaining < 0)
     return true; // Already run out of budget, give up.
+  const SCEV *S = WorkItem.S;
 
   // Was the cost of expansion of this expression already accounted for?
   if (!Processed.insert(S).second)
     return false; // We have already accounted for this expression.
@@ -2197,62 +2285,33 @@
   TargetTransformInfo::TargetCostKind CostKind =
       TargetTransformInfo::TCK_RecipThroughput;
 
-  if (auto *CastExpr = dyn_cast<SCEVCastExpr>(S)) {
-    unsigned Opcode;
-    switch (S->getSCEVType()) {
-    case scTruncate:
-      Opcode = Instruction::Trunc;
-      break;
-    case scZeroExtend:
-      Opcode = Instruction::ZExt;
-      break;
-    case scSignExtend:
-      Opcode = Instruction::SExt;
-      break;
-    default:
-      llvm_unreachable("There are no other cast types.");
-    }
-    const SCEV *Op = CastExpr->getOperand();
-    BudgetRemaining -= TTI.getCastInstrCost(
-        Opcode, /*Dst=*/S->getType(),
-        /*Src=*/Op->getType(), TTI::CastContextHint::None, CostKind);
-    Worklist.emplace_back(Op);
+  if (isa<SCEVCastExpr>(S)) {
+    int Cost =
+        costAndCollectOperands<SCEVCastExpr>(WorkItem, TTI, CostKind, Worklist);
+    BudgetRemaining -= Cost;
     return false; // Will answer upon next entry into this function.
   }
 
-  if (auto *UDivExpr = dyn_cast<SCEVUDivExpr>(S)) {
-    // If the divisor is a power of two count this as a logical right-shift.
-    if (auto *SC = dyn_cast<SCEVConstant>(UDivExpr->getRHS())) {
-      if (SC->getAPInt().isPowerOf2()) {
-        BudgetRemaining -=
-            TTI.getArithmeticInstrCost(Instruction::LShr, S->getType(),
-                                       CostKind);
-        // Note that we don't count the cost of RHS, because it is a constant,
-        // and we consider those to be free. But if that changes, we would need
-        // to log2() it first before calling isHighCostExpansionHelper().
-        Worklist.emplace_back(UDivExpr->getLHS());
-        return false; // Will answer upon next entry into this function.
-      }
-    }
-
+  if (isa<SCEVUDivExpr>(S)) {
     // UDivExpr is very likely a UDiv that ScalarEvolution's HowFarToZero or
     // HowManyLessThans produced to compute a precise expression, rather than a
     // UDiv from the user's code. If we can't find a UDiv in the code with some
     // simple searching, we need to account for it's cost.
 
     // At the beginning of this function we already tried to find existing
     // value for plain 'S'. Now try to lookup 'S + 1' since it is common
     // pattern involving division. This is just a simple search heuristic.
     if (getRelatedExistingExpansion(
             SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), &At, L))
       return false; // Consider it to be free.
 
     // Need to count the cost of this UDiv.
-    BudgetRemaining -=
-        TTI.getArithmeticInstrCost(Instruction::UDiv, S->getType(),
-                                   CostKind);
-    Worklist.insert(Worklist.end(), {UDivExpr->getLHS(), UDivExpr->getRHS()});
+    // Only queue the operands now that the UDiv is known not to be free, so a
+    // free UDiv does not get its operands charged either.
+    int Cost =
+        costAndCollectOperands<SCEVUDivExpr>(WorkItem, TTI, CostKind, Worklist);
+    BudgetRemaining -= Cost;
     return false; // Will answer upon next entry into this function.
   }
 
   if (const auto *NAry = dyn_cast<SCEVAddRecExpr>(S)) {
@@ -2266,6 +2325,10 @@
     int MulCost =
         TTI.getArithmeticInstrCost(Instruction::Mul, OpType, CostKind);
 
+    // Queue the operands; the polynomial itself is costed below.
+    (void)costAndCollectOperands<SCEVAddRecExpr>(WorkItem, TTI, CostKind,
+                                                 Worklist);
+
     // In this polynominal, we may have some zero operands, and we shouldn't
     // really charge for those. So how many non-zero coeffients are there?
     int NumTerms = llvm::count_if(NAry->operands(),
@@ -2303,58 +2366,18 @@
     // x ^ {PolyDegree}  will give us  x ^ {2} .. x ^ {PolyDegree-1}  for free.
     // FIXME: this is conservatively correct, but might be overly pessimistic.
     BudgetRemaining -= MulCost * (PolyDegree - 1);
-    if (BudgetRemaining < 0)
-      return true;
-
-    // And finally, the operands themselves should fit within the budget.
-    Worklist.insert(Worklist.end(), NAry->operands().begin(),
-                    NAry->operands().end());
-    return false; // So far so good, though ops may be too costly?
+    return BudgetRemaining < 0;
   }
 
   if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) {
-    Type *OpType = NAry->getType();
-
-    int PairCost;
-    switch (S->getSCEVType()) {
-    case scAddExpr:
-      PairCost =
-          TTI.getArithmeticInstrCost(Instruction::Add, OpType, CostKind);
-      break;
-    case scMulExpr:
-      // TODO: this is a very pessimistic cost modelling for Mul,
-      // because of Bin Pow algorithm actually used by the expander,
-      // see SCEVExpander::visitMulExpr(), ExpandOpBinPowN().
-      PairCost =
-          TTI.getArithmeticInstrCost(Instruction::Mul, OpType, CostKind);
-      break;
-    case scSMaxExpr:
-    case scUMaxExpr:
-    case scSMinExpr:
-    case scUMinExpr:
-      PairCost = TTI.getCmpSelInstrCost(Instruction::ICmp, OpType,
-                                        CmpInst::makeCmpResultType(OpType),
-                                        CostKind) +
-                 TTI.getCmpSelInstrCost(Instruction::Select, OpType,
-                                        CmpInst::makeCmpResultType(OpType),
-                                        CostKind);
-      break;
-    default:
-      llvm_unreachable("There are no other variants here.");
-    }
-
     assert(NAry->getNumOperands() > 1 &&
            "Nary expr should have more than 1 operand.");
     // The simple nary expr will require one less op (or pair of ops)
     // than the number of it's terms.
+    int PairCost =
+        costAndCollectOperands<SCEVNAryExpr>(WorkItem, TTI, CostKind, Worklist);
     BudgetRemaining -= PairCost * (NAry->getNumOperands() - 1);
-    if (BudgetRemaining < 0)
-      return true;
-
-    // And finally, the operands themselves should fit within the budget.
-    Worklist.insert(Worklist.end(), NAry->operands().begin(),
-                    NAry->operands().end());
-    return false; // So far so good, though ops may be too costly?
+    return BudgetRemaining < 0;
   }
 
   llvm_unreachable("No other scev expressions possible.");