Index: include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- include/llvm/CodeGen/BasicTTIImpl.h +++ include/llvm/CodeGen/BasicTTIImpl.h @@ -927,16 +927,71 @@ unsigned getReductionCost(unsigned Opcode, Type *Ty, bool IsPairwise) { assert(Ty->isVectorTy() && "Expect a vector type"); + Type *ScalarTy = Ty->getVectorElementType(); unsigned NumVecElts = Ty->getVectorNumElements(); unsigned NumReduxLevels = Log2_32(NumVecElts); - unsigned ArithCost = - NumReduxLevels * - static_cast(this)->getArithmeticInstrCost(Opcode, Ty); - // Assume the pairwise shuffles add a cost. - unsigned ShuffleCost = - NumReduxLevels * (IsPairwise + 1) * - static_cast(this) - ->getShuffleCost(TTI::SK_ExtractSubvector, Ty, NumVecElts / 2, Ty); + // Try to calculate arithmetic and shuffle op costs for reduction operations. + // We assume that reduction operations are performed in the following way: + // 1. Non-pairwise reduction + // %val1 = shufflevector %val, %undef, + // + // \----------------v-------------/ \----------v------------/ + // n/2 elements n/2 elements + // %red1 = op %val, val1 + // After this operation we have a vector %red1 in which only the first n/2 + // elements are meaningful; the second n/2 elements are undefined and can be + // dropped. All other operations are actually working with the vector of + // length n/2, not n, though the real vector length is still n. + // %val2 = shufflevector %red1, %undef, + // + // \----------------v-------------/ \----------v------------/ + // n/4 elements 3*n/4 elements + // %red2 = op %red1, val2 - working with the vector of + // length n/2, the resulting vector has length n/4 etc. + // 2. Pairwise reduction: + // Everything is the same except for an additional shuffle operation which + // is used to produce operands for the pairwise kind of reduction. 
+ // %val1 = shufflevector %val, %undef, + // + // \-------------v----------/ \----------v------------/ + // n/2 elements n/2 elements + // %val2 = shufflevector %val, %undef, + // + // \-------------v----------/ \----------v------------/ + // n/2 elements n/2 elements + // %red1 = op %val1, val2 + // Again, the operation is performed on the n-element vector, but only the + // first n/2 elements of the resulting vector %red1 are meaningful. + // + // The cost model should take into account that the actual length of the + // vector is reduced on each iteration. + unsigned ArithCost = 0; + unsigned ShuffleCost = 0; + auto *ConcreteTTI = static_cast(this); + std::pair LT = + ConcreteTTI->getTLI()->getTypeLegalizationCost(DL, Ty); + unsigned LongVectorCount = 0; + unsigned MVTLen = + LT.second.isVector() ? LT.second.getVectorNumElements() : 1; + while (NumVecElts > MVTLen) { + NumVecElts /= 2; + // Assume the pairwise shuffles add a cost. + ShuffleCost += (IsPairwise + 1) * + ConcreteTTI->getShuffleCost(TTI::SK_ExtractSubvector, Ty, + NumVecElts, Ty); + ArithCost += ConcreteTTI->getArithmeticInstrCost(Opcode, Ty); + Ty = VectorType::get(ScalarTy, NumVecElts); + ++LongVectorCount; + } + // The minimal length of the vector is limited by the real length of vector + // operations performed on the current platform. That's why several final + // reduction operations are performed on the vectors with the same + // architecture-dependent length. 
+ ShuffleCost += (NumReduxLevels - LongVectorCount) * (IsPairwise + 1) * + ConcreteTTI->getShuffleCost(TTI::SK_ExtractSubvector, Ty, + NumVecElts, Ty); + ArithCost += (NumReduxLevels - LongVectorCount) * + ConcreteTTI->getArithmeticInstrCost(Opcode, Ty); return ShuffleCost + ArithCost + getScalarizationOverhead(Ty, false, true); } Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -4283,7 +4283,7 @@ int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost; int ScalarReduxCost = - ReduxWidth * TTI->getArithmeticInstrCost(ReductionOpcode, VecTy); + ReduxWidth * TTI->getArithmeticInstrCost(ReductionOpcode, ScalarTy); DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost << " for reduction that starts with " << *FirstReducedVal Index: test/Analysis/CostModel/X86/reduction.ll =================================================================== --- test/Analysis/CostModel/X86/reduction.ll +++ test/Analysis/CostModel/X86/reduction.ll @@ -33,7 +33,7 @@ %bin.rdx.3 = add <8 x i32> %bin.rdx.2, %rdx.shuf.3 ; CHECK-LABEL: reduction_cost_int -; CHECK: cost of 17 {{.*}} extractelement +; CHECK: cost of 11 {{.*}} extractelement %r = extractelement <8 x i32> %bin.rdx.3, i32 0 ret i32 %r