diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -1737,7 +1737,16 @@
 def int_aarch64_sve_fdivr   : AdvSIMD_Pred2VectorArg_Intrinsic;
 def int_aarch64_sve_fexpa_x : AdvSIMD_SVE_EXPA_Intrinsic;
 def int_aarch64_sve_fmad    : AdvSIMD_Pred3VectorArg_Intrinsic;
+
+def int_aarch64_sve_fmax_unpred   : AdvSIMD_2VectorArg_Intrinsic;
+def int_aarch64_sve_fmin_unpred   : AdvSIMD_2VectorArg_Intrinsic;
+def int_aarch64_sve_fmaxnm_unpred : AdvSIMD_2VectorArg_Intrinsic;
+def int_aarch64_sve_fminnm_unpred : AdvSIMD_2VectorArg_Intrinsic;
+
 def int_aarch64_sve_fmax    : AdvSIMD_Pred2VectorArg_Intrinsic;
+
+
+
 def int_aarch64_sve_fmaxnm  : AdvSIMD_Pred2VectorArg_Intrinsic;
 def int_aarch64_sve_fmin    : AdvSIMD_Pred2VectorArg_Intrinsic;
 def int_aarch64_sve_fminnm  : AdvSIMD_Pred2VectorArg_Intrinsic;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -202,6 +202,11 @@
   EORV_PRED,
   ANDV_PRED,
 
+  FMAX,
+  FMIN,
+  FMAXNM,
+  FMINNM,
+
   // Vector bitwise negation
   NOT,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1571,8 +1571,11 @@
     MAKE_CASE(AArch64ISD::STNP)
     MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
+    MAKE_CASE(AArch64ISD::FMAX)
+    MAKE_CASE(AArch64ISD::FMIN)
+    MAKE_CASE(AArch64ISD::FMAXNM)
+    MAKE_CASE(AArch64ISD::FMINNM)
   }
-#undef MAKE_CASE
   return nullptr;
 }
@@ -7856,6 +7859,31 @@
     return DAG.getNode(ISD::BITCAST, DL, VT, TBL);
 }
 
+static SDValue combineSVEPredIntrinsic(unsigned Opc, SDNode *N, SelectionDAG &DAG) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Pred = N->getOperand(1);
+  SDValue Vector1 = N->getOperand(2);
+  SDValue Vector2 = N->getOperand(3);
+
+  return DAG.getNode(Opc, DL, VT, Pred, Vector1, Vector2);
+}
+
+static SDValue combineToSVEPred(SDNode *N, SelectionDAG &DAG, unsigned NewOp) {
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  assert(N->getOperand(1).getValueType().isScalableVector() &&
+         N->getOperand(2).getValueType().isScalableVector() &&
+         "Only scalable vectors are supported");
+
+  auto PredTy =
+      VT.getVectorVT(*DAG.getContext(), MVT::i1, VT.getVectorElementCount());
+  SDValue Mask = getPTrue(DAG, DL, PredTy, AArch64SVEPredPattern::all);
+
+  SmallVector<SDValue, 3> Operands = {Mask, N->getOperand(1), N->getOperand(2)};
+  return DAG.getNode(NewOp, DL, VT, Operands);
+}
+
 static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits,
                                APInt &UndefBits) {
@@ -11862,7 +11890,24 @@
   case Intrinsic::aarch64_sve_ptest_last:
     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
                     AArch64CC::LAST_ACTIVE);
+  case Intrinsic::aarch64_sve_fmax:
+    return combineSVEPredIntrinsic(AArch64ISD::FMAX, N, DAG);
+  case Intrinsic::aarch64_sve_fmin:
+    return combineSVEPredIntrinsic(AArch64ISD::FMIN, N, DAG);
+  case Intrinsic::aarch64_sve_fmaxnm:
+    return combineSVEPredIntrinsic(AArch64ISD::FMAXNM, N, DAG);
+  case Intrinsic::aarch64_sve_fminnm:
+    return combineSVEPredIntrinsic(AArch64ISD::FMINNM, N, DAG);
+  case Intrinsic::aarch64_sve_fmax_unpred:
+    return combineToSVEPred(N, DAG, AArch64ISD::FMAX);
+  case Intrinsic::aarch64_sve_fmin_unpred:
+    return combineToSVEPred(N, DAG, AArch64ISD::FMIN);
+  case Intrinsic::aarch64_sve_fmaxnm_unpred:
+    return combineToSVEPred(N, DAG, AArch64ISD::FMAXNM);
+  case Intrinsic::aarch64_sve_fminnm_unpred:
+    return combineToSVEPred(N, DAG, AArch64ISD::FMINNM);
   }
