Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7261,8 +7261,21 @@
   // If we're using ordered reductions then we can just return the base cost
   // here, since getArithmeticReductionCost calculates the full ordered
   // reduction cost when FP reassociation is not allowed.
-  if (useOrderedReductions(RdxDesc))
-    return BaseCost;
+  if (useOrderedReductions(RdxDesc)) {
+    if (RdxDesc.getRecurrenceKind() != RecurKind::FMulAdd)
+      return BaseCost;
+    // For a call to the llvm.fmuladd intrinsic we need to add the cost of a
+    // normal fmul instruction to the cost of the fadd reduction.
+    Value *Op2 = I->getOperand(1);
+    TargetTransformInfo::OperandValueProperties Op2VP;
+    TargetTransformInfo::OperandValueKind Op2VK =
+        TTI.getOperandInfo(Op2, Op2VP);
+    SmallVector<const Value *, 4> Operands(I->operand_values());
+    InstructionCost FMulCost = TTI.getArithmeticInstrCost(
+        Instruction::FMul, VectorTy, CostKind, TargetTransformInfo::OK_AnyValue,
+        Op2VK, TargetTransformInfo::OP_None, Op2VP, Operands, I);
+    return BaseCost + FMulCost;
+  }
 
   // Get the operand that was not the reduction chain and match it to one of the
   // patterns, returning the better cost if it is found.
@@ -7929,6 +7942,12 @@
     return TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I);
   }
   case Instruction::Call: {
+    // Recognize a call to the llvm.fmuladd intrinsic.
+    if (RecurrenceDescriptor::isFMulAddIntrinsic(I)) {
+      // Detect reduction patterns.
+      if (auto RedCost = getReductionPatternCost(I, VF, VectorTy, CostKind))
+        return *RedCost;
+    }
     bool NeedToScalarize;
     CallInst *CI = cast<CallInst>(I);
     InstructionCost CallCost = getVectorCallCost(CI, VF, NeedToScalarize);
Index: llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll
===================================================================
--- llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll
+++ llvm/test/Transforms/LoopVectorize/AArch64/strict-fadd-cost.ll
@@ -48,3 +48,53 @@
 for.end:
   ret double %add
 }
+
+; CHECK-VF4: Found an estimated cost of 23 for VF 4 For instruction: %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+; CHECK-VF8: Found an estimated cost of 46 for VF 8 For instruction: %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+
+define float @fmuladd_strict32(float* %a, float* %b, i64 %n) {
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %sum.07 = phi float [ 0.000000e+00, %entry ], [ %muladd, %for.body ]
+  %arrayidx = getelementptr inbounds float, float* %a, i64 %iv
+  %0 = load float, float* %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds float, float* %b, i64 %iv
+  %1 = load float, float* %arrayidx2, align 4
+  %muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, %n
+  br i1 %exitcond.not, label %for.end, label %for.body
+
+for.end:
+  ret float %muladd
+}
+
+declare float @llvm.fmuladd.f32(float, float, float)
+
+; CHECK-VF4: Found an estimated cost of 22 for VF 4 For instruction: %muladd = tail call double @llvm.fmuladd.f64(double %0, double %1, double %sum.07)
+; CHECK-VF8: Found an estimated cost of 44 for VF 8 For instruction: %muladd = tail call double @llvm.fmuladd.f64(double %0, double %1, double %sum.07)
+
+define double @fmuladd_strict64(double* %a, double* %b, i64 %n) {
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %sum.07 = phi double [ 0.000000e+00, %entry ], [ %muladd, %for.body ]
+  %arrayidx = getelementptr inbounds double, double* %a, i64 %iv
+  %0 = load double, double* %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds double, double* %b, i64 %iv
+  %1 = load double, double* %arrayidx2, align 4
+  %muladd = tail call double @llvm.fmuladd.f64(double %0, double %1, double %sum.07)
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, %n
+  br i1 %exitcond.not, label %for.end, label %for.body
+
+for.end:
+  ret double %muladd
+}
+
+declare double @llvm.fmuladd.f64(double, double, double)
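
Note: the cost split in getReductionPatternCost mirrors how an in-order (strict) fmuladd reduction is expected to be vectorized: the call is decomposed into a vector fmul plus an ordered fadd reduction, so the estimate is the ordered-reduction base cost plus one vector fmul. A minimal sketch of the expected shape of the vectorized loop body for the VF 4 float case, assuming the usual in-loop ordered reduction lowering (value names here are illustrative, not taken from actual vectorizer output):

  %wide.load.a = load <4 x float>, <4 x float>* %gep.a, align 4
  %wide.load.b = load <4 x float>, <4 x float>* %gep.b, align 4
  ; the fmul half of the fmuladd is costed as a normal vector fmul
  %mul = fmul <4 x float> %wide.load.a, %wide.load.b
  ; the fadd half becomes an ordered reduction fed by the scalar accumulator phi
  %rdx = call float @llvm.vector.reduce.fadd.v4f32(float %sum.phi, <4 x float> %mul)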