Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -330,6 +330,8 @@ SDValue visitFP_TO_FP16(SDNode *N); SDValue visitFP16_TO_FP(SDNode *N); + bool isContractable(SDNode *N); + bool isContractableFMUL(SDValue N); SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); SDValue visitFMULForFMADistributiveCombine(SDNode *N); @@ -8692,6 +8694,22 @@ return DAG.getBuildVector(VT, DL, Ops); } +static bool allowFusionGlobally(SelectionDAG &DAG) { + const TargetOptions &Options = DAG.getTarget().Options; + return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath; +} + +bool DAGCombiner::isContractable(SDNode *N) { + if (allowFusionGlobally(DAG)) + return true; + SDNodeFlags F = cast<BinaryWithFlagsSDNode>(N)->Flags; + return F.hasAllowContract() || F.hasUnsafeAlgebra(); +} + +bool DAGCombiner::isContractableFMUL(SDValue N) { + return N.getOpcode() == ISD::FMUL && isContractable(N.getNode()); +} + /// Try to perform FMA combining on a given FADD node. SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); @@ -8700,15 +8718,15 @@ SDLoc SL(N); const TargetOptions &Options = DAG.getTarget().Options; - bool AllowFusion = - (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); + if (!isContractable(N)) + return SDValue(); // Floating-point multiply-add with intermediate rounding. bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); // Floating-point multiply-add without intermediate rounding. bool HasFMA = - AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && + TLI.isFMAFasterThanFMulAndFAdd(VT) && (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. 
@@ -8717,7 +8735,7 @@ const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo(); ; - if (AllowFusion && STI && STI->generateFMAsInMachineCombiner(OptLevel)) + if (STI && STI->generateFMAsInMachineCombiner(OptLevel)) return SDValue(); // Always prefer FMAD to FMA for precision. @@ -8727,33 +8745,30 @@ // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), // prefer to fold the multiply with fewer uses. - if (Aggressive && N0.getOpcode() == ISD::FMUL && - N1.getOpcode() == ISD::FMUL) { + if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) { if (N0.getNode()->use_size() > N1.getNode()->use_size()) std::swap(N0, N1); } // fold (fadd (fmul x, y), z) -> (fma x, y, z) - if (N0.getOpcode() == ISD::FMUL && - (Aggressive || N0->hasOneUse())) { + if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), N1); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FMUL && - (Aggressive || N1->hasOneUse())) { + if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), N0); } // Look through FP_EXTEND nodes to do more combining. - if (AllowFusion && LookThroughFPExt) { + if (LookThroughFPExt) { // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) if (N0.getOpcode() == ISD::FP_EXTEND) { SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::FMUL) + if (isContractableFMUL(N00)) return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), @@ -8765,7 +8780,7 @@ // Note: Commutes FADD operands. 
if (N1.getOpcode() == ISD::FP_EXTEND) { SDValue N10 = N1.getOperand(0); - if (N10.getOpcode() == ISD::FMUL) + if (isContractableFMUL(N10)) return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), @@ -8806,7 +8821,7 @@ N0)); } - if (AllowFusion && LookThroughFPExt) { + if (/*AllowFusion &&*/ LookThroughFPExt) { // fold (fadd (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y, (fma (fpext u), (fpext v), z)) auto FoldFAddFMAFPExtFMul = [&] ( @@ -8821,7 +8836,7 @@ SDValue N02 = N0.getOperand(2); if (N02.getOpcode() == ISD::FP_EXTEND) { SDValue N020 = N02.getOperand(0); - if (N020.getOpcode() == ISD::FMUL) + if (isContractableFMUL(N020)) return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1), N020.getOperand(0), N020.getOperand(1), N1); @@ -8847,7 +8862,7 @@ SDValue N00 = N0.getOperand(0); if (N00.getOpcode() == PreferredFusedOpcode) { SDValue N002 = N00.getOperand(2); - if (N002.getOpcode() == ISD::FMUL) + if (isContractableFMUL(N002)) return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1), N002.getOperand(0), N002.getOperand(1), N1); @@ -8860,7 +8875,7 @@ SDValue N12 = N1.getOperand(2); if (N12.getOpcode() == ISD::FP_EXTEND) { SDValue N120 = N12.getOperand(0); - if (N120.getOpcode() == ISD::FMUL) + if (isContractableFMUL(N120)) return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1), N120.getOperand(0), N120.getOperand(1), N0); @@ -8876,7 +8891,7 @@ SDValue N10 = N1.getOperand(0); if (N10.getOpcode() == PreferredFusedOpcode) { SDValue N102 = N10.getOperand(2); - if (N102.getOpcode() == ISD::FMUL) + if (isContractableFMUL(N102)) return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1), N102.getOperand(0), N102.getOperand(1), N0); Index: test/CodeGen/AArch64/neon-fma-FMF.ll =================================================================== --- /dev/null +++ test/CodeGen/AArch64/neon-fma-FMF.ll @@ -0,0 +1,25 @@ +; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-linux-gnu 
-mattr=+neon | FileCheck %s + +define <2 x float> @fma(<2 x float> %A, <2 x float> %B, <2 x float> %C) { +; CHECK-LABEL: fma: +; CHECK: fmla {{v[0-9]+}}.2s, {{v[0-9]+}}.2s, {{v[0-9]+}}.2s + %tmp1 = fmul contract <2 x float> %A, %B; + %tmp2 = fadd contract <2 x float> %C, %tmp1; + ret <2 x float> %tmp2 +} + +define <2 x float> @no_fma_1(<2 x float> %A, <2 x float> %B, <2 x float> %C) { +; CHECK-LABEL: no_fma_1: +; CHECK-NOT: fmla + %tmp1 = fmul contract <2 x float> %A, %B; + %tmp2 = fadd <2 x float> %C, %tmp1; + ret <2 x float> %tmp2 +} + +define <2 x float> @no_fma_2(<2 x float> %A, <2 x float> %B, <2 x float> %C) { +; CHECK-LABEL: no_fma_2: +; CHECK-NOT: fmla + %tmp1 = fmul <2 x float> %A, %B; + %tmp2 = fadd contract <2 x float> %C, %tmp1; + ret <2 x float> %tmp2 +}