Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -103,6 +103,9 @@
   CCMN,
   FCCMP,
 
+  // Unary Floating Point Operation
+  FRINTP_PRED,
+
   // Floating point comparison
   FCMP,
 
@@ -866,7 +869,7 @@
   SDValue LowerSPLAT_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDUPQLane(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerToPredicatedOp(SDValue Op, SelectionDAG &DAG,
-                              unsigned NewOp) const;
+                              unsigned NewOp, bool Merging = false) const;
   SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -954,6 +954,7 @@
       setOperationAction(ISD::FMA, VT, Custom);
       setOperationAction(ISD::FMUL, VT, Custom);
       setOperationAction(ISD::FSUB, VT, Custom);
+      setOperationAction(ISD::FCEIL, VT, Custom);
     }
   }
 
@@ -1616,6 +1617,7 @@
     MAKE_CASE(AArch64ISD::STNP)
     MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
+    MAKE_CASE(AArch64ISD::FRINTP_PRED)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -3436,12 +3438,17 @@
   return SDValue();
 }
 
+static auto CreateNodeWithImplicitDef(SDValue Op, SelectionDAG &DAG) {
+  EVT VT = Op.getValueType();
+  SDLoc DL(Op);
+  auto NewOperand = SDValue(DAG.getMachineNode(TargetOpcode::IMPLICIT_DEF, DL, VT), 0);
+  return DAG.getNode(Op.getOpcode(), DL, VT, NewOperand, Op->getOperand(0));
+}
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   LLVM_DEBUG(dbgs() << "Custom lowering: ");
   LLVM_DEBUG(Op.dump());
-
   switch (Op.getOpcode()) {
   default:
     llvm_unreachable("unimplemented operand");
@@ -3506,6 +3513,8 @@
     if (Op.getValueType() == MVT::f128)
       return LowerF128Call(Op, DAG, RTLIB::DIV_F128);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
+  case ISD::FCEIL:
+    return LowerToPredicatedOp(CreateNodeWithImplicitDef(Op, DAG), DAG, AArch64ISD::FRINTP_PRED, true);
   case ISD::FP_ROUND:
   case ISD::STRICT_FP_ROUND:
     return LowerFP_ROUND(Op, DAG);
@@ -7934,6 +7943,15 @@
   return DAG.getNode(ISD::BITCAST, DL, VT, TBL);
 }
 
+static SDValue combineSVEPredIntrinsic(unsigned Opc, SDNode *N, SelectionDAG &DAG) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Vector1 = N->getOperand(1);
+  SDValue Pred = N->getOperand(2);
+  SDValue Vector2 = N->getOperand(3);
+
+  return DAG.getNode(Opc, DL, VT, Vector1, Pred, Vector2);
+}
 static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits,
                                APInt &UndefBits) {
@@ -12116,6 +12134,8 @@
   case Intrinsic::aarch64_sve_ptest_last:
     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
                     AArch64CC::LAST_ACTIVE);
+  case Intrinsic::aarch64_sve_frintp:
+    return combineSVEPredIntrinsic(AArch64ISD::FRINTP_PRED, N, DAG);
   }
   return SDValue();
 }
@@ -15182,7 +15202,8 @@
 SDValue
 AArch64TargetLowering::LowerToPredicatedOp(SDValue Op, SelectionDAG &DAG,
-                                           unsigned NewOp) const {
+                                           unsigned NewOp,
+                                           bool Merging) const {
   EVT VT = Op.getValueType();
   SDLoc DL(Op);
   auto Pg = getPredicateForVector(DAG, DL, VT);
@@ -15209,12 +15230,16 @@
   assert(VT.isScalableVector() && "Only expect to lower scalable vector op!");
 
-  SmallVector<SDValue, 4> Operands = {Pg};
+  SmallVector<SDValue, 4> Operands;
   for (const SDValue &V : Op->op_values()) {
     assert((isa<CondCodeSDNode>(V) || V.getValueType().isScalableVector()) &&
            "Only scalable vectors are supported!");
     Operands.push_back(V);
   }
 
+  if (Merging)
+    Operands.insert(Operands.begin() + 1, Pg);
+  else
+    Operands.insert(Operands.begin(), Pg);
   return DAG.getNode(NewOp, DL, VT, Operands);
-}
+}
\ No newline at end of file
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -210,6 +210,9 @@
 
 def reinterpret_cast : SDNode<"AArch64ISD::REINTERPRET_CAST", SDTUnaryOp>;
 
+def SDT_AArch64PredUnFp : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>]>;
+def AArch64frintp : SDNode<"AArch64ISD::FRINTP_PRED", SDT_AArch64PredUnFp>;
+
 let Predicates = [HasSVE] in {
   defm RDFFR_PPz : sve_int_rdffr_pred<0b0, "rdffr", int_aarch64_sve_rdffr_z>;
   def RDFFRS_PPz : sve_int_rdffr_pred<0b1, "rdffrs">;
@@ -1368,7 +1371,7 @@
   defm FCVTZU_ZPmZ_DtoD : sve_fp_2op_p_zd<0b1111111, "fcvtzu", ZPR64, ZPR64, int_aarch64_sve_fcvtzu, nxv2i64, nxv2i1, nxv2f64, ElementSizeD>;
 
   defm FRINTN_ZPmZ : sve_fp_2op_p_zd_HSD<0b00000, "frintn", int_aarch64_sve_frintn>;
-  defm FRINTP_ZPmZ : sve_fp_2op_p_zd_HSD<0b00001, "frintp", int_aarch64_sve_frintp>;
+  defm FRINTP_ZPmZ : sve_fp_2op_p_zd_HSD<0b00001, "frintp", AArch64frintp>;
   defm FRINTM_ZPmZ : sve_fp_2op_p_zd_HSD<0b00010, "frintm", int_aarch64_sve_frintm>;
   defm FRINTZ_ZPmZ : sve_fp_2op_p_zd_HSD<0b00011, "frintz", int_aarch64_sve_frintz>;
   defm FRINTA_ZPmZ : sve_fp_2op_p_zd_HSD<0b00100, "frinta", int_aarch64_sve_frinta>;
Index: llvm/test/CodeGen/AArch64/sve-fp.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-fp.ll
+++ llvm/test/CodeGen/AArch64/sve-fp.ll
@@ -408,6 +408,38 @@
   ret void
 }
 
+; FCEIL
+
+define <vscale x 8 x half> @frintp_nxv8f16(<vscale x 8 x half> %a) {
+; CHECK-LABEL: frintp_nxv8f16:
+; CHECK: ptrue p0.h
+; CHECK-NEXT: frintp z0.h, p0/m, z0.h
+; CHECK-NEXT: ret
+  %res = call <vscale x 8 x half> @llvm.ceil.nxv8f16(<vscale x 8 x half> %a)
+  ret <vscale x 8 x half> %res
+
+}
+
+define <vscale x 4 x float> @frintp_nxv4f32(<vscale x 4 x float> %a) {
+; CHECK-LABEL: frintp_nxv4f32:
+; CHECK: ptrue p0.s
+; CHECK-NEXT: frintp z0.s, p0/m, z0.s
+; CHECK-NEXT: ret
+  %res = call <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float> %a)
+  ret <vscale x 4 x float> %res
+
+}
+
+define <vscale x 2 x double> @frintp_nxv2f64(<vscale x 2 x double> %a) {
+; CHECK-LABEL: frintp_nxv2f64:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: frintp z0.d, p0/m, z0.d
+; CHECK-NEXT: ret
+  %res = call <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double> %a)
+  ret <vscale x 2 x double> %res
+
+}
+
 declare <vscale x 8 x half> @llvm.aarch64.sve.frecps.x.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 4 x float> @llvm.aarch64.sve.frecps.x.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.aarch64.sve.frecps.x.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
@@ -423,5 +455,9 @@
 declare <vscale x 4 x half> @llvm.fma.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half>, <vscale x 4 x half>)
 declare <vscale x 2 x half> @llvm.fma.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half>, <vscale x 2 x half>)
 
+declare <vscale x 8 x half> @llvm.ceil.nxv8f16(<vscale x 8 x half>)
+declare <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float>)
+declare <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double>)
+
 ; Function Attrs: nounwind readnone
 declare double @llvm.aarch64.sve.faddv.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>) #2
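
For context on the operand order the patch relies on: the merging intrinsic routed through combineSVEPredIntrinsic already carries an explicit inactive-lanes (passthru) operand, so its operands 1-3 map straight onto the three-operand FRINTP_PRED node as (Vector1, Pred, Vector2), the same order the FRINTP_ZPmZ pattern consumes. A plain ISD::FCEIL node has no passthru, which is why CreateNodeWithImplicitDef supplies an IMPLICIT_DEF and LowerToPredicatedOp, called with Merging set, inserts the predicate into the second slot instead of the front. Below is a minimal IR sketch of the intrinsic form; it assumes the usual (inactive, predicate, operand) signature of llvm.aarch64.sve.frintp, and the function name frintp_merging is illustrative only, not part of this patch or the existing tests.

declare <vscale x 4 x float> @llvm.aarch64.sve.frintp.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, <vscale x 4 x float>)

define <vscale x 4 x float> @frintp_merging(<vscale x 4 x float> %inactive, <vscale x 4 x i1> %pg, <vscale x 4 x float> %a) {
  ; The intrinsic operands (%inactive, %pg, %a) become (Vector1, Pred, Vector2)
  ; in combineSVEPredIntrinsic, matching the operand order of FRINTP_PRED.
  %res = call <vscale x 4 x float> @llvm.aarch64.sve.frintp.nxv4f32(<vscale x 4 x float> %inactive, <vscale x 4 x i1> %pg, <vscale x 4 x float> %a)
  ret <vscale x 4 x float> %res
}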