diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7403,6 +7403,12 @@
   InstructionCost BaseCost = TTI.getArithmeticReductionCost(
       RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind);
 
+  // For a call to the llvm.fmuladd intrinsic we need to add the cost of a
+  // normal fmul instruction to the cost of the fadd reduction.
+  if (RdxDesc.getRecurrenceKind() == RecurKind::FMulAdd)
+    BaseCost +=
+        TTI.getArithmeticInstrCost(Instruction::FMul, VectorTy, CostKind);
+
   // If we're using ordered reductions then we can just return the base cost
   // here, since getArithmeticReductionCost calculates the full ordered
   // reduction cost when FP reassociation is not allowed.
@@ -8079,6 +8085,12 @@
     return TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I);
   }
   case Instruction::Call: {
+    // An llvm.fmuladd call may be part of a reduction pattern. If
+    // getReductionPatternCost returns a cost for it, use that cost instead of
+    // the generic call cost computed below.
+    if (RecurrenceDescriptor::isFMulAddIntrinsic(I))
+      if (auto RedCost = getReductionPatternCost(I, VF, VectorTy, CostKind))
+        return *RedCost;
     bool NeedToScalarize;
     CallInst *CI = cast<CallInst>(I);
     InstructionCost CallCost = getVectorCallCost(CI, VF, NeedToScalarize);
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll b/llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll
--- a/llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll
@@ -48,3 +48,55 @@
 for.end:
   ret double %add
 }
+
+; float reduction through llvm.fmuladd (no fast-math flags on the call); the CHECK lines pin the cost reported for the call.
+; CHECK-VF4: Found an estimated cost of 23 for VF 4 For instruction: %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+; CHECK-VF8: Found an estimated cost of 46 for VF 8 For instruction: %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+
+define float @fmuladd_strict32(float* %a, float* %b, i64 %n) {
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %sum.07 = phi float [ 0.000000e+00, %entry ], [ %muladd, %for.body ]
+  %arrayidx = getelementptr inbounds float, float* %a, i64 %iv
+  %0 = load float, float* %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds float, float* %b, i64 %iv
+  %1 = load float, float* %arrayidx2, align 4
+  %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, %n
+  br i1 %exitcond.not, label %for.end, label %for.body
+
+for.end:
+  ret float %muladd
+}
+
+declare float @llvm.fmuladd.f32(float, float, float)
+
+; double reduction through llvm.fmuladd (no fast-math flags on the call); the CHECK lines pin the cost reported for the call.
+; CHECK-VF4: Found an estimated cost of 22 for VF 4 For instruction: %muladd = tail call double @llvm.fmuladd.f64(double %0, double %1, double %sum.07)
+; CHECK-VF8: Found an estimated cost of 44 for VF 8 For instruction: %muladd = tail call double @llvm.fmuladd.f64(double %0, double %1, double %sum.07)
+
+define double @fmuladd_strict64(double* %a, double* %b, i64 %n) {
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %sum.07 = phi double [ 0.000000e+00, %entry ], [ %muladd, %for.body ]
+  %arrayidx = getelementptr inbounds double, double* %a, i64 %iv
+  %0 = load double, double* %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds double, double* %b, i64 %iv
+  %1 = load double, double* %arrayidx2, align 4
+  %muladd = tail call double @llvm.fmuladd.f64(double %0, double %1, double %sum.07)
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, %n
+  br i1 %exitcond.not, label %for.end, label %for.body
+
+for.end:
+  ret double %muladd
+}
+
+declare double @llvm.fmuladd.f64(double, double, double)