[RISCV] Share getRVVMCOpcode and use it for the RVV FMA commute switches

Move getRVVMCOpcode out of RISCVInsertVSETVLI.cpp into the RISCV namespace
(RISCVInstrInfo.h/.cpp) so it can be shared, and use it in
findCommutedOpIndices and commuteInstructionImpl to replace the
CASE_VFMA_OPCODE_* / CASE_VFMA_SPLATS case-label macros.

convertToThreeAddress intentionally keeps its explicit pseudo list. The
CASE_WIDEOP_OPCODE_* macros match only the *_TIED pseudos, and the
CASE_WIDEOP_CHANGE_OPCODE_* macros defined alongside them also only produce
*_TIED case labels. Switching on getRVVMCOpcode() there would instead accept
any pseudo whose base MC opcode is e.g. VWADD_WV, including untied or masked
forms that the opcode-rewriting switch does not handle.

NOTE(review): the new FMA switches likewise match every pseudo that shares the
base MC opcode, while the CASE_VFMA_CHANGE_OPCODE_* switches in
commuteInstructionImpl presumably still enumerate specific pseudos. This looks
like it relies on the other variants (e.g. masked ones) not being marked
commutable, so they never reach these switches -- please confirm.

diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
--- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -67,16 +67,8 @@
   return RISCV::X0 == MI.getOperand(0).getReg();
 }
 
-static uint16_t getRVVMCOpcode(uint16_t RVVPseudoOpcode) {
-  const RISCVVPseudosTable::PseudoInfo *RVV =
-      RISCVVPseudosTable::getPseudoInfo(RVVPseudoOpcode);
-  if (!RVV)
-    return 0;
-  return RVV->BaseInstr;
-}
-
 static bool isScalarMoveInstr(const MachineInstr &MI) {
-  switch (getRVVMCOpcode(MI.getOpcode())) {
+  switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
   default:
     return false;
   case RISCV::VMV_S_X:
@@ -86,7 +78,7 @@
 }
 
 static bool isVSlideInstr(const MachineInstr &MI) {
-  switch (getRVVMCOpcode(MI.getOpcode())) {
+  switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
   default:
     return false;
   case RISCV::VSLIDEDOWN_VX:
@@ -100,7 +92,7 @@
 /// Get the EEW for a load or store instruction. Return std::nullopt if MI is
 /// not a load or store which ignores SEW.
 static std::optional<unsigned> getEEWForLoadStore(const MachineInstr &MI) {
-  switch (getRVVMCOpcode(MI.getOpcode())) {
+  switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
   default:
     return std::nullopt;
   case RISCV::VLE8_V:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -262,6 +262,10 @@
 // one of the instructions does not have rounding mode, false will be returned.
 bool hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2);
 
+// Returns the MC opcode (BaseInstr) of the RVV pseudo RVVPseudoOpcode, or 0 if
+// RVVPseudoOpcode has no entry in RISCVVPseudosTable.
+unsigned getRVVMCOpcode(unsigned RVVPseudoOpcode);
+
 // Special immediate for AVL operand of V pseudo instructions to indicate VLMax.
 static constexpr int64_t VLMaxSentinel = -1LL;
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2140,34 +2140,6 @@
   return Comment;
 }
 
-// clang-format off
-#define CASE_VFMA_OPCODE_COMMON(OP, TYPE, LMUL) \
-  RISCV::PseudoV##OP##_##TYPE##_##LMUL
-
-#define CASE_VFMA_OPCODE_LMULS_M1(OP, TYPE) \
-  CASE_VFMA_OPCODE_COMMON(OP, TYPE, M1): \
-  case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M2): \
-  case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M4): \
-  case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M8)
-
-#define CASE_VFMA_OPCODE_LMULS_MF2(OP, TYPE) \
-  CASE_VFMA_OPCODE_COMMON(OP, TYPE, MF2): \
-  case CASE_VFMA_OPCODE_LMULS_M1(OP, TYPE)
-
-#define CASE_VFMA_OPCODE_LMULS_MF4(OP, TYPE) \
-  CASE_VFMA_OPCODE_COMMON(OP, TYPE, MF4): \
-  case CASE_VFMA_OPCODE_LMULS_MF2(OP, TYPE)
-
-#define CASE_VFMA_OPCODE_LMULS(OP, TYPE) \
-  CASE_VFMA_OPCODE_COMMON(OP, TYPE, MF8): \
-  case CASE_VFMA_OPCODE_LMULS_MF4(OP, TYPE)
-
-#define CASE_VFMA_SPLATS(OP) \
-  CASE_VFMA_OPCODE_LMULS_MF4(OP, VF16): \
-  case CASE_VFMA_OPCODE_LMULS_MF2(OP, VF32): \
-  case CASE_VFMA_OPCODE_LMULS_M1(OP, VF64)
-// clang-format on
-
 bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI,
                                            unsigned &SrcOpIdx1,
                                            unsigned &SrcOpIdx2) const {
@@ -2196,24 +2168,27 @@
   case RISCV::PseudoCCMOVGPR:
     // Operands 4 and 5 are commutable.
     return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 4, 5);
