diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -838,14 +838,9 @@ SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const; - SDValue lowerFixedLengthVectorLogicOpToRVV(SDValue Op, SelectionDAG &DAG, - unsigned MaskOpc, - unsigned VecOpc) const; - SDValue lowerFixedLengthVectorShiftToRVV(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op, SelectionDAG &DAG) const; - SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG, unsigned NewOpc, - bool HasMergeOp = false, bool HasMask = true) const; + SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const; SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const; SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG, unsigned RISCVISDOpc, bool HasMergeOp = false) const; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -4564,6 +4564,105 @@ ISD::CondCode::SETNE); } +/// Get the RISC-V target-specific VL op for a given SDNode. 
+static unsigned getRISCVVLOp(SDValue Op) { +#define OP_CASE(NODE) \ + case ISD::NODE: \ + return RISCVISD::NODE##_VL; + switch (Op.getOpcode()) { + default: + llvm_unreachable("don't have RISC-V specified VL op for this SDNode"); + // clang-format off + OP_CASE(ADD) + OP_CASE(SUB) + OP_CASE(MUL) + OP_CASE(MULHS) + OP_CASE(MULHU) + OP_CASE(SDIV) + OP_CASE(SREM) + OP_CASE(UDIV) + OP_CASE(UREM) + OP_CASE(SHL) + OP_CASE(SRA) + OP_CASE(SRL) + OP_CASE(SADDSAT) + OP_CASE(UADDSAT) + OP_CASE(SSUBSAT) + OP_CASE(USUBSAT) + OP_CASE(FADD) + OP_CASE(FSUB) + OP_CASE(FMUL) + OP_CASE(FDIV) + OP_CASE(FNEG) + OP_CASE(FABS) + OP_CASE(FSQRT) + OP_CASE(SMIN) + OP_CASE(SMAX) + OP_CASE(UMIN) + OP_CASE(UMAX) + OP_CASE(FMINNUM) + OP_CASE(FMAXNUM) + OP_CASE(STRICT_FADD) + OP_CASE(STRICT_FSUB) + OP_CASE(STRICT_FMUL) + OP_CASE(STRICT_FDIV) + OP_CASE(STRICT_FSQRT) + // clang-format on +#undef OP_CASE + case ISD::FMA: + return RISCVISD::VFMADD_VL; + case ISD::STRICT_FMA: + return RISCVISD::STRICT_VFMADD_VL; + case ISD::AND: + if (Op.getSimpleValueType().getVectorElementType() == MVT::i1) + return RISCVISD::VMAND_VL; + return RISCVISD::AND_VL; + case ISD::OR: + if (Op.getSimpleValueType().getVectorElementType() == MVT::i1) + return RISCVISD::VMOR_VL; + return RISCVISD::OR_VL; + case ISD::XOR: + if (Op.getSimpleValueType().getVectorElementType() == MVT::i1) + return RISCVISD::VMXOR_VL; + return RISCVISD::XOR_VL; + } +} + +/// Return true if a RISC-V target-specific op has a merge operand. 
+static bool hasMergeOp(unsigned Opcode) { + assert(Opcode > RISCVISD::FIRST_NUMBER && + Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL && + "not a RISC-V target specific op"); + assert(RISCVISD::STRICT_VFROUND_NOEXCEPT_VL - RISCVISD::FIRST_NUMBER == 421 && + "adding target specific op should update this function"); + if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::FMAXNUM_VL) + return true; + if (Opcode == RISCVISD::FCOPYSIGN_VL) + return true; + if (Opcode >= RISCVISD::VWMUL_VL && Opcode <= RISCVISD::VFWSUB_W_VL) + return true; + if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL) + return true; + return false; +} + +/// Return true if a RISC-V target-specific op has a mask operand. +static bool hasMaskOp(unsigned Opcode) { + assert(Opcode > RISCVISD::FIRST_NUMBER && + Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL && + "not a RISC-V target specific op"); + assert(RISCVISD::STRICT_VFROUND_NOEXCEPT_VL - RISCVISD::FIRST_NUMBER == 421 && + "adding target specific op should update this function"); + if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL) + return true; + if (Opcode >= RISCVISD::VRGATHER_VX_VL && Opcode <= RISCVISD::VFIRST_VL) + return true; + if (Opcode >= RISCVISD::STRICT_FADD_VL && + Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL) + return true; + return false; +} + SDValue RISCVTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { switch (Op.getOpcode()) { @@ -5154,83 +5253,46 @@ return lowerFixedLengthVectorSetccToRVV(Op, DAG); } case ISD::ADD: - return lowerToScalableOp(Op, DAG, RISCVISD::ADD_VL, /*HasMergeOp*/ true); case ISD::SUB: - return lowerToScalableOp(Op, DAG, RISCVISD::SUB_VL, /*HasMergeOp*/ true); case ISD::MUL: - return lowerToScalableOp(Op, DAG, RISCVISD::MUL_VL, /*HasMergeOp*/ true); case ISD::MULHS: - return lowerToScalableOp(Op, DAG, RISCVISD::MULHS_VL, /*HasMergeOp*/ true); case ISD::MULHU: - return lowerToScalableOp(Op, DAG, RISCVISD::MULHU_VL, /*HasMergeOp*/ 
true); case ISD::AND: - return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMAND_VL, - RISCVISD::AND_VL); case ISD::OR: - return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMOR_VL, - RISCVISD::OR_VL); case ISD::XOR: - return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMXOR_VL, - RISCVISD::XOR_VL); case ISD::SDIV: - return lowerToScalableOp(Op, DAG, RISCVISD::SDIV_VL, /*HasMergeOp*/ true); case ISD::SREM: - return lowerToScalableOp(Op, DAG, RISCVISD::SREM_VL, /*HasMergeOp*/ true); case ISD::UDIV: - return lowerToScalableOp(Op, DAG, RISCVISD::UDIV_VL, /*HasMergeOp*/ true); case ISD::UREM: - return lowerToScalableOp(Op, DAG, RISCVISD::UREM_VL, /*HasMergeOp*/ true); + return lowerToScalableOp(Op, DAG); case ISD::SHL: case ISD::SRA: case ISD::SRL: if (Op.getSimpleValueType().isFixedLengthVector()) - return lowerFixedLengthVectorShiftToRVV(Op, DAG); + return lowerToScalableOp(Op, DAG); // This can be called for an i32 shift amount that needs to be promoted. assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() && "Unexpected custom legalisation"); return SDValue(); case ISD::SADDSAT: - return lowerToScalableOp(Op, DAG, RISCVISD::SADDSAT_VL, - /*HasMergeOp*/ true); case ISD::UADDSAT: - return lowerToScalableOp(Op, DAG, RISCVISD::UADDSAT_VL, - /*HasMergeOp*/ true); case ISD::SSUBSAT: - return lowerToScalableOp(Op, DAG, RISCVISD::SSUBSAT_VL, - /*HasMergeOp*/ true); case ISD::USUBSAT: - return lowerToScalableOp(Op, DAG, RISCVISD::USUBSAT_VL, - /*HasMergeOp*/ true); case ISD::FADD: - return lowerToScalableOp(Op, DAG, RISCVISD::FADD_VL, /*HasMergeOp*/ true); case ISD::FSUB: - return lowerToScalableOp(Op, DAG, RISCVISD::FSUB_VL, /*HasMergeOp*/ true); case ISD::FMUL: - return lowerToScalableOp(Op, DAG, RISCVISD::FMUL_VL, /*HasMergeOp*/ true); case ISD::FDIV: - return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL, /*HasMergeOp*/ true); case ISD::FNEG: - return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL); case ISD::FABS: - 
return lowerToScalableOp(Op, DAG, RISCVISD::FABS_VL); case ISD::FSQRT: - return lowerToScalableOp(Op, DAG, RISCVISD::FSQRT_VL); case ISD::FMA: - return lowerToScalableOp(Op, DAG, RISCVISD::VFMADD_VL); case ISD::SMIN: - return lowerToScalableOp(Op, DAG, RISCVISD::SMIN_VL, /*HasMergeOp*/ true); case ISD::SMAX: - return lowerToScalableOp(Op, DAG, RISCVISD::SMAX_VL, /*HasMergeOp*/ true); case ISD::UMIN: - return lowerToScalableOp(Op, DAG, RISCVISD::UMIN_VL, /*HasMergeOp*/ true); case ISD::UMAX: - return lowerToScalableOp(Op, DAG, RISCVISD::UMAX_VL, /*HasMergeOp*/ true); case ISD::FMINNUM: - return lowerToScalableOp(Op, DAG, RISCVISD::FMINNUM_VL, - /*HasMergeOp*/ true); case ISD::FMAXNUM: - return lowerToScalableOp(Op, DAG, RISCVISD::FMAXNUM_VL, - /*HasMergeOp*/ true); + return lowerToScalableOp(Op, DAG); case ISD::ABS: case ISD::VP_ABS: return lowerABS(Op, DAG); @@ -5243,21 +5305,12 @@ case ISD::FCOPYSIGN: return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG); case ISD::STRICT_FADD: - return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FADD_VL, - /*HasMergeOp*/ true); case ISD::STRICT_FSUB: - return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FSUB_VL, - /*HasMergeOp*/ true); case ISD::STRICT_FMUL: - return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FMUL_VL, - /*HasMergeOp*/ true); case ISD::STRICT_FDIV: - return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FDIV_VL, - /*HasMergeOp*/ true); case ISD::STRICT_FSQRT: - return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FSQRT_VL); case ISD::STRICT_FMA: - return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_VFMADD_VL); + return lowerToScalableOp(Op, DAG); case ISD::STRICT_FSETCC: case ISD::STRICT_FSETCCS: return lowerVectorStrictFSetcc(Op, DAG); @@ -8338,31 +8391,6 @@ return Res; } -SDValue RISCVTargetLowering::lowerFixedLengthVectorLogicOpToRVV( - SDValue Op, SelectionDAG &DAG, unsigned MaskOpc, unsigned VecOpc) const { - MVT VT = Op.getSimpleValueType(); - - if (VT.getVectorElementType() == MVT::i1) - return 
lowerToScalableOp(Op, DAG, MaskOpc, /*HasMergeOp*/ false, - /*HasMask*/ false); - - return lowerToScalableOp(Op, DAG, VecOpc, /*HasMergeOp*/ true); -} - -SDValue -RISCVTargetLowering::lowerFixedLengthVectorShiftToRVV(SDValue Op, - SelectionDAG &DAG) const { - unsigned Opc; - switch (Op.getOpcode()) { - default: llvm_unreachable("Unexpected opcode!"); - case ISD::SHL: Opc = RISCVISD::SHL_VL; break; - case ISD::SRA: Opc = RISCVISD::SRA_VL; break; - case ISD::SRL: Opc = RISCVISD::SRL_VL; break; - } - - return lowerToScalableOp(Op, DAG, Opc, /*HasMergeOp*/ true); -} - // Lower vector ABS to smax(X, sub(0, X)). SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); @@ -8446,9 +8474,12 @@ return convertFromScalableVector(VT, Select, DAG, Subtarget); } -SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG, - unsigned NewOpc, bool HasMergeOp, - bool HasMask) const { +SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, + SelectionDAG &DAG) const { + unsigned NewOpc = getRISCVVLOp(Op); + bool HasMergeOp = hasMergeOp(NewOpc); + bool HasMask = hasMaskOp(NewOpc); + MVT VT = Op.getSimpleValueType(); MVT ContainerVT = getContainerForFixedLengthVector(VT);