diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h --- a/clang/include/clang/Support/RISCVVIntrinsicUtils.h +++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h @@ -9,6 +9,7 @@ #ifndef CLANG_SUPPORT_RISCVVINTRINSICUTILS_H #define CLANG_SUPPORT_RISCVVINTRINSICUTILS_H +#include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/StringRef.h" #include @@ -18,9 +19,127 @@ namespace clang { namespace RISCV { -using BasicType = char; using VScaleVal = llvm::Optional; +// Modifier for vector type. +enum class VectorTypeModifier : uint8_t { + NoModifier, + Log2EEW3, + Log2EEW4, + Log2EEW5, + Log2EEW6, + FixedSEW8, + FixedSEW16, + FixedSEW32, + FixedSEW64, + LFixedLog2LMULN3, + LFixedLog2LMULN2, + LFixedLog2LMULN1, + LFixedLog2LMUL0, + LFixedLog2LMUL1, + LFixedLog2LMUL2, + LFixedLog2LMUL3, + SFixedLog2LMULN3, + SFixedLog2LMULN2, + SFixedLog2LMULN1, + SFixedLog2LMUL0, + SFixedLog2LMUL1, + SFixedLog2LMUL2, + SFixedLog2LMUL3, +}; + +// Similar to basic type but used to describe what kind of type is related to +// the basic vector type; used to compute type info of arguments. +enum class PrimitiveType : uint8_t { + Invalid, + Scalar, + Vector, + Widening2XVector, + Widening4XVector, + Widening8XVector, + MaskVector, + Void, + SizeT, + Ptrdiff, + UnsignedLong, + SignedLong, +}; + +// Modifier for type, used for both scalar and vector types. +enum class TypeModifier : uint8_t { + NoModifier = 0, + Pointer = 1 << 0, + Const = 1 << 1, + Immediate = 1 << 2, + UnsignedInteger = 1 << 3, + SignedInteger = 1 << 4, + Float = 1 << 5, + LMUL1 = 1 << 6, + MaxOffset = 6, + LLVM_MARK_AS_BITMASK_ENUM(LMUL1), +}; + +// TypeProfile is used to compute type info of arguments or return value.
+struct TypeProfile { + constexpr TypeProfile() = default; + constexpr TypeProfile(PrimitiveType PT) : PT(static_cast(PT)) {} + constexpr TypeProfile(PrimitiveType PT, TypeModifier TM) + : PT(static_cast(PT)), TM(static_cast(TM)) {} + constexpr TypeProfile(uint8_t PT, uint8_t VTM, uint8_t TM) + : PT(PT), VTM(VTM), TM(TM) {} + + uint8_t PT = static_cast(PrimitiveType::Invalid); + uint8_t VTM = static_cast(VectorTypeModifier::NoModifier); + uint8_t TM = static_cast(TypeModifier::NoModifier); + + std::string IndexStr() const { + return std::to_string(PT) + "_" + std::to_string(VTM) + "_" + + std::to_string(TM); + } + + bool operator!=(const TypeProfile &TP) const { + return TP.PT != PT || TP.VTM != VTM || TP.TM != TM; + } + bool operator>(const TypeProfile &TP) const { + // Lexicographic order on (PT, VTM, TM), i.e. a strict weak ordering. + return PT != TP.PT ? PT > TP.PT : VTM != TP.VTM ? VTM > TP.VTM : TM > TP.TM; + } + + static const TypeProfile Mask; + static const TypeProfile Vector; + static const TypeProfile VL; + static llvm::Optional + parseTypeProfile(llvm::StringRef PrototypeStr); +}; + +// Basic type of vector type. +enum class BasicType : uint8_t { + Unknown = 0, + Int8 = 1 << 0, + Int16 = 1 << 1, + Int32 = 1 << 2, + Int64 = 1 << 3, + Float16 = 1 << 4, + Float32 = 1 << 5, + Float64 = 1 << 6, + MaxOffset = 6, + LLVM_MARK_AS_BITMASK_ENUM(Float64), +}; + +// Kind of the scalar (element) type. +enum ScalarTypeKind : uint8_t { + Void, + Size_t, + Ptrdiff_t, + UnsignedLong, + SignedLong, + Boolean, + SignedInteger, + UnsignedInteger, + Float, + Invalid, +}; + // Exponential LMUL struct LMULType { int Log2LMUL; @@ -34,18 +153,6 @@ // This class is compact representation of a valid and invalid RVVType.
class RVVType { - enum ScalarTypeKind : uint32_t { - Void, - Size_t, - Ptrdiff_t, - UnsignedLong, - SignedLong, - Boolean, - SignedInteger, - UnsignedInteger, - Float, - Invalid, - }; BasicType BT; ScalarTypeKind ScalarType = Invalid; LMULType LMUL; @@ -64,8 +171,8 @@ std::string ShortStr; public: - RVVType() : RVVType(BasicType(), 0, llvm::StringRef()) {} - RVVType(BasicType BT, int Log2LMUL, llvm::StringRef prototype); + RVVType() : BT(BasicType::Unknown), LMUL(0), Valid(false) {} + RVVType(BasicType BT, int Log2LMUL, const TypeProfile &Profile); // Return the string representation of a type, which is an encoded string for // passing to the BUILTIN() macro in Builtins.def. @@ -114,7 +221,11 @@ // Applies a prototype modifier to the current type. The result maybe an // invalid type. - void applyModifier(llvm::StringRef prototype); + void applyModifier(const TypeProfile &prototype); + + void applyLog2EEW(unsigned Log2EEW); + void applyFixedSEW(unsigned NewSEW); + void applyFixedLog2LMUL(int Log2LMUL, bool LargerThan); // Compute and record a string for legal type.
void initBuiltinStr(); diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp --- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp +++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp @@ -22,6 +22,9 @@ namespace clang { namespace RISCV { +const TypeProfile TypeProfile::Mask = TypeProfile(PrimitiveType::MaskVector); +const TypeProfile TypeProfile::VL = TypeProfile(PrimitiveType::SizeT); +const TypeProfile TypeProfile::Vector = TypeProfile(PrimitiveType::Vector); //===----------------------------------------------------------------------===// // Type implementation //===----------------------------------------------------------------------===// @@ -70,7 +73,7 @@ return *this; } -RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype) +RVVType::RVVType(BasicType BT, int Log2LMUL, const TypeProfile &prototype) : BT(BT), LMUL(LMULType(Log2LMUL)) { applyBasicType(); applyModifier(prototype); @@ -326,31 +329,31 @@ void RVVType::applyBasicType() { switch (BT) { - case 'c': + case BasicType::Int8: ElementBitwidth = 8; ScalarType = ScalarTypeKind::SignedInteger; break; - case 's': + case BasicType::Int16: ElementBitwidth = 16; ScalarType = ScalarTypeKind::SignedInteger; break; - case 'i': + case BasicType::Int32: ElementBitwidth = 32; ScalarType = ScalarTypeKind::SignedInteger; break; - case 'l': + case BasicType::Int64: ElementBitwidth = 64; ScalarType = ScalarTypeKind::SignedInteger; break; - case 'x': + case BasicType::Float16: ElementBitwidth = 16; ScalarType = ScalarTypeKind::Float; break; - case 'f': + case BasicType::Float32: ElementBitwidth = 32; ScalarType = ScalarTypeKind::Float; break; - case 'd': + case BasicType::Float64: ElementBitwidth = 64; ScalarType = ScalarTypeKind::Float; break; @@ -360,160 +363,425 @@ assert(ElementBitwidth != 0 && "Bad element bitwidth!"); } -void RVVType::applyModifier(StringRef Transformer) { - if (Transformer.empty()) - return; +Optional +TypeProfile::parseTypeProfile(llvm::StringRef
TypeProfileStr) { + TypeProfile TP; + PrimitiveType PT = PrimitiveType::Invalid; + if (TypeProfileStr.empty()) + return TP; // Handle primitive type transformer - auto PType = Transformer.back(); + auto PType = TypeProfileStr.back(); switch (PType) { case 'e': - Scale = 0; + PT = PrimitiveType::Scalar; break; case 'v': - Scale = LMUL.getScale(ElementBitwidth); + PT = PrimitiveType::Vector; break; case 'w': - ElementBitwidth *= 2; - LMUL *= 2; - Scale = LMUL.getScale(ElementBitwidth); + PT = PrimitiveType::Widening2XVector; break; case 'q': - ElementBitwidth *= 4; - LMUL *= 4; - Scale = LMUL.getScale(ElementBitwidth); + PT = PrimitiveType::Widening4XVector; break; case 'o': - ElementBitwidth *= 8; - LMUL *= 8; - Scale = LMUL.getScale(ElementBitwidth); + PT = PrimitiveType::Widening8XVector; break; case 'm': - ScalarType = ScalarTypeKind::Boolean; - Scale = LMUL.getScale(ElementBitwidth); - ElementBitwidth = 1; + PT = PrimitiveType::MaskVector; break; case '0': - ScalarType = ScalarTypeKind::Void; + PT = PrimitiveType::Void; break; case 'z': - ScalarType = ScalarTypeKind::Size_t; + PT = PrimitiveType::SizeT; break; case 't': - ScalarType = ScalarTypeKind::Ptrdiff_t; + PT = PrimitiveType::Ptrdiff; break; case 'u': - ScalarType = ScalarTypeKind::UnsignedLong; + PT = PrimitiveType::UnsignedLong; break; case 'l': - ScalarType = ScalarTypeKind::SignedLong; + PT = PrimitiveType::SignedLong; break; default: llvm_unreachable("Illegal primitive type transformers!"); } - Transformer = Transformer.drop_back(); + TP.PT = static_cast(PT); + TypeProfileStr = TypeProfileStr.drop_back(); // Extract and compute complex type transformer. It can only appear one time.
- if (Transformer.startswith("(")) { - size_t Idx = Transformer.find(')'); + if (TypeProfileStr.startswith("(")) { + size_t Idx = TypeProfileStr.find(')'); assert(Idx != StringRef::npos); - StringRef ComplexType = Transformer.slice(1, Idx); - Transformer = Transformer.drop_front(Idx + 1); - assert(!Transformer.contains('(') && + StringRef ComplexType = TypeProfileStr.slice(1, Idx); + TypeProfileStr = TypeProfileStr.drop_front(Idx + 1); + assert(!TypeProfileStr.contains('(') && "Only allow one complex type transformer"); - auto UpdateAndCheckComplexProto = [&]() { - Scale = LMUL.getScale(ElementBitwidth); - const StringRef VectorPrototypes("vwqom"); - if (!VectorPrototypes.contains(PType)) - llvm_unreachable("Complex type transformer only supports vector type!"); - if (Transformer.find_first_of("PCKWS") != StringRef::npos) - llvm_unreachable( - "Illegal type transformer for Complex type transformer"); - }; - auto ComputeFixedLog2LMUL = - [&](StringRef Value, - std::function Compare) { - int32_t Log2LMUL; - Value.getAsInteger(10, Log2LMUL); - if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { - ScalarType = Invalid; - return false; - } - // Update new LMUL - LMUL = LMULType(Log2LMUL); - UpdateAndCheckComplexProto(); - return true; - }; auto ComplexTT = ComplexType.split(":"); + VectorTypeModifier VTM = VectorTypeModifier::NoModifier; if (ComplexTT.first == "Log2EEW") { uint32_t Log2EEW; - ComplexTT.second.getAsInteger(10, Log2EEW); - // update new elmul = (eew/sew) * lmul - LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); - // update new eew - ElementBitwidth = 1 << Log2EEW; - ScalarType = ScalarTypeKind::SignedInteger; - UpdateAndCheckComplexProto(); + if (ComplexTT.second.getAsInteger(10, Log2EEW)) { + llvm_unreachable("Invalid Log2EEW value!"); + return None; + } + switch (Log2EEW) { + case 3: + VTM = VectorTypeModifier::Log2EEW3; + break; + case 4: + VTM = VectorTypeModifier::Log2EEW4; + break; + case 5: + VTM = VectorTypeModifier::Log2EEW5; + break; + case 6: + VTM =
VectorTypeModifier::Log2EEW6; + break; + default: + llvm_unreachable("Invalid Log2EEW value, should be [3-6]"); + return None; + } } else if (ComplexTT.first == "FixedSEW") { uint32_t NewSEW; - ComplexTT.second.getAsInteger(10, NewSEW); - // Set invalid type if src and dst SEW are same. - if (ElementBitwidth == NewSEW) { - ScalarType = Invalid; - return; + if (ComplexTT.second.getAsInteger(10, NewSEW)) { + llvm_unreachable("Invalid FixedSEW value!"); + return None; + } + switch (NewSEW) { + case 8: + VTM = VectorTypeModifier::FixedSEW8; + break; + case 16: + VTM = VectorTypeModifier::FixedSEW16; + break; + case 32: + VTM = VectorTypeModifier::FixedSEW32; + break; + case 64: + VTM = VectorTypeModifier::FixedSEW64; + break; + default: + llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64"); + return None; } - // Update new SEW - ElementBitwidth = NewSEW; - UpdateAndCheckComplexProto(); } else if (ComplexTT.first == "LFixedLog2LMUL") { - // New LMUL should be larger than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater())) - return; + int32_t Log2LMUL; + if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { + llvm_unreachable("Invalid LFixedLog2LMUL value!"); + return None; + } + switch (Log2LMUL) { + case -3: + VTM = VectorTypeModifier::LFixedLog2LMULN3; + break; + case -2: + VTM = VectorTypeModifier::LFixedLog2LMULN2; + break; + case -1: + VTM = VectorTypeModifier::LFixedLog2LMULN1; + break; + case 0: + VTM = VectorTypeModifier::LFixedLog2LMUL0; + break; + case 1: + VTM = VectorTypeModifier::LFixedLog2LMUL1; + break; + case 2: + VTM = VectorTypeModifier::LFixedLog2LMUL2; + break; + case 3: + VTM = VectorTypeModifier::LFixedLog2LMUL3; + break; + default: + llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); + return None; + } } else if (ComplexTT.first == "SFixedLog2LMUL") { - // New LMUL should be smaller than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less())) - return; + int32_t Log2LMUL; + if
(ComplexTT.second.getAsInteger(10, Log2LMUL)) { + llvm_unreachable("Invalid SFixedLog2LMUL value!"); + return None; + } + switch (Log2LMUL) { + case -3: + VTM = VectorTypeModifier::SFixedLog2LMULN3; + break; + case -2: + VTM = VectorTypeModifier::SFixedLog2LMULN2; + break; + case -1: + VTM = VectorTypeModifier::SFixedLog2LMULN1; + break; + case 0: + VTM = VectorTypeModifier::SFixedLog2LMUL0; + break; + case 1: + VTM = VectorTypeModifier::SFixedLog2LMUL1; + break; + case 2: + VTM = VectorTypeModifier::SFixedLog2LMUL2; + break; + case 3: + VTM = VectorTypeModifier::SFixedLog2LMUL3; + break; + default: + llvm_unreachable("Invalid SFixedLog2LMUL value, should be [-3, 3]"); + return None; + } + } else { llvm_unreachable("Illegal complex type transformers!"); } + TP.VTM = static_cast(VTM); } // Compute the remain type transformers - for (char I : Transformer) { + TypeModifier TM = TypeModifier::NoModifier; + for (char I : TypeProfileStr) { switch (I) { case 'P': - if (IsConstant) + if ((TM & TypeModifier::Const) == TypeModifier::Const) llvm_unreachable("'P' transformer cannot be used after 'C'"); - if (IsPointer) + if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer) llvm_unreachable("'P' transformer cannot be used twice"); - IsPointer = true; + TM |= TypeModifier::Pointer; break; case 'C': - if (IsConstant) - llvm_unreachable("'C' transformer cannot be used twice"); - IsConstant = true; + TM |= TypeModifier::Const; break; case 'K': - IsImmediate = true; + TM |= TypeModifier::Immediate; break; case 'U': - ScalarType = ScalarTypeKind::UnsignedInteger; + TM |= TypeModifier::UnsignedInteger; break; case 'I': - ScalarType = ScalarTypeKind::SignedInteger; + TM |= TypeModifier::SignedInteger; break; case 'F': - ScalarType = ScalarTypeKind::Float; + TM |= TypeModifier::Float; break; case 'S': + TM |= TypeModifier::LMUL1; + break; + default: + llvm_unreachable("Illegal non-primitive type transformer!"); + } + } + TP.TM = static_cast(TM); + + return TP; +} + +void
RVVType::applyModifier(const TypeProfile &Transformer) { + // Handle primitive type transformer + switch (static_cast(Transformer.PT)) { + case PrimitiveType::Scalar: + Scale = 0; + break; + case PrimitiveType::Vector: + Scale = LMUL.getScale(ElementBitwidth); + break; + case PrimitiveType::Widening2XVector: + ElementBitwidth *= 2; + LMUL *= 2; + Scale = LMUL.getScale(ElementBitwidth); + break; + case PrimitiveType::Widening4XVector: + ElementBitwidth *= 4; + LMUL *= 4; + Scale = LMUL.getScale(ElementBitwidth); + break; + case PrimitiveType::Widening8XVector: + ElementBitwidth *= 8; + LMUL *= 8; + Scale = LMUL.getScale(ElementBitwidth); + break; + case PrimitiveType::MaskVector: + ScalarType = ScalarTypeKind::Boolean; + Scale = LMUL.getScale(ElementBitwidth); + ElementBitwidth = 1; + break; + case PrimitiveType::Void: + ScalarType = ScalarTypeKind::Void; + break; + case PrimitiveType::SizeT: + ScalarType = ScalarTypeKind::Size_t; + break; + case PrimitiveType::Ptrdiff: + ScalarType = ScalarTypeKind::Ptrdiff_t; + break; + case PrimitiveType::UnsignedLong: + ScalarType = ScalarTypeKind::UnsignedLong; + break; + case PrimitiveType::SignedLong: + ScalarType = ScalarTypeKind::SignedLong; + break; + case PrimitiveType::Invalid: + ScalarType = ScalarTypeKind::Invalid; + return; + default: + llvm_unreachable("Illegal primitive type transformers!"); + } + + switch (static_cast(Transformer.VTM)) { + case VectorTypeModifier::Log2EEW3: + applyLog2EEW(3); + break; + case VectorTypeModifier::Log2EEW4: + applyLog2EEW(4); + break; + case VectorTypeModifier::Log2EEW5: + applyLog2EEW(5); + break; + case VectorTypeModifier::Log2EEW6: + applyLog2EEW(6); + break; + case VectorTypeModifier::FixedSEW8: + applyFixedSEW(8); + break; + case VectorTypeModifier::FixedSEW16: + applyFixedSEW(16); + break; + case VectorTypeModifier::FixedSEW32: + applyFixedSEW(32); + break; + case VectorTypeModifier::FixedSEW64: + applyFixedSEW(64); + break; + case VectorTypeModifier::LFixedLog2LMULN3: +
applyFixedLog2LMUL(-3, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMULN2: + applyFixedLog2LMUL(-2, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMULN1: + applyFixedLog2LMUL(-1, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMUL0: + applyFixedLog2LMUL(0, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMUL1: + applyFixedLog2LMUL(1, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMUL2: + applyFixedLog2LMUL(2, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMUL3: + applyFixedLog2LMUL(3, /* LargerThan= */ true); + break; + case VectorTypeModifier::SFixedLog2LMULN3: + applyFixedLog2LMUL(-3, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMULN2: + applyFixedLog2LMUL(-2, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMULN1: + applyFixedLog2LMUL(-1, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMUL0: + applyFixedLog2LMUL(0, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMUL1: + applyFixedLog2LMUL(1, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMUL2: + applyFixedLog2LMUL(2, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMUL3: + applyFixedLog2LMUL(3, /* LargerThan= */ false); + break; + case VectorTypeModifier::NoModifier: + break; + default: + llvm_unreachable("Illegal vector type modifier!"); + } + + // A vector type modifier may have invalidated the type (e.g. FixedSEW equal + // to the current SEW); skip the remaining modifiers so that U/I/F cannot + // turn an invalid type back into a valid one. + if (ScalarType == ScalarTypeKind::Invalid) + return; + + for (unsigned TypeModifierMaskShift = 0; + TypeModifierMaskShift <= static_cast(TypeModifier::MaxOffset); + ++TypeModifierMaskShift) { + unsigned TypeModifierMask = 1 << TypeModifierMaskShift; + if ((static_cast(Transformer.TM) & TypeModifierMask) != + TypeModifierMask) + continue; + switch (static_cast(TypeModifierMask)) { + case TypeModifier::Pointer: + IsPointer = true; + break; + case TypeModifier::Const: + IsConstant = true; + break; + case TypeModifier::Immediate: +
IsImmediate = true; + IsConstant = true; + break; + case TypeModifier::UnsignedInteger: + ScalarType = ScalarTypeKind::UnsignedInteger; + break; + case TypeModifier::SignedInteger: + ScalarType = ScalarTypeKind::SignedInteger; + break; + case TypeModifier::Float: + ScalarType = ScalarTypeKind::Float; + break; + case TypeModifier::LMUL1: LMUL = LMULType(0); // Update ElementBitwidth need to update Scale too. Scale = LMUL.getScale(ElementBitwidth); break; default: - llvm_unreachable("Illegal non-primitive type transformer!"); + llvm_unreachable("Unknown type modifier mask!"); + } + } +} + +void RVVType::applyLog2EEW(unsigned Log2EEW) { + // update new elmul = (eew/sew) * lmul + LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); + // update new eew + ElementBitwidth = 1 << Log2EEW; + ScalarType = ScalarTypeKind::SignedInteger; + Scale = LMUL.getScale(ElementBitwidth); +} + +void RVVType::applyFixedSEW(unsigned NewSEW) { + // Set invalid type if src and dst SEW are same. + if (ElementBitwidth == NewSEW) { + ScalarType = ScalarTypeKind::Invalid; + return; + } + // Update new SEW + ElementBitwidth = NewSEW; + Scale = LMUL.getScale(ElementBitwidth); +} + +void RVVType::applyFixedLog2LMUL(int Log2LMUL, bool LargerThan) { + // The new LMUL must be strictly larger (LargerThan) or strictly smaller than + // the current one; an equal LMUL yields an invalid type. + if (LargerThan) { + if (Log2LMUL <= LMUL.Log2LMUL) { + ScalarType = ScalarTypeKind::Invalid; + return; + } + } else { + if (Log2LMUL >= LMUL.Log2LMUL) { + ScalarType = ScalarTypeKind::Invalid; + return; + } } + // Update new LMUL + LMUL = LMULType(Log2LMUL); + Scale = LMUL.getScale(ElementBitwidth); } //===----------------------------------------------------------------------===// diff --git a/clang/utils/TableGen/RISCVVEmitter.cpp b/clang/utils/TableGen/RISCVVEmitter.cpp --- a/clang/utils/TableGen/RISCVVEmitter.cpp +++ b/clang/utils/TableGen/RISCVVEmitter.cpp @@ -48,7 +48,7 @@ /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
void createCodeGen(raw_ostream &o); - std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes); + std::string getSuffixStr(BasicType Type, int Log2LMUL, StringRef Prototypes); private: /// Create all intrinsics and add them to \p Out @@ -60,8 +60,9 @@ /// or illegal set to avoid compute the same config again. The result maybe /// have illegal RVVType. Optional computeTypes(BasicType BT, int Log2LMUL, unsigned NF, - ArrayRef PrototypeSeq); - Optional computeType(BasicType BT, int Log2LMUL, StringRef Proto); + ArrayRef PrototypeSeq); + Optional computeType(BasicType BT, int Log2LMUL, + TypeProfile Proto); /// Emit Acrh predecessor definitions and body, assume the element of Defs are /// sorted by extension. @@ -76,11 +77,40 @@ // Slice Prototypes string into sub prototype string and process each sub // prototype string individually in the Handler. void parsePrototypes(StringRef Prototypes, - std::function Handler); + std::function Handler); }; } // namespace +static BasicType ParseBasicType(char c) { + switch (c) { + case 'c': + return BasicType::Int8; + break; + case 's': + return BasicType::Int16; + break; + case 'i': + return BasicType::Int32; + break; + case 'l': + return BasicType::Int64; + break; + case 'x': + return BasicType::Float16; + break; + case 'f': + return BasicType::Float32; + break; + case 'd': + return BasicType::Float64; + break; + + default: + return BasicType::Unknown; + } +} + void emitCodeGenSwitchBody(const RVVIntrinsic *RVVI, raw_ostream &OS) { if (!RVVI->getIRName().empty()) OS << " ID = Intrinsic::riscv_" + RVVI->getIRName() + ";\n"; @@ -202,24 +232,27 @@ constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3}; // Print RVV boolean types. for (int Log2LMUL : Log2LMULs) { - auto T = computeType('c', Log2LMUL, "m"); + auto T = computeType(BasicType::Int8, Log2LMUL, TypeProfile::Mask); if (T.hasValue()) printType(T.getValue()); } // Print RVV int/float types.
for (char I : StringRef("csil")) { + BasicType BT = ParseBasicType(I); for (int Log2LMUL : Log2LMULs) { - auto T = computeType(I, Log2LMUL, "v"); + auto T = computeType(BT, Log2LMUL, TypeProfile::Vector); if (T.hasValue()) { printType(T.getValue()); - auto UT = computeType(I, Log2LMUL, "Uv"); + auto UT = computeType( + BT, Log2LMUL, + TypeProfile(PrimitiveType::Vector, TypeModifier::UnsignedInteger)); printType(UT.getValue()); } } } OS << "#if defined(__riscv_zvfh)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('x', Log2LMUL, "v"); + auto T = computeType(BasicType::Float16, Log2LMUL, TypeProfile::Vector); if (T.hasValue()) printType(T.getValue()); } @@ -227,7 +260,7 @@ OS << "#if defined(__riscv_f)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('f', Log2LMUL, "v"); + auto T = computeType(BasicType::Float32, Log2LMUL, TypeProfile::Vector); if (T.hasValue()) printType(T.getValue()); } @@ -235,7 +268,7 @@ OS << "#if defined(__riscv_d)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('d', Log2LMUL, "v"); + auto T = computeType(BasicType::Float64, Log2LMUL, TypeProfile::Vector); if (T.hasValue()) printType(T.getValue()); } @@ -360,7 +393,7 @@ } void RVVEmitter::parsePrototypes(StringRef Prototypes, - std::function Handler) { + std::function Handler) { const StringRef Primaries("evwqom0ztul"); while (!Prototypes.empty()) { size_t Idx = 0; @@ -370,15 +403,18 @@ Idx = Prototypes.find_first_of(')'); Idx = Prototypes.find_first_of(Primaries, Idx); assert(Idx != StringRef::npos); - Handler(Prototypes.slice(0, Idx + 1)); + auto TP = TypeProfile::parseTypeProfile(Prototypes.slice(0, Idx + 1)); + if (!TP) + PrintFatalError("Error during parsing prototype."); + Handler(*TP); Prototypes = Prototypes.drop_front(Idx + 1); } } -std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL, +std::string RVVEmitter::getSuffixStr(BasicType Type, int Log2LMUL, StringRef Prototypes) { SmallVector SuffixStrs; - parsePrototypes(Prototypes, [&](StringRef
Proto) { + parsePrototypes(Prototypes, [&](TypeProfile Proto) { auto T = computeType(Type, Log2LMUL, Proto); SuffixStrs.push_back(T.getValue()->getShortStr()); }); @@ -419,13 +455,13 @@ // Parse prototype and create a list of primitive type with transformers // (operand) in ProtoSeq. ProtoSeq[0] is output operand. - SmallVector ProtoSeq; - parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) { - ProtoSeq.push_back(Proto.str()); + SmallVector ProtoSeq; + parsePrototypes(Prototypes, [&ProtoSeq](TypeProfile Proto) { + ProtoSeq.push_back(Proto); }); // Compute Builtin types - SmallVector ProtoMaskSeq = ProtoSeq; + SmallVector ProtoMaskSeq = ProtoSeq; if (HasMasked) { // If HasMaskedOffOperand, insert result type as first input operand. if (HasMaskedOffOperand) { @@ -436,10 +472,10 @@ // (void, op0 address, op1 address, ...) // to // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) + TypeProfile MaskoffType = ProtoSeq[1]; + MaskoffType.TM &= ~static_cast(TypeModifier::Pointer); for (unsigned I = 0; I < NF; ++I) - ProtoMaskSeq.insert( - ProtoMaskSeq.begin() + NF + 1, - ProtoSeq[1].substr(1)); // Use substr(1) to skip '*' + ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, MaskoffType); } } if (HasMaskedOffOperand && NF > 1) { @@ -448,28 +484,29 @@ // to // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1, // ...) - ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m"); + ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, TypeProfile::Mask); } else { - // If HasMasked, insert 'm' as first input operand. - ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m"); + // If HasMasked, insert TypeProfile::Mask as first input operand.
+ ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, TypeProfile::Mask); } } - // If HasVL, append 'z' to last operand + // If HasVL, append TypeProfile::VL to last operand if (HasVL) { - ProtoSeq.push_back("z"); - ProtoMaskSeq.push_back("z"); + ProtoSeq.push_back(TypeProfile::VL); + ProtoMaskSeq.push_back(TypeProfile::VL); } // Create Intrinsics for each type and LMUL. for (char I : TypeRange) { + BasicType BT = ParseBasicType(I); for (int Log2LMUL : Log2LMULList) { - Optional Types = computeTypes(I, Log2LMUL, NF, ProtoSeq); + Optional Types = computeTypes(BT, Log2LMUL, NF, ProtoSeq); // Ignored to create new intrinsic if there are any illegal types. if (!Types.hasValue()) continue; - auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto); - auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto); + auto SuffixStr = getSuffixStr(BT, Log2LMUL, SuffixProto); + auto MangledSuffixStr = getSuffixStr(BT, Log2LMUL, MangledSuffixProto); // Create a unmasked intrinsic Out.push_back(std::make_unique( Name, SuffixStr, MangledName, MangledSuffixStr, IRName, @@ -480,7 +517,7 @@ if (HasMasked) { // Create a masked intrinsic Optional MaskTypes = - computeTypes(I, Log2LMUL, NF, ProtoMaskSeq); + computeTypes(BT, Log2LMUL, NF, ProtoMaskSeq); Out.push_back(std::make_unique( Name, SuffixStr, MangledName, MangledSuffixStr, MaskedIRName, /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicy, @@ -503,13 +540,13 @@ Optional RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, - ArrayRef PrototypeSeq) { + ArrayRef PrototypeSeq) { // LMUL x NF must be less than or equal to 8.
if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) return llvm::None; RVVTypes Types; - for (const std::string &Proto : PrototypeSeq) { + for (const TypeProfile &Proto : PrototypeSeq) { auto T = computeType(BT, Log2LMUL, Proto); if (!T.hasValue()) return llvm::None; @@ -520,8 +557,10 @@ } Optional RVVEmitter::computeType(BasicType BT, int Log2LMUL, - StringRef Proto) { - std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str(); + TypeProfile Proto) { + std::string Idx = + Twine(Twine(static_cast(BT)) + Twine(Log2LMUL) + Proto.IndexStr()) + .str(); // Search first auto It = LegalTypes.find(Idx); if (It != LegalTypes.end())