diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h --- a/clang/include/clang/Support/RISCVVIntrinsicUtils.h +++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h @@ -9,7 +9,10 @@ #ifndef CLANG_SUPPORT_RISCVVINTRINSICUTILS_H #define CLANG_SUPPORT_RISCVVINTRINSICUTILS_H +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include #include @@ -18,9 +21,133 @@ namespace clang { namespace RISCV { -using BasicType = char; using VScaleVal = llvm::Optional; +// Modifier for vector type. +enum class VectorTypeModifier : uint8_t { + NoModifier, + Widening2XVector, + Widening4XVector, + Widening8XVector, + MaskVector, + Log2EEW3, + Log2EEW4, + Log2EEW5, + Log2EEW6, + FixedSEW8, + FixedSEW16, + FixedSEW32, + FixedSEW64, + LFixedLog2LMULN3, + LFixedLog2LMULN2, + LFixedLog2LMULN1, + LFixedLog2LMUL0, + LFixedLog2LMUL1, + LFixedLog2LMUL2, + LFixedLog2LMUL3, + SFixedLog2LMULN3, + SFixedLog2LMULN2, + SFixedLog2LMULN1, + SFixedLog2LMUL0, + SFixedLog2LMUL1, + SFixedLog2LMUL2, + SFixedLog2LMUL3, +}; + +// Similar to basic type but used to describe what's kind of type related to +// basic vector type, used to compute type info of arguments. +enum class BaseTypeModifier : uint8_t { + Invalid, + Scalar, + Vector, + Void, + SizeT, + Ptrdiff, + UnsignedLong, + SignedLong, +}; + +// Modifier for type, used for both scalar and vector types. +enum class TypeModifier : uint8_t { + NoModifier = 0, + Pointer = 1 << 0, + Const = 1 << 1, + Immediate = 1 << 2, + UnsignedInteger = 1 << 3, + SignedInteger = 1 << 4, + Float = 1 << 5, + // LMUL1 should be kind of VectorTypeModifier, but that might come with + // Widening2XVector for widening reduction. 
+ // However that might require VectorTypeModifier become bitmask rather than + // simple enum, so we decided to keep LMUL1 in TypeModifier for code size + // optimization of clang binary size. + LMUL1 = 1 << 6, + MaxOffset = 6, + LLVM_MARK_AS_BITMASK_ENUM(LMUL1), +}; + +// PrototypeDescriptor is used to compute type info of arguments or return +// value. +struct PrototypeDescriptor { + constexpr PrototypeDescriptor() = default; + constexpr PrototypeDescriptor( + BaseTypeModifier PT, + VectorTypeModifier VTM = VectorTypeModifier::NoModifier, + TypeModifier TM = TypeModifier::NoModifier) + : PT(static_cast(PT)), VTM(static_cast(VTM)), + TM(static_cast(TM)) {} + constexpr PrototypeDescriptor(uint8_t PT, uint8_t VTM, uint8_t TM) + : PT(PT), VTM(VTM), TM(TM) {} + + uint8_t PT = static_cast(BaseTypeModifier::Invalid); + uint8_t VTM = static_cast(VectorTypeModifier::NoModifier); + uint8_t TM = static_cast(TypeModifier::NoModifier); + + bool operator!=(const PrototypeDescriptor &PD) const { + return PD.PT != PT || PD.VTM != VTM || PD.TM != TM; + } + bool operator>(const PrototypeDescriptor &PD) const { + return !(PD.PT <= PT && PD.VTM <= VTM && PD.TM <= TM); + } + + static const PrototypeDescriptor Mask; + static const PrototypeDescriptor Vector; + static const PrototypeDescriptor VL; + static llvm::Optional + parsePrototypeDescriptor(llvm::StringRef PrototypeStr); +}; + +llvm::SmallVector +parsePrototypes(llvm::StringRef Prototypes); + +// Basic type of vector type. +enum class BasicType : uint8_t { + Unknown = 0, + Int8 = 1 << 0, + Int16 = 1 << 1, + Int32 = 1 << 2, + Int64 = 1 << 3, + Float16 = 1 << 4, + Float32 = 1 << 5, + Float64 = 1 << 6, + MaxOffset = 6, + LLVM_MARK_AS_BITMASK_ENUM(Float64), +}; + +// Type of vector type. 
+enum ScalarTypeKind : uint8_t { + Void, + Size_t, + Ptrdiff_t, + UnsignedLong, + SignedLong, + Boolean, + SignedInteger, + UnsignedInteger, + Float, + Invalid, +}; + // Exponential LMUL struct LMULType { int Log2LMUL; @@ -32,20 +159,12 @@ LMULType &operator*=(uint32_t RHS); }; +class RVVType; +using RVVTypePtr = RVVType *; +using RVVTypes = std::vector; + // This class is compact representation of a valid and invalid RVVType. class RVVType { - enum ScalarTypeKind : uint32_t { - Void, - Size_t, - Ptrdiff_t, - UnsignedLong, - SignedLong, - Boolean, - SignedInteger, - UnsignedInteger, - Float, - Invalid, - }; BasicType BT; ScalarTypeKind ScalarType = Invalid; LMULType LMUL; @@ -63,9 +182,11 @@ std::string Str; std::string ShortStr; + enum class FixedLMULType { LargerThan, SmallerThan }; + public: - RVVType() : RVVType(BasicType(), 0, llvm::StringRef()) {} - RVVType(BasicType BT, int Log2LMUL, llvm::StringRef prototype); + RVVType() : BT(BasicType::Unknown), LMUL(0), Valid(false) {} + RVVType(BasicType BT, int Log2LMUL, const PrototypeDescriptor &Profile); // Return the string representation of a type, which is an encoded string for // passing to the BUILTIN() macro in Builtins.def. @@ -114,7 +235,11 @@ // Applies a prototype modifier to the current type. The result maybe an // invalid type. - void applyModifier(llvm::StringRef prototype); + void applyModifier(const PrototypeDescriptor &prototype); + + void applyLog2EEW(unsigned Log2EEW); + void applyFixedSEW(unsigned NewSEW); + void applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type); // Compute and record a string for legal type. void initBuiltinStr(); @@ -124,10 +249,19 @@ void initTypeStr(); // Compute and record a short name of a type for C/C++ name suffix. void initShortStr(); + +public: + /// Compute output and input types by applying different config (basic type + /// and LMUL with type transformers). It also record result of type in legal + /// or illegal set to avoid compute the same config again. 
The result maybe + /// have illegal RVVType. + static llvm::Optional + computeTypes(BasicType BT, int Log2LMUL, unsigned NF, + llvm::ArrayRef PrototypeSeq); + static llvm::Optional computeType(BasicType BT, int Log2LMUL, + PrototypeDescriptor Proto); }; -using RVVTypePtr = RVVType *; -using RVVTypes = std::vector; using RISCVPredefinedMacroT = uint8_t; enum RISCVPredefinedMacro : RISCVPredefinedMacroT { @@ -206,6 +340,10 @@ // Return the type string for a BUILTIN() macro in Builtins.def. std::string getBuiltinTypeStr() const; + + static std::string getSuffixStr( + BasicType Type, int Log2LMUL, + const llvm::SmallVector &PrototypeDescriptors); }; } // end namespace RISCV diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp --- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp +++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp @@ -16,12 +16,25 @@ #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" #include +#include +#include using namespace llvm; namespace clang { namespace RISCV { +const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor( + BaseTypeModifier::Vector, VectorTypeModifier::MaskVector); +const PrototypeDescriptor PrototypeDescriptor::VL = + PrototypeDescriptor(BaseTypeModifier::SizeT); +const PrototypeDescriptor PrototypeDescriptor::Vector = + PrototypeDescriptor(BaseTypeModifier::Vector); + +// Concat BasicType, LMUL and Proto as key +static std::unordered_map LegalTypes; +static std::set IllegalTypes; + //===----------------------------------------------------------------------===// // Type implementation //===----------------------------------------------------------------------===// @@ -70,7 +83,8 @@ return *this; } -RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype) +RVVType::RVVType(BasicType BT, int Log2LMUL, + const PrototypeDescriptor &prototype) : BT(BT), LMUL(LMULType(Log2LMUL)) { applyBasicType(); applyModifier(prototype); @@ -326,31 +340,31 @@ void 
RVVType::applyBasicType() { switch (BT) { - case 'c': + case BasicType::Int8: ElementBitwidth = 8; ScalarType = ScalarTypeKind::SignedInteger; break; - case 's': + case BasicType::Int16: ElementBitwidth = 16; ScalarType = ScalarTypeKind::SignedInteger; break; - case 'i': + case BasicType::Int32: ElementBitwidth = 32; ScalarType = ScalarTypeKind::SignedInteger; break; - case 'l': + case BasicType::Int64: ElementBitwidth = 64; ScalarType = ScalarTypeKind::SignedInteger; break; - case 'x': + case BasicType::Float16: ElementBitwidth = 16; ScalarType = ScalarTypeKind::Float; break; - case 'f': + case BasicType::Float32: ElementBitwidth = 32; ScalarType = ScalarTypeKind::Float; break; - case 'd': + case BasicType::Float64: ElementBitwidth = 64; ScalarType = ScalarTypeKind::Float; break; @@ -360,160 +374,481 @@ assert(ElementBitwidth != 0 && "Bad element bitwidth!"); } -void RVVType::applyModifier(StringRef Transformer) { - if (Transformer.empty()) - return; - // Handle primitive type transformer - auto PType = Transformer.back(); +Optional PrototypeDescriptor::parsePrototypeDescriptor( + llvm::StringRef PrototypeDescriptorStr) { + PrototypeDescriptor PD; + BaseTypeModifier PT = BaseTypeModifier::Invalid; + VectorTypeModifier VTM = VectorTypeModifier::NoModifier; + + if (PrototypeDescriptorStr.empty()) + return PD; + + // Handle base type modifier + auto PType = PrototypeDescriptorStr.back(); switch (PType) { case 'e': - Scale = 0; + PT = BaseTypeModifier::Scalar; break; case 'v': - Scale = LMUL.getScale(ElementBitwidth); + PT = BaseTypeModifier::Vector; break; case 'w': - ElementBitwidth *= 2; - LMUL *= 2; - Scale = LMUL.getScale(ElementBitwidth); + PT = BaseTypeModifier::Vector; + VTM = VectorTypeModifier::Widening2XVector; break; case 'q': - ElementBitwidth *= 4; - LMUL *= 4; - Scale = LMUL.getScale(ElementBitwidth); + PT = BaseTypeModifier::Vector; + VTM = VectorTypeModifier::Widening4XVector; break; case 'o': - ElementBitwidth *= 8; - LMUL *= 8; - Scale = 
LMUL.getScale(ElementBitwidth); + PT = BaseTypeModifier::Vector; + VTM = VectorTypeModifier::Widening8XVector; break; case 'm': - ScalarType = ScalarTypeKind::Boolean; - Scale = LMUL.getScale(ElementBitwidth); - ElementBitwidth = 1; + PT = BaseTypeModifier::Vector; + VTM = VectorTypeModifier::MaskVector; break; case '0': - ScalarType = ScalarTypeKind::Void; + PT = BaseTypeModifier::Void; break; case 'z': - ScalarType = ScalarTypeKind::Size_t; + PT = BaseTypeModifier::SizeT; break; case 't': - ScalarType = ScalarTypeKind::Ptrdiff_t; + PT = BaseTypeModifier::Ptrdiff; break; case 'u': - ScalarType = ScalarTypeKind::UnsignedLong; + PT = BaseTypeModifier::UnsignedLong; break; case 'l': - ScalarType = ScalarTypeKind::SignedLong; + PT = BaseTypeModifier::SignedLong; break; default: llvm_unreachable("Illegal primitive type transformers!"); } - Transformer = Transformer.drop_back(); - - // Extract and compute complex type transformer. It can only appear one time. - if (Transformer.startswith("(")) { - size_t Idx = Transformer.find(')'); + PD.PT = static_cast(PT); + PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back(); + + // Compute the vector type transformers, it can only appear one time. 
+ if (PrototypeDescriptorStr.startswith("(")) { + assert(VTM == VectorTypeModifier::NoModifier && + "VectorTypeModifier should only have one modifier"); + size_t Idx = PrototypeDescriptorStr.find(')'); assert(Idx != StringRef::npos); - StringRef ComplexType = Transformer.slice(1, Idx); - Transformer = Transformer.drop_front(Idx + 1); - assert(!Transformer.contains('(') && - "Only allow one complex type transformer"); + StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx); + PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1); + assert(!PrototypeDescriptorStr.contains('(') && + "Only allow one vector type modifier"); - auto UpdateAndCheckComplexProto = [&]() { - Scale = LMUL.getScale(ElementBitwidth); - const StringRef VectorPrototypes("vwqom"); - if (!VectorPrototypes.contains(PType)) - llvm_unreachable("Complex type transformer only supports vector type!"); - if (Transformer.find_first_of("PCKWS") != StringRef::npos) - llvm_unreachable( - "Illegal type transformer for Complex type transformer"); - }; - auto ComputeFixedLog2LMUL = - [&](StringRef Value, - std::function Compare) { - int32_t Log2LMUL; - Value.getAsInteger(10, Log2LMUL); - if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { - ScalarType = Invalid; - return false; - } - // Update new LMUL - LMUL = LMULType(Log2LMUL); - UpdateAndCheckComplexProto(); - return true; - }; auto ComplexTT = ComplexType.split(":"); if (ComplexTT.first == "Log2EEW") { uint32_t Log2EEW; - ComplexTT.second.getAsInteger(10, Log2EEW); - // update new elmul = (eew/sew) * lmul - LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); - // update new eew - ElementBitwidth = 1 << Log2EEW; - ScalarType = ScalarTypeKind::SignedInteger; - UpdateAndCheckComplexProto(); + if (ComplexTT.second.getAsInteger(10, Log2EEW)) { + llvm_unreachable("Invalid Log2EEW value!"); + return None; + } + switch (Log2EEW) { + case 3: + VTM = VectorTypeModifier::Log2EEW3; + break; + case 4: + VTM = VectorTypeModifier::Log2EEW4; + break; + case 5: + 
VTM = VectorTypeModifier::Log2EEW5; + break; + case 6: + VTM = VectorTypeModifier::Log2EEW6; + break; + default: + llvm_unreachable("Invalid Log2EEW value, should be [3-6]"); + return None; + } } else if (ComplexTT.first == "FixedSEW") { uint32_t NewSEW; - ComplexTT.second.getAsInteger(10, NewSEW); - // Set invalid type if src and dst SEW are same. - if (ElementBitwidth == NewSEW) { - ScalarType = Invalid; - return; + if (ComplexTT.second.getAsInteger(10, NewSEW)) { + llvm_unreachable("Invalid FixedSEW value!"); + return None; + } + switch (NewSEW) { + case 8: + VTM = VectorTypeModifier::FixedSEW8; + break; + case 16: + VTM = VectorTypeModifier::FixedSEW16; + break; + case 32: + VTM = VectorTypeModifier::FixedSEW32; + break; + case 64: + VTM = VectorTypeModifier::FixedSEW64; + break; + default: + llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64"); + return None; } - // Update new SEW - ElementBitwidth = NewSEW; - UpdateAndCheckComplexProto(); } else if (ComplexTT.first == "LFixedLog2LMUL") { - // New LMUL should be larger than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater())) - return; + int32_t Log2LMUL; + if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { + llvm_unreachable("Invalid LFixedLog2LMUL value!"); + return None; + } + switch (Log2LMUL) { + case -3: + VTM = VectorTypeModifier::LFixedLog2LMULN3; + break; + case -2: + VTM = VectorTypeModifier::LFixedLog2LMULN2; + break; + case -1: + VTM = VectorTypeModifier::LFixedLog2LMULN1; + break; + case 0: + VTM = VectorTypeModifier::LFixedLog2LMUL0; + break; + case 1: + VTM = VectorTypeModifier::LFixedLog2LMUL1; + break; + case 2: + VTM = VectorTypeModifier::LFixedLog2LMUL2; + break; + case 3: + VTM = VectorTypeModifier::LFixedLog2LMUL3; + break; + default: + llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); + return None; + } } else if (ComplexTT.first == "SFixedLog2LMUL") { - // New LMUL should be smaller than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, 
std::less())) - return; + int32_t Log2LMUL; + if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { + llvm_unreachable("Invalid SFixedLog2LMUL value!"); + return None; + } + switch (Log2LMUL) { + case -3: + VTM = VectorTypeModifier::SFixedLog2LMULN3; + break; + case -2: + VTM = VectorTypeModifier::SFixedLog2LMULN2; + break; + case -1: + VTM = VectorTypeModifier::SFixedLog2LMULN1; + break; + case 0: + VTM = VectorTypeModifier::SFixedLog2LMUL0; + break; + case 1: + VTM = VectorTypeModifier::SFixedLog2LMUL1; + break; + case 2: + VTM = VectorTypeModifier::SFixedLog2LMUL2; + break; + case 3: + VTM = VectorTypeModifier::SFixedLog2LMUL3; + break; + default: + llvm_unreachable("Invalid SFixedLog2LMUL value, should be [-3, 3]"); + return None; + } + } else { llvm_unreachable("Illegal complex type transformers!"); } } + PD.VTM = static_cast(VTM); // Compute the remain type transformers - for (char I : Transformer) { + TypeModifier TM = TypeModifier::NoModifier; + for (char I : PrototypeDescriptorStr) { switch (I) { case 'P': - if (IsConstant) + if ((TM & TypeModifier::Const) == TypeModifier::Const) llvm_unreachable("'P' transformer cannot be used after 'C'"); - if (IsPointer) + if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer) llvm_unreachable("'P' transformer cannot be used twice"); - IsPointer = true; + TM |= TypeModifier::Pointer; break; case 'C': - if (IsConstant) - llvm_unreachable("'C' transformer cannot be used twice"); - IsConstant = true; + TM |= TypeModifier::Const; break; case 'K': - IsImmediate = true; + TM |= TypeModifier::Immediate; break; case 'U': - ScalarType = ScalarTypeKind::UnsignedInteger; + TM |= TypeModifier::UnsignedInteger; break; case 'I': - ScalarType = ScalarTypeKind::SignedInteger; + TM |= TypeModifier::SignedInteger; break; case 'F': - ScalarType = ScalarTypeKind::Float; + TM |= TypeModifier::Float; break; case 'S': + TM |= TypeModifier::LMUL1; + break; + default: + llvm_unreachable("Illegal non-primitive type transformer!"); + } + } + 
PD.TM = static_cast(TM); + + return PD; +} + +void RVVType::applyModifier(const PrototypeDescriptor &Transformer) { + // Handle primitive type transformer + switch (static_cast(Transformer.PT)) { + case BaseTypeModifier::Scalar: + Scale = 0; + break; + case BaseTypeModifier::Vector: + Scale = LMUL.getScale(ElementBitwidth); + break; + case BaseTypeModifier::Void: + ScalarType = ScalarTypeKind::Void; + break; + case BaseTypeModifier::SizeT: + ScalarType = ScalarTypeKind::Size_t; + break; + case BaseTypeModifier::Ptrdiff: + ScalarType = ScalarTypeKind::Ptrdiff_t; + break; + case BaseTypeModifier::UnsignedLong: + ScalarType = ScalarTypeKind::UnsignedLong; + break; + case BaseTypeModifier::SignedLong: + ScalarType = ScalarTypeKind::SignedLong; + break; + case BaseTypeModifier::Invalid: + ScalarType = ScalarTypeKind::Invalid; + return; + } + + switch (static_cast(Transformer.VTM)) { + case VectorTypeModifier::Widening2XVector: + ElementBitwidth *= 2; + LMUL *= 2; + Scale = LMUL.getScale(ElementBitwidth); + break; + case VectorTypeModifier::Widening4XVector: + ElementBitwidth *= 4; + LMUL *= 4; + Scale = LMUL.getScale(ElementBitwidth); + break; + case VectorTypeModifier::Widening8XVector: + ElementBitwidth *= 8; + LMUL *= 8; + Scale = LMUL.getScale(ElementBitwidth); + break; + case VectorTypeModifier::MaskVector: + ScalarType = ScalarTypeKind::Boolean; + Scale = LMUL.getScale(ElementBitwidth); + ElementBitwidth = 1; + break; + case VectorTypeModifier::Log2EEW3: + applyLog2EEW(3); + break; + case VectorTypeModifier::Log2EEW4: + applyLog2EEW(4); + break; + case VectorTypeModifier::Log2EEW5: + applyLog2EEW(5); + break; + case VectorTypeModifier::Log2EEW6: + applyLog2EEW(6); + break; + case VectorTypeModifier::FixedSEW8: + applyFixedSEW(8); + break; + case VectorTypeModifier::FixedSEW16: + applyFixedSEW(16); + break; + case VectorTypeModifier::FixedSEW32: + applyFixedSEW(32); + break; + case VectorTypeModifier::FixedSEW64: + applyFixedSEW(64); + break; + case 
VectorTypeModifier::LFixedLog2LMULN3: + applyFixedLog2LMUL(-3, FixedLMULType::LargerThan); + break; + case VectorTypeModifier::LFixedLog2LMULN2: + applyFixedLog2LMUL(-2, FixedLMULType::LargerThan); + break; + case VectorTypeModifier::LFixedLog2LMULN1: + applyFixedLog2LMUL(-1, FixedLMULType::LargerThan); + break; + case VectorTypeModifier::LFixedLog2LMUL0: + applyFixedLog2LMUL(0, FixedLMULType::LargerThan); + break; + case VectorTypeModifier::LFixedLog2LMUL1: + applyFixedLog2LMUL(1, FixedLMULType::LargerThan); + break; + case VectorTypeModifier::LFixedLog2LMUL2: + applyFixedLog2LMUL(2, FixedLMULType::LargerThan); + break; + case VectorTypeModifier::LFixedLog2LMUL3: + applyFixedLog2LMUL(3, FixedLMULType::LargerThan); + break; + case VectorTypeModifier::SFixedLog2LMULN3: + applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan); + break; + case VectorTypeModifier::SFixedLog2LMULN2: + applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan); + break; + case VectorTypeModifier::SFixedLog2LMULN1: + applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan); + break; + case VectorTypeModifier::SFixedLog2LMUL0: + applyFixedLog2LMUL(0, FixedLMULType::SmallerThan); + break; + case VectorTypeModifier::SFixedLog2LMUL1: + applyFixedLog2LMUL(1, FixedLMULType::SmallerThan); + break; + case VectorTypeModifier::SFixedLog2LMUL2: + applyFixedLog2LMUL(2, FixedLMULType::SmallerThan); + break; + case VectorTypeModifier::SFixedLog2LMUL3: + applyFixedLog2LMUL(3, FixedLMULType::SmallerThan); + break; + case VectorTypeModifier::NoModifier: + break; + } + + // The vector type modifier may have marked the type Invalid (same-SEW + // FixedSEW, or a FixedLog2LMUL that fails its ordering constraint). Stop + // here so the type modifiers below cannot overwrite ScalarType and turn the + // type valid again. + if (ScalarType == ScalarTypeKind::Invalid) + return; + + for (unsigned TypeModifierMaskShift = 0; + TypeModifierMaskShift <= static_cast(TypeModifier::MaxOffset); + ++TypeModifierMaskShift) { + unsigned TypeModifierMask = 1 << TypeModifierMaskShift; + if ((static_cast(Transformer.TM) & TypeModifierMask) != + TypeModifierMask) + continue; + switch (static_cast(TypeModifierMask)) { + case TypeModifier::Pointer: + IsPointer = true; + break; + case TypeModifier::Const: + IsConstant = true; + break; + case 
TypeModifier::Immediate: + IsImmediate = true; + IsConstant = true; + break; + case TypeModifier::UnsignedInteger: + ScalarType = ScalarTypeKind::UnsignedInteger; + break; + case TypeModifier::SignedInteger: + ScalarType = ScalarTypeKind::SignedInteger; + break; + case TypeModifier::Float: + ScalarType = ScalarTypeKind::Float; + break; + case TypeModifier::LMUL1: LMUL = LMULType(0); // Update ElementBitwidth need to update Scale too. Scale = LMUL.getScale(ElementBitwidth); break; default: - llvm_unreachable("Illegal non-primitive type transformer!"); + llvm_unreachable("Unknown type modifier mask!"); + } + } +} + +void RVVType::applyLog2EEW(unsigned Log2EEW) { + // update new elmul = (eew/sew) * lmul + LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); + // update new eew + ElementBitwidth = 1 << Log2EEW; + ScalarType = ScalarTypeKind::SignedInteger; + Scale = LMUL.getScale(ElementBitwidth); +} + +void RVVType::applyFixedSEW(unsigned NewSEW) { + // Set invalid type if src and dst SEW are same. + if (ElementBitwidth == NewSEW) { + ScalarType = ScalarTypeKind::Invalid; + return; + } + // Update new SEW + ElementBitwidth = NewSEW; + Scale = LMUL.getScale(ElementBitwidth); +} + +// Replace the current LMUL with the fixed Log2LMUL, or mark the type Invalid +// when the fixed value does not satisfy the ordering required by Type. +void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) { + switch (Type) { + case FixedLMULType::LargerThan: + // New LMUL must be strictly larger than the old one; equal is invalid. + if (Log2LMUL <= LMUL.Log2LMUL) { + ScalarType = ScalarTypeKind::Invalid; + return; + } + break; + case FixedLMULType::SmallerThan: + // New LMUL must be strictly smaller than the old one; equal is invalid. + if (Log2LMUL >= LMUL.Log2LMUL) { + ScalarType = ScalarTypeKind::Invalid; + return; + } + break; + default: + llvm_unreachable("Unknown FixedLMULType??"); } + + // Update new LMUL + LMUL = LMULType(Log2LMUL); + Scale = LMUL.getScale(ElementBitwidth); +} + +Optional +RVVType::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, + ArrayRef PrototypeSeq) { + // LMUL x NF must be less than or equal to 8. 
+ if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) + return llvm::None; + + RVVTypes Types; + for (const PrototypeDescriptor &Proto : PrototypeSeq) { + auto T = computeType(BT, Log2LMUL, Proto); + if (!T.hasValue()) + return llvm::None; + // Record legal type index + Types.push_back(T.getValue()); + } + return Types; +} + +// Compute the hash value of RVVType, used for cache the result of computeType. +static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL, + PrototypeDescriptor Proto) { + // Layout of hash value: + // 0 8 16 24 32 40 + // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM | + assert(Log2LMUL >= -3 && Log2LMUL <= 3); + return (Log2LMUL + 3) | (static_cast(BT) & 0xff) << 8 | + ((uint64_t)(Proto.PT & 0xff) << 16) | + ((uint64_t)(Proto.TM & 0xff) << 24) | + ((uint64_t)(Proto.VTM & 0xff) << 32); +} + +Optional RVVType::computeType(BasicType BT, int Log2LMUL, + PrototypeDescriptor Proto) { + uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto); + // Search first + auto It = LegalTypes.find(Idx); + if (It != LegalTypes.end()) + return &(It->second); + + if (IllegalTypes.count(Idx)) + return llvm::None; + + // Compute type and record the result. + RVVType T(BT, Log2LMUL, Proto); + if (T.isValid()) { + // Record legal type index and value. + LegalTypes.insert({Idx, T}); + return &(LegalTypes[Idx]); + } + // Record illegal type index. 
+ IllegalTypes.insert(Idx); + return llvm::None; } //===----------------------------------------------------------------------===// @@ -593,5 +928,37 @@ return S; } +std::string RVVIntrinsic::getSuffixStr( + BasicType Type, int Log2LMUL, + const llvm::SmallVector &PrototypeDescriptors) { + SmallVector SuffixStrs; + for (auto PD : PrototypeDescriptors) { + auto T = RVVType::computeType(Type, Log2LMUL, PD); + SuffixStrs.push_back(T.getValue()->getShortStr()); + } + return join(SuffixStrs, "_"); +} + +SmallVector parsePrototypes(StringRef Prototypes) { + SmallVector PrototypeDescriptors; + const StringRef Primaries("evwqom0ztul"); + while (!Prototypes.empty()) { + size_t Idx = 0; + // Skip over complex prototype because it could contain primitive type + // character. + if (Prototypes[0] == '(') + Idx = Prototypes.find_first_of(')'); + Idx = Prototypes.find_first_of(Primaries, Idx); + assert(Idx != StringRef::npos); + auto PD = PrototypeDescriptor::parsePrototypeDescriptor( + Prototypes.slice(0, Idx + 1)); + if (!PD) + llvm_unreachable("Error during parsing prototype."); + PrototypeDescriptors.push_back(*PD); + Prototypes = Prototypes.drop_front(Idx + 1); + } + return PrototypeDescriptors; +} + } // end namespace RISCV } // end namespace clang diff --git a/clang/utils/TableGen/RISCVVEmitter.cpp b/clang/utils/TableGen/RISCVVEmitter.cpp --- a/clang/utils/TableGen/RISCVVEmitter.cpp +++ b/clang/utils/TableGen/RISCVVEmitter.cpp @@ -32,9 +32,6 @@ class RVVEmitter { private: RecordKeeper &Records; - // Concat BasicType, LMUL and Proto as key - StringMap LegalTypes; - StringSet<> IllegalTypes; public: RVVEmitter(RecordKeeper &R) : Records(R) {} @@ -48,20 +45,11 @@ /// Emit all the information needed to map builtin -> LLVM IR intrinsic. 
void createCodeGen(raw_ostream &o); - std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes); - private: /// Create all intrinsics and add them to \p Out void createRVVIntrinsics(std::vector> &Out); /// Print HeaderCode in RVVHeader Record to \p Out void printHeaderCode(raw_ostream &OS); - /// Compute output and input types by applying different config (basic type - /// and LMUL with type transformers). It also record result of type in legal - /// or illegal set to avoid compute the same config again. The result maybe - /// have illegal RVVType. - Optional computeTypes(BasicType BT, int Log2LMUL, unsigned NF, - ArrayRef PrototypeSeq); - Optional computeType(BasicType BT, int Log2LMUL, StringRef Proto); /// Emit Acrh predecessor definitions and body, assume the element of Defs are /// sorted by extension. @@ -73,14 +61,39 @@ // non-empty string. bool emitMacroRestrictionStr(RISCVPredefinedMacroT PredefinedMacros, raw_ostream &o); - // Slice Prototypes string into sub prototype string and process each sub - // prototype string individually in the Handler. - void parsePrototypes(StringRef Prototypes, - std::function Handler); }; } // namespace +static BasicType ParseBasicType(char c) { + switch (c) { + case 'c': + return BasicType::Int8; + break; + case 's': + return BasicType::Int16; + break; + case 'i': + return BasicType::Int32; + break; + case 'l': + return BasicType::Int64; + break; + case 'x': + return BasicType::Float16; + break; + case 'f': + return BasicType::Float32; + break; + case 'd': + return BasicType::Float64; + break; + + default: + return BasicType::Unknown; + } +} + void emitCodeGenSwitchBody(const RVVIntrinsic *RVVI, raw_ostream &OS) { if (!RVVI->getIRName().empty()) OS << " ID = Intrinsic::riscv_" + RVVI->getIRName() + ";\n"; @@ -202,24 +215,31 @@ constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3}; // Print RVV boolean types. 
for (int Log2LMUL : Log2LMULs) { - auto T = computeType('c', Log2LMUL, "m"); + auto T = RVVType::computeType(BasicType::Int8, Log2LMUL, + PrototypeDescriptor::Mask); if (T.hasValue()) printType(T.getValue()); } // Print RVV int/float types. for (char I : StringRef("csil")) { + BasicType BT = ParseBasicType(I); for (int Log2LMUL : Log2LMULs) { - auto T = computeType(I, Log2LMUL, "v"); + auto T = RVVType::computeType(BT, Log2LMUL, PrototypeDescriptor::Vector); if (T.hasValue()) { printType(T.getValue()); - auto UT = computeType(I, Log2LMUL, "Uv"); + auto UT = RVVType::computeType( + BT, Log2LMUL, + PrototypeDescriptor(BaseTypeModifier::Vector, + VectorTypeModifier::NoModifier, + TypeModifier::UnsignedInteger)); printType(UT.getValue()); } } } OS << "#if defined(__riscv_zvfh)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('x', Log2LMUL, "v"); + auto T = RVVType::computeType(BasicType::Float16, Log2LMUL, + PrototypeDescriptor::Vector); if (T.hasValue()) printType(T.getValue()); } @@ -227,7 +247,8 @@ OS << "#if defined(__riscv_f)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('f', Log2LMUL, "v"); + auto T = RVVType::computeType(BasicType::Float32, Log2LMUL, + PrototypeDescriptor::Vector); if (T.hasValue()) printType(T.getValue()); } @@ -235,7 +256,8 @@ OS << "#if defined(__riscv_d)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('d', Log2LMUL, "v"); + auto T = RVVType::computeType(BasicType::Float64, Log2LMUL, + PrototypeDescriptor::Vector); if (T.hasValue()) printType(T.getValue()); } @@ -359,32 +381,6 @@ OS << "\n"; } -void RVVEmitter::parsePrototypes(StringRef Prototypes, - std::function Handler) { - const StringRef Primaries("evwqom0ztul"); - while (!Prototypes.empty()) { - size_t Idx = 0; - // Skip over complex prototype because it could contain primitive type - // character. 
- if (Prototypes[0] == '(') - Idx = Prototypes.find_first_of(')'); - Idx = Prototypes.find_first_of(Primaries, Idx); - assert(Idx != StringRef::npos); - Handler(Prototypes.slice(0, Idx + 1)); - Prototypes = Prototypes.drop_front(Idx + 1); - } -} - -std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL, - StringRef Prototypes) { - SmallVector SuffixStrs; - parsePrototypes(Prototypes, [&](StringRef Proto) { - auto T = computeType(Type, Log2LMUL, Proto); - SuffixStrs.push_back(T.getValue()->getShortStr()); - }); - return join(SuffixStrs, "_"); -} - void RVVEmitter::createRVVIntrinsics( std::vector> &Out) { std::vector RV = Records.getAllDerivedDefinitions("RVVBuiltin"); @@ -419,13 +415,15 @@ // Parse prototype and create a list of primitive type with transformers // (operand) in ProtoSeq. ProtoSeq[0] is output operand. - SmallVector ProtoSeq; - parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) { - ProtoSeq.push_back(Proto.str()); - }); + SmallVector ProtoSeq = parsePrototypes(Prototypes); + + SmallVector SuffixProtoSeq = + parsePrototypes(SuffixProto); + SmallVector MangledSuffixProtoSeq = + parsePrototypes(MangledSuffixProto); // Compute Builtin types - SmallVector ProtoMaskSeq = ProtoSeq; + SmallVector ProtoMaskSeq = ProtoSeq; if (HasMasked) { // If HasMaskedOffOperand, insert result type as first input operand. if (HasMaskedOffOperand) { @@ -436,10 +434,10 @@ // (void, op0 address, op1 address, ...) // to // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) + PrototypeDescriptor MaskoffType = ProtoSeq[1]; + MaskoffType.TM &= ~static_cast(TypeModifier::Pointer); for (unsigned I = 0; I < NF; ++I) - ProtoMaskSeq.insert( - ProtoMaskSeq.begin() + NF + 1, - ProtoSeq[1].substr(1)); // Use substr(1) to skip '*' + ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, MaskoffType); } } if (HasMaskedOffOperand && NF > 1) { @@ -448,28 +446,34 @@ // to // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1, // ...) 
- ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m"); + ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, + PrototypeDescriptor::Mask); } else { - // If HasMasked, insert 'm' as first input operand. - ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m"); + // If HasMasked, insert PrototypeDescriptor::Mask as first input operand. + ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, + PrototypeDescriptor::Mask); } } - // If HasVL, append 'z' to last operand + // If HasVL, append PrototypeDescriptor::VL to last operand if (HasVL) { - ProtoSeq.push_back("z"); - ProtoMaskSeq.push_back("z"); + ProtoSeq.push_back(PrototypeDescriptor::VL); + ProtoMaskSeq.push_back(PrototypeDescriptor::VL); } // Create Intrinsics for each type and LMUL. for (char I : TypeRange) { for (int Log2LMUL : Log2LMULList) { - Optional Types = computeTypes(I, Log2LMUL, NF, ProtoSeq); + BasicType BT = ParseBasicType(I); + Optional Types = + RVVType::computeTypes(BT, Log2LMUL, NF, ProtoSeq); // Ignored to create new intrinsic if there are any illegal types. 
if (!Types.hasValue()) continue; - auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto); - auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto); + auto SuffixStr = + RVVIntrinsic::getSuffixStr(BT, Log2LMUL, SuffixProtoSeq); + auto MangledSuffixStr = + RVVIntrinsic::getSuffixStr(BT, Log2LMUL, MangledSuffixProtoSeq); // Create a unmasked intrinsic Out.push_back(std::make_unique( Name, SuffixStr, MangledName, MangledSuffixStr, IRName, @@ -480,7 +484,7 @@ if (HasMasked) { // Create a masked intrinsic Optional MaskTypes = - computeTypes(I, Log2LMUL, NF, ProtoMaskSeq); + RVVType::computeTypes(BT, Log2LMUL, NF, ProtoMaskSeq); Out.push_back(std::make_unique( Name, SuffixStr, MangledName, MangledSuffixStr, MaskedIRName, /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicy, @@ -501,45 +505,6 @@ } } -Optional -RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, - ArrayRef PrototypeSeq) { - // LMUL x NF must be less than or equal to 8. - if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) - return llvm::None; - - RVVTypes Types; - for (const std::string &Proto : PrototypeSeq) { - auto T = computeType(BT, Log2LMUL, Proto); - if (!T.hasValue()) - return llvm::None; - // Record legal type index - Types.push_back(T.getValue()); - } - return Types; -} - -Optional RVVEmitter::computeType(BasicType BT, int Log2LMUL, - StringRef Proto) { - std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str(); - // Search first - auto It = LegalTypes.find(Idx); - if (It != LegalTypes.end()) - return &(It->second); - if (IllegalTypes.count(Idx)) - return llvm::None; - // Compute type and record the result. - RVVType T(BT, Log2LMUL, Proto); - if (T.isValid()) { - // Record legal type index and value. - LegalTypes.insert({Idx, T}); - return &(LegalTypes[Idx]); - } - // Record illegal type index. - IllegalTypes.insert(Idx); - return llvm::None; -} - void RVVEmitter::emitArchMacroAndBody( std::vector> &Defs, raw_ostream &OS, std::function PrintBody) {