diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9665,6 +9665,59 @@
   return tryToVectorizeList(VL, R);
 }
 
+/// Calculate the cost of losing FMA contraction opportunities when an fadd
+/// reduction or bundle is fed by an SLP tree rooted at an fmul bundle.
+
+// When a bundle of multiplies feeds a bundle of adds and floating-point
+// contraction is enabled, it may be more profitable to let them combine
+// into FMAs than to perform a horizontal reduction or to vectorize the
+// tree of multiplies.
+static InstructionCost getFMALossCost(const TargetTransformInfo *TTI,
+                                      Value *TreeRootOp, FastMathFlags FMF,
+                                      unsigned Width) {
+
+  if (!FMF.allowContract())
+    return 0;
+
+  // FIXME: Return 0 if the target does not support FMA instructions at all.
+
+  auto *MainInstr = dyn_cast<Instruction>(TreeRootOp);
+  if (!MainInstr ||
+      MainInstr->getOpcode() != Instruction::FMul ||
+      !MainInstr->hasAllowContract())
+    return 0;
+
+  // Cost = Cost(FMul) + Cost(FAdd) - Cost(FMA).
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  Type *ScalarTy = MainInstr->getType();
+
+  InstructionCost MulCost =
+      TTI->getArithmeticInstrCost(Instruction::FMul, ScalarTy, CostKind);
+  InstructionCost AddCost =
+      TTI->getArithmeticInstrCost(Instruction::FAdd, ScalarTy, CostKind);
+
+  Intrinsic::ID ID = Intrinsic::fmuladd;
+  SmallVector<Type *> Tys;
+  Tys.push_back(ScalarTy);
+  Tys.push_back(ScalarTy);
+  ArrayRef<Type *> ArgTys(Tys);
+  IntrinsicCostAttributes CostAttrs(ID, ScalarTy, ArgTys);
+  InstructionCost FMACost = TTI->getIntrinsicInstrCost(CostAttrs, CostKind);
+
+  InstructionCost ScalarCost = MulCost + AddCost - FMACost;
+  if (ScalarCost < 0)
+    ScalarCost = 0;
+
+  if (ScalarCost > 0) {
+    LLVM_DEBUG(dbgs() << "SLP: Adding cost " << Width * ScalarCost
+                      << " for vectorization that breaks " << Width << " FMAs\n");
+    LLVM_DEBUG(dbgs() << "SLP: MulCost = " << MulCost << ", AddCost = "
+                      << AddCost << ", FMACost = " << FMACost << "\n");
+  }
+
+  return Width * ScalarCost;
+}
+
 bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
                                            bool LimitForRegisterSize) {
   if (VL.size() < 2)
@@ -10795,6 +10848,8 @@
       VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
                                                    CostKind);
       ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind);
+      if (RdxKind == RecurKind::FAdd)
+        VectorCost += getFMALossCost(TTI, FirstReducedVal, FMF, ReduxWidth);
       break;
     }
     case RecurKind::FMax:
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/slp-fma-loss.ll b/llvm/test/Transforms/SLPVectorizer/X86/slp-fma-loss.ll
--- a/llvm/test/Transforms/SLPVectorizer/X86/slp-fma-loss.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/slp-fma-loss.ll
@@ -5,18 +5,20 @@
 ; adds may look profitable, but is not because it eliminates generation of
 ; floating-point FMAs that would be more profitable.
 
-; FIXME: We generate a horizontal reduction today.
-
 define void @hr() {
 ; CHECK-LABEL: @hr(
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[PHI0:%.*]] = phi double [ 0.000000e+00, [[TMP0:%.*]] ], [ [[OP_RDX:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI0:%.*]] = phi double [ 0.000000e+00, [[TMP0:%.*]] ], [ [[ADD3:%.*]], [[LOOP]] ]
 ; CHECK-NEXT:    [[CVT0:%.*]] = uitofp i16 0 to double
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x double> , double [[CVT0]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast <4 x double> zeroinitializer, [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = call fast double @llvm.vector.reduce.fadd.v4f64(double -0.000000e+00, <4 x double> [[TMP2]])
-; CHECK-NEXT:    [[OP_RDX]] = fadd fast double [[TMP3]], [[PHI0]]
+; CHECK-NEXT:    [[MUL0:%.*]] = fmul fast double 0.000000e+00, [[CVT0]]
+; CHECK-NEXT:    [[ADD0:%.*]] = fadd fast double [[MUL0]], [[PHI0]]
+; CHECK-NEXT:    [[MUL1:%.*]] = fmul fast double 0.000000e+00, 0.000000e+00
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd fast double [[MUL1]], [[ADD0]]
+; CHECK-NEXT:    [[MUL2:%.*]] = fmul fast double 0.000000e+00, 0.000000e+00
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd fast double [[MUL2]], [[ADD1]]
+; CHECK-NEXT:    [[MUL3:%.*]] = fmul fast double 0.000000e+00, 0.000000e+00
+; CHECK-NEXT:    [[ADD3]] = fadd fast double [[MUL3]], [[ADD2]]
 ; CHECK-NEXT:    br i1 true, label [[EXIT:%.*]], label [[LOOP]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
@@ -45,18 +47,25 @@
 ; may look profitable; but both are not because this eliminates generation
 ; of floating-point FMAs that would be more profitable.
 
-; FIXME: We generate a horizontal reduction today, and if that's disabled, we
-; still vectorize some of the multiplies.
+; FIXME: We no longer generate a horizontal reduction today, but we
+; still vectorize the multiplies.
 
 define double @hr_or_mul() {
 ; CHECK-LABEL: @hr_or_mul(
 ; CHECK-NEXT:    [[CVT0:%.*]] = uitofp i16 3 to double
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x double> poison, double [[CVT0]], i32 0
-; CHECK-NEXT:    [[SHUFFLE:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast <4 x double> , [[SHUFFLE]]
-; CHECK-NEXT:    [[TMP3:%.*]] = call fast double @llvm.vector.reduce.fadd.v4f64(double -0.000000e+00, <4 x double> [[TMP2]])
-; CHECK-NEXT:    [[OP_RDX:%.*]] = fadd fast double [[TMP3]], [[CVT0]]
-; CHECK-NEXT:    ret double [[OP_RDX]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x double> poison, double [[CVT0]], i32 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> [[TMP1]], double [[CVT0]], i32 1
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast <2 x double> , [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[TMP3]], i32 1
+; CHECK-NEXT:    [[ADD0:%.*]] = fadd fast double [[TMP4]], [[CVT0]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x double> [[TMP3]], i32 0
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd fast double [[TMP5]], [[ADD0]]
+; CHECK-NEXT:    [[TMP6:%.*]] = fmul fast <2 x double> , [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x double> [[TMP6]], i32 1
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd fast double [[TMP7]], [[ADD1]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP6]], i32 0
+; CHECK-NEXT:    [[ADD3:%.*]] = fadd fast double [[TMP8]], [[ADD2]]
+; CHECK-NEXT:    ret double [[ADD3]]
 ;
   %cvt0 = uitofp i16 3 to double
   %mul0 = fmul fast double 7.000000e+00, %cvt0