diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
--- a/clang/include/clang/Support/RISCVVIntrinsicUtils.h
+++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -9,7 +9,10 @@
 #ifndef CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
 #define CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
 
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include <cstdint>
 #include <string>
@@ -18,9 +21,128 @@
 namespace clang {
 namespace RISCV {
 
-using BasicType = char;
 using VScaleVal = llvm::Optional<unsigned>;
 
+// Modifier for vector type.
+enum class VectorTypeModifier : uint8_t {
+  NoModifier,
+  Log2EEW3,
+  Log2EEW4,
+  Log2EEW5,
+  Log2EEW6,
+  FixedSEW8,
+  FixedSEW16,
+  FixedSEW32,
+  FixedSEW64,
+  LFixedLog2LMULN3,
+  LFixedLog2LMULN2,
+  LFixedLog2LMULN1,
+  LFixedLog2LMUL0,
+  LFixedLog2LMUL1,
+  LFixedLog2LMUL2,
+  LFixedLog2LMUL3,
+  SFixedLog2LMULN3,
+  SFixedLog2LMULN2,
+  SFixedLog2LMULN1,
+  SFixedLog2LMUL0,
+  SFixedLog2LMUL1,
+  SFixedLog2LMUL2,
+  SFixedLog2LMUL3,
+};
+
+// Similar to basic type but used to describe what kind of type is related to
+// the basic vector type; used to compute type info of arguments.
+enum class PrimitiveType : uint8_t {
+  Invalid,
+  Scalar,
+  Vector,
+  Widening2XVector,
+  Widening4XVector,
+  Widening8XVector,
+  MaskVector,
+  Void,
+  SizeT,
+  Ptrdiff,
+  UnsignedLong,
+  SignedLong,
+};
+
+// Modifier for type, used for both scalar and vector types.
+enum class TypeModifier : uint8_t {
+  NoModifier = 0,
+  Pointer = 1 << 0,
+  Const = 1 << 1,
+  Immediate = 1 << 2,
+  UnsignedInteger = 1 << 3,
+  SignedInteger = 1 << 4,
+  Float = 1 << 5,
+  LMUL1 = 1 << 6,
+  MaxOffset = 6,
+  LLVM_MARK_AS_BITMASK_ENUM(LMUL1),
+};
+
+// TypeProfile is used to compute type info of arguments or return value.
+struct TypeProfile {
+  constexpr TypeProfile() = default;
+  constexpr TypeProfile(PrimitiveType PT) : PT(static_cast<uint8_t>(PT)) {}
+  constexpr TypeProfile(PrimitiveType PT, TypeModifier TM)
+      : PT(static_cast<uint8_t>(PT)), TM(static_cast<uint8_t>(TM)) {}
+  constexpr TypeProfile(uint8_t PT, uint8_t VTM, uint8_t TM)
+      : PT(PT), VTM(VTM), TM(TM) {}
+
+  uint8_t PT = static_cast<uint8_t>(PrimitiveType::Invalid);
+  uint8_t VTM = static_cast<uint8_t>(VectorTypeModifier::NoModifier);
+  uint8_t TM = static_cast<uint8_t>(TypeModifier::NoModifier);
+
+  std::string IndexStr() const {
+    return std::to_string(PT) + "_" + std::to_string(VTM) + "_" +
+           std::to_string(TM);
+  }
+
+  bool operator!=(const TypeProfile &TP) const {
+    return TP.PT != PT || TP.VTM != VTM || TP.TM != TM;
+  }
+  bool operator>(const TypeProfile &TP) const {
+    return PT != TP.PT ? PT > TP.PT : VTM != TP.VTM ? VTM > TP.VTM : TM > TP.TM;
+  }
+
+  static const TypeProfile Mask;
+  static const TypeProfile Vector;
+  static const TypeProfile VL;
+  static llvm::Optional<TypeProfile>
+  parseTypeProfile(llvm::StringRef PrototypeStr);
+};
+
+llvm::SmallVector<TypeProfile> parsePrototypes(llvm::StringRef Prototypes);
+
+// Basic type of vector type.
+enum class BasicType : uint8_t {
+  Unknown = 0,
+  Int8 = 1 << 0,
+  Int16 = 1 << 1,
+  Int32 = 1 << 2,
+  Int64 = 1 << 3,
+  Float16 = 1 << 4,
+  Float32 = 1 << 5,
+  Float64 = 1 << 6,
+  MaxOffset = 6,
+  LLVM_MARK_AS_BITMASK_ENUM(Float64),
+};
+
+// Kind of the scalar type an RVVType is built from.
+enum ScalarTypeKind : uint8_t {
+  Void,
+  Size_t,
+  Ptrdiff_t,
+  UnsignedLong,
+  SignedLong,
+  Boolean,
+  SignedInteger,
+  UnsignedInteger,
+  Float,
+  Invalid,
+};
+
 // Exponential LMUL
 struct LMULType {
   int Log2LMUL;
@@ -32,20 +154,12 @@
   LMULType &operator*=(uint32_t RHS);
 };
 
+class RVVType;
+using RVVTypePtr = RVVType *;
+using RVVTypes = std::vector<RVVTypePtr>;
+
 // This class is compact representation of a valid and invalid RVVType.
class RVVType { - enum ScalarTypeKind : uint32_t { - Void, - Size_t, - Ptrdiff_t, - UnsignedLong, - SignedLong, - Boolean, - SignedInteger, - UnsignedInteger, - Float, - Invalid, - }; BasicType BT; ScalarTypeKind ScalarType = Invalid; LMULType LMUL; @@ -64,8 +178,8 @@ std::string ShortStr; public: - RVVType() : RVVType(BasicType(), 0, llvm::StringRef()) {} - RVVType(BasicType BT, int Log2LMUL, llvm::StringRef prototype); + RVVType() : BT(BasicType::Unknown), LMUL(0), Valid(false) {} + RVVType(BasicType BT, int Log2LMUL, const TypeProfile &Profile); // Return the string representation of a type, which is an encoded string for // passing to the BUILTIN() macro in Builtins.def. @@ -114,7 +228,11 @@ // Applies a prototype modifier to the current type. The result maybe an // invalid type. - void applyModifier(llvm::StringRef prototype); + void applyModifier(const TypeProfile &prototype); + + void applyLog2EEW(unsigned Log2EEW); + void applyFixedSEW(unsigned NewSEW); + void applyFixedLog2LMUL(int Log2LMUL, bool LargerThan); // Compute and record a string for legal type. void initBuiltinStr(); @@ -124,10 +242,19 @@ void initTypeStr(); // Compute and record a short name of a type for C/C++ name suffix. void initShortStr(); + +public: + /// Compute output and input types by applying different config (basic type + /// and LMUL with type transformers). It also record result of type in legal + /// or illegal set to avoid compute the same config again. The result maybe + /// have illegal RVVType. + static llvm::Optional + computeTypes(BasicType BT, int Log2LMUL, unsigned NF, + llvm::ArrayRef PrototypeSeq); + static llvm::Optional computeType(BasicType BT, int Log2LMUL, + TypeProfile Proto); }; -using RVVTypePtr = RVVType *; -using RVVTypes = std::vector; using RISCVPredefinedMacroT = uint8_t; enum RISCVPredefinedMacro : RISCVPredefinedMacroT { @@ -206,6 +333,10 @@ // Return the type string for a BUILTIN() macro in Builtins.def. 
std::string getBuiltinTypeStr() const; + + static std::string + getSuffixStr(BasicType Type, int Log2LMUL, + const llvm::SmallVector &TypeProfiles); }; } // end namespace RISCV diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp --- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp +++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp @@ -22,6 +22,14 @@ namespace clang { namespace RISCV { +const TypeProfile TypeProfile::Mask = TypeProfile(PrimitiveType::MaskVector); +const TypeProfile TypeProfile::VL = TypeProfile(PrimitiveType::SizeT); +const TypeProfile TypeProfile::Vector = TypeProfile(PrimitiveType::Vector); + +// Concat BasicType, LMUL and Proto as key +static StringMap LegalTypes; +static StringSet<> IllegalTypes; + //===----------------------------------------------------------------------===// // Type implementation //===----------------------------------------------------------------------===// @@ -70,7 +78,7 @@ return *this; } -RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype) +RVVType::RVVType(BasicType BT, int Log2LMUL, const TypeProfile &prototype) : BT(BT), LMUL(LMULType(Log2LMUL)) { applyBasicType(); applyModifier(prototype); @@ -326,31 +334,31 @@ void RVVType::applyBasicType() { switch (BT) { - case 'c': + case BasicType::Int8: ElementBitwidth = 8; ScalarType = ScalarTypeKind::SignedInteger; break; - case 's': + case BasicType::Int16: ElementBitwidth = 16; ScalarType = ScalarTypeKind::SignedInteger; break; - case 'i': + case BasicType::Int32: ElementBitwidth = 32; ScalarType = ScalarTypeKind::SignedInteger; break; - case 'l': + case BasicType::Int64: ElementBitwidth = 64; ScalarType = ScalarTypeKind::SignedInteger; break; - case 'x': + case BasicType::Float16: ElementBitwidth = 16; ScalarType = ScalarTypeKind::Float; break; - case 'f': + case BasicType::Float32: ElementBitwidth = 32; ScalarType = ScalarTypeKind::Float; break; - case 'd': + case BasicType::Float64: ElementBitwidth = 64; 
ScalarType = ScalarTypeKind::Float; break; @@ -360,162 +368,460 @@ assert(ElementBitwidth != 0 && "Bad element bitwidth!"); } -void RVVType::applyModifier(StringRef Transformer) { - if (Transformer.empty()) - return; +Optional +TypeProfile::parseTypeProfile(llvm::StringRef TypeProfileStr) { + TypeProfile TP; + PrimitiveType PT = PrimitiveType::Invalid; + if (TypeProfileStr.empty()) + return TP; // Handle primitive type transformer - auto PType = Transformer.back(); + auto PType = TypeProfileStr.back(); switch (PType) { case 'e': - Scale = 0; + PT = PrimitiveType::Scalar; break; case 'v': - Scale = LMUL.getScale(ElementBitwidth); + PT = PrimitiveType::Vector; break; case 'w': - ElementBitwidth *= 2; - LMUL *= 2; - Scale = LMUL.getScale(ElementBitwidth); + PT = PrimitiveType::Widening2XVector; break; case 'q': - ElementBitwidth *= 4; - LMUL *= 4; - Scale = LMUL.getScale(ElementBitwidth); + PT = PrimitiveType::Widening4XVector; break; case 'o': - ElementBitwidth *= 8; - LMUL *= 8; - Scale = LMUL.getScale(ElementBitwidth); + PT = PrimitiveType::Widening8XVector; break; case 'm': - ScalarType = ScalarTypeKind::Boolean; - Scale = LMUL.getScale(ElementBitwidth); - ElementBitwidth = 1; + PT = PrimitiveType::MaskVector; break; case '0': - ScalarType = ScalarTypeKind::Void; + PT = PrimitiveType::Void; break; case 'z': - ScalarType = ScalarTypeKind::Size_t; + PT = PrimitiveType::SizeT; break; case 't': - ScalarType = ScalarTypeKind::Ptrdiff_t; + PT = PrimitiveType::Ptrdiff; break; case 'u': - ScalarType = ScalarTypeKind::UnsignedLong; + PT = PrimitiveType::UnsignedLong; break; case 'l': - ScalarType = ScalarTypeKind::SignedLong; + PT = PrimitiveType::SignedLong; break; default: llvm_unreachable("Illegal primitive type transformers!"); } - Transformer = Transformer.drop_back(); + TP.PT = static_cast(PT); + TypeProfileStr = TypeProfileStr.drop_back(); // Extract and compute complex type transformer. It can only appear one time. 
- if (Transformer.startswith("(")) { - size_t Idx = Transformer.find(')'); + if (TypeProfileStr.startswith("(")) { + size_t Idx = TypeProfileStr.find(')'); assert(Idx != StringRef::npos); - StringRef ComplexType = Transformer.slice(1, Idx); - Transformer = Transformer.drop_front(Idx + 1); - assert(!Transformer.contains('(') && + StringRef ComplexType = TypeProfileStr.slice(1, Idx); + TypeProfileStr = TypeProfileStr.drop_front(Idx + 1); + assert(!TypeProfileStr.contains('(') && "Only allow one complex type transformer"); - auto UpdateAndCheckComplexProto = [&]() { - Scale = LMUL.getScale(ElementBitwidth); - const StringRef VectorPrototypes("vwqom"); - if (!VectorPrototypes.contains(PType)) - llvm_unreachable("Complex type transformer only supports vector type!"); - if (Transformer.find_first_of("PCKWS") != StringRef::npos) - llvm_unreachable( - "Illegal type transformer for Complex type transformer"); - }; - auto ComputeFixedLog2LMUL = - [&](StringRef Value, - std::function Compare) { - int32_t Log2LMUL; - Value.getAsInteger(10, Log2LMUL); - if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { - ScalarType = Invalid; - return false; - } - // Update new LMUL - LMUL = LMULType(Log2LMUL); - UpdateAndCheckComplexProto(); - return true; - }; auto ComplexTT = ComplexType.split(":"); + VectorTypeModifier VTM = VectorTypeModifier::NoModifier; if (ComplexTT.first == "Log2EEW") { uint32_t Log2EEW; - ComplexTT.second.getAsInteger(10, Log2EEW); - // update new elmul = (eew/sew) * lmul - LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); - // update new eew - ElementBitwidth = 1 << Log2EEW; - ScalarType = ScalarTypeKind::SignedInteger; - UpdateAndCheckComplexProto(); + if (ComplexTT.second.getAsInteger(10, Log2EEW)) { + llvm_unreachable("Invalid Log2EEW value!"); + return None; + } + switch (Log2EEW) { + case 3: + VTM = VectorTypeModifier::Log2EEW3; + break; + case 4: + VTM = VectorTypeModifier::Log2EEW4; + break; + case 5: + VTM = VectorTypeModifier::Log2EEW5; + break; + case 6: + VTM = 
VectorTypeModifier::Log2EEW6; + break; + default: + llvm_unreachable("Invalid Log2EEW value, should be [3-6]"); + return None; + } } else if (ComplexTT.first == "FixedSEW") { uint32_t NewSEW; - ComplexTT.second.getAsInteger(10, NewSEW); - // Set invalid type if src and dst SEW are same. - if (ElementBitwidth == NewSEW) { - ScalarType = Invalid; - return; + if (ComplexTT.second.getAsInteger(10, NewSEW)) { + llvm_unreachable("Invalid FixedSEW value!"); + return None; + } + switch (NewSEW) { + case 8: + VTM = VectorTypeModifier::FixedSEW8; + break; + case 16: + VTM = VectorTypeModifier::FixedSEW16; + break; + case 32: + VTM = VectorTypeModifier::FixedSEW32; + break; + case 64: + VTM = VectorTypeModifier::FixedSEW64; + break; + default: + llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64"); + return None; } - // Update new SEW - ElementBitwidth = NewSEW; - UpdateAndCheckComplexProto(); } else if (ComplexTT.first == "LFixedLog2LMUL") { - // New LMUL should be larger than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater())) - return; + int32_t Log2LMUL; + if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { + llvm_unreachable("Invalid LFixedLog2LMUL value!"); + return None; + } + switch (Log2LMUL) { + case -3: + VTM = VectorTypeModifier::LFixedLog2LMULN3; + break; + case -2: + VTM = VectorTypeModifier::LFixedLog2LMULN2; + break; + case -1: + VTM = VectorTypeModifier::LFixedLog2LMULN1; + break; + case 0: + VTM = VectorTypeModifier::LFixedLog2LMUL0; + break; + case 1: + VTM = VectorTypeModifier::LFixedLog2LMUL1; + break; + case 2: + VTM = VectorTypeModifier::LFixedLog2LMUL2; + break; + case 3: + VTM = VectorTypeModifier::LFixedLog2LMUL3; + break; + default: + llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); + return None; + } } else if (ComplexTT.first == "SFixedLog2LMUL") { - // New LMUL should be smaller than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less())) - return; + int32_t Log2LMUL; + if 
(ComplexTT.second.getAsInteger(10, Log2LMUL)) { + llvm_unreachable("Invalid SFixedLog2LMUL value!"); + return None; + } + switch (Log2LMUL) { + case -3: + VTM = VectorTypeModifier::SFixedLog2LMULN3; + break; + case -2: + VTM = VectorTypeModifier::SFixedLog2LMULN2; + break; + case -1: + VTM = VectorTypeModifier::SFixedLog2LMULN1; + break; + case 0: + VTM = VectorTypeModifier::SFixedLog2LMUL0; + break; + case 1: + VTM = VectorTypeModifier::SFixedLog2LMUL1; + break; + case 2: + VTM = VectorTypeModifier::SFixedLog2LMUL2; + break; + case 3: + VTM = VectorTypeModifier::SFixedLog2LMUL3; + break; + default: + llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); + return None; + } + } else { llvm_unreachable("Illegal complex type transformers!"); } + TP.VTM = static_cast(VTM); } // Compute the remain type transformers - for (char I : Transformer) { + TypeModifier TM = TypeModifier::NoModifier; + for (char I : TypeProfileStr) { switch (I) { case 'P': - if (IsConstant) + if ((TM & TypeModifier::Const) == TypeModifier::Const) llvm_unreachable("'P' transformer cannot be used after 'C'"); - if (IsPointer) + if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer) llvm_unreachable("'P' transformer cannot be used twice"); - IsPointer = true; + TM |= TypeModifier::Pointer; break; case 'C': - if (IsConstant) - llvm_unreachable("'C' transformer cannot be used twice"); - IsConstant = true; + TM |= TypeModifier::Const; break; case 'K': - IsImmediate = true; + TM |= TypeModifier::Immediate; break; case 'U': - ScalarType = ScalarTypeKind::UnsignedInteger; + TM |= TypeModifier::UnsignedInteger; break; case 'I': - ScalarType = ScalarTypeKind::SignedInteger; + TM |= TypeModifier::SignedInteger; break; case 'F': - ScalarType = ScalarTypeKind::Float; + TM |= TypeModifier::Float; break; case 'S': + TM |= TypeModifier::LMUL1; + break; + default: + llvm_unreachable("Illegal non-primitive type transformer!"); + } + } + TP.TM = static_cast(TM); + + return TP; +} + +void 
RVVType::applyModifier(const TypeProfile &Transformer) { + // Handle primitive type transformer + switch (static_cast(Transformer.PT)) { + case PrimitiveType::Scalar: + Scale = 0; + break; + case PrimitiveType::Vector: + Scale = LMUL.getScale(ElementBitwidth); + break; + case PrimitiveType::Widening2XVector: + ElementBitwidth *= 2; + LMUL *= 2; + Scale = LMUL.getScale(ElementBitwidth); + break; + case PrimitiveType::Widening4XVector: + ElementBitwidth *= 4; + LMUL *= 4; + Scale = LMUL.getScale(ElementBitwidth); + break; + case PrimitiveType::Widening8XVector: + ElementBitwidth *= 8; + LMUL *= 8; + Scale = LMUL.getScale(ElementBitwidth); + break; + case PrimitiveType::MaskVector: + ScalarType = ScalarTypeKind::Boolean; + Scale = LMUL.getScale(ElementBitwidth); + ElementBitwidth = 1; + break; + case PrimitiveType::Void: + ScalarType = ScalarTypeKind::Void; + break; + case PrimitiveType::SizeT: + ScalarType = ScalarTypeKind::Size_t; + break; + case PrimitiveType::Ptrdiff: + ScalarType = ScalarTypeKind::Ptrdiff_t; + break; + case PrimitiveType::UnsignedLong: + ScalarType = ScalarTypeKind::UnsignedLong; + break; + case PrimitiveType::SignedLong: + ScalarType = ScalarTypeKind::SignedLong; + break; + case PrimitiveType::Invalid: + ScalarType = ScalarTypeKind::Invalid; + return; + default: + llvm_unreachable("Illegal primitive type transformers!"); + } + + switch (static_cast(Transformer.VTM)) { + case VectorTypeModifier::Log2EEW3: + applyLog2EEW(3); + break; + case VectorTypeModifier::Log2EEW4: + applyLog2EEW(4); + break; + case VectorTypeModifier::Log2EEW5: + applyLog2EEW(5); + break; + case VectorTypeModifier::Log2EEW6: + applyLog2EEW(6); + break; + case VectorTypeModifier::FixedSEW8: + applyFixedSEW(8); + break; + case VectorTypeModifier::FixedSEW16: + applyFixedSEW(16); + break; + case VectorTypeModifier::FixedSEW32: + applyFixedSEW(32); + break; + case VectorTypeModifier::FixedSEW64: + applyFixedSEW(64); + break; + case VectorTypeModifier::LFixedLog2LMULN3: + 
applyFixedLog2LMUL(-3, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMULN2: + applyFixedLog2LMUL(-2, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMULN1: + applyFixedLog2LMUL(-1, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMUL0: + applyFixedLog2LMUL(0, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMUL1: + applyFixedLog2LMUL(1, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMUL2: + applyFixedLog2LMUL(2, /* LargerThan= */ true); + break; + case VectorTypeModifier::LFixedLog2LMUL3: + applyFixedLog2LMUL(3, /* LargerThan= */ true); + break; + case VectorTypeModifier::SFixedLog2LMULN3: + applyFixedLog2LMUL(-3, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMULN2: + applyFixedLog2LMUL(-2, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMULN1: + applyFixedLog2LMUL(-1, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMUL0: + applyFixedLog2LMUL(0, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMUL1: + applyFixedLog2LMUL(1, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMUL2: + applyFixedLog2LMUL(2, /* LargerThan= */ false); + break; + case VectorTypeModifier::SFixedLog2LMUL3: + applyFixedLog2LMUL(3, /* LargerThan= */ false); + break; + case VectorTypeModifier::NoModifier: + break; + default: + llvm_unreachable("Illegal vector type modifier!"); + } + + for (unsigned TypeModifierMaskShift = 0; + TypeModifierMaskShift <= static_cast(TypeModifier::MaxOffset); + ++TypeModifierMaskShift) { + unsigned TypeModifierMask = 1 << TypeModifierMaskShift; + if ((static_cast(Transformer.TM) & TypeModifierMask) != + TypeModifierMask) + continue; + switch (static_cast(TypeModifierMask)) { + case TypeModifier::Pointer: + IsPointer = true; + break; + case TypeModifier::Const: + IsConstant = true; + break; + case TypeModifier::Immediate: + 
IsImmediate = true; + IsConstant = true; + break; + case TypeModifier::UnsignedInteger: + ScalarType = ScalarTypeKind::UnsignedInteger; + break; + case TypeModifier::SignedInteger: + ScalarType = ScalarTypeKind::SignedInteger; + break; + case TypeModifier::Float: + ScalarType = ScalarTypeKind::Float; + break; + case TypeModifier::LMUL1: LMUL = LMULType(0); // Update ElementBitwidth need to update Scale too. Scale = LMUL.getScale(ElementBitwidth); break; default: - llvm_unreachable("Illegal non-primitive type transformer!"); + llvm_unreachable("Unknown type modifier mask!"); } } } +void RVVType::applyLog2EEW(unsigned Log2EEW) { + // update new elmul = (eew/sew) * lmul + LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); + // update new eew + ElementBitwidth = 1 << Log2EEW; + ScalarType = ScalarTypeKind::SignedInteger; + Scale = LMUL.getScale(ElementBitwidth); +} + +void RVVType::applyFixedSEW(unsigned NewSEW) { + // Set invalid type if src and dst SEW are same. + if (ElementBitwidth == NewSEW) { + ScalarType = ScalarTypeKind::Invalid; + return; + } + // Update new SEW + ElementBitwidth = NewSEW; + Scale = LMUL.getScale(ElementBitwidth); +} + +void RVVType::applyFixedLog2LMUL(int Log2LMUL, bool LargerThan) { + if (LargerThan) { + if (Log2LMUL < LMUL.Log2LMUL) { + ScalarType = ScalarTypeKind::Invalid; + return; + } + } else { + if (Log2LMUL > LMUL.Log2LMUL) { + ScalarType = ScalarTypeKind::Invalid; + return; + } + } + // Update new LMUL + LMUL = LMULType(Log2LMUL); + Scale = LMUL.getScale(ElementBitwidth); +} + +Optional RVVType::computeTypes(BasicType BT, int Log2LMUL, + unsigned NF, + ArrayRef PrototypeSeq) { + // LMUL x NF must be less than or equal to 8. 
+  if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
+    return llvm::None;
+
+  RVVTypes Types;
+  for (const TypeProfile &Proto : PrototypeSeq) {
+    auto T = computeType(BT, Log2LMUL, Proto);
+    if (!T.hasValue())
+      return llvm::None;
+    // Record legal type index
+    Types.push_back(T.getValue());
+  }
+  return Types;
+}
+
+Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL,
+                                          TypeProfile Proto) {
+  std::string Idx = (Twine(static_cast<unsigned>(BT)) + "_" + Twine(Log2LMUL) +
+                     "_" + Proto.IndexStr())
+                        .str();
+  // Search first
+  auto It = LegalTypes.find(Idx);
+  if (It != LegalTypes.end())
+    return &(It->second);
+  if (IllegalTypes.count(Idx))
+    return llvm::None;
+  // Compute type and record the result.
+  RVVType T(BT, Log2LMUL, Proto);
+  if (T.isValid()) {
+    // Record legal type index and value.
+    auto Inserted = LegalTypes.insert({Idx, T});
+    return &(Inserted.first->second);
+  }
+  // Record illegal type index.
+  IllegalTypes.insert(Idx);
+  return llvm::None;
+}
+
 //===----------------------------------------------------------------------===//
 // RVVIntrinsic implementation
 //===----------------------------------------------------------------------===//
@@ -593,5 +899,36 @@
   return S;
 }
 
+std::string
+RVVIntrinsic::getSuffixStr(BasicType Type, int Log2LMUL,
+                           const llvm::SmallVector<TypeProfile> &TypeProfiles) {
+  SmallVector<std::string> SuffixStrs;
+  for (const TypeProfile &TP : TypeProfiles) {
+    auto T = RVVType::computeType(Type, Log2LMUL, TP);
+    SuffixStrs.push_back(T.getValue()->getShortStr());
+  }
+  return join(SuffixStrs, "_");
+}
+
+SmallVector<TypeProfile> parsePrototypes(StringRef Prototypes) {
+  SmallVector<TypeProfile> TypeProfiles;
+  const StringRef Primaries("evwqom0ztul");
+  while (!Prototypes.empty()) {
+    size_t Idx = 0;
+    // Skip over complex prototype because it could contain primitive type
+    // character.
+    if (Prototypes[0] == '(')
+      Idx = Prototypes.find_first_of(')');
+    Idx = Prototypes.find_first_of(Primaries, Idx);
+    assert(Idx != StringRef::npos);
+    auto TP = TypeProfile::parseTypeProfile(Prototypes.slice(0, Idx + 1));
+    if (!TP)
+      llvm_unreachable("Error during parsing prototype.");
+    TypeProfiles.push_back(*TP);
+    Prototypes = Prototypes.drop_front(Idx + 1);
+  }
+  return TypeProfiles;
+}
+
 } // end namespace RISCV
 } // end namespace clang
diff --git a/clang/utils/TableGen/RISCVVEmitter.cpp b/clang/utils/TableGen/RISCVVEmitter.cpp
--- a/clang/utils/TableGen/RISCVVEmitter.cpp
+++ b/clang/utils/TableGen/RISCVVEmitter.cpp
@@ -32,9 +32,6 @@
 class RVVEmitter {
 private:
   RecordKeeper &Records;
-  // Concat BasicType, LMUL and Proto as key
-  StringMap<RVVType> LegalTypes;
-  StringSet<> IllegalTypes;
 
 public:
   RVVEmitter(RecordKeeper &R) : Records(R) {}
@@ -48,20 +45,11 @@
   /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
   void createCodeGen(raw_ostream &o);
 
-  std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
-
 private:
   /// Create all intrinsics and add them to \p Out
   void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
   /// Print HeaderCode in RVVHeader Record to \p Out
   void printHeaderCode(raw_ostream &OS);
-  /// Compute output and input types by applying different config (basic type
-  /// and LMUL with type transformers). It also record result of type in legal
-  /// or illegal set to avoid compute the same config again. The result maybe
-  /// have illegal RVVType.
-  Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
-                                  ArrayRef<std::string> PrototypeSeq);
-  Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
 
   /// Emit Acrh predecessor definitions and body, assume the element of Defs are
   /// sorted by extension.
@@ -73,14 +61,31 @@
   // non-empty string.
   bool emitMacroRestrictionStr(RISCVPredefinedMacroT PredefinedMacros,
                                raw_ostream &o);
-  // Slice Prototypes string into sub prototype string and process each sub
-  // prototype string individually in the Handler.
-  void parsePrototypes(StringRef Prototypes,
-                       std::function<void(StringRef)> Handler);
 };
 
 } // namespace
 
+static BasicType ParseBasicType(char c) {
+  switch (c) {
+  case 'c':
+    return BasicType::Int8;
+  case 's':
+    return BasicType::Int16;
+  case 'i':
+    return BasicType::Int32;
+  case 'l':
+    return BasicType::Int64;
+  case 'x':
+    return BasicType::Float16;
+  case 'f':
+    return BasicType::Float32;
+  case 'd':
+    return BasicType::Float64;
+  default:
+    return BasicType::Unknown;
+  }
+}
+
 void emitCodeGenSwitchBody(const RVVIntrinsic *RVVI, raw_ostream &OS) {
   if (!RVVI->getIRName().empty())
     OS << "  ID = Intrinsic::riscv_" + RVVI->getIRName() + ";\n";
@@ -202,24 +207,28 @@
   constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};
   // Print RVV boolean types.
   for (int Log2LMUL : Log2LMULs) {
-    auto T = computeType('c', Log2LMUL, "m");
+    auto T = RVVType::computeType(BasicType::Int8, Log2LMUL, TypeProfile::Mask);
     if (T.hasValue())
       printType(T.getValue());
   }
   // Print RVV int/float types.
for (char I : StringRef("csil")) { + BasicType BT = ParseBasicType(I); for (int Log2LMUL : Log2LMULs) { - auto T = computeType(I, Log2LMUL, "v"); + auto T = RVVType::computeType(BT, Log2LMUL, TypeProfile::Vector); if (T.hasValue()) { printType(T.getValue()); - auto UT = computeType(I, Log2LMUL, "Uv"); + auto UT = RVVType::computeType( + BT, Log2LMUL, + TypeProfile(PrimitiveType::Vector, TypeModifier::UnsignedInteger)); printType(UT.getValue()); } } } OS << "#if defined(__riscv_zvfh)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('x', Log2LMUL, "v"); + auto T = + RVVType::computeType(BasicType::Float16, Log2LMUL, TypeProfile::Vector); if (T.hasValue()) printType(T.getValue()); } @@ -227,7 +244,8 @@ OS << "#if defined(__riscv_f)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('f', Log2LMUL, "v"); + auto T = + RVVType::computeType(BasicType::Float32, Log2LMUL, TypeProfile::Vector); if (T.hasValue()) printType(T.getValue()); } @@ -235,7 +253,8 @@ OS << "#if defined(__riscv_d)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('d', Log2LMUL, "v"); + auto T = + RVVType::computeType(BasicType::Float64, Log2LMUL, TypeProfile::Vector); if (T.hasValue()) printType(T.getValue()); } @@ -359,32 +378,6 @@ OS << "\n"; } -void RVVEmitter::parsePrototypes(StringRef Prototypes, - std::function Handler) { - const StringRef Primaries("evwqom0ztul"); - while (!Prototypes.empty()) { - size_t Idx = 0; - // Skip over complex prototype because it could contain primitive type - // character. 
-    if (Prototypes[0] == '(')
-      Idx = Prototypes.find_first_of(')');
-    Idx = Prototypes.find_first_of(Primaries, Idx);
-    assert(Idx != StringRef::npos);
-    Handler(Prototypes.slice(0, Idx + 1));
-    Prototypes = Prototypes.drop_front(Idx + 1);
-  }
-}
-
-std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
-                                     StringRef Prototypes) {
-  SmallVector<std::string> SuffixStrs;
-  parsePrototypes(Prototypes, [&](StringRef Proto) {
-    auto T = computeType(Type, Log2LMUL, Proto);
-    SuffixStrs.push_back(T.getValue()->getShortStr());
-  });
-  return join(SuffixStrs, "_");
-}
-
 void RVVEmitter::createRVVIntrinsics(
     std::vector<std::unique_ptr<RVVIntrinsic>> &Out) {
   std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin");
@@ -419,13 +412,14 @@
 
     // Parse prototype and create a list of primitive type with transformers
     // (operand) in ProtoSeq. ProtoSeq[0] is output operand.
-    SmallVector<std::string> ProtoSeq;
-    parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
-      ProtoSeq.push_back(Proto.str());
-    });
+    SmallVector<TypeProfile> ProtoSeq = parsePrototypes(Prototypes);
+
+    SmallVector<TypeProfile> SuffixProtoSeq = parsePrototypes(SuffixProto);
+    SmallVector<TypeProfile> MangledSuffixProtoSeq =
+        parsePrototypes(MangledSuffixProto);
 
     // Compute Builtin types
-    SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
+    SmallVector<TypeProfile> ProtoMaskSeq = ProtoSeq;
     if (HasMasked) {
       // If HasMaskedOffOperand, insert result type as first input operand.
       if (HasMaskedOffOperand) {
@@ -436,10 +430,10 @@
           // (void, op0 address, op1 address, ...)
           // to
           // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
+          TypeProfile MaskoffType = ProtoSeq[1];
+          MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
           for (unsigned I = 0; I < NF; ++I)
-            ProtoMaskSeq.insert(
-                ProtoMaskSeq.begin() + NF + 1,
-                ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
+            ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, MaskoffType);
         }
       }
       if (HasMaskedOffOperand && NF > 1) {
@@ -448,28 +442,32 @@
         // to
         // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
         // ...)
-        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
+        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, TypeProfile::Mask);
       } else {
-        // If HasMasked, insert 'm' as first input operand.
-        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
+        // If HasMasked, insert TypeProfile::Mask as first input operand.
+        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, TypeProfile::Mask);
       }
     }
