diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -77,6 +77,7 @@
   FADD_PRED,
   SDIV_PRED,
   UDIV_PRED,
+  FMA_PRED,
   SMIN_MERGE_OP1,
   UMIN_MERGE_OP1,
   SMAX_MERGE_OP1,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -948,6 +948,7 @@
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
       setOperationAction(ISD::SELECT, VT, Custom);
+      setOperationAction(ISD::FMA, VT, Custom);
     }
   }
@@ -1470,6 +1471,7 @@
     MAKE_CASE(AArch64ISD::FADD_PRED)
     MAKE_CASE(AArch64ISD::FADDA_PRED)
     MAKE_CASE(AArch64ISD::FADDV_PRED)
+    MAKE_CASE(AArch64ISD::FMA_PRED)
     MAKE_CASE(AArch64ISD::FMAXV_PRED)
     MAKE_CASE(AArch64ISD::FMAXNMV_PRED)
     MAKE_CASE(AArch64ISD::FMINV_PRED)
@@ -3455,6 +3457,8 @@
     return LowerF128Call(Op, DAG, RTLIB::SUB_F128);
   case ISD::FMUL:
     return LowerF128Call(Op, DAG, RTLIB::MUL_F128);
+  case ISD::FMA:
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
   case ISD::FDIV:
     return LowerF128Call(Op, DAG, RTLIB::DIV_F128);
   case ISD::FP_ROUND:
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -167,9 +167,15 @@
   SDTCVecEltisVT<1,i1>, SDTCisSameAs<2,3>
 ]>;
 
+def SDT_AArch64FMA : SDTypeProfile<1, 4, [
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>, SDTCisVec<4>,
+  SDTCVecEltisVT<1,i1>, SDTCisSameAs<2,3>, SDTCisSameAs<3,4>
+]>;
+
 // Predicated operations with the result of inactive lanes being unspecified.
 def AArch64add_p  : SDNode<"AArch64ISD::ADD_PRED",  SDT_AArch64Arith>;
 def AArch64fadd_p : SDNode<"AArch64ISD::FADD_PRED", SDT_AArch64Arith>;
+def AArch64fma_p  : SDNode<"AArch64ISD::FMA_PRED",  SDT_AArch64FMA>;
 def AArch64sdiv_p : SDNode<"AArch64ISD::SDIV_PRED", SDT_AArch64Arith>;
 def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
@@ -393,6 +399,16 @@
   defm FNMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b10, "fnmad", int_aarch64_sve_fnmad>;
   defm FNMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b11, "fnmsb", int_aarch64_sve_fnmsb>;
 
+  // Add patterns for FMA where disabled lanes are undef.
+  // FIXME: Implement a pseudo so we can choose a better instruction after
+  // regalloc.
+  def : Pat<(nxv8f16 (AArch64fma_p nxv8i1:$P, nxv8f16:$Op1, nxv8f16:$Op2, nxv8f16:$Op3)),
+            (FMLA_ZPmZZ_H $P, $Op3, $Op1, $Op2)>;
+  def : Pat<(nxv4f32 (AArch64fma_p nxv4i1:$P, nxv4f32:$Op1, nxv4f32:$Op2, nxv4f32:$Op3)),
+            (FMLA_ZPmZZ_S $P, $Op3, $Op1, $Op2)>;
+  def : Pat<(nxv2f64 (AArch64fma_p nxv2i1:$P, nxv2f64:$Op1, nxv2f64:$Op2, nxv2f64:$Op3)),
+            (FMLA_ZPmZZ_D $P, $Op3, $Op1, $Op2)>;
+
   defm FTMAD_ZZI : sve_fp_ftmad<"ftmad", int_aarch64_sve_ftmad_x>;
 
   defm FMLA_ZZZI : sve_fp_fma_by_indexed_elem<0b0, "fmla", int_aarch64_sve_fmla_lane>;
diff --git a/llvm/test/CodeGen/AArch64/sve-fp.ll b/llvm/test/CodeGen/AArch64/sve-fp.ll
--- a/llvm/test/CodeGen/AArch64/sve-fp.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp.ll
@@ -85,6 +85,56 @@
   ret %res
 }
 
+define <vscale x 8 x half> @fma_half(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fma_half:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fmla z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+  %r = call <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  ret <vscale x 8 x half> %r
+}
+define <vscale x 4 x float> @fma_float(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fma_float:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fmla z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+  %r = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c)
+  ret <vscale x 4 x float> %r
+}
+define <vscale x 2 x double> @fma_double_1(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fma_double_1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmla z2.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+  %r = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c)
+  ret <vscale x 2 x double> %r
+}
+define <vscale x 2 x double> @fma_double_2(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fma_double_2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmla z2.d, p0/m, z1.d, z0.d
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+  %r = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %b, <vscale x 2 x double> %a, <vscale x 2 x double> %c)
+  ret <vscale x 2 x double> %r
+}
+define <vscale x 2 x double> @fma_double_3(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fma_double_3:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmla z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT: ret
+  %r = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %c, <vscale x 2 x double> %b, <vscale x 2 x double> %a)
+  ret <vscale x 2 x double> %r
+}
+
 define <vscale x 8 x half> @frecps_h(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
 ; CHECK-LABEL: frecps_h:
 ; CHECK: // %bb.0:
@@ -166,5 +216,9 @@
 declare <vscale x 4 x float> @llvm.aarch64.sve.frsqrts.x.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.aarch64.sve.frsqrts.x.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
 
+declare <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>)
+declare <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x half>)
+
 ; Function Attrs: nounwind readnone
 declare double @llvm.aarch64.sve.faddv.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>) #2
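Note (not part of the patch): the FIXME above exists because the new patterns always select FMLA, which ties the accumulator operand ($Op3) to the destination register. Whenever the function result must live in a register other than the accumulator, an extra copy is emitted, as in fma_half, fma_float, fma_double_1 and fma_double_2 above; fma_double_3 avoids it only because its accumulator already sits in z0. The sketch below illustrates what a post-regalloc pseudo could do differently; the function name is hypothetical and the exact register assignment is an assumption based on the tests above.

; Hypothetical sketch (assumes llc -mtriple=aarch64-linux-gnu -mattr=+sve).
define <vscale x 4 x float> @fma_copy_sketch(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
  ; With the patterns above, %c (in z2) is tied to FMLA's destination, so the
  ; result has to be copied back into z0:
  ;   ptrue p0.s
  ;   fmla  z2.s, p0/m, z0.s, z1.s
  ;   mov   z0.d, z2.d
  ; A pseudo expanded after register allocation could instead pick FMAD, which
  ; ties a multiplicand (here %a, already in z0) and needs no copy:
  ;   ptrue p0.s
  ;   fmad  z0.s, p0/m, z1.s, z2.s
  %r = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c)
  ret <vscale x 4 x float> %r
}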