Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8700,6 +8700,14 @@
   return DAG.getBuildVector(VT, DL, Ops);
 }
 
+/// Returns true if the fast-math flags on \p N allow floating-point
+/// contraction (the 'contract' flag or full unsafe-algebra). \p N must be a
+/// BinaryWithFlagsSDNode; cast<> asserts otherwise.
+static bool isContractable(SDNode *N) {
+  SDNodeFlags F = cast<BinaryWithFlagsSDNode>(N)->Flags;
+  return F.hasAllowContract() || F.hasUnsafeAlgebra();
+}
+
 /// Try to perform FMA combining on a given FADD node.
 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   SDValue N0 = N->getOperand(0);
@@ -8708,24 +8716,29 @@
   SDLoc SL(N);
 
   const TargetOptions &Options = DAG.getTarget().Options;
-  bool AllowFusion =
-      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
 
   // Floating-point multiply-add with intermediate rounding.
   bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
 
   // Floating-point multiply-add without intermediate rounding.
   bool HasFMA =
-      AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      TLI.isFMAFasterThanFMulAndFAdd(VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // No valid opcode, do not combine.
   if (!HasFMAD && !HasFMA)
     return SDValue();
 
+  // FMAD keeps the intermediate rounding, so forming it never changes results
+  // and needs no permission; FMA is only formed when contraction is allowed.
+  bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
+                              Options.UnsafeFPMath || HasFMAD);
+  // If the addition is not contractable, do not combine.
+  if (!AllowFusionGlobally && !isContractable(N))
+    return SDValue();
+
   const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
-  ;
-  if (AllowFusion && STI && STI->generateFMAsInMachineCombiner(OptLevel))
+  if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
     return SDValue();
 
   // Always prefer FMAD to FMA for precision.
@@ -8733,35 +8746,39 @@
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
   bool LookThroughFPExt = TLI.isFPExtFree(VT);
 
+  // Is the node an FMUL and contractable either due to global flags or
+  // SDNodeFlags.
+  auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
+    if (N.getOpcode() != ISD::FMUL)
+      return false;
+    return AllowFusionGlobally || isContractable(N.getNode());
+  };
   // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
   // prefer to fold the multiply with fewer uses.
-  if (Aggressive && N0.getOpcode() == ISD::FMUL &&
-      N1.getOpcode() == ISD::FMUL) {
+  if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
     if (N0.getNode()->use_size() > N1.getNode()->use_size())
       std::swap(N0, N1);
   }
 
   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
-  if (N0.getOpcode() == ISD::FMUL &&
-      (Aggressive || N0->hasOneUse())) {
+  if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
     return DAG.getNode(PreferredFusedOpcode, SL, VT,
                        N0.getOperand(0), N0.getOperand(1), N1);
   }
 
   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
   // Note: Commutes FADD operands.
-  if (N1.getOpcode() == ISD::FMUL &&
-      (Aggressive || N1->hasOneUse())) {
+  if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
     return DAG.getNode(PreferredFusedOpcode, SL, VT,
                        N1.getOperand(0), N1.getOperand(1), N0);
   }
 
   // Look through FP_EXTEND nodes to do more combining.
-  if (AllowFusion && LookThroughFPExt) {
+  if (LookThroughFPExt) {
     // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
-      if (N00.getOpcode() == ISD::FMUL)
+      if (isContractableFMUL(N00))
         return DAG.getNode(PreferredFusedOpcode, SL, VT,
                            DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                        N00.getOperand(0)),
@@ -8773,7 +8790,7 @@
     // Note: Commutes FADD operands.
     if (N1.getOpcode() == ISD::FP_EXTEND) {
       SDValue N10 = N1.getOperand(0);
-      if (N10.getOpcode() == ISD::FMUL)
+      if (isContractableFMUL(N10))
         return DAG.getNode(PreferredFusedOpcode, SL, VT,
                            DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                        N10.getOperand(0)),
@@ -8814,7 +8831,7 @@
                                      N0));
     }
 
