diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp --- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp +++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp @@ -275,33 +275,11 @@ // e.g.: read/write or user/supervisor/machine privileges. }; - enum class VSEW { - SEW_8 = 0, - SEW_16, - SEW_32, - SEW_64, - SEW_128, - SEW_256, - SEW_512, - SEW_1024, - }; - - enum class VLMUL { - LMUL_1 = 0, - LMUL_2, - LMUL_4, - LMUL_8, - LMUL_F8 = 5, - LMUL_F4, - LMUL_F2 - }; - struct VTypeOp { - VSEW Sew; - VLMUL Lmul; + RISCVVSEW Sew; + RISCVVLMUL Lmul; bool TailAgnostic; bool MaskedoffAgnostic; - unsigned Encoding; }; SMLoc StartLoc, EndLoc; @@ -752,43 +730,43 @@ return Tok; } - static StringRef getSEWStr(VSEW Sew) { + static StringRef getSEWStr(RISCVVSEW Sew) { switch (Sew) { - case VSEW::SEW_8: + case RISCVVSEW::SEW_8: return "e8"; - case VSEW::SEW_16: + case RISCVVSEW::SEW_16: return "e16"; - case VSEW::SEW_32: + case RISCVVSEW::SEW_32: return "e32"; - case VSEW::SEW_64: + case RISCVVSEW::SEW_64: return "e64"; - case VSEW::SEW_128: + case RISCVVSEW::SEW_128: return "e128"; - case VSEW::SEW_256: + case RISCVVSEW::SEW_256: return "e256"; - case VSEW::SEW_512: + case RISCVVSEW::SEW_512: return "e512"; - case VSEW::SEW_1024: + case RISCVVSEW::SEW_1024: return "e1024"; } llvm_unreachable("Unknown SEW."); } - static StringRef getLMULStr(VLMUL Lmul) { + static StringRef getLMULStr(RISCVVLMUL Lmul) { switch (Lmul) { - case VLMUL::LMUL_1: + case RISCVVLMUL::LMUL_1: return "m1"; - case VLMUL::LMUL_2: + case RISCVVLMUL::LMUL_2: return "m2"; - case VLMUL::LMUL_4: + case RISCVVLMUL::LMUL_4: return "m4"; - case VLMUL::LMUL_8: + case RISCVVLMUL::LMUL_8: return "m8"; - case VLMUL::LMUL_F2: + case RISCVVLMUL::LMUL_F2: return "mf2"; - case VLMUL::LMUL_F4: + case RISCVVLMUL::LMUL_F4: return "mf4"; - case VLMUL::LMUL_F8: + case RISCVVLMUL::LMUL_F8: return "mf8"; } llvm_unreachable("Unknown LMUL."); @@ -872,21 +850,12 @@ auto Op = std::make_unique(KindTy::VType); unsigned SewLog2 = Log2_32(Sew / 8); unsigned LmulLog2 = Log2_32(Lmul); - Op->VType.Sew = static_cast(SewLog2); + Op->VType.Sew = static_cast(SewLog2); if (Fractional) { unsigned Flmul = 8 - LmulLog2; - Op->VType.Lmul = static_cast(Flmul); - Op->VType.Encoding = - ((Flmul & 0x4) << 3) | ((SewLog2 & 0x7) << 2) | (Flmul & 0x3); + Op->VType.Lmul = static_cast(Flmul); } else { - Op->VType.Lmul = static_cast(LmulLog2); - Op->VType.Encoding = (SewLog2 << 2) | LmulLog2; - } - if (TailAgnostic) { - Op->VType.Encoding |= 0x40; - } - if (MaskedoffAgnostic) { - Op->VType.Encoding |= 0x80; + Op->VType.Lmul = static_cast(LmulLog2); } Op->VType.TailAgnostic = TailAgnostic; Op->VType.MaskedoffAgnostic = MaskedoffAgnostic; @@ -954,7 +923,9 @@ void addVTypeIOperands(MCInst &Inst, unsigned N) const { assert(N == 1 && "Invalid number of operands!"); - Inst.addOperand(MCOperand::createImm(VType.Encoding)); + unsigned VTypeI = RISCVVType::encodeVTYPE( + VType.Lmul, VType.Sew, VType.TailAgnostic, VType.MaskedoffAgnostic); + Inst.addOperand(MCOperand::createImm(VTypeI)); } // Returns the rounding mode represented by this RISCVOperand. Should only @@ -1600,8 +1571,7 @@ unsigned Sew; if (Name.getAsInteger(10, Sew)) return MatchOperand_NoMatch; - if (Sew != 8 && Sew != 16 && Sew != 32 && Sew != 64 && Sew != 128 && - Sew != 256 && Sew != 512 && Sew != 1024) + if (!RISCVVType::isValidSEW(Sew)) return MatchOperand_NoMatch; getLexer().Lex(); @@ -1613,16 +1583,11 @@ if (!Name.consume_front("m")) return MatchOperand_NoMatch; // "m" or "mf" - bool Fractional = false; - if (Name.consume_front("f")) { - Fractional = true; - } + bool Fractional = Name.consume_front("f"); unsigned Lmul; if (Name.getAsInteger(10, Lmul)) return MatchOperand_NoMatch; - if (Lmul != 1 && Lmul != 2 && Lmul != 4 && Lmul != 8) - return MatchOperand_NoMatch; - if (Fractional && Lmul == 1) + if (!RISCVVType::isValidLMUL(Lmul, Fractional)) return MatchOperand_NoMatch; getLexer().Lex(); diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1929,39 +1929,11 @@ const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); unsigned SEW = MI.getOperand(SEWIndex).getImm(); - RISCVVLengthMultiplier::LengthMultiplier Multiplier; + assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW"); + RISCVVSEW ElementWidth = static_cast(Log2_32(SEW / 8)); - switch (VLMul) { - default: - llvm_unreachable("Unexpected LMUL for instruction"); - case 0: - case 1: - case 2: - case 3: - case 5: - case 6: - case 7: - Multiplier = static_cast(VLMul); - break; - } - - RISCVVStandardElementWidth::StandardElementWidth ElementWidth; - switch (SEW) { - default: - llvm_unreachable("Unexpected SEW for instruction"); - case 8: - ElementWidth = RISCVVStandardElementWidth::ElementWidth8; - break; - case 16: - ElementWidth = RISCVVStandardElementWidth::ElementWidth16; - break; - case 32: - ElementWidth = RISCVVStandardElementWidth::ElementWidth32; - break; - case 64: - ElementWidth = RISCVVStandardElementWidth::ElementWidth64; - break; - } + // LMUL should already be encoded correctly. + RISCVVLMUL Multiplier = static_cast(VLMul); MachineRegisterInfo &MRI = MF.getRegInfo(); @@ -1979,13 +1951,9 @@ .addReg(RISCV::X0, RegState::Kill); // For simplicity we reuse the vtype representation here. - // Bits | Name | Description - // -----+------------+------------------------------------------------ - // 5 | vlmul[2] | Fractional lmul? - // 4:2 | vsew[2:0] | Standard element width (SEW) setting - // 1:0 | vlmul[1:0] | Vector register group multiplier (LMUL) setting - MIB.addImm(((Multiplier & 0x4) << 3) | ((ElementWidth & 0x3) << 2) | - (Multiplier & 0x3)); + MIB.addImm(RISCVVType::encodeVTYPE(Multiplier, ElementWidth, + /*TailAgnostic*/ false, + /*MaskedOffAgnostic*/ false)); // Remove (now) redundant operands from pseudo MI.getOperand(SEWIndex).setImm(-1); diff --git a/llvm/lib/Target/RISCV/Utils/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/Utils/RISCVBaseInfo.h --- a/llvm/lib/Target/RISCV/Utils/RISCVBaseInfo.h +++ b/llvm/lib/Target/RISCV/Utils/RISCVBaseInfo.h @@ -330,30 +330,54 @@ } // namespace RISCVVMVTs -namespace RISCVVLengthMultiplier { - -enum LengthMultiplier { - LMul1 = 0, - LMul2 = 1, - LMul4 = 2, - LMul8 = 3, - LMulF8 = 5, - LMulF4 = 6, - LMulF2 = 7 +enum class RISCVVSEW { + SEW_8 = 0, + SEW_16, + SEW_32, + SEW_64, + SEW_128, + SEW_256, + SEW_512, + SEW_1024, }; -} +enum class RISCVVLMUL { + LMUL_1 = 0, + LMUL_2, + LMUL_4, + LMUL_8, + LMUL_F8 = 5, + LMUL_F4, + LMUL_F2 +}; -namespace RISCVVStandardElementWidth { +namespace RISCVVType { +// Is this a SEW value that can be encoded into the VTYPE format. +inline static bool isValidSEW(unsigned SEW) { + return isPowerOf2_32(SEW) && SEW >= 8 && SEW <= 1024; +} -enum StandardElementWidth { - ElementWidth8 = 0, - ElementWidth16 = 1, - ElementWidth32 = 2, - ElementWidth64 = 3 -}; +// Is this a LMUL value that can be encoded into the VTYPE format. +inline static bool isValidLMUL(unsigned LMUL, bool Fractional) { + return isPowerOf2_32(LMUL) && LMUL <= 8 && (!Fractional || LMUL != 1); +} +// Encode VTYPE into the binary format used by the the VSETVLI instruction which +// is used by our MC layer representation. +inline static unsigned encodeVTYPE(RISCVVLMUL VLMUL, RISCVVSEW VSEW, + bool TailAgnostic, bool MaskedoffAgnostic) { + unsigned VLMULBits = static_cast(VLMUL); + unsigned VSEWBits = static_cast(VSEW); + unsigned VTypeI = + ((VLMULBits & 0x4) << 3) | (VSEWBits << 2) | (VLMULBits & 0x3); + if (TailAgnostic) + VTypeI |= 0x40; + if (MaskedoffAgnostic) + VTypeI |= 0x80; + + return VTypeI; } +} // namespace RISCVVType namespace RISCVVPseudosTable {