-    // If HasVL, append 'z' to last operand
+    // If HasVL, append TypeProfile::VL to last operand
     if (HasVL) {
-      ProtoSeq.push_back("z");
-      ProtoMaskSeq.push_back("z");
+      ProtoSeq.push_back(TypeProfile::VL);
+      ProtoMaskSeq.push_back(TypeProfile::VL);
     }
 
     // Create Intrinsics for each type and LMUL.
     for (char I : TypeRange) {
+      BasicType BT = ParseBasicType(I);
       for (int Log2LMUL : Log2LMULList) {
-        Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
+        Optional<RVVTypes> Types =
+            RVVType::computeTypes(BT, Log2LMUL, NF, ProtoSeq);
         // Ignored to create new intrinsic if there are any illegal types.
         if (!Types.hasValue())
           continue;
 
-        auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
-        auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
+        auto SuffixStr =
+            RVVIntrinsic::getSuffixStr(BT, Log2LMUL, SuffixProtoSeq);
+        auto MangledSuffixStr =
+            RVVIntrinsic::getSuffixStr(BT, Log2LMUL, MangledSuffixProtoSeq);
         // Create a unmasked intrinsic
         Out.push_back(std::make_unique<RVVIntrinsic>(
             Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
@@ -480,7 +478,7 @@
         if (HasMasked) {
           // Create a masked intrinsic
           Optional<RVVTypes> MaskTypes =
-              computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
+              RVVType::computeTypes(BT, Log2LMUL, NF, ProtoMaskSeq);
           Out.push_back(std::make_unique<RVVIntrinsic>(
               Name, SuffixStr, MangledName, MangledSuffixStr, MaskedIRName,
               /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicy,
@@ -501,45 +499,6 @@
   }
 }
 
-Optional<RVVTypes>
-RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
-                         ArrayRef<std::string> PrototypeSeq) {
-  // LMUL x NF must be less than or equal to 8.
-  if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
-    return llvm::None;
-
-  RVVTypes Types;
-  for (const std::string &Proto : PrototypeSeq) {
-    auto T = computeType(BT, Log2LMUL, Proto);
-    if (!T.hasValue())
-      return llvm::None;
-    // Record legal type index
-    Types.push_back(T.getValue());
-  }
-  return Types;
-}
-
-Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
-                                             StringRef Proto) {
-  std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str();
-  // Search first
-  auto It = LegalTypes.find(Idx);
-  if (It != LegalTypes.end())
-    return &(It->second);
-  if (IllegalTypes.count(Idx))
-    return llvm::None;
-  // Compute type and record the result.
-  RVVType T(BT, Log2LMUL, Proto);
-  if (T.isValid()) {
-    // Record legal type index and value.
-    LegalTypes.insert({Idx, T});
-    return &(LegalTypes[Idx]);
-  }
-  // Record illegal type index.
-  IllegalTypes.insert(Idx);
-  return llvm::None;
-}
-
 void RVVEmitter::emitArchMacroAndBody(
     std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS,
     std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) {