-    if (AllowFusion && LookThroughFPExt) {
+    if (LookThroughFPExt) {
       // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
       //   -> (fma x, y, (fma (fpext u), (fpext v), z))
       auto FoldFAddFMAFPExtFMul = [&] (
@@ -8829,7 +8846,7 @@
         SDValue N02 = N0.getOperand(2);
         if (N02.getOpcode() == ISD::FP_EXTEND) {
           SDValue N020 = N02.getOperand(0);
-          if (N020.getOpcode() == ISD::FMUL)
+          if (isContractableFMUL(N020))
             return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
                                         N020.getOperand(0), N020.getOperand(1),
                                         N1);
@@ -8855,7 +8872,7 @@
         SDValue N00 = N0.getOperand(0);
         if (N00.getOpcode() == PreferredFusedOpcode) {
           SDValue N002 = N00.getOperand(2);
-          if (N002.getOpcode() == ISD::FMUL)
+          if (isContractableFMUL(N002))
             return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
                                         N002.getOperand(0), N002.getOperand(1),
                                         N1);
@@ -8868,7 +8885,7 @@
         SDValue N12 = N1.getOperand(2);
         if (N12.getOpcode() == ISD::FP_EXTEND) {
           SDValue N120 = N12.getOperand(0);
-          if (N120.getOpcode() == ISD::FMUL)
+          if (isContractableFMUL(N120))
             return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
                                         N120.getOperand(0), N120.getOperand(1),
                                         N0);
@@ -8884,7 +8901,7 @@
         SDValue N10 = N1.getOperand(0);
         if (N10.getOpcode() == PreferredFusedOpcode) {
           SDValue N102 = N10.getOperand(2);
-          if (N102.getOpcode() == ISD::FMUL)
+          if (isContractableFMUL(N102))
             return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
                                         N102.getOperand(0), N102.getOperand(1),
                                         N0);
Index: test/CodeGen/AArch64/neon-fma-FMF.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/neon-fma-FMF.ll
@@ -0,0 +1,30 @@
+; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-linux-gnu -mattr=+neon | FileCheck %s
+
+; Both the fmul and the fadd carry 'contract', so they fuse into an fmla.
+define <2 x float> @fma(<2 x float> %A, <2 x float> %B, <2 x float> %C) {
+; CHECK-LABEL: fma:
+; CHECK: fmla {{v[0-9]+}}.2s, {{v[0-9]+}}.2s, {{v[0-9]+}}.2s
+  %tmp1 = fmul contract <2 x float> %A, %B
+  %tmp2 = fadd contract <2 x float> %C, %tmp1
+  ret <2 x float> %tmp2
+}
+
+; The fadd lacks 'contract', so the fmul must not be folded into it.
+define <2 x float> @no_fma_1(<2 x float> %A, <2 x float> %B, <2 x float> %C) {
+; CHECK-LABEL: no_fma_1:
+; CHECK: fmul
+; CHECK: fadd
+  %tmp1 = fmul contract <2 x float> %A, %B
+  %tmp2 = fadd <2 x float> %C, %tmp1
+  ret <2 x float> %tmp2
+}
+
+; The fmul lacks 'contract', so it must not be folded into the fadd.
+define <2 x float> @no_fma_2(<2 x float> %A, <2 x float> %B, <2 x float> %C) {
+; CHECK-LABEL: no_fma_2:
+; CHECK: fmul
+; CHECK: fadd
+  %tmp1 = fmul <2 x float> %A, %B
+  %tmp2 = fadd contract <2 x float> %C, %tmp1
+  ret <2 x float> %tmp2
+}
Index: test/CodeGen/PowerPC/fma-aggr-FMF.ll
===================================================================
--- /dev/null
+++ test/CodeGen/PowerPC/fma-aggr-FMF.ll
@@ -0,0 +1,30 @@
+; RUN: llc < %s -verify-machineinstrs -mtriple=powerpc-darwin64 | FileCheck %s
+
+; Both muls are contractable; the combiner folds the one with fewer uses
+; (%mul2) into the fma, so only %mul1 is emitted as an fmuls.
+define float @can_fma_with_fewer_uses(float %f1, float %f2, float %f3, float %f4) {
+; CHECK-LABEL: can_fma_with_fewer_uses:
+  %mul1 = fmul contract float %f1, %f2
+; CHECK: fmuls
+  %mul2 = fmul contract float %f3, %f4
+; CHECK-NOT: fmuls
+  %add = fadd contract float %mul1, %mul2
+; CHECK: fmadds
+  %second_use_of_mul1 = fdiv float %mul1, %add
+  ret float %second_use_of_mul1
+}
+
+; There is no contract on the mul with no extra use (%mul2), so it cannot be
+; fused. Instead %mul1, which has an extra use, is fused; its fmul must still
+; be emitted beside the fma, together with the unfused fmul for %mul2.
+define float @no_fma_with_fewer_uses(float %f1, float %f2, float %f3, float %f4) {
+; CHECK-LABEL: no_fma_with_fewer_uses:
+  %mul1 = fmul contract float %f1, %f2
+; CHECK: fmuls
+  %mul2 = fmul float %f3, %f4
+; CHECK: fmuls
+  %add = fadd contract float %mul1, %mul2
+; CHECK: fmadds
+  %second_use_of_mul1 = fdiv float %mul1, %add
+  ret float %second_use_of_mul1
+}