Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1193,6 +1193,13 @@
       VectorType *Ty, VectorType *CondTy, bool IsUnsigned,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
+  /// Calculate the cost of a call to the llvm.fmuladd intrinsic. This is
+  /// modeled as the cost of a normal fmul instruction plus the cost of an fadd
+  /// reduction.
+  InstructionCost getFMulAddReductionCost(
+      VectorType *Ty, Optional<FastMathFlags> FMF,
+      TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
+
   /// Calculate the cost of an extended reduction pattern, similar to
   /// getArithmeticReductionCost of an Add reduction with an extension and
   /// optional multiply. This is the cost of as:
@@ -1662,6 +1669,9 @@
   virtual InstructionCost
   getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy, bool IsUnsigned,
                          TTI::TargetCostKind CostKind) = 0;
+  virtual InstructionCost
+  getFMulAddReductionCost(VectorType *Ty, Optional<FastMathFlags> FMF,
+                          TTI::TargetCostKind CostKind) = 0;
   virtual InstructionCost getExtendedAddReductionCost(
       bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
@@ -2177,6 +2187,11 @@
                                          TTI::TargetCostKind CostKind) override {
     return Impl.getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
   }
+  InstructionCost
+  getFMulAddReductionCost(VectorType *Ty, Optional<FastMathFlags> FMF,
+                          TTI::TargetCostKind CostKind) override {
+    return Impl.getFMulAddReductionCost(Ty, FMF, CostKind);
+  }
   InstructionCost getExtendedAddReductionCost(
       bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) override {
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -637,6 +637,11 @@
     return 1;
   }
 
+  InstructionCost getFMulAddReductionCost(VectorType *, Optional<FastMathFlags>,
+                                          TTI::TargetCostKind) const {
+    return 1;
+  }
+
   InstructionCost
   getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned, Type *ResTy,
                               VectorType *Ty,
Index: llvm/include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2174,6 +2174,16 @@
            thisT()->getVectorInstrCost(Instruction::ExtractElement, Ty, 0);
   }
 
+  InstructionCost getFMulAddReductionCost(VectorType *Ty,
+                                          Optional<FastMathFlags> FMF,
+                                          TTI::TargetCostKind CostKind) {
+    InstructionCost FAddReductionCost = thisT()->getArithmeticReductionCost(
+        Instruction::FAdd, Ty, FMF, CostKind);
+    InstructionCost FMulCost =
+        thisT()->getArithmeticInstrCost(Instruction::FMul, Ty, CostKind);
+    return FMulCost + FAddReductionCost;
+  }
+
   InstructionCost getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
                                               Type *ResTy, VectorType *Ty,
                                               TTI::TargetCostKind CostKind) {
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -917,6 +917,14 @@
   return Cost;
 }
 
+InstructionCost TargetTransformInfo::getFMulAddReductionCost(
+    VectorType *Ty, Optional<FastMathFlags> FMF,
+    TTI::TargetCostKind CostKind) const {
+  InstructionCost Cost = TTIImpl->getFMulAddReductionCost(Ty, FMF, CostKind);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
 InstructionCost TargetTransformInfo::getExtendedAddReductionCost(
     bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7255,8 +7255,14 @@
   const RecurrenceDescriptor &RdxDesc =
       Legal->getReductionVars()[cast<PHINode>(ReductionPhi)];
 
-  InstructionCost BaseCost = TTI.getArithmeticReductionCost(
-      RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  InstructionCost BaseCost;
+  if (RdxDesc.getRecurrenceKind() == RecurKind::FMulAdd)
+    // Recognize a call to the llvm.fmuladd intrinsic.
+    BaseCost = TTI.getFMulAddReductionCost(VectorTy, RdxDesc.getFastMathFlags(),
+                                           CostKind);
+  else
+    BaseCost = TTI.getArithmeticReductionCost(
+        RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind);
 
   // If we're using ordered reductions then we can just return the base cost
   // here, since getArithmeticReductionCost calculates the full ordered
@@ -7929,6 +7935,12 @@
     return TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I);
   }
   case Instruction::Call: {
+    // Recognize a call to the llvm.fmuladd intrinsic.
+    if (RecurrenceDescriptor::isFMulAddIntrinsic(I)) {
+      // Detect reduction patterns.
+      if (auto RedCost = getReductionPatternCost(I, VF, VectorTy, CostKind))
+        return *RedCost;
+    }
     bool NeedToScalarize;
     CallInst *CI = cast<CallInst>(I);
     InstructionCost CallCost = getVectorCallCost(CI, VF, NeedToScalarize);
Index: llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll
===================================================================
--- llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll
+++ llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll
@@ -48,3 +48,53 @@
 for.end:
   ret double %add
 }
+
+; CHECK-VF4: Found an estimated cost of 23 for VF 4 For instruction: %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+; CHECK-VF8: Found an estimated cost of 46 for VF 8 For instruction: %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+
+define float @fmuladd_strict32(float* %a, float* %b, i64 %n) {
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %sum.07 = phi float [ 0.000000e+00, %entry ], [ %muladd, %for.body ]
+  %arrayidx = getelementptr inbounds float, float* %a, i64 %iv
+  %0 = load float, float* %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds float, float* %b, i64 %iv
+  %1 = load float, float* %arrayidx2, align 4
+  %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, %n
+  br i1 %exitcond.not, label %for.end, label %for.body
+
+for.end:
+  ret float %muladd
+}
+
+declare float @llvm.fmuladd.f32(float, float, float)
+
+; CHECK-VF4: Found an estimated cost of 22 for VF 4 For instruction: %muladd = tail call double @llvm.fmuladd.f64(double %0, double %1, double %sum.07)
+; CHECK-VF8: Found an estimated cost of 44 for VF 8 For instruction: %muladd = tail call double @llvm.fmuladd.f64(double %0, double %1, double %sum.07)
+
+define double @fmuladd_strict64(double* %a, double* %b, i64 %n) {
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %sum.07 = phi double [ 0.000000e+00, %entry ], [ %muladd, %for.body ]
+  %arrayidx = getelementptr inbounds double, double* %a, i64 %iv
+  %0 = load double, double* %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds double, double* %b, i64 %iv
+  %1 = load double, double* %arrayidx2, align 4
+  %muladd = tail call double @llvm.fmuladd.f64(double %0, double %1, double %sum.07)
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, %n
+  br i1 %exitcond.not, label %for.end, label %for.body
+
+for.end:
+  ret double %muladd
+}
+
+declare double @llvm.fmuladd.f64(double, double, double)