diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -112,6 +112,11 @@ Optional FMF, TTI::TargetCostKind CostKind); + InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned, + Type *ResTy, VectorType *ValTy, + Optional FMF, + TTI::TargetCostKind CostKind); + bool isElementTypeLegalForScalableVector(Type *Ty) const { return TLI->isLegalElementTypeForRVV(Ty); } diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -377,6 +377,43 @@ return (LT.first - 1) + BaseCost + Log2_32_Ceil(VL); } +InstructionCost RISCVTTIImpl::getExtendedReductionCost( + unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *ValTy, + Optional FMF, TTI::TargetCostKind CostKind) { + if (isa(ValTy) && !ST->useRVVForFixedLengthVectors()) + return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy, + FMF, CostKind); + + // Skip if scalar size of ResTy is bigger than ELEN. + if (ResTy->getScalarSizeInBits() > ST->getELEN()) + return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy, + FMF, CostKind); + + int ISD = TLI->InstructionOpcodeToISD(Opcode); + assert(ISD && "Invalid opcode"); + + std::pair LT = TLI->getTypeLegalizationCost(DL, ValTy); + + switch (ISD) { + case ISD::ADD: + if (ResTy->getScalarSizeInBits() == 2 * LT.second.getScalarSizeInBits()) + // vwredsum and vwredsumu. + return (LT.first - 1) + + getArithmeticReductionCost(Instruction::Add, ValTy, FMF, CostKind); + break; + case ISD::FADD: + if (ResTy->getScalarSizeInBits() == 2 * LT.second.getScalarSizeInBits()) + // vfwredosum and vfwredusum. + return (LT.first - 1) + getArithmeticReductionCost(Instruction::FAdd, + ValTy, FMF, CostKind); + break; + default: + break; + } + return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy, FMF, + CostKind); +} + void RISCVTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE, TTI::UnrollingPreferences &UP, OptimizationRemarkEmitter *ORE) {