diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -1017,6 +1017,9 @@
     setTargetDAGCombine(ISD::SELECT);
     setTargetDAGCombine(ISD::SELECT_CC);
   }
+  if (Subtarget->hasMVEFloatOps()) {
+    setTargetDAGCombine(ISD::FADD);
+  }
 
   if (!Subtarget->hasFP64()) {
     // When targeting a floating-point unit with only single-precision
@@ -16407,6 +16410,42 @@
   return FixConv;
 }
 
+static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
+                                         const ARMSubtarget *Subtarget) {
+  if (!Subtarget->hasMVEFloatOps())
+    return SDValue();
+
+  // Turn (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x)
+  // The second form can be more easily turned into a predicated vadd, and
+  // possibly combined into a fma to become a predicated vfma.
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  // The identity element for a fadd is -0.0, which these VMOV's represent.
+  auto isNegativeZeroSplat = [&](SDValue Op) {
+    if (Op.getOpcode() != ISD::BITCAST ||
+        Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM)
+      return false;
+    if (VT == MVT::v4f32 && Op.getOperand(0).getConstantOperandVal(0) == 1664)
+      return true;
+    if (VT == MVT::v8f16 && Op.getOperand(0).getConstantOperandVal(0) == 2688)
+      return true;
+    return false;
+  };
+
+  if (Op0.getOpcode() == ISD::VSELECT && Op1.getOpcode() != ISD::VSELECT)
+    std::swap(Op0, Op1);
+
+  if (Op1.getOpcode() != ISD::VSELECT ||
+      !isNegativeZeroSplat(Op1.getOperand(2)))
+    return SDValue();
+  SDValue FAdd =
+      DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), N->getFlags());
+  return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0);
+}
+
 /// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)
 /// can replace combinations of VCVT (integer to floating-point) and VDIV
 /// when the VDIV has a constant operand that is a power of 2.
@@ -18201,6 +18240,8 @@
   case ISD::FP_TO_SINT:
   case ISD::FP_TO_UINT:
     return PerformVCVTCombine(N, DCI.DAG, Subtarget);
+  case ISD::FADD:
+    return PerformFAddVSelectCombine(N, DCI.DAG, Subtarget);
   case ISD::FDIV:
     return PerformVDIVCombine(N, DCI.DAG, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
diff --git a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
--- a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll
@@ -470,10 +470,9 @@
 define arm_aapcs_vfpcc <4 x float> @fma_v4f32_x(<4 x float> %x, <4 x float> %y, <4 x float> %z, i32 %n) {
 ; CHECK-LABEL: fma_v4f32_x:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.f32 q1, q1, q2
 ; CHECK-NEXT:    vctp.32 r0
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f32 q0, q0, q1
+; CHECK-NEXT:    vfmat.f32 q0, q1, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
@@ -486,10 +485,9 @@
 define arm_aapcs_vfpcc <8 x half> @fma_v8f16_x(<8 x half> %x, <8 x half> %y, <8 x half> %z, i32 %n) {
 ; CHECK-LABEL: fma_v8f16_x:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.f16 q1, q1, q2
 ; CHECK-NEXT:    vctp.16 r0
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f16 q0, q0, q1
+; CHECK-NEXT:    vfmat.f16 q0, q1, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n)
@@ -2422,7 +2420,7 @@
 ; CHECK-NEXT:    vctp.32 r0
 ; CHECK-NEXT:    vdup.32 q1, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f32 q1, q1, q0
+; CHECK-NEXT:    vaddt.f32 q1, q0, r1
 ; CHECK-NEXT:    vmov q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
@@ -2441,7 +2439,7 @@
 ; CHECK-NEXT:    vctp.16 r0
 ; CHECK-NEXT:    vdup.16 q1, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f16 q1, q1, q0
+; CHECK-NEXT:    vaddt.f16 q1, q0, r1
 ; CHECK-NEXT:    vmov q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
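
For illustration only (not part of the patch): IR of roughly the shape below is what the new combine matches when lowering for MVE. The function name @fadd_v4f32_x_sketch is hypothetical; the select/fadd pattern and the @llvm.arm.mve.vctp32 intrinsic are taken from the existing tests in mve-pred-selectop3.ll.

; Sketch: the select's false operand is a -0.0 splat (the fadd identity), so
; inactive lanes keep %x unchanged and the fadd can be folded into a
; predicated vaddt.f32 (or a vfmat.f32 if the active-lane value is a fmul).
define arm_aapcs_vfpcc <4 x float> @fadd_v4f32_x_sketch(<4 x float> %x, <4 x float> %y, i32 %n) {
entry:
  %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
  %s = select <4 x i1> %c, <4 x float> %y, <4 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>
  %a = fadd <4 x float> %x, %s
  ret <4 x float> %a
}
declare <4 x i1> @llvm.arm.mve.vctp32(i32)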