diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -111,6 +111,10 @@
                                              Optional<FastMathFlags> FMF,
                                              TTI::TargetCostKind CostKind);
 
+  InstructionCost getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
+                                              Type *ResTy, VectorType *ValTy,
+                                              TTI::TargetCostKind CostKind);
+
   bool isElementTypeLegalForScalableVector(Type *Ty) const {
     return TLI->isLegalElementTypeForRVV(Ty);
   }
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -371,6 +371,32 @@
   return (LT.first - 1) + BaseCost + Log2_32_Ceil(VL);
 }
 
+// Cost of an add reduction over ValTy whose elements are first extended to
+// ResTy's element type (with a multiply of two extended operands if IsMLA).
+InstructionCost
+RISCVTTIImpl::getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
+                                          Type *ResTy, VectorType *ValTy,
+                                          TTI::TargetCostKind CostKind) {
+  // Without RVV code generation for fixed-length vectors no vwredsum(u) is
+  // emitted, so fall back to the generic extend + reduce estimate.
+  if (isa<FixedVectorType>(ValTy) && !ST->useRVVForFixedLengthVectors())
+    return BaseT::getExtendedAddReductionCost(IsMLA, IsUnsigned, ResTy, ValTy,
+                                              CostKind);
+
+  // vwredsum/vwredsumu only cover a plain (non-MLA) add reduction whose
+  // result element width is exactly twice the source SEW, and the widened
+  // result must itself fit within ELEN.
+  unsigned ResBits = ResTy->getScalarSizeInBits();
+  if (IsMLA || ResBits != 2 * ValTy->getScalarSizeInBits() ||
+      ResBits > ST->getELEN())
+    return BaseT::getExtendedAddReductionCost(IsMLA, IsUnsigned, ResTy, ValTy,
+                                              CostKind);
+
+  // Model vwredsum/vwredsumu as costing the same as vredsum on the source
+  // vector type.
+  return getArithmeticReductionCost(Instruction::Add, ValTy, {}, CostKind);
+}
+
 void RISCVTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                                            TTI::UnrollingPreferences &UP,
                                            OptimizationRemarkEmitter *ORE) {