-  case CASE_VFMA_SPLATS(FMADD):
-  case CASE_VFMA_SPLATS(FMSUB):
-  case CASE_VFMA_SPLATS(FMACC):
-  case CASE_VFMA_SPLATS(FMSAC):
-  case CASE_VFMA_SPLATS(FNMADD):
-  case CASE_VFMA_SPLATS(FNMSUB):
-  case CASE_VFMA_SPLATS(FNMACC):
-  case CASE_VFMA_SPLATS(FNMSAC):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FMACC, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FMSAC, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FNMACC, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FNMSAC, VV):
-  case CASE_VFMA_OPCODE_LMULS(MADD, VX):
-  case CASE_VFMA_OPCODE_LMULS(NMSUB, VX):
-  case CASE_VFMA_OPCODE_LMULS(MACC, VX):
-  case CASE_VFMA_OPCODE_LMULS(NMSAC, VX):
-  case CASE_VFMA_OPCODE_LMULS(MACC, VV):
-  case CASE_VFMA_OPCODE_LMULS(NMSAC, VV): {
+  }
+
+  switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
+  case RISCV::VFMADD_VF:
+  case RISCV::VFMSUB_VF:
+  case RISCV::VFMACC_VF:
+  case RISCV::VFMSAC_VF:
+  case RISCV::VFNMADD_VF:
+  case RISCV::VFNMSUB_VF:
+  case RISCV::VFNMACC_VF:
+  case RISCV::VFNMSAC_VF:
+  case RISCV::VFMACC_VV:
+  case RISCV::VFMSAC_VV:
+  case RISCV::VFNMACC_VV:
+  case RISCV::VFNMSAC_VV:
+  case RISCV::VMACC_VV:
+  case RISCV::VNMSAC_VV:
+  case RISCV::VMADD_VX:
+  case RISCV::VNMSUB_VX:
+  case RISCV::VMACC_VX:
+  case RISCV::VNMSAC_VX: {
     // If the tail policy is undisturbed we can't commute.
     assert(RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags));
     if ((MI.getOperand(MI.getNumExplicitOperands() - 1).getImm() & 1) == 0)
@@ -2228,12 +2203,12 @@
       return false;
     return true;
   }
-  case CASE_VFMA_OPCODE_LMULS_MF4(FMADD, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FMSUB, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FNMADD, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FNMSUB, VV):
-  case CASE_VFMA_OPCODE_LMULS(MADD, VV):
-  case CASE_VFMA_OPCODE_LMULS(NMSUB, VV): {
+  case RISCV::VFMADD_VV:
+  case RISCV::VFMSUB_VV:
+  case RISCV::VFNMADD_VV:
+  case RISCV::VFNMSUB_VV:
+  case RISCV::VMADD_VV:
+  case RISCV::VNMSUB_VV: {
     // If the tail policy is undisturbed we can't commute.
     assert(RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags));
     if ((MI.getOperand(MI.getNumExplicitOperands() - 1).getImm() & 1) == 0)
@@ -2358,24 +2333,27 @@
     return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI*/ false,
                                                    OpIdx1, OpIdx2);
   }