+
   return SDValue();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -552,6 +552,7 @@
 def AArch64smaxv : SDNode<"AArch64ISD::SMAXV", SDT_AArch64UnaryVec>;
 def AArch64umaxv : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>;
 
+
 def SDT_AArch64SETTAG : SDTypeProfile<0, 2, [SDTCisPtrTy<0>, SDTCisPtrTy<1>]>;
 def AArch64stg  : SDNode<"AArch64ISD::STG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
 def AArch64stzg : SDNode<"AArch64ISD::STZG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -201,6 +201,12 @@
 def reinterpret_cast : SDNode<"AArch64ISD::REINTERPRET_CAST", SDTUnaryOp>;
 
+def SDT_AArch64PredBinFp : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>]>;
+def AArch64fmax   : SDNode<"AArch64ISD::FMAX",   SDT_AArch64PredBinFp>;
+def AArch64fmin   : SDNode<"AArch64ISD::FMIN",   SDT_AArch64PredBinFp>;
+def AArch64fmaxnm : SDNode<"AArch64ISD::FMAXNM", SDT_AArch64PredBinFp>;
+def AArch64fminnm : SDNode<"AArch64ISD::FMINNM", SDT_AArch64PredBinFp>;
+
 let Predicates = [HasSVE] in {
   defm RDFFR_PPz  : sve_int_rdffr_pred<0b0, "rdffr", int_aarch64_sve_rdffr_z>;
   def  RDFFRS_PPz : sve_int_rdffr_pred<0b1, "rdffrs">;
@@ -342,10 +348,10 @@
   defm FSUB_ZPmZ   : sve_fp_2op_p_zds<0b0001, "fsub", "FSUB_ZPZZ", int_aarch64_sve_fsub, DestructiveBinaryCommWithRev, "FSUBR_ZPmZ", 1>;
   defm FMUL_ZPmZ   : sve_fp_2op_p_zds<0b0010, "fmul", "FMUL_ZPZZ", int_aarch64_sve_fmul, DestructiveBinaryComm>;
   defm FSUBR_ZPmZ  : sve_fp_2op_p_zds<0b0011, "fsubr", "FSUBR_ZPZZ", int_aarch64_sve_fsubr, DestructiveBinaryCommWithRev, "FSUB_ZPmZ", 0>;
-  defm FMAXNM_ZPmZ : sve_fp_2op_p_zds<0b0100, "fmaxnm", "FMAXNM_ZPZZ", int_aarch64_sve_fmaxnm, DestructiveBinaryComm>;
-  defm FMINNM_ZPmZ : sve_fp_2op_p_zds<0b0101, "fminnm", "FMINNM_ZPZZ", int_aarch64_sve_fminnm, DestructiveBinaryComm>;
-  defm FMAX_ZPmZ   : sve_fp_2op_p_zds<0b0110, "fmax", "FMAX_ZPZZ", int_aarch64_sve_fmax, DestructiveBinaryComm>;
-  defm FMIN_ZPmZ   : sve_fp_2op_p_zds<0b0111, "fmin", "FMIN_ZPZZ", int_aarch64_sve_fmin, DestructiveBinaryComm>;
+  defm FMAXNM_ZPmZ : sve_fp_2op_p_zds<0b0100, "fmaxnm", "FMAXNM_ZPZZ", AArch64fmaxnm, DestructiveBinaryComm>;
+  defm FMINNM_ZPmZ : sve_fp_2op_p_zds<0b0101, "fminnm", "FMINNM_ZPZZ", AArch64fminnm, DestructiveBinaryComm>;
+  defm FMAX_ZPmZ   : sve_fp_2op_p_zds<0b0110, "fmax", "FMAX_ZPZZ", AArch64fmax, DestructiveBinaryComm>;
+  defm FMIN_ZPmZ   : sve_fp_2op_p_zds<0b0111, "fmin", "FMIN_ZPZZ", AArch64fmin, DestructiveBinaryComm>;
   defm FABD_ZPmZ   : sve_fp_2op_p_zds<0b1000, "fabd", "FABD_ZPZZ", int_aarch64_sve_fabd, DestructiveBinaryComm>;
   defm FSCALE_ZPmZ : sve_fp_2op_p_zds_fscale<0b1001, "fscale", int_aarch64_sve_fscale>;
   defm FMULX_ZPmZ  : sve_fp_2op_p_zds<0b1010, "fmulx", "FMULX_ZPZZ", int_aarch64_sve_fmulx, DestructiveBinaryComm>;
@@ -358,10 +364,10 @@
   defm FSUB_ZPZZ   : sve_fp_2op_p_zds_zx<int_aarch64_sve_fsub>;
   defm FMUL_ZPZZ   : sve_fp_2op_p_zds_zx<int_aarch64_sve_fmul>;
   defm FSUBR_ZPZZ  : sve_fp_2op_p_zds_zx<int_aarch64_sve_fsubr>;
-  defm FMAXNM_ZPZZ : sve_fp_2op_p_zds_zx<int_aarch64_sve_fmaxnm>;
-  defm FMINNM_ZPZZ : sve_fp_2op_p_zds_zx<int_aarch64_sve_fminnm>;
-  defm FMAX_ZPZZ   : sve_fp_2op_p_zds_zx<int_aarch64_sve_fmax>;
-  defm FMIN_ZPZZ   : sve_fp_2op_p_zds_zx<int_aarch64_sve_fmin>;
+  defm FMAXNM_ZPZZ : sve_fp_2op_p_zds_zx<AArch64fmaxnm>;
+  defm FMINNM_ZPZZ : sve_fp_2op_p_zds_zx<AArch64fminnm>;
+  defm FMAX_ZPZZ   : sve_fp_2op_p_zds_zx<AArch64fmax>;
+  defm FMIN_ZPZZ   : sve_fp_2op_p_zds_zx<AArch64fmin>;
   defm FABD_ZPZZ   : sve_fp_2op_p_zds_zx<int_aarch64_sve_fabd>;
   defm FMULX_ZPZZ  : sve_fp_2op_p_zds_zx<int_aarch64_sve_fmulx>;
   defm FDIVR_ZPZZ  : sve_fp_2op_p_zds_zx<int_aarch64_sve_fdivr>;
diff --git a/llvm/test/CodeGen/AArch64/sve-fp-unpred-to-pred.ll b/llvm/test/CodeGen/AArch64/sve-fp-unpred-to-pred.ll
new file
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fp-unpred-to-pred.ll
@@ -0,0 +1,144 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+; FMAX
+
+define <vscale x 8 x half> @fmax_half(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fmax_half:
+; CHECK: fmax z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = call <vscale x 8 x half> @llvm.aarch64.sve.fmax.unpred.nxv8f16(<vscale x 8 x half> %a,
+                                                                        <vscale x 8 x half> %b)
+  ret <vscale x 8 x half> %res
+}
+
+define <vscale x 4 x float> @fmax_float(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fmax_float:
+; CHECK: fmax z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+  %res = call <vscale x 4 x float> @llvm.aarch64.sve.fmax.unpred.nxv4f32(<vscale x 4 x float> %a,
+                                                                         <vscale x 4 x float> %b)
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @fmax_double(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: fmax_double:
+; CHECK: fmax z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: ret
+  %res = call <vscale x 2 x double> @llvm.aarch64.sve.fmax.unpred.nxv2f64(<vscale x 2 x double> %a,
+                                                                          <vscale x 2 x double> %b)
+  ret <vscale x 2 x double> %res
+}
+
+; FMIN
+
+define <vscale x 8 x half> @fmin_half(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fmin_half:
+; CHECK: fmin z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = call <vscale x 8 x half> @llvm.aarch64.sve.fmin.unpred.nxv8f16(<vscale x 8 x half> %a,
+                                                                        <vscale x 8 x half> %b)
+  ret <vscale x 8 x half> %res
+}
+
+define <vscale x 4 x float> @fmin_float(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fmin_float:
+; CHECK: fmin z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+  %res = call <vscale x 4 x float> @llvm.aarch64.sve.fmin.unpred.nxv4f32(<vscale x 4 x float> %a,
+                                                                         <vscale x 4 x float> %b)
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @fmin_double(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: fmin_double:
+; CHECK: fmin z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: ret
+  %res = call <vscale x 2 x double> @llvm.aarch64.sve.fmin.unpred.nxv2f64(<vscale x 2 x double> %a,
+                                                                          <vscale x 2 x double> %b)
+  ret <vscale x 2 x double> %res
+}
+
+; FMAXNM
+
+define <vscale x 8 x half> @fmaxnm_half(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fmaxnm_half:
+; CHECK: fmaxnm z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = call <vscale x 8 x half> @llvm.aarch64.sve.fmaxnm.unpred.nxv8f16(<vscale x 8 x half> %a,
+                                                                          <vscale x 8 x half> %b)
+  ret <vscale x 8 x half> %res
+}
+
+define <vscale x 4 x float> @fmaxnm_float(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fmaxnm_float:
+; CHECK: fmaxnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+  %res = call <vscale x 4 x float> @llvm.aarch64.sve.fmaxnm.unpred.nxv4f32(<vscale x 4 x float> %a,
+                                                                           <vscale x 4 x float> %b)
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @fmaxnm_double(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: fmaxnm_double:
+; CHECK: fmaxnm z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: ret
+  %res = call <vscale x 2 x double> @llvm.aarch64.sve.fmaxnm.unpred.nxv2f64(<vscale x 2 x double> %a,
+                                                                            <vscale x 2 x double> %b)
+  ret <vscale x 2 x double> %res
+}
+
+; FMINNM
+
+define <vscale x 8 x half> @fminnm_half(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fminnm_half:
+; CHECK: fminnm z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+  %res = call <vscale x 8 x half> @llvm.aarch64.sve.fminnm.unpred.nxv8f16(<vscale x 8 x half> %a,
+                                                                          <vscale x 8 x half> %b)
+  ret <vscale x 8 x half> %res
+}
+
+define <vscale x 4 x float> @fminnm_float(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fminnm_float:
+; CHECK: fminnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+  %res = call <vscale x 4 x float> @llvm.aarch64.sve.fminnm.unpred.nxv4f32(<vscale x 4 x float> %a,
+                                                                           <vscale x 4 x float> %b)
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @fminnm_double(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: fminnm_double:
+; CHECK: fminnm z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: ret
+  %res = call <vscale x 2 x double> @llvm.aarch64.sve.fminnm.unpred.nxv2f64(<vscale x 2 x double> %a,
+                                                                            <vscale x 2 x double> %b)
+  ret <vscale x 2 x double> %res
+}
+
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmax.unpred.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 4 x float> @llvm.aarch64.sve.fmax.unpred.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 2 x double> @llvm.aarch64.sve.fmax.unpred.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
+
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmin.unpred.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 4 x float> @llvm.aarch64.sve.fmin.unpred.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 2 x double> @llvm.aarch64.sve.fmin.unpred.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
+
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmaxnm.unpred.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 4 x float> @llvm.aarch64.sve.fmaxnm.unpred.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 2 x double> @llvm.aarch64.sve.fmaxnm.unpred.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
+
+declare <vscale x 8 x half> @llvm.aarch64.sve.fminnm.unpred.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 4 x float> @llvm.aarch64.sve.fminnm.unpred.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 2 x double> @llvm.aarch64.sve.fminnm.unpred.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)