diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -878,14 +878,11 @@ SelectionDAG &DAG) const; SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const; SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const; - SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG, unsigned RISCVISDOpc, - bool HasMergeOp = false) const; - SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG, unsigned MaskOpc, - unsigned VecOpc) const; + SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG) const; + SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const; SDValue lowerVPExtMaskOp(SDValue Op, SelectionDAG &DAG) const; SDValue lowerVPSetCCMaskOp(SDValue Op, SelectionDAG &DAG) const; - SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG, - unsigned RISCVISDOpc) const; + SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG) const; SDValue lowerVPStridedLoad(SDValue Op, SelectionDAG &DAG) const; SDValue lowerVPStridedStore(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG, diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -4664,10 +4664,13 @@ #define OP_CASE(NODE) \ case ISD::NODE: \ return RISCVISD::NODE##_VL; +#define VP_CASE(NODE) \ + case ISD::VP_##NODE: \ + return RISCVISD::NODE##_VL; + // clang-format off switch (Op.getOpcode()) { default: llvm_unreachable("don't have RISC-V specified VL op for this SDNode"); - // clang-format off OP_CASE(ADD) OP_CASE(SUB) OP_CASE(MUL) @@ -4702,25 +4705,81 @@ OP_CASE(STRICT_FMUL) OP_CASE(STRICT_FDIV) OP_CASE(STRICT_FSQRT) - // clang-format on -#undef OP_CASE + VP_CASE(ADD) // VP_ADD + VP_CASE(SUB) // VP_SUB + VP_CASE(MUL) // VP_MUL + VP_CASE(SDIV) // VP_SDIV + VP_CASE(SREM) // VP_SREM + VP_CASE(UDIV) // VP_UDIV + VP_CASE(UREM) // VP_UREM + VP_CASE(SHL) // VP_SHL + VP_CASE(FADD) // VP_FADD + VP_CASE(FSUB) // VP_FSUB + VP_CASE(FMUL) // VP_FMUL + VP_CASE(FDIV) // VP_FDIV + VP_CASE(FNEG) // VP_FNEG + VP_CASE(FABS) // VP_FABS + VP_CASE(SMIN) // VP_SMIN + VP_CASE(SMAX) // VP_SMAX + VP_CASE(UMIN) // VP_UMIN + VP_CASE(UMAX) // VP_UMAX + VP_CASE(FMINNUM) // VP_FMINNUM + VP_CASE(FMAXNUM) // VP_FMAXNUM + VP_CASE(FCOPYSIGN) // VP_FCOPYSIGN + VP_CASE(SETCC) // VP_SETCC + VP_CASE(SINT_TO_FP) // VP_SINT_TO_FP + VP_CASE(UINT_TO_FP) // VP_UINT_TO_FP + VP_CASE(BITREVERSE) // VP_BITREVERSE + VP_CASE(BSWAP) // VP_BSWAP + VP_CASE(CTLZ) // VP_CTLZ + VP_CASE(CTTZ) // VP_CTTZ + VP_CASE(CTPOP) // VP_CTPOP + case ISD::VP_CTLZ_ZERO_UNDEF: + return RISCVISD::CTLZ_VL; + case ISD::VP_CTTZ_ZERO_UNDEF: + return RISCVISD::CTTZ_VL; case ISD::FMA: + case ISD::VP_FMA: return RISCVISD::VFMADD_VL; case ISD::STRICT_FMA: return RISCVISD::STRICT_VFMADD_VL; case ISD::AND: + case ISD::VP_AND: if (Op.getSimpleValueType().getVectorElementType() == MVT::i1) return RISCVISD::VMAND_VL; return RISCVISD::AND_VL; case ISD::OR: + case ISD::VP_OR: if (Op.getSimpleValueType().getVectorElementType() == MVT::i1) return RISCVISD::VMOR_VL; return RISCVISD::OR_VL; case ISD::XOR: + case ISD::VP_XOR: if (Op.getSimpleValueType().getVectorElementType() == MVT::i1) return RISCVISD::VMXOR_VL; return RISCVISD::XOR_VL; + case ISD::VP_SELECT: + return RISCVISD::VSELECT_VL; + case ISD::VP_MERGE: + return RISCVISD::VP_MERGE_VL; + case ISD::VP_ASHR: + return RISCVISD::SRA_VL; + case ISD::VP_LSHR: + return RISCVISD::SRL_VL; + case ISD::VP_SQRT: + return RISCVISD::FSQRT_VL; + case ISD::VP_SIGN_EXTEND: + return RISCVISD::VSEXT_VL; + case ISD::VP_ZERO_EXTEND: + return RISCVISD::VZEXT_VL; + case ISD::VP_FP_TO_SINT: + return RISCVISD::VFCVT_RTZ_X_F_VL; + case ISD::VP_FP_TO_UINT: + return RISCVISD::VFCVT_RTZ_XU_F_VL; } + // clang-format on +#undef OP_CASE +#undef VP_CASE } /// Return true if a RISC-V target specified op has a merge operand. @@ -4739,6 +4798,8 @@ return true; if (Opcode >= RISCVISD::VWMUL_VL && Opcode <= RISCVISD::VFWSUB_W_VL) return true; + if (Opcode == RISCVISD::SETCC_VL) + return true; if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL) return true; return false; @@ -5476,106 +5537,72 @@ case ISD::EH_DWARF_CFA: return lowerEH_DWARF_CFA(Op, DAG); case ISD::VP_SELECT: - return lowerVPOp(Op, DAG, RISCVISD::VSELECT_VL); case ISD::VP_MERGE: - return lowerVPOp(Op, DAG, RISCVISD::VP_MERGE_VL); case ISD::VP_ADD: - return lowerVPOp(Op, DAG, RISCVISD::ADD_VL, /*HasMergeOp*/ true); case ISD::VP_SUB: - return lowerVPOp(Op, DAG, RISCVISD::SUB_VL, /*HasMergeOp*/ true); case ISD::VP_MUL: - return lowerVPOp(Op, DAG, RISCVISD::MUL_VL, /*HasMergeOp*/ true); case ISD::VP_SDIV: - return lowerVPOp(Op, DAG, RISCVISD::SDIV_VL, /*HasMergeOp*/ true); case ISD::VP_UDIV: - return lowerVPOp(Op, DAG, RISCVISD::UDIV_VL, /*HasMergeOp*/ true); case ISD::VP_SREM: - return lowerVPOp(Op, DAG, RISCVISD::SREM_VL, /*HasMergeOp*/ true); case ISD::VP_UREM: - return lowerVPOp(Op, DAG, RISCVISD::UREM_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); case ISD::VP_AND: - return lowerLogicVPOp(Op, DAG, RISCVISD::VMAND_VL, RISCVISD::AND_VL); case ISD::VP_OR: - return lowerLogicVPOp(Op, DAG, RISCVISD::VMOR_VL, RISCVISD::OR_VL); case ISD::VP_XOR: - return lowerLogicVPOp(Op, DAG, RISCVISD::VMXOR_VL, RISCVISD::XOR_VL); + return lowerLogicVPOp(Op, DAG); case ISD::VP_ASHR: - return lowerVPOp(Op, DAG, RISCVISD::SRA_VL, /*HasMergeOp*/ true); case ISD::VP_LSHR: - return lowerVPOp(Op, DAG, RISCVISD::SRL_VL, /*HasMergeOp*/ true); case ISD::VP_SHL: - return lowerVPOp(Op, DAG, RISCVISD::SHL_VL, /*HasMergeOp*/ true); case ISD::VP_FADD: - return lowerVPOp(Op, DAG, RISCVISD::FADD_VL, /*HasMergeOp*/ true); case ISD::VP_FSUB: - return lowerVPOp(Op, DAG, RISCVISD::FSUB_VL, /*HasMergeOp*/ true); case ISD::VP_FMUL: - return lowerVPOp(Op, DAG, RISCVISD::FMUL_VL, /*HasMergeOp*/ true); case ISD::VP_FDIV: - return lowerVPOp(Op, DAG, RISCVISD::FDIV_VL, /*HasMergeOp*/ true); case ISD::VP_FNEG: - return lowerVPOp(Op, DAG, RISCVISD::FNEG_VL); case ISD::VP_FABS: - return lowerVPOp(Op, DAG, RISCVISD::FABS_VL); case ISD::VP_SQRT: - return lowerVPOp(Op, DAG, RISCVISD::FSQRT_VL); case ISD::VP_FMA: - return lowerVPOp(Op, DAG, RISCVISD::VFMADD_VL); case ISD::VP_FMINNUM: - return lowerVPOp(Op, DAG, RISCVISD::FMINNUM_VL, /*HasMergeOp*/ true); case ISD::VP_FMAXNUM: - return lowerVPOp(Op, DAG, RISCVISD::FMAXNUM_VL, /*HasMergeOp*/ true); case ISD::VP_FCOPYSIGN: - return lowerVPOp(Op, DAG, RISCVISD::FCOPYSIGN_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); case ISD::VP_SIGN_EXTEND: case ISD::VP_ZERO_EXTEND: if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1) return lowerVPExtMaskOp(Op, DAG); - return lowerVPOp(Op, DAG, - Op.getOpcode() == ISD::VP_SIGN_EXTEND - ? RISCVISD::VSEXT_VL - : RISCVISD::VZEXT_VL); + return lowerVPOp(Op, DAG); case ISD::VP_TRUNCATE: return lowerVectorTruncLike(Op, DAG); case ISD::VP_FP_EXTEND: case ISD::VP_FP_ROUND: return lowerVectorFPExtendOrRoundLike(Op, DAG); case ISD::VP_FP_TO_SINT: - return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_X_F_VL); case ISD::VP_FP_TO_UINT: - return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_XU_F_VL); case ISD::VP_SINT_TO_FP: - return lowerVPFPIntConvOp(Op, DAG, RISCVISD::SINT_TO_FP_VL); case ISD::VP_UINT_TO_FP: - return lowerVPFPIntConvOp(Op, DAG, RISCVISD::UINT_TO_FP_VL); + return lowerVPFPIntConvOp(Op, DAG); case ISD::VP_SETCC: if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1) return lowerVPSetCCMaskOp(Op, DAG); - return lowerVPOp(Op, DAG, RISCVISD::SETCC_VL, /*HasMergeOp*/ true); + [[fallthrough]]; case ISD::VP_SMIN: - return lowerVPOp(Op, DAG, RISCVISD::SMIN_VL, /*HasMergeOp*/ true); case ISD::VP_SMAX: - return lowerVPOp(Op, DAG, RISCVISD::SMAX_VL, /*HasMergeOp*/ true); case ISD::VP_UMIN: - return lowerVPOp(Op, DAG, RISCVISD::UMIN_VL, /*HasMergeOp*/ true); case ISD::VP_UMAX: - return lowerVPOp(Op, DAG, RISCVISD::UMAX_VL, /*HasMergeOp*/ true); case ISD::VP_BITREVERSE: - return lowerVPOp(Op, DAG, RISCVISD::BITREVERSE_VL, /*HasMergeOp*/ true); case ISD::VP_BSWAP: - return lowerVPOp(Op, DAG, RISCVISD::BSWAP_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); case ISD::VP_CTLZ: case ISD::VP_CTLZ_ZERO_UNDEF: if (Subtarget.hasStdExtZvbb()) - return lowerVPOp(Op, DAG, RISCVISD::CTLZ_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG); case ISD::VP_CTTZ: case ISD::VP_CTTZ_ZERO_UNDEF: if (Subtarget.hasStdExtZvbb()) - return lowerVPOp(Op, DAG, RISCVISD::CTTZ_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG); case ISD::VP_CTPOP: - return lowerVPOp(Op, DAG, RISCVISD::CTPOP_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: return lowerVPStridedLoad(Op, DAG); case ISD::EXPERIMENTAL_VP_STRIDED_STORE: @@ -8827,9 +8854,10 @@ // * The EVL operand is promoted from i32 to i64 on RV64. // * Fixed-length vectors are converted to their scalable-vector container // types. -SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG, - unsigned RISCVISDOpc, - bool HasMergeOp) const { +SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const { + unsigned RISCVISDOpc = getRISCVVLOp(Op); + bool HasMergeOp = hasMergeOp(RISCVISDOpc); + SDLoc DL(Op); MVT VT = Op.getSimpleValueType(); SmallVector Ops; @@ -8978,13 +9006,14 @@ } // Lower Floating-Point/Integer Type-Convert VP SDNodes -SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG, - unsigned RISCVISDOpc) const { +SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, + SelectionDAG &DAG) const { SDLoc DL(Op); SDValue Src = Op.getOperand(0); SDValue Mask = Op.getOperand(1); SDValue VL = Op.getOperand(2); + unsigned RISCVISDOpc = getRISCVVLOp(Op); MVT DstVT = Op.getSimpleValueType(); MVT SrcVT = Src.getSimpleValueType(); @@ -9110,12 +9139,11 @@ return convertFromScalableVector(VT, Result, DAG, Subtarget); } -SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, SelectionDAG &DAG, - unsigned MaskOpc, - unsigned VecOpc) const { +SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, + SelectionDAG &DAG) const { MVT VT = Op.getSimpleValueType(); if (VT.getVectorElementType() != MVT::i1) - return lowerVPOp(Op, DAG, VecOpc, true); + return lowerVPOp(Op, DAG); // It is safe to drop mask parameter as masked-off elements are undef. SDValue Op1 = Op->getOperand(0); @@ -9131,7 +9159,7 @@ } SDLoc DL(Op); - SDValue Val = DAG.getNode(MaskOpc, DL, ContainerVT, Op1, Op2, VL); + SDValue Val = DAG.getNode(getRISCVVLOp(Op), DL, ContainerVT, Op1, Op2, VL); if (!IsFixed) return Val; return convertFromScalableVector(VT, Val, DAG, Subtarget);