diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -16884,6 +16884,46 @@
   return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0, FaddFlags);
 }
 
+static SDValue PerformFADDVCMLACombine(SDNode *N, SelectionDAG &DAG) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  if (!N->getFlags().hasAllowReassociation())
+    return SDValue();
+
+  // Combine fadd(a, vcmla(b, c, d)) -> vcmla(fadd(a, b), c, d)
+  auto ReassocComplex = [&](SDValue A, SDValue B) {
+    if (A.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+      return SDValue();
+    unsigned Opc = A.getConstantOperandVal(0);
+    if (Opc != Intrinsic::arm_mve_vcmlaq)
+      return SDValue();
+    SDValue VCMLA = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, VT, A.getOperand(0), A.getOperand(1),
+        DAG.getNode(ISD::FADD, DL, VT, A.getOperand(2), B, N->getFlags()),
+        A.getOperand(3), A.getOperand(4));
+    VCMLA->setFlags(A->getFlags());
+    return VCMLA;
+  };
+  if (SDValue R = ReassocComplex(LHS, RHS))
+    return R;
+  if (SDValue R = ReassocComplex(RHS, LHS))
+    return R;
+
+  return SDValue();
+}
+
+static SDValue PerformFADDCombine(SDNode *N, SelectionDAG &DAG,
+                                  const ARMSubtarget *Subtarget) {
+  if (SDValue S = PerformFAddVSelectCombine(N, DAG, Subtarget))
+    return S;
+  if (SDValue S = PerformFADDVCMLACombine(N, DAG))
+    return S;
+  return SDValue();
+}
+
 /// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)
 /// can replace combinations of VCVT (integer to floating-point) and VDIV
 /// when the VDIV has a constant operand that is a power of 2.
@@ -18771,7 +18811,7 @@
   case ISD::FP_TO_UINT:
     return PerformVCVTCombine(N, DCI.DAG, Subtarget);
   case ISD::FADD:
-    return PerformFAddVSelectCombine(N, DCI.DAG, Subtarget);
+    return PerformFADDCombine(N, DCI.DAG, Subtarget);
   case ISD::FDIV:
     return PerformVDIVCombine(N, DCI.DAG, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
diff --git a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
--- a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
@@ -391,16 +391,16 @@
 ; CHECK-LABEL: mul_addequal:
 ; CHECK:       @ %bb.0: @ %entry
 ; CHECK-NEXT:    vmov d0, r0, r1
-; CHECK-NEXT:    mov r1, sp
-; CHECK-NEXT:    vldrw.u32 q2, [r1]
-; CHECK-NEXT:    vmov d1, r2, r3
-; CHECK-NEXT:    add r0, sp, #16
-; CHECK-NEXT:    vcmul.f32 q3, q0, q2, #0
+; CHECK-NEXT:    mov r0, sp
+; CHECK-NEXT:    add r1, sp, #16
 ; CHECK-NEXT:    vldrw.u32 q1, [r0]
-; CHECK-NEXT:    vcmla.f32 q3, q0, q2, #90
-; CHECK-NEXT:    vadd.f32 q0, q3, q1
-; CHECK-NEXT:    vmov r0, r1, d0
-; CHECK-NEXT:    vmov r2, r3, d1
+; CHECK-NEXT:    vmov d1, r2, r3
+; CHECK-NEXT:    vldrw.u32 q2, [r1]
+; CHECK-NEXT:    vcmul.f32 q3, q0, q1, #0
+; CHECK-NEXT:    vadd.f32 q2, q3, q2
+; CHECK-NEXT:    vcmla.f32 q2, q0, q1, #90
+; CHECK-NEXT:    vmov r0, r1, d4
+; CHECK-NEXT:    vmov r2, r3, d5
 ; CHECK-NEXT:    bx lr
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32>
diff --git a/llvm/test/CodeGen/Thumb2/mve-vcmla.ll b/llvm/test/CodeGen/Thumb2/mve-vcmla.ll
--- a/llvm/test/CodeGen/Thumb2/mve-vcmla.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-vcmla.ll
@@ -10,9 +10,7 @@
 define arm_aapcs_vfpcc <4 x float> @reassoc_f32x4(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: reassoc_f32x4:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q3, #0x0
-; CHECK-NEXT:    vcmla.f32 q3, q1, q2, #0
-; CHECK-NEXT:    vadd.f32 q0, q3, q0
+; CHECK-NEXT:    vcmla.f32 q0, q1, q2, #0
 ; CHECK-NEXT:    bx lr
 entry:
   %d = tail call <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32 0, <4 x float> zeroinitializer, <4 x float> %b, <4 x float> %c)
@@ -23,9 +21,7 @@
 define arm_aapcs_vfpcc <4 x float> @reassoc_c_f32x4(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: reassoc_c_f32x4:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q3, #0x0
-; CHECK-NEXT:    vcmla.f32 q3, q1, q2, #90
-; CHECK-NEXT:    vadd.f32 q0, q0, q3
+; CHECK-NEXT:    vcmla.f32 q0, q1, q2, #90
 ; CHECK-NEXT:    bx lr
 entry:
   %d = tail call <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32 1, <4 x float> zeroinitializer, <4 x float> %b, <4 x float> %c)
@@ -36,9 +32,7 @@
 define arm_aapcs_vfpcc <8 x half> @reassoc_f16x4(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
 ; CHECK-LABEL: reassoc_f16x4:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q3, #0x0
-; CHECK-NEXT:    vcmla.f16 q3, q1, q2, #180
-; CHECK-NEXT:    vadd.f16 q0, q3, q0
+; CHECK-NEXT:    vcmla.f16 q0, q1, q2, #180
 ; CHECK-NEXT:    bx lr
 entry:
   %d = tail call <8 x half> @llvm.arm.mve.vcmlaq.v8f16(i32 2, <8 x half> zeroinitializer, <8 x half> %b, <8 x half> %c)
@@ -49,9 +43,7 @@
 define arm_aapcs_vfpcc <8 x half> @reassoc_c_f16x4(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
 ; CHECK-LABEL: reassoc_c_f16x4:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q3, #0x0
-; CHECK-NEXT:    vcmla.f16 q3, q1, q2, #270
-; CHECK-NEXT:    vadd.f16 q0, q0, q3
+; CHECK-NEXT:    vcmla.f16 q0, q1, q2, #270
 ; CHECK-NEXT:    bx lr
 entry:
   %d = tail call <8 x half> @llvm.arm.mve.vcmlaq.v8f16(i32 3, <8 x half> zeroinitializer, <8 x half> %b, <8 x half> %c)
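
For reference, a minimal standalone .ll sketch (not part of the patch) of the IR shape the new combine matches; the function name @fold_into_acc is illustrative, and the intrinsic signature is the one used by the tests above.

; The fadd must carry the 'reassoc' fast-math flag: PerformFADDVCMLACombine
; bails out via hasAllowReassociation() otherwise, and the separate vadd stays.
declare <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32, <4 x float>, <4 x float>, <4 x float>)

define arm_aapcs_vfpcc <4 x float> @fold_into_acc(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
entry:
  ; vcmla with a zero accumulator and rotation #0
  %d = tail call <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32 0, <4 x float> zeroinitializer, <4 x float> %b, <4 x float> %c)
  ; With reassoc, %a is folded into the vcmla accumulator operand, as in the
  ; reassoc_f32x4 test above.
  %res = fadd reassoc <4 x float> %d, %a
  ret <4 x float> %res
}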