diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -112,6 +112,11 @@
                                              Optional<FastMathFlags> FMF,
                                              TTI::TargetCostKind CostKind);
 
+  InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
+                                           Type *ResTy, VectorType *ValTy,
+                                           Optional<FastMathFlags> FMF,
+                                           TTI::TargetCostKind CostKind);
+
   bool isElementTypeLegalForScalableVector(Type *Ty) const {
     return TLI->isLegalElementTypeForRVV(Ty);
   }
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -377,6 +377,49 @@
   return (LT.first - 1) + BaseCost + Log2_32_Ceil(VL);
 }
 
+/// Cost of an extended reduction, ResTy vecreduce.<Opcode>(ext(ValTy)): the
+/// elements of ValTy are extended to ResTy's scalar type (zext if IsUnsigned,
+/// else sext; fpext for FAdd) and then reduced with Opcode.
+///
+/// Only Add/FAdd reductions whose result element is exactly twice as wide as
+/// the legalized source element are costed here; that is the shape the RVV
+/// widening reductions (vwredsum[u].vs, vfwredusum/vfwredosum.vs) compute
+/// directly. Every other case defers to the generic BaseT implementation.
+InstructionCost RISCVTTIImpl::getExtendedReductionCost(
+    unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *ValTy,
+    Optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) {
+  // Fixed-length vectors are only costed as RVV when the subtarget enables
+  // RVV codegen for them.
+  if (isa<FixedVectorType>(ValTy) && !ST->useRVVForFixedLengthVectors())
+    return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
+                                           FMF, CostKind);
+
+  // Skip if scalar size of ResTy is bigger than ELEN.
+  if (ResTy->getScalarSizeInBits() > ST->getELEN())
+    return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
+                                           FMF, CostKind);
+
+  // Only add-style reductions have a widening RVV form.
+  if (Opcode != Instruction::Add && Opcode != Instruction::FAdd)
+    return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
+                                           FMF, CostKind);
+
+  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
+
+  // A widening reduction yields a 2*SEW result from SEW-wide source
+  // elements; any other extension ratio falls back to the generic cost.
+  if (ResTy->getScalarSizeInBits() != 2 * LT.second.getScalarSizeInBits())
+    return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
+                                           FMF, CostKind);
+
+  // Add one step per extra legal part when ValTy is split, on top of the
+  // plain reduction cost. NOTE(review): if getArithmeticReductionCost already
+  // includes (LT.first - 1) for the split, this counts it a second time --
+  // confirm that is intended.
+  return (LT.first - 1) +
+         getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
+}
+
 void RISCVTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                                            TTI::UnrollingPreferences &UP,
                                            OptimizationRemarkEmitter *ORE) {