-  case CASE_VFMA_SPLATS(FMACC):
-  case CASE_VFMA_SPLATS(FMADD):
-  case CASE_VFMA_SPLATS(FMSAC):
-  case CASE_VFMA_SPLATS(FMSUB):
-  case CASE_VFMA_SPLATS(FNMACC):
-  case CASE_VFMA_SPLATS(FNMADD):
-  case CASE_VFMA_SPLATS(FNMSAC):
-  case CASE_VFMA_SPLATS(FNMSUB):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FMACC, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FMSAC, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FNMACC, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FNMSAC, VV):
-  case CASE_VFMA_OPCODE_LMULS(MADD, VX):
-  case CASE_VFMA_OPCODE_LMULS(NMSUB, VX):
-  case CASE_VFMA_OPCODE_LMULS(MACC, VX):
-  case CASE_VFMA_OPCODE_LMULS(NMSAC, VX):
-  case CASE_VFMA_OPCODE_LMULS(MACC, VV):
-  case CASE_VFMA_OPCODE_LMULS(NMSAC, VV): {
+  }
+
+  switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
+  case RISCV::VFMADD_VF:
+  case RISCV::VFMSUB_VF:
+  case RISCV::VFMACC_VF:
+  case RISCV::VFMSAC_VF:
+  case RISCV::VFNMADD_VF:
+  case RISCV::VFNMSUB_VF:
+  case RISCV::VFNMACC_VF:
+  case RISCV::VFNMSAC_VF:
+  case RISCV::VFMACC_VV:
+  case RISCV::VFMSAC_VV:
+  case RISCV::VFNMACC_VV:
+  case RISCV::VFNMSAC_VV:
+  case RISCV::VMACC_VV:
+  case RISCV::VNMSAC_VV:
+  case RISCV::VMADD_VX:
+  case RISCV::VNMSUB_VX:
+  case RISCV::VMACC_VX:
+  case RISCV::VNMSAC_VX: {
     // It only make sense to toggle these between clobbering the
     // addend/subtrahend/minuend one of the multiplicands.
     assert((OpIdx1 == 1 || OpIdx2 == 1) && "Unexpected opcode index");
@@ -2409,12 +2387,12 @@
     return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI=*/false,
                                                    OpIdx1, OpIdx2);
   }
-  case CASE_VFMA_OPCODE_LMULS_MF4(FMADD, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FMSUB, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FNMADD, VV):
-  case CASE_VFMA_OPCODE_LMULS_MF4(FNMSUB, VV):
-  case CASE_VFMA_OPCODE_LMULS(MADD, VV):
-  case CASE_VFMA_OPCODE_LMULS(NMSUB, VV): {
+  case RISCV::VFMADD_VV:
+  case RISCV::VFMSUB_VV:
+  case RISCV::VFNMADD_VV:
+  case RISCV::VFNMSUB_VV:
+  case RISCV::VMADD_VV:
+  case RISCV::VNMSUB_VV: {
     assert((OpIdx1 == 1 || OpIdx2 == 1) && "Unexpected opcode index");
     // If one of the operands, is the addend we need to change opcode.
     // Otherwise we're just swapping 2 of the multiplicands.
@@ -2447,9 +2425,6 @@
 #undef CASE_VFMA_CHANGE_OPCODE_SPLATS
 #undef CASE_VFMA_CHANGE_OPCODE_LMULS
 #undef CASE_VFMA_CHANGE_OPCODE_COMMON
-#undef CASE_VFMA_SPLATS
-#undef CASE_VFMA_OPCODE_LMULS
-#undef CASE_VFMA_OPCODE_COMMON
 
 // clang-format off
 #define CASE_WIDEOP_OPCODE_COMMON(OP, LMUL) \
@@ -2769,3 +2744,11 @@
   MachineOperand FrmOp2 = MI2.getOperand(MI2FrmOpIdx);
   return FrmOp1.getImm() == FrmOp2.getImm();
 }
+
+unsigned RISCV::getRVVMCOpcode(unsigned RVVPseudoOpcode) {
+  const RISCVVPseudosTable::PseudoInfo *RVV =
+      RISCVVPseudosTable::getPseudoInfo(RVVPseudoOpcode);
+  if (!RVV)
+    return 0;
+  return RVV->BaseInstr;
+}