diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h @@ -18,7 +18,6 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/MC/MCInstrDesc.h" #include "llvm/MC/SubtargetFeature.h" -#include "llvm/Support/MachineValueType.h" namespace llvm { @@ -257,62 +256,6 @@ } // namespace RISCVFeatures -namespace RISCVVMVTs { - -constexpr MVT vint8mf8_t = MVT::nxv1i8; -constexpr MVT vint8mf4_t = MVT::nxv2i8; -constexpr MVT vint8mf2_t = MVT::nxv4i8; -constexpr MVT vint8m1_t = MVT::nxv8i8; -constexpr MVT vint8m2_t = MVT::nxv16i8; -constexpr MVT vint8m4_t = MVT::nxv32i8; -constexpr MVT vint8m8_t = MVT::nxv64i8; - -constexpr MVT vint16mf4_t = MVT::nxv1i16; -constexpr MVT vint16mf2_t = MVT::nxv2i16; -constexpr MVT vint16m1_t = MVT::nxv4i16; -constexpr MVT vint16m2_t = MVT::nxv8i16; -constexpr MVT vint16m4_t = MVT::nxv16i16; -constexpr MVT vint16m8_t = MVT::nxv32i16; - -constexpr MVT vint32mf2_t = MVT::nxv1i32; -constexpr MVT vint32m1_t = MVT::nxv2i32; -constexpr MVT vint32m2_t = MVT::nxv4i32; -constexpr MVT vint32m4_t = MVT::nxv8i32; -constexpr MVT vint32m8_t = MVT::nxv16i32; - -constexpr MVT vint64m1_t = MVT::nxv1i64; -constexpr MVT vint64m2_t = MVT::nxv2i64; -constexpr MVT vint64m4_t = MVT::nxv4i64; -constexpr MVT vint64m8_t = MVT::nxv8i64; - -constexpr MVT vfloat16mf4_t = MVT::nxv1f16; -constexpr MVT vfloat16mf2_t = MVT::nxv2f16; -constexpr MVT vfloat16m1_t = MVT::nxv4f16; -constexpr MVT vfloat16m2_t = MVT::nxv8f16; -constexpr MVT vfloat16m4_t = MVT::nxv16f16; -constexpr MVT vfloat16m8_t = MVT::nxv32f16; - -constexpr MVT vfloat32mf2_t = MVT::nxv1f32; -constexpr MVT vfloat32m1_t = MVT::nxv2f32; -constexpr MVT vfloat32m2_t = MVT::nxv4f32; -constexpr MVT vfloat32m4_t = MVT::nxv8f32; -constexpr MVT vfloat32m8_t = MVT::nxv16f32; - -constexpr MVT vfloat64m1_t = MVT::nxv1f64; -constexpr MVT vfloat64m2_t = MVT::nxv2f64; -constexpr MVT vfloat64m4_t = MVT::nxv4f64; -constexpr MVT vfloat64m8_t = MVT::nxv8f64; - -constexpr MVT vbool1_t = MVT::nxv64i1; -constexpr MVT vbool2_t = MVT::nxv32i1; -constexpr MVT vbool4_t = MVT::nxv16i1; -constexpr MVT vbool8_t = MVT::nxv8i1; -constexpr MVT vbool16_t = MVT::nxv4i1; -constexpr MVT vbool32_t = MVT::nxv2i1; -constexpr MVT vbool64_t = MVT::nxv1i1; - -} // namespace RISCVVMVTs - enum class RISCVVSEW { SEW_8 = 0, SEW_16, diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -90,64 +90,56 @@ if (Subtarget.hasStdExtD()) addRegisterClass(MVT::f64, &RISCV::FPR64RegClass); + static const MVT::SimpleValueType BoolVecVTs[] = { + MVT::nxv1i1, MVT::nxv2i1, MVT::nxv4i1, MVT::nxv8i1, + MVT::nxv16i1, MVT::nxv32i1, MVT::nxv64i1}; + static const MVT::SimpleValueType IntVecVTs[] = { + MVT::nxv1i8, MVT::nxv2i8, MVT::nxv4i8, MVT::nxv8i8, MVT::nxv16i8, + MVT::nxv32i8, MVT::nxv64i8, MVT::nxv1i16, MVT::nxv2i16, MVT::nxv4i16, + MVT::nxv8i16, MVT::nxv16i16, MVT::nxv32i16, MVT::nxv1i32, MVT::nxv2i32, + MVT::nxv4i32, MVT::nxv8i32, MVT::nxv16i32, MVT::nxv1i64, MVT::nxv2i64, + MVT::nxv4i64, MVT::nxv8i64}; + static const MVT::SimpleValueType F16VecVTs[] = { + MVT::nxv1f16, MVT::nxv2f16, MVT::nxv4f16, + MVT::nxv8f16, MVT::nxv16f16, MVT::nxv32f16}; + static const MVT::SimpleValueType F32VecVTs[] = { + MVT::nxv1f32, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv8f32, MVT::nxv16f32}; + static const MVT::SimpleValueType F64VecVTs[] = { + MVT::nxv1f64, MVT::nxv2f64, MVT::nxv4f64, MVT::nxv8f64}; + if (Subtarget.hasStdExtV()) { - addRegisterClass(RISCVVMVTs::vbool64_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vbool32_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vbool16_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vbool8_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vbool4_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vbool2_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vbool1_t, &RISCV::VRRegClass); - - addRegisterClass(RISCVVMVTs::vint8mf8_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint8mf4_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint8mf2_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint8m1_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint8m2_t, &RISCV::VRM2RegClass); - addRegisterClass(RISCVVMVTs::vint8m4_t, &RISCV::VRM4RegClass); - addRegisterClass(RISCVVMVTs::vint8m8_t, &RISCV::VRM8RegClass); - - addRegisterClass(RISCVVMVTs::vint16mf4_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint16mf2_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint16m1_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint16m2_t, &RISCV::VRM2RegClass); - addRegisterClass(RISCVVMVTs::vint16m4_t, &RISCV::VRM4RegClass); - addRegisterClass(RISCVVMVTs::vint16m8_t, &RISCV::VRM8RegClass); - - addRegisterClass(RISCVVMVTs::vint32mf2_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint32m1_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint32m2_t, &RISCV::VRM2RegClass); - addRegisterClass(RISCVVMVTs::vint32m4_t, &RISCV::VRM4RegClass); - addRegisterClass(RISCVVMVTs::vint32m8_t, &RISCV::VRM8RegClass); - - addRegisterClass(RISCVVMVTs::vint64m1_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vint64m2_t, &RISCV::VRM2RegClass); - addRegisterClass(RISCVVMVTs::vint64m4_t, &RISCV::VRM4RegClass); - addRegisterClass(RISCVVMVTs::vint64m8_t, &RISCV::VRM8RegClass); - - if (Subtarget.hasStdExtZfh()) { - addRegisterClass(RISCVVMVTs::vfloat16mf4_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vfloat16mf2_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vfloat16m1_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vfloat16m2_t, &RISCV::VRM2RegClass); - addRegisterClass(RISCVVMVTs::vfloat16m4_t, &RISCV::VRM4RegClass); - addRegisterClass(RISCVVMVTs::vfloat16m8_t, &RISCV::VRM8RegClass); - } + auto addRegClassForRVV = [this](MVT VT) { + unsigned Size = VT.getSizeInBits().getKnownMinValue(); + assert(Size <= 512 && isPowerOf2_32(Size)); + const TargetRegisterClass *RC; + if (Size <= 64) + RC = &RISCV::VRRegClass; + else if (Size == 128) + RC = &RISCV::VRM2RegClass; + else if (Size == 256) + RC = &RISCV::VRM4RegClass; + else + RC = &RISCV::VRM8RegClass; - if (Subtarget.hasStdExtF()) { - addRegisterClass(RISCVVMVTs::vfloat32mf2_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vfloat32m1_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vfloat32m2_t, &RISCV::VRM2RegClass); - addRegisterClass(RISCVVMVTs::vfloat32m4_t, &RISCV::VRM4RegClass); - addRegisterClass(RISCVVMVTs::vfloat32m8_t, &RISCV::VRM8RegClass); - } + addRegisterClass(VT, RC); + }; - if (Subtarget.hasStdExtD()) { - addRegisterClass(RISCVVMVTs::vfloat64m1_t, &RISCV::VRRegClass); - addRegisterClass(RISCVVMVTs::vfloat64m2_t, &RISCV::VRM2RegClass); - addRegisterClass(RISCVVMVTs::vfloat64m4_t, &RISCV::VRM4RegClass); - addRegisterClass(RISCVVMVTs::vfloat64m8_t, &RISCV::VRM8RegClass); - } + for (MVT VT : BoolVecVTs) + addRegClassForRVV(VT); + for (MVT VT : IntVecVTs) + addRegClassForRVV(VT); + + if (Subtarget.hasStdExtZfh()) + for (MVT VT : F16VecVTs) + addRegClassForRVV(VT); + + if (Subtarget.hasStdExtF()) + for (MVT VT : F32VecVTs) + addRegClassForRVV(VT); + + if (Subtarget.hasStdExtD()) + for (MVT VT : F64VecVTs) + addRegClassForRVV(VT); } // Compute derived properties from the register classes. @@ -379,9 +371,22 @@ if (Subtarget.is64Bit()) { setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i64, Custom); setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i64, Custom); + } else { + // We must custom-lower certain vXi64 operations on RV32 due to the vector + // element type being illegal. + setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom); } - for (auto VT : MVT::integer_scalable_vector_valuetypes()) { + for (MVT VT : BoolVecVTs) { + setOperationAction(ISD::SPLAT_VECTOR, VT, Legal); + + // Mask VTs are custom-expanded into a series of standard nodes + setOperationAction(ISD::TRUNCATE, VT, Custom); + } + + for (MVT VT : IntVecVTs) { setOperationAction(ISD::SPLAT_VECTOR, VT, Legal); setOperationAction(ISD::SMIN, VT, Legal); @@ -392,30 +397,18 @@ setOperationAction(ISD::ROTL, VT, Expand); setOperationAction(ISD::ROTR, VT, Expand); - if (isTypeLegal(VT)) { - // Custom-lower extensions and truncations from/to mask types. - setOperationAction(ISD::ANY_EXTEND, VT, Custom); - setOperationAction(ISD::SIGN_EXTEND, VT, Custom); - setOperationAction(ISD::ZERO_EXTEND, VT, Custom); - - // We custom-lower all legally-typed vector truncates: - // 1. Mask VTs are custom-expanded into a series of standard nodes - // 2. Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR" - // nodes which truncate by one power of two at a time. - setOperationAction(ISD::TRUNCATE, VT, Custom); - - // Custom-lower insert/extract operations to simplify patterns. - setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); - setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); - } - } + // Custom-lower extensions and truncations from/to mask types. + setOperationAction(ISD::ANY_EXTEND, VT, Custom); + setOperationAction(ISD::SIGN_EXTEND, VT, Custom); + setOperationAction(ISD::ZERO_EXTEND, VT, Custom); - // We must custom-lower certain vXi64 operations on RV32 due to the vector - // element type being illegal. - if (!Subtarget.is64Bit()) { - setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom); - setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom); - setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom); + // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR" + // nodes which truncate by one power of two at a time. + setOperationAction(ISD::TRUNCATE, VT, Custom); + + // Custom-lower insert/extract operations to simplify patterns. + setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); } // Expand various CCs to best match the RVV ISA, which natively supports UNE @@ -441,25 +434,17 @@ setCondCodeAction(CC, VT, Expand); }; - if (Subtarget.hasStdExtZfh()) { - for (auto VT : {RISCVVMVTs::vfloat16mf4_t, RISCVVMVTs::vfloat16mf2_t, - RISCVVMVTs::vfloat16m1_t, RISCVVMVTs::vfloat16m2_t, - RISCVVMVTs::vfloat16m4_t, RISCVVMVTs::vfloat16m8_t}) + if (Subtarget.hasStdExtZfh()) + for (MVT VT : F16VecVTs) SetCommonVFPActions(VT); - } - if (Subtarget.hasStdExtF()) { - for (auto VT : {RISCVVMVTs::vfloat32mf2_t, RISCVVMVTs::vfloat32m1_t, - RISCVVMVTs::vfloat32m2_t, RISCVVMVTs::vfloat32m4_t, - RISCVVMVTs::vfloat32m8_t}) + if (Subtarget.hasStdExtF()) + for (MVT VT : F32VecVTs) SetCommonVFPActions(VT); - } - if (Subtarget.hasStdExtD()) { - for (auto VT : {RISCVVMVTs::vfloat64m1_t, RISCVVMVTs::vfloat64m2_t, - RISCVVMVTs::vfloat64m4_t, RISCVVMVTs::vfloat64m8_t}) + if (Subtarget.hasStdExtD()) + for (MVT VT : F64VecVTs) SetCommonVFPActions(VT); - } } // Function alignments.