diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -75,9 +75,12 @@
   // Arithmetic instructions
   ADD_PRED,
   FADD_PRED,
+  FDIV_PRED,
+  FMA_PRED,
+  FMUL_PRED,
+  FSUB_PRED,
   SDIV_PRED,
   UDIV_PRED,
-  FMA_PRED,
   SMIN_MERGE_OP1,
   UMIN_MERGE_OP1,
   SMAX_MERGE_OP1,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -948,7 +948,11 @@
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
       setOperationAction(ISD::SELECT, VT, Custom);
+      setOperationAction(ISD::FADD, VT, Custom);
+      setOperationAction(ISD::FDIV, VT, Custom);
       setOperationAction(ISD::FMA, VT, Custom);
+      setOperationAction(ISD::FMUL, VT, Custom);
+      setOperationAction(ISD::FSUB, VT, Custom);
     }
   }

@@ -1483,11 +1487,14 @@
     MAKE_CASE(AArch64ISD::FADD_PRED)
     MAKE_CASE(AArch64ISD::FADDA_PRED)
     MAKE_CASE(AArch64ISD::FADDV_PRED)
+    MAKE_CASE(AArch64ISD::FDIV_PRED)
     MAKE_CASE(AArch64ISD::FMA_PRED)
     MAKE_CASE(AArch64ISD::FMAXV_PRED)
     MAKE_CASE(AArch64ISD::FMAXNMV_PRED)
     MAKE_CASE(AArch64ISD::FMINV_PRED)
     MAKE_CASE(AArch64ISD::FMINNMV_PRED)
+    MAKE_CASE(AArch64ISD::FMUL_PRED)
+    MAKE_CASE(AArch64ISD::FSUB_PRED)
     MAKE_CASE(AArch64ISD::NOT)
     MAKE_CASE(AArch64ISD::BIT)
     MAKE_CASE(AArch64ISD::CBZ)
@@ -3468,16 +3475,23 @@
   case ISD::UMULO:
     return LowerXALUO(Op, DAG);
   case ISD::FADD:
-    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+    if (Op.getValueType().isScalableVector() ||
+        useSVEForFixedLengthVectorVT(Op.getValueType()))
       return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
     return LowerF128Call(Op, DAG, RTLIB::ADD_F128);
   case ISD::FSUB:
+    if (Op.getValueType().isScalableVector())
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
     return LowerF128Call(Op, DAG, RTLIB::SUB_F128);
   case ISD::FMUL:
+    if (Op.getValueType().isScalableVector())
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
     return LowerF128Call(Op, DAG, RTLIB::MUL_F128);
   case ISD::FMA:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
   case ISD::FDIV:
+    if (Op.getValueType().isScalableVector())
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
    return LowerF128Call(Op, DAG, RTLIB::DIV_F128);
   case ISD::FP_ROUND:
   case ISD::STRICT_FP_ROUND:
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -175,7 +175,10 @@
 // Predicated operations with the result of inactive lanes being unspecified.
 def AArch64add_p  : SDNode<"AArch64ISD::ADD_PRED",  SDT_AArch64Arith>;
 def AArch64fadd_p : SDNode<"AArch64ISD::FADD_PRED", SDT_AArch64Arith>;
+def AArch64fdiv_p : SDNode<"AArch64ISD::FDIV_PRED", SDT_AArch64Arith>;
 def AArch64fma_p  : SDNode<"AArch64ISD::FMA_PRED",  SDT_AArch64FMA>;
+def AArch64fmul_p : SDNode<"AArch64ISD::FMUL_PRED", SDT_AArch64Arith>;
+def AArch64fsub_p : SDNode<"AArch64ISD::FSUB_PRED", SDT_AArch64Arith>;
 def AArch64sdiv_p : SDNode<"AArch64ISD::SDIV_PRED", SDT_AArch64Arith>;
 def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;

@@ -361,6 +364,9 @@
   defm FDIV_ZPmZ : sve_fp_2op_p_zds<0b1101, "fdiv", "FDIV_ZPZZ", int_aarch64_sve_fdiv, DestructiveBinaryCommWithRev, "FDIVR_ZPmZ">;

   defm FADD_ZPZZ : sve_fp_bin_pred_hfd<AArch64fadd_p>;
+  defm FSUB_ZPZZ : sve_fp_bin_pred_hfd<AArch64fsub_p>;
+  defm FMUL_ZPZZ : sve_fp_bin_pred_hfd<AArch64fmul_p>;
+  defm FDIV_ZPZZ : sve_fp_bin_pred_hfd<AArch64fdiv_p>;

   let Predicates = [HasSVE, UseExperimentalZeroingPseudos] in {
     defm FADD_ZPZZ : sve_fp_2op_p_zds_zeroing_hsd<int_aarch64_sve_fadd>;
@@ -377,10 +383,10 @@
     defm FDIV_ZPZZ : sve_fp_2op_p_zds_zeroing_hsd<int_aarch64_sve_fdiv>;
   }

-  defm FADD_ZZZ   : sve_fp_3op_u_zd<0b000, "fadd",   fadd>;
-  defm FSUB_ZZZ   : sve_fp_3op_u_zd<0b001, "fsub",   fsub>;
-  defm FMUL_ZZZ   : sve_fp_3op_u_zd<0b010, "fmul",   fmul>;
-  defm FTSMUL_ZZZ : sve_fp_3op_u_zd_ftsmul<0b011, "ftsmul", int_aarch64_sve_ftsmul_x>;
+  defm FADD_ZZZ   : sve_fp_3op_u_zd<0b000, "fadd",   fadd,   AArch64fadd_p>;
+  defm FSUB_ZZZ   : sve_fp_3op_u_zd<0b001, "fsub",   fsub,   AArch64fsub_p>;
+  defm FMUL_ZZZ   : sve_fp_3op_u_zd<0b010, "fmul",   fmul,   AArch64fmul_p>;
+  defm FTSMUL_ZZZ : sve_fp_3op_u_zd_ftsmul<0b011, "ftsmul", int_aarch64_sve_ftsmul_x>;
   defm FRECPS_ZZZ  : sve_fp_3op_u_zd<0b110, "frecps",  int_aarch64_sve_frecps_x>;
   defm FRSQRTS_ZZZ : sve_fp_3op_u_zd<0b111, "frsqrts", int_aarch64_sve_frsqrts_x>;
@@ -404,8 +410,14 @@
   // regalloc.
  def : Pat<(nxv8f16 (AArch64fma_p nxv8i1:$P, nxv8f16:$Op1, nxv8f16:$Op2, nxv8f16:$Op3)),
            (FMLA_ZPmZZ_H $P, $Op3, $Op1, $Op2)>;
+ def : Pat<(nxv4f16 (AArch64fma_p nxv4i1:$P, nxv4f16:$Op1, nxv4f16:$Op2, nxv4f16:$Op3)),
+           (FMLA_ZPmZZ_H $P, $Op3, $Op1, $Op2)>;
+ def : Pat<(nxv2f16 (AArch64fma_p nxv2i1:$P, nxv2f16:$Op1, nxv2f16:$Op2, nxv2f16:$Op3)),
+           (FMLA_ZPmZZ_H $P, $Op3, $Op1, $Op2)>;
  def : Pat<(nxv4f32 (AArch64fma_p nxv4i1:$P, nxv4f32:$Op1, nxv4f32:$Op2, nxv4f32:$Op3)),
            (FMLA_ZPmZZ_S $P, $Op3, $Op1, $Op2)>;
+ def : Pat<(nxv2f32 (AArch64fma_p nxv2i1:$P, nxv2f32:$Op1, nxv2f32:$Op2, nxv2f32:$Op3)),
+           (FMLA_ZPmZZ_S $P, $Op3, $Op1, $Op2)>;
  def : Pat<(nxv2f64 (AArch64fma_p nxv2i1:$P, nxv2f64:$Op1, nxv2f64:$Op2, nxv2f64:$Op3)),
            (FMLA_ZPmZZ_D $P, $Op3, $Op1, $Op2)>;

diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -340,6 +340,12 @@
 : Pat<(vtd (op vt1:$Op1, vt2:$Op2)),
       (inst $Op1, $Op2)>;

+class SVE_2_Op_Pred_All_Active<ValueType vtd, SDPatternOperator op,
+                               ValueType pt, ValueType vt1, ValueType vt2,
+                               Instruction inst>
+: Pat<(vtd (op (pt (AArch64ptrue 31)), vt1:$Op1, vt2:$Op2)),
+      (inst $Op1, $Op2)>;
+
 class SVE_2_Op_Pat_Reduce_To_Neon<ValueType vtd, SDPatternOperator op,
                                   ValueType vt1, ValueType vt2,
                                   Instruction inst, SubRegIndex sub>
 : Pat<(vtd (op vt1:$Op1, vt2:$Op2)),
@@ -1665,7 +1671,8 @@
     let Inst{4-0} = Zd;
   }

-multiclass sve_fp_3op_u_zd<bits<3> opc, string asm, SDPatternOperator op> {
+multiclass sve_fp_3op_u_zd<bits<3> opc, string asm, SDPatternOperator op,
+                           SDPatternOperator predicated_op = null_frag> {
   def _H : sve_fp_3op_u_zd<0b01, opc, asm, ZPR16>;
   def _S : sve_fp_3op_u_zd<0b10, opc, asm, ZPR32>;
   def _D : sve_fp_3op_u_zd<0b11, opc, asm, ZPR64>;
@@ -1674,6 +1681,9 @@
   def : SVE_2_Op_Pat<nxv8f16, op, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
   def : SVE_2_Op_Pat<nxv4f32, op, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_2_Op_Pat<nxv2f64, op, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
+  def : SVE_2_Op_Pred_All_Active<nxv8f16, predicated_op, nxv8i1, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_2_Op_Pred_All_Active<nxv4f32, predicated_op, nxv4i1, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
+  def : SVE_2_Op_Pred_All_Active<nxv2f64, predicated_op, nxv2i1, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }

 multiclass sve_fp_3op_u_zd_ftsmul<bits<3> opc, string asm, SDPatternOperator op> {
@@ -7804,7 +7814,10 @@
   def _UNDEF_D : PredTwoOpPseudo<NAME # _D, ZPR64, FalseLanesUndef>;

   def : SVE_3_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, !cast<Pseudo>(NAME # _UNDEF_H)>;
+  def : SVE_3_Op_Pat<nxv4f16, op, nxv4i1, nxv4f16, nxv4f16, !cast<Pseudo>(NAME # _UNDEF_H)>;
+  def : SVE_3_Op_Pat<nxv2f16, op, nxv2i1, nxv2f16, nxv2f16, !cast<Pseudo>(NAME # _UNDEF_H)>;
   def : SVE_3_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, !cast<Pseudo>(NAME # _UNDEF_S)>;
+  def : SVE_3_Op_Pat<nxv2f32, op, nxv2i1, nxv2f32, nxv2f32, !cast<Pseudo>(NAME # _UNDEF_S)>;
   def : SVE_3_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, !cast<Pseudo>(NAME # _UNDEF_D)>;
 }

diff --git a/llvm/test/CodeGen/AArch64/sve-fp.ll b/llvm/test/CodeGen/AArch64/sve-fp.ll
--- a/llvm/test/CodeGen/AArch64/sve-fp.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp.ll
@@ -5,8 +5,8 @@
 ; If this check fails please read test/CodeGen/AArch64/README for instructions on how to resolve it.
 ; WARN-NOT: warning

-define <vscale x 8 x half> @fadd_h(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
-; CHECK-LABEL: fadd_h:
+define <vscale x 8 x half> @fadd_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fadd_nxv8f16:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: fadd z0.h, z0.h, z1.h
 ; CHECK-NEXT: ret
@@ -14,8 +14,28 @@
   ret <vscale x 8 x half> %res
 }

-define <vscale x 4 x float> @fadd_s(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
-; CHECK-LABEL: fadd_s:
+define <vscale x 4 x half> @fadd_nxv4f16(<vscale x 4 x half> %a, <vscale x 4 x half> %b) {
+; CHECK-LABEL: fadd_nxv4f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fadd z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = fadd <vscale x 4 x half> %a, %b
+  ret <vscale x 4 x half> %res
+}
+
+define <vscale x 2 x half> @fadd_nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x half> %b) {
+; CHECK-LABEL: fadd_nxv2f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fadd z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = fadd <vscale x 2 x half> %a, %b
+  ret <vscale x 2 x half> %res
+}
+
+define <vscale x 4 x float> @fadd_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fadd_nxv4f32:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: fadd z0.s, z0.s, z1.s
 ; CHECK-NEXT: ret
@@ -23,8 +43,18 @@
   ret <vscale x 4 x float> %res
 }

-define <vscale x 2 x double> @fadd_d(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
-; CHECK-LABEL: fadd_d:
+define <vscale x 2 x float> @fadd_nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x float> %b) {
+; CHECK-LABEL: fadd_nxv2f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fadd z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+  %res = fadd <vscale x 2 x float> %a, %b
+  ret <vscale x 2 x float> %res
+}
+
+define <vscale x 2 x double> @fadd_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: fadd_nxv2f64:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: fadd z0.d, z0.d, z1.d
 ; CHECK-NEXT: ret
@@ -32,8 +62,68 @@
   ret <vscale x 2 x double> %res
 }

-define <vscale x 8 x half> @fsub_h(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
-; CHECK-LABEL: fsub_h:
+define <vscale x 8 x half> @fdiv_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fdiv_nxv8f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fdiv z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = fdiv <vscale x 8 x half> %a, %b
+  ret <vscale x 8 x half> %res
+}
+
+define <vscale x 4 x half> @fdiv_nxv4f16(<vscale x 4 x half> %a, <vscale x 4 x half> %b) {
+; CHECK-LABEL: fdiv_nxv4f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fdiv z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = fdiv <vscale x 4 x half> %a, %b
+  ret <vscale x 4 x half> %res
+}
+
+define <vscale x 2 x half> @fdiv_nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x half> %b) {
+; CHECK-LABEL: fdiv_nxv2f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fdiv z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = fdiv <vscale x 2 x half> %a, %b
+  ret <vscale x 2 x half> %res
+}
+
+define <vscale x 4 x float> @fdiv_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fdiv_nxv4f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+  %res = fdiv <vscale x 4 x float> %a, %b
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x float> @fdiv_nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x float> %b) {
+; CHECK-LABEL: fdiv_nxv2f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+  %res = fdiv <vscale x 2 x float> %a, %b
+  ret <vscale x 2 x float> %res
+}
+
+define <vscale x 2 x double> @fdiv_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: fdiv_nxv2f64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fdiv z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: ret
+  %res = fdiv <vscale x 2 x double> %a, %b
+  ret <vscale x 2 x double> %res
+}
+
+define <vscale x 8 x half> @fsub_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fsub_nxv8f16:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: fsub z0.h, z0.h, z1.h
 ; CHECK-NEXT: ret
@@ -41,8 +131,28 @@
   ret <vscale x 8 x half> %res
 }

-define <vscale x 4 x float> @fsub_s(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
-; CHECK-LABEL: fsub_s:
+define <vscale x 4 x half> @fsub_nxv4f16(<vscale x 4 x half> %a, <vscale x 4 x half> %b) {
+; CHECK-LABEL: fsub_nxv4f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fsub z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = fsub <vscale x 4 x half> %a, %b
+  ret <vscale x 4 x half> %res
+}
+
+define <vscale x 2 x half> @fsub_nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x half> %b) {
+; CHECK-LABEL: fsub_nxv2f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fsub z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = fsub <vscale x 2 x half> %a, %b
+  ret <vscale x 2 x half> %res
+}
+
+define <vscale x 4 x float> @fsub_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fsub_nxv4f32:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: fsub z0.s, z0.s, z1.s
 ; CHECK-NEXT: ret
@@ -50,8 +160,18 @@
   ret <vscale x 4 x float> %res
 }

-define <vscale x 2 x double> @fsub_d(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
-; CHECK-LABEL: fsub_d:
+define <vscale x 2 x float> @fsub_nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x float> %b) {
+; CHECK-LABEL: fsub_nxv2f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fsub z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+  %res = fsub <vscale x 2 x float> %a, %b
+  ret <vscale x 2 x float> %res
+}
+
+define <vscale x 2 x double> @fsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: fsub_nxv2f64:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: fsub z0.d, z0.d, z1.d
 ; CHECK-NEXT: ret
@@ -59,8 +179,8 @@
   ret <vscale x 2 x double> %res
 }

-define <vscale x 8 x half> @fmul_h(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
-; CHECK-LABEL: fmul_h:
+define <vscale x 8 x half> @fmul_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fmul_nxv8f16:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: fmul z0.h, z0.h, z1.h
 ; CHECK-NEXT: ret
@@ -68,8 +188,28 @@
   ret <vscale x 8 x half> %res
 }

-define <vscale x 4 x float> @fmul_s(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
-; CHECK-LABEL: fmul_s:
+define <vscale x 4 x half> @fmul_nxv4f16(<vscale x 4 x half> %a, <vscale x 4 x half> %b) {
+; CHECK-LABEL: fmul_nxv4f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fmul z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = fmul <vscale x 4 x half> %a, %b
+  ret <vscale x 4 x half> %res
+}
+
+define <vscale x 2 x half> @fmul_nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x half> %b) {
+; CHECK-LABEL: fmul_nxv2f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmul z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = fmul <vscale x 2 x half> %a, %b
+  ret <vscale x 2 x half> %res
+}
+
+define <vscale x 4 x float> @fmul_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fmul_nxv4f32:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: fmul z0.s, z0.s, z1.s
 ; CHECK-NEXT: ret
@@ -77,8 +217,18 @@
   ret <vscale x 4 x float> %res
 }

-define <vscale x 2 x double> @fmul_d(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
-; CHECK-LABEL: fmul_d:
+define <vscale x 2 x float> @fmul_nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x float> %b) {
+; CHECK-LABEL: fmul_nxv2f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+  %res = fmul <vscale x 2 x float> %a, %b
+  ret <vscale x 2 x float> %res
+}
+
+define <vscale x 2 x double> @fmul_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: fmul_nxv2f64:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: fmul z0.d, z0.d, z1.d
 ; CHECK-NEXT: ret
@@ -86,8 +236,8 @@
   ret <vscale x 2 x double> %res
 }

-define <vscale x 8 x half> @fma_half(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
-; CHECK-LABEL: fma_half:
+define <vscale x 8 x half> @fma_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fma_nxv8f16:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: ptrue p0.h
 ; CHECK-NEXT: fmla z2.h, p0/m, z0.h, z1.h
@@ -96,8 +246,31 @@
   %r = call <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
   ret <vscale x 8 x half> %r
 }
-define <vscale x 4 x float> @fma_float(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
-; CHECK-LABEL: fma_float:
+
+define <vscale x 4 x half> @fma_nxv4f16(<vscale x 4 x half> %a, <vscale x 4 x half> %b, <vscale x 4 x half> %c) {
+; CHECK-LABEL: fma_nxv4f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fmla z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+  %r = call <vscale x 4 x half> @llvm.fma.nxv4f16(<vscale x 4 x half> %a, <vscale x 4 x half> %b, <vscale x 4 x half> %c)
+  ret <vscale x 4 x half> %r
+}
+
+define <vscale x 2 x half> @fma_nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x half> %b, <vscale x 2 x half> %c) {
+; CHECK-LABEL: fma_nxv2f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmla z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+  %r = call <vscale x 2 x half> @llvm.fma.nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x half> %b, <vscale x 2 x half> %c)
+  ret <vscale x 2 x half> %r
+}
+
+define <vscale x 4 x float> @fma_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fma_nxv4f32:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: ptrue p0.s
 ; CHECK-NEXT: fmla z2.s, p0/m, z0.s, z1.s
@@ -106,8 +279,20 @@
   %r = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c)
   ret <vscale x 4 x float> %r
 }
-define <vscale x 2 x double> @fma_double_1(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
-; CHECK-LABEL: fma_double_1:
+
+define <vscale x 2 x float> @fma_nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %c) {
+; CHECK-LABEL: fma_nxv2f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmla z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+  %r = call <vscale x 2 x float> @llvm.fma.nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %c)
+  ret <vscale x 2 x float> %r
+}
+
+define <vscale x 2 x double> @fma_nxv2f64_1(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fma_nxv2f64_1:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: ptrue p0.d
 ; CHECK-NEXT: fmla z2.d, p0/m, z0.d, z1.d
@@ -116,8 +301,9 @@
   %r = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c)
   ret <vscale x 2 x double> %r
 }
-define <vscale x 2 x double> @fma_double_2(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
-; CHECK-LABEL: fma_double_2:
+
+define <vscale x 2 x double> @fma_nxv2f64_2(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fma_nxv2f64_2:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: ptrue p0.d
 ; CHECK-NEXT: fmla z2.d, p0/m, z1.d, z0.d
@@ -126,8 +312,9 @@
   %r = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %b, <vscale x 2 x double> %a, <vscale x 2 x double> %c)
   ret <vscale x 2 x double> %r
 }
-define <vscale x 2 x double> @fma_double_3(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
-; CHECK-LABEL: fma_double_3:
+
+define <vscale x 2 x double> @fma_nxv2f64_3(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fma_nxv2f64_3:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: ptrue p0.d
 ; CHECK-NEXT: fmla z0.d, p0/m, z2.d, z1.d
@@ -231,7 +418,10 @@
 declare <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 2 x float> @llvm.fma.nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x float>)
 declare <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 4 x half> @llvm.fma.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half>, <vscale x 4 x half>)
+declare <vscale x 2 x half> @llvm.fma.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half>, <vscale x 2 x half>)

 ; Function Attrs: nounwind readnone
 declare double @llvm.aarch64.sve.faddv.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>) #2
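
All four scalar-ISD cases above funnel into `LowerToPredicatedOp`, a helper that already exists in AArch64ISelLowering.cpp (the FMA case used it before this change). Its body is not part of this diff; the sketch below is only a hypothetical reconstruction of its shape, to show why the tests expect a `ptrue` ahead of each predicated instruction. The predicate-type computation and the `getPTrue` helper name are assumptions, not a quote of the in-tree code; pattern 31 is the SVE "all lanes" `ptrue` pattern, matching `AArch64ptrue 31` in the new `SVE_2_Op_Pred_All_Active` class.

```cpp
// Hypothetical sketch (not the in-tree implementation). Given e.g.
// (fdiv nxv4f32:%a, nxv4f32:%b), this emits
// (AArch64ISD::FDIV_PRED (ptrue all), %a, %b), which the TableGen
// patterns above select to "fdiv z0.s, p0/m, z0.s, z1.s".
SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
                                                   SelectionDAG &DAG,
                                                   unsigned NewOp) const {
  EVT VT = Op.getValueType();
  SDLoc DL(Op);

  // All-active governing predicate with one i1 per element, e.g. nxv4i1 for
  // nxv4f32. For unpacked types such as nxv2f32 this yields nxv2i1, whose
  // lanes correspond to 64-bit containers -- hence the "ptrue p0.d" in the
  // nxv2f32 tests above. (Assumed mapping; the real helper may differ.)
  EVT PredVT =
      EVT::getVectorVT(*DAG.getContext(), MVT::i1, VT.getVectorElementCount());
  SDValue Pg = getPTrue(DAG, DL, PredVT, /*Pattern=*/31); // 31 == all lanes

  // Rebuild the node with the predicate prepended to the original operands.
  SmallVector<SDValue, 4> Operands = {Pg};
  for (const SDValue &V : Op->op_values())
    Operands.push_back(V);

  return DAG.getNode(NewOp, DL, VT, Operands);
}
```

With that shape in mind, the two pattern paths in the diff fit together: when the predicate operand is literally `ptrue all` and the element type is packed, `SVE_2_Op_Pred_All_Active` strips the predicate and selects the unpredicated form (`fadd z0.h, z0.h, z1.h`); otherwise the `sve_fp_bin_pred_hfd` pseudos select the merging predicated form (`fadd z0.h, p0/m, z0.h, z1.h`).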