diff --git a/clang/include/clang/Basic/CMakeLists.txt b/clang/include/clang/Basic/CMakeLists.txt --- a/clang/include/clang/Basic/CMakeLists.txt +++ b/clang/include/clang/Basic/CMakeLists.txt @@ -90,3 +90,6 @@ clang_tablegen(riscv_vector_builtin_cg.inc -gen-riscv-vector-builtin-codegen SOURCE riscv_vector.td TARGET ClangRISCVVectorBuiltinCG) +clang_tablegen(riscv_vector_builtin_sema.inc -gen-riscv-vector-builtin-sema + SOURCE riscv_vector.td + TARGET ClangRISCVVectorBuiltinSema) diff --git a/clang/include/clang/Basic/TokenKinds.def b/clang/include/clang/Basic/TokenKinds.def --- a/clang/include/clang/Basic/TokenKinds.def +++ b/clang/include/clang/Basic/TokenKinds.def @@ -888,6 +888,9 @@ // Annotation for the attribute pragma directives - #pragma clang attribute ... PRAGMA_ANNOTATION(pragma_attribute) +// Annotation for the riscv pragma directives - #pragma clang riscv intrinsic ... +PRAGMA_ANNOTATION(pragma_riscv) + // Annotations for module import translated from #include etc. ANNOTATION(module_include) ANNOTATION(module_begin) diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -212,6 +212,7 @@ std::unique_ptr AttributePragmaHandler; std::unique_ptr MaxTokensHerePragmaHandler; std::unique_ptr MaxTokensTotalPragmaHandler; + std::unique_ptr RISCVPragmaHandler; std::unique_ptr CommentSemaHandler; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -1567,6 +1567,9 @@ Optional> CachedDarwinSDKInfo; + /// Set by '#pragma clang riscv intrinsic vector' to enable RVV builtins.
+ bool DeclareRVVBuiltins = false; + public: Sema(Preprocessor &pp, ASTContext &ctxt, ASTConsumer &consumer, TranslationUnitKind TUKind = TU_Complete, @@ -13215,6 +13218,9 @@ llvm::StringRef StackSlotLabel, AlignPackInfo Value); +bool GetRVVBuiltinInfo(Sema &S, LookupResult &LR, IdentifierInfo *II, + Preprocessor &PP); + } // end namespace clang namespace llvm { diff --git a/clang/lib/Parse/ParsePragma.cpp b/clang/lib/Parse/ParsePragma.cpp --- a/clang/lib/Parse/ParsePragma.cpp +++ b/clang/lib/Parse/ParsePragma.cpp @@ -356,6 +356,16 @@ Token &FirstToken) override; }; +struct PragmaRISCVHandler : public PragmaHandler { + PragmaRISCVHandler(Sema &Actions) + : PragmaHandler("riscv"), Actions(Actions) {} + void HandlePragma(Preprocessor &PP, PragmaIntroducer Introducer, + Token &FirstToken) override; + +private: + Sema &Actions; +}; + void markAsReinjectedForRelexing(llvm::MutableArrayRef Toks) { for (auto &T : Toks) T.setFlag(clang::Token::IsReinjected); @@ -495,6 +505,11 @@ MaxTokensTotalPragmaHandler = std::make_unique(); PP.AddPragmaHandler("clang", MaxTokensTotalPragmaHandler.get()); + + if (getTargetInfo().getTriple().isRISCV()) { + RISCVPragmaHandler = std::make_unique(Actions); + PP.AddPragmaHandler("clang", RISCVPragmaHandler.get()); + } } void Parser::resetPragmaHandlers() { @@ -615,6 +630,11 @@ PP.RemovePragmaHandler("clang", MaxTokensTotalPragmaHandler.get()); MaxTokensTotalPragmaHandler.reset(); + + if (getTargetInfo().getTriple().isRISCV()) { + PP.RemovePragmaHandler("clang", RISCVPragmaHandler.get()); + RISCVPragmaHandler.reset(); + } } /// Handle the annotation token produced for #pragma unused(...) @@ -3798,3 +3818,34 @@ PP.overrideMaxTokens(MaxTokens, Loc); } + +// Handle '#pragma clang riscv intrinsic vector'.
+void PragmaRISCVHandler::HandlePragma(Preprocessor &PP, + PragmaIntroducer Introducer, + Token &FirstToken) { + Token Tok; + PP.Lex(Tok); + IdentifierInfo *II = Tok.getIdentifierInfo(); + if (!II || (!II->isStr("intrinsic"))) { + PP.Diag(Tok.getLocation(), diag::warn_pragma_invalid_argument) + << PP.getSpelling(Tok) << "riscv" << /*Expected=*/true << "'intrinsic'"; + return; + } + + PP.Lex(Tok); + II = Tok.getIdentifierInfo(); + if (!II || (!II->isStr("vector"))) { + PP.Diag(Tok.getLocation(), diag::warn_pragma_invalid_argument) + << PP.getSpelling(Tok) << "riscv" << /*Expected=*/true << "'vector'"; + return; + } + + PP.Lex(Tok); + if (Tok.isNot(tok::eod)) { + PP.Diag(Tok.getLocation(), diag::warn_pragma_extra_tokens_at_eol) + << "clang riscv intrinsic"; + return; + } + + Actions.DeclareRVVBuiltins = true; // Read during builtin name lookup. +} diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt --- a/clang/lib/Sema/CMakeLists.txt +++ b/clang/lib/Sema/CMakeLists.txt @@ -46,6 +46,7 @@ SemaInit.cpp SemaLambda.cpp SemaLookup.cpp + SemaRVVLookup.cpp SemaModule.cpp SemaObjCProperty.cpp SemaOpenMP.cpp diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp --- a/clang/lib/Sema/SemaLookup.cpp +++ b/clang/lib/Sema/SemaLookup.cpp @@ -23,6 +23,8 @@ #include "clang/Basic/Builtins.h" #include "clang/Basic/FileManager.h" #include "clang/Basic/LangOptions.h" +#include "clang/Basic/TargetBuiltins.h" +#include "clang/Basic/TargetInfo.h" #include "clang/Lex/HeaderSearch.h" #include "clang/Lex/ModuleLoader.h" #include "clang/Lex/Preprocessor.h" @@ -928,6 +930,12 @@ } } + if (DeclareRVVBuiltins) { // RVV intrinsics are declared lazily here. + if (GetRVVBuiltinInfo(*this, R, II, PP)) { + return true; + } + } + // If this is a builtin on this (or all) targets, create the decl.
if (unsigned BuiltinID = II->getBuiltinID()) { // In C++ and OpenCL (spec v1.2 s6.9.f), we don't have any predefined diff --git a/clang/lib/Sema/SemaRVVLookup.cpp b/clang/lib/Sema/SemaRVVLookup.cpp new file mode 100644 --- /dev/null +++ b/clang/lib/Sema/SemaRVVLookup.cpp @@ -0,0 +1,441 @@ +//===-- SemaRVVLookup.cpp - Name Lookup for RISC-V Vector Intrinsic -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements lazy name lookup for the RISC-V vector (RVV) C +// intrinsics enabled by '#pragma clang riscv intrinsic vector'. +// +//===----------------------------------------------------------------------===// + +#include "clang/AST/ASTContext.h" +#include "clang/AST/CXXInheritance.h" +#include "clang/AST/Decl.h" +#include "clang/AST/DeclCXX.h" +#include "clang/AST/DeclLookups.h" +#include "clang/AST/DeclObjC.h" +#include "clang/AST/DeclTemplate.h" +#include "clang/AST/Expr.h" +#include "clang/AST/ExprCXX.h" +#include "clang/Basic/Builtins.h" +#include "clang/Basic/FileManager.h" +#include "clang/Basic/LangOptions.h" +#include "clang/Basic/TargetBuiltins.h" +#include "clang/Basic/TargetInfo.h" +#include "clang/Lex/HeaderSearch.h" +#include "clang/Lex/ModuleLoader.h" +#include "clang/Lex/Preprocessor.h" +#include "clang/Sema/DeclSpec.h" +#include "clang/Sema/Lookup.h" +#include "clang/Sema/Overload.h" +#include "clang/Sema/Scope.h" +#include "clang/Sema/ScopeInfo.h" +#include "clang/Sema/Sema.h" +#include "clang/Sema/SemaInternal.h" +#include "clang/Sema/TemplateDeduction.h" +#include "clang/Sema/TypoCorrection.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/ADT/edit_distance.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/RISCVVIntrinsicUtils.h"
+#include +#include +#include +#include +#include +#include + +using namespace llvm; +using namespace clang; +using namespace llvm::RISCV; + +enum RVVRequireExtensionMask { + RVVRequireZvlsseg = 1 << 0, +}; + +struct RVVIntrinsicRecord { + const char *Name; + const char *MangledName; + uint16_t ProtoSeqIndex; + uint16_t ProtoMaskSeqIndex; + uint16_t SuffixProtoIndex; + uint16_t MangledSuffixProtoIndex; + uint8_t ProtoSeqSize; + uint8_t ProtoMaskSeqSize; + uint8_t SuffixProtoSize; + uint8_t MangledSuffixProtoSize; + uint8_t RequiredExtension; + uint8_t TypeRangeMask; + uint8_t Log2LMULMask; + uint8_t NF; +}; + +static const RVVTypeProfile RVVSignatureTable[] = { +#define DECL_SIGNATURE_TABLE +#include "clang/Basic/riscv_vector_builtin_sema.inc" +#undef DECL_SIGNATURE_TABLE +}; + +static const RVVIntrinsicRecord RVVIntrinsicRecords[] = { +#define DECL_INTRINSIC_RECORDS +#include "clang/Basic/riscv_vector_builtin_sema.inc" +#undef DECL_INTRINSIC_RECORDS +}; + +struct RVVIntrinsicDef { + std::string Name; + std::string GenericName; + std::string BuiltinName; + RVVTypes Signature; +}; + +struct RVVGenericIntrinsicDef { + SmallVector Indices; +}; + +// Returns the cached RVVType for (BT, Log2LMUL, Proto), or None if that +// combination is illegal. Results are memoized in function-local statics. +static Optional ComputeType(RVVBasicType BT, int Log2LMUL, + const RVVTypeProfile &Proto) { + static StringMap LegalTypes; + static StringSet<> IllegalTypes; + RVVType T(BT, Log2LMUL, Proto); + StringRef Idx = T.getMangledStr(); + auto It = LegalTypes.find(Idx); + if (It != LegalTypes.end()) + return &(It->second); + if (IllegalTypes.count(Idx)) + return llvm::None; + if (T.isValid()) { + // Record legal type index and value. + auto Inserted = LegalTypes.insert({Idx, T}); + // StringMap entries are allocated separately; the pointer stays valid. + return &(Inserted.first->second); + } + // Record illegal type index. + IllegalTypes.insert(Idx); + return llvm::None; +} + +// Computes the types of PrototypeSeq, or None if LMUL x NF exceeds 8 or any +// of the types is illegal. +static Optional +ComputeTypes(RVVBasicType BT, int Log2LMUL, unsigned NF, + const ArrayRef &PrototypeSeq) { + // LMUL x NF must be less than or equal to 8.
+ if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) + return llvm::None; + + RVVTypes Types; + for (const auto &Proto : PrototypeSeq) { + auto T = ComputeType(BT, Log2LMUL, Proto); + if (!T.hasValue()) + return llvm::None; + // Record legal type index + Types.push_back(T.getValue()); + } + return Types; +} + +// Joins the short names of the types produced by Prototypes with '_'. +// NOTE(review): assumes every such type is legal (result is not checked). +static std::string GetSuffixStr(RVVBasicType Type, int Log2LMUL, + const ArrayRef &Prototypes) { + SmallVector SuffixStrs; + for (auto Proto : Prototypes) { + auto T = ComputeType(Type, Log2LMUL, Proto); + SuffixStrs.push_back(T.getValue()->getShortStr()); + } + return join(SuffixStrs, "_"); +} + +// Returns a view of Length entries of RVVSignatureTable starting at Index. +static ArrayRef ProtoSeq2ArrayRef(uint16_t Index, + uint8_t Length) { + return ArrayRef(&RVVSignatureTable[Index], Length); +} + +// Converts an RVVType to the corresponding clang QualType. Returns a null +// QualType for scalar kinds that are not handled below. +static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) { + QualType QT; + switch (Type->getScalarType()) { + case STK_Void: + QT = Context.VoidTy; + break; + case STK_Size_t: + QT = Context.getSizeType(); + break; + case STK_Ptrdiff_t: + QT = Context.getPointerDiffType(); + break; + case STK_UnsignedLong: + QT = Context.UnsignedLongTy; + break; + case STK_SignedLong: + QT = Context.LongTy; + break; + case STK_Boolean: + QT = Context.BoolTy; + break; + case STK_SignedInteger: + QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), true); + break; + case STK_UnsignedInteger: + QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false); + break; + case STK_Float: + switch (Type->getElementBitwidth()) { + case 64: + QT = Context.DoubleTy; + break; + case 32: + QT = Context.FloatTy; + break; + case 16: + QT = Context.Float16Ty; + break; + default: + llvm_unreachable("Unhandled floating-point element width"); + } + break; + default: + return QT; + } + if (Type->isVector()) + QT = Context.getScalableVectorType(QT, Type->getScale().getValue()); + + if (Type->isConstant()) { + QT = Context.getConstType(QT); + } + + // Transform the type to a pointer as the last step, if necessary.
+ if (Type->isPointer()) { + QT = Context.getPointerType(QT); + } + + return QT; +} + +// Builds an implicit, overloadable FunctionDecl for RVVIntrinsicList[Index] +// that aliases its __builtin_rvv_* builtin, and adds it to LR. +// NOTE(review): builds a new decl per lookup; confirm no caching is needed. +static void Create(Sema &S, LookupResult &LR, IdentifierInfo *II, + Preprocessor &PP, unsigned Index, + std::vector &RVVIntrinsicList) { + ASTContext &Context = S.Context; + const RVVIntrinsicDef &IDef = RVVIntrinsicList[Index]; + const auto &Sigs = IDef.Signature; + size_t SigLength = Sigs.size(); + auto ReturnType = Sigs[0]; + QualType RetType = RVVType2Qual(Context, ReturnType); + SmallVector ArgTypes; + QualType BuiltinFuncType; + for (size_t i = 1; i < SigLength; ++i) { + ArgTypes.push_back(RVVType2Qual(Context, Sigs[i])); + } + FunctionProtoType::ExtProtoInfo PI( + Context.getDefaultCallingConvention(false, false, true)); + PI.Variadic = false; + + SourceLocation Loc = LR.getNameLoc(); + BuiltinFuncType = Context.getFunctionType(RetType, ArgTypes, PI); + DeclContext *Parent = Context.getTranslationUnitDecl(); + + FunctionDecl *NewRVVBuiltin = FunctionDecl::Create( + Context, Parent, Loc, Loc, II, BuiltinFuncType, /*TInfo=*/nullptr, + SC_Extern, S.getCurFPFeatures().isFPConstrained(), false, + BuiltinFuncType->isFunctionProtoType()); + + NewRVVBuiltin->setImplicit(); + + // Create Decl objects for each parameter, adding them to the + // FunctionDecl.
+ const auto *FP = cast(BuiltinFuncType); + SmallVector ParmList; + for (unsigned IParm = 0, e = FP->getNumParams(); IParm != e; ++IParm) { + ParmVarDecl *Parm = ParmVarDecl::Create( + Context, NewRVVBuiltin, SourceLocation(), SourceLocation(), nullptr, + FP->getParamType(IParm), nullptr, SC_None, nullptr); + Parm->setScopeInfo(0, IParm); + ParmList.push_back(Parm); + } + NewRVVBuiltin->setParams(ParmList); + NewRVVBuiltin->addAttr(OverloadableAttr::CreateImplicit(Context)); + auto &IntrinsicII = PP.getIdentifierTable().get(IDef.BuiltinName); + NewRVVBuiltin->addAttr( + BuiltinAliasAttr::CreateImplicit(S.Context, &IntrinsicII)); + + LR.addDecl(NewRVVBuiltin); +} + +// Holds the RVV intrinsics enabled by the ASTContext's target features, +// indexed by full name and by overloaded (generic) name. +class RVVIntrinsicManager { +private: + std::vector IntrinsicList; + StringMap Intrinsics; + StringMap GenericIntrinsics; + + void InitIntrinsicList(); + ASTContext &Context; + +public: + explicit RVVIntrinsicManager(ASTContext &Context) : Context(Context) { + InitIntrinsicList(); + } + bool CreateIntrinsicIfFound(Sema &S, LookupResult &LR, IdentifierInfo *II, + Preprocessor &PP); + + void InitRVVIntrinsic(const RVVIntrinsicRecord &Record, StringRef SuffixStr, + StringRef MangledSuffixStr, bool IsMask, + const RVVTypes &Types); +}; + +// Enumerates every record x basic type x LMUL allowed by the target features. +void RVVIntrinsicManager::InitIntrinsicList() { + const TargetInfo &TI = Context.getTargetInfo(); + bool HasF = TI.hasFeature("f"); + bool HasD = TI.hasFeature("d"); + bool HasZfh = TI.hasFeature("experimental-zfh"); + bool HasZvlsseg = TI.hasFeature("experimental-zvlsseg"); + + for (auto &Record : RVVIntrinsicRecords) { + // Create Intrinsics for each type and LMUL.
+ RVVBasicType BaseType = RVVBasicTypeUnknown; + auto ProtoSeq = + ProtoSeq2ArrayRef(Record.ProtoSeqIndex, Record.ProtoSeqSize); + auto ProtoMaskSeq = + ProtoSeq2ArrayRef(Record.ProtoMaskSeqIndex, Record.ProtoMaskSeqSize); + auto SuffixProto = + ProtoSeq2ArrayRef(Record.SuffixProtoIndex, Record.SuffixProtoSize); + auto MangledSuffixProto = ProtoSeq2ArrayRef(Record.MangledSuffixProtoIndex, + Record.MangledSuffixProtoSize); + for (int TypeRangeMaskShift = 0; + TypeRangeMaskShift <= RVVBasicTypeMaxOffset; ++TypeRangeMaskShift) { + BaseType = static_cast(1 << TypeRangeMaskShift); + + if (!(BaseType & Record.TypeRangeMask)) + continue; + + // Check requirement. + if (BaseType == RVVBasicTypeFloat16 && !HasZfh) + continue; + + if (BaseType == RVVBasicTypeFloat32 && !HasF) + continue; + + if (BaseType == RVVBasicTypeFloat64 && !HasD) + continue; + + if ((Record.RequiredExtension & RVVRequireZvlsseg) && !HasZvlsseg) + continue; + + // TODO: Part of SEW=64 instructions are not available on zve64*, but + // those extensions are not merged yet. + + for (int Log2LMUL = -3; Log2LMUL <= 3; Log2LMUL++) { + if (!(Record.Log2LMULMask & (1 << (Log2LMUL + 3)))) { + continue; + } + Optional Types = + ComputeTypes(BaseType, Log2LMUL, Record.NF, ProtoSeq); + // Skip this intrinsic if any of its types is illegal.
+ if (!Types.hasValue()) { + continue; + } + auto SuffixStr = GetSuffixStr(BaseType, Log2LMUL, SuffixProto); + auto MangledSuffixStr = + GetSuffixStr(BaseType, Log2LMUL, MangledSuffixProto); + InitRVVIntrinsic(Record, SuffixStr, MangledSuffixStr, false, *Types); + bool HasMask = Record.ProtoMaskSeqSize != 0; + if (HasMask) { + // Create the masked variant, unless one of its types is illegal. + if (Optional MaskTypes = + ComputeTypes(BaseType, Log2LMUL, Record.NF, ProtoMaskSeq)) + InitRVVIntrinsic(Record, SuffixStr, MangledSuffixStr, true, + *MaskTypes); + } + } + } + } +} + +// Registers one intrinsic built from Record, its name suffixes and Types. +void RVVIntrinsicManager::InitRVVIntrinsic(const RVVIntrinsicRecord &Record, + StringRef SuffixStr, + StringRef MangledSuffixStr, + bool IsMask, const RVVTypes &Types) { + std::string Name = Record.Name; + std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name); + std::string MangledName; + if (!Record.MangledName) + MangledName = StringRef(Record.Name).split("_").first.str(); + else + MangledName = Record.MangledName; + if (!SuffixStr.empty()) + Name += "_" + SuffixStr.str(); + if (!MangledSuffixStr.empty()) + MangledName += "_" + MangledSuffixStr.str(); + + if (IsMask) { + BuiltinName += "_m"; + Name += "_m"; + } + + size_t Index = IntrinsicList.size(); + IntrinsicList.push_back({Name, MangledName, BuiltinName, Types}); + + Intrinsics.insert({Name, Index}); + + RVVGenericIntrinsicDef &GenericIntrinsicDef = GenericIntrinsics[MangledName]; + + GenericIntrinsicDef.Indices.push_back(Index); +} + +// Declares every intrinsic named II into LR; a generic name may map to +// several overloads. Returns false if II is not an RVV intrinsic name. +bool RVVIntrinsicManager::CreateIntrinsicIfFound(Sema &S, LookupResult &LR, + IdentifierInfo *II, + Preprocessor &PP) { + StringRef Name = II->getName(); + + auto GIItr = GenericIntrinsics.find(Name); + if (GIItr != GenericIntrinsics.end()) { + const auto &GIntrinsicDef = GIItr->second; + for (auto Index : GIntrinsicDef.Indices) { + Create(S, LR, II, PP, Index, IntrinsicList); + } + LR.resolveKind(); + return true; + } + + auto Itr = Intrinsics.find(Name); + if (Itr != Intrinsics.end()) { + Create(S, LR, II, PP, Itr->second, IntrinsicList);
+ return true; + } + + return false; +} + +namespace clang { +// Adds declarations for the RVV intrinsic named II to LR, if there is one, +// and returns true in that case. +bool GetRVVBuiltinInfo(Sema &S, LookupResult &LR, IdentifierInfo *II, + Preprocessor &PP) { + // NOTE(review): this static is shared by every Sema in the process and its + // intrinsic list is built once from the first ASTContext's target features, + // so later compilations with different features reuse a stale list. + // Consider owning the manager in Sema instead. + static std::unique_ptr IntrinsicManager = + std::make_unique(S.Context); + + return IntrinsicManager->CreateIntrinsicIfFound(S, LR, II, PP); +} +} // namespace clang diff --git a/clang/utils/TableGen/RISCVVEmitter.cpp b/clang/utils/TableGen/RISCVVEmitter.cpp --- a/clang/utils/TableGen/RISCVVEmitter.cpp +++ b/clang/utils/TableGen/RISCVVEmitter.cpp @@ -19,122 +19,20 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" +#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Twine.h" +#include "llvm/Support/RISCVVIntrinsicUtils.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" +#include "llvm/TableGen/StringMatcher.h" +#include "llvm/TableGen/TableGenBackend.h" #include using namespace llvm; -using BasicType = char; -using VScaleVal = Optional; +using namespace llvm::RISCV; namespace { -// Exponential LMUL -struct LMULType { - int Log2LMUL; - LMULType(int Log2LMUL); - // Return the C/C++ string representation of LMUL - std::string str() const; - Optional getScale(unsigned ElementBitwidth) const; - void MulLog2LMUL(int Log2LMUL); - LMULType &operator*=(uint32_t RHS); -}; - -// This class is compact representation of a valid and invalid RVVType. -class RVVType { - enum ScalarTypeKind : uint32_t { - Void, - Size_t, - Ptrdiff_t, - UnsignedLong, - SignedLong, - Boolean, - SignedInteger, - UnsignedInteger, - Float, - Invalid, - }; - BasicType BT; - ScalarTypeKind ScalarType = Invalid; - LMULType LMUL; - bool IsPointer = false; - // IsConstant indices are "int", but have the constant expression. - bool IsImmediate = false; - // Const qualifier for pointer to const object or object of const type.
- bool IsConstant = false; - unsigned ElementBitwidth = 0; - VScaleVal Scale = 0; - bool Valid; - - std::string BuiltinStr; - std::string ClangBuiltinStr; - std::string Str; - std::string ShortStr; - -public: - RVVType() : RVVType(BasicType(), 0, StringRef()) {} - RVVType(BasicType BT, int Log2LMUL, StringRef prototype); - - // Return the string representation of a type, which is an encoded string for - // passing to the BUILTIN() macro in Builtins.def. - const std::string &getBuiltinStr() const { return BuiltinStr; } - - // Return the clang builtin type for RVV vector type which are used in the - // riscv_vector.h header file. - const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; } - - // Return the C/C++ string representation of a type for use in the - // riscv_vector.h header file. - const std::string &getTypeStr() const { return Str; } - - // Return the short name of a type for C/C++ name suffix. - const std::string &getShortStr() { - // Not all types are used in short name, so compute the short name by - // demanded. - if (ShortStr.empty()) - initShortStr(); - return ShortStr; - } - - bool isValid() const { return Valid; } - bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; } - bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; } - bool isFloat() const { return ScalarType == ScalarTypeKind::Float; } - bool isSignedInteger() const { - return ScalarType == ScalarTypeKind::SignedInteger; - } - bool isFloatVector(unsigned Width) const { - return isVector() && isFloat() && ElementBitwidth == Width; - } - bool isFloat(unsigned Width) const { - return isFloat() && ElementBitwidth == Width; - } - -private: - // Verify RVV vector type and set Valid. - bool verifyType() const; - - // Creates a type based on basic types of TypeRange - void applyBasicType(); - - // Applies a prototype modifier to the current type. The result maybe an - // invalid type. 
- void applyModifier(StringRef prototype); - - // Compute and record a string for legal type. - void initBuiltinStr(); - // Compute and record a builtin RVV vector type string. - void initClangBuiltinStr(); - // Compute and record a type string for used in the header. - void initTypeStr(); - // Compute and record a short name of a type for C/C++ name suffix. - void initShortStr(); -}; - -using RVVTypePtr = RVVType *; -using RVVTypes = std::vector; - enum RISCVExtension : uint8_t { Basic = 0, F = 1 << 1, @@ -200,24 +98,37 @@ // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should // init the RVVIntrinsic ID and IntrinsicTypes. void emitCodeGenSwitchBody(raw_ostream &o) const; +}; + +struct SemaRecord { + std::string Name; + std::string MangledName; + std::string TypeRange; + std::vector Log2LMULList; + std::string RequiredExtension; - // Emit the macros for mapping C/C++ intrinsic function to builtin functions. - void emitIntrinsicFuncDef(raw_ostream &o) const; + SmallVector ProtoSeq; + SmallVector ProtoMaskSeq; + SmallVector SuffixProto; + SmallVector MangledSuffixProto; - // Emit the mangled function definition. - void emitMangledFuncDef(raw_ostream &o) const; + unsigned NF; }; class RVVEmitter { private: RecordKeeper &Records; + std::vector> Defs; std::string HeaderCode; - // Concat BasicType, LMUL and Proto as key + // Concat RVVBasicType, LMUL and Proto as key StringMap LegalTypes; StringSet<> IllegalTypes; + std::vector SemaRecords; + std::vector SemaSignatureTable; + public: - RVVEmitter(RecordKeeper &R) : Records(R) {} + RVVEmitter(RecordKeeper &R) : Records(R) { createRVVIntrinsics(Defs); } /// Emit riscv_vector.h void createHeader(raw_ostream &o); @@ -228,529 +139,67 @@ /// Emit all the information needed to map builtin -> LLVM IR intrinsic. void createCodeGen(raw_ostream &o); - std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes); + /// Emit all the information needed by SemaLookup.cpp. 
+ void createSema(raw_ostream &o); + + std::string getSuffixStr(RVVBasicType BT, int Log2LMUL, StringRef Prototypes); private: /// Create all intrinsics and add them to \p Out void createRVVIntrinsics(std::vector> &Out); /// Create Headers and add them to \p Out void createRVVHeaders(raw_ostream &OS); + /// + unsigned GetSemaSignatureIndex(const SmallVector &Signature); /// Compute output and input types by applying different config (basic type /// and LMUL with type transformers). It also record result of type in legal /// or illegal set to avoid compute the same config again. The result maybe /// have illegal RVVType. - Optional computeTypes(BasicType BT, int Log2LMUL, unsigned NF, + Optional computeTypes(RVVBasicType BT, int Log2LMUL, unsigned NF, ArrayRef PrototypeSeq); - Optional computeType(BasicType BT, int Log2LMUL, StringRef Proto); - - /// Emit Acrh predecessor definitions and body, assume the element of Defs are - /// sorted by extension. - void emitArchMacroAndBody( - std::vector> &Defs, raw_ostream &o, - std::function); + Optional computeType(RVVBasicType BT, int Log2LMUL, + StringRef Proto); - // Emit the architecture preprocessor definitions. Return true when emits - // non-empty string. - bool emitExtDefStr(uint8_t Extensions, raw_ostream &o); +#if 0 // Slice Prototypes string into sub prototype string and process each sub // prototype string individually in the Handler. 
void parsePrototypes(StringRef Prototypes, std::function Handler); +#endif + + void EmitSemaRecords(raw_ostream &OS); + void ConstructSemaSignatureTable(); + void EmitSemaSignatureTable(raw_ostream &OS); }; } // namespace -//===----------------------------------------------------------------------===// -// Type implementation -//===----------------------------------------------------------------------===// - -LMULType::LMULType(int NewLog2LMUL) { - // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 - assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); - Log2LMUL = NewLog2LMUL; -} - -std::string LMULType::str() const { - if (Log2LMUL < 0) - return "mf" + utostr(1ULL << (-Log2LMUL)); - return "m" + utostr(1ULL << Log2LMUL); -} - -VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { - int Log2ScaleResult = 0; - switch (ElementBitwidth) { - default: - break; - case 8: - Log2ScaleResult = Log2LMUL + 3; - break; - case 16: - Log2ScaleResult = Log2LMUL + 2; - break; - case 32: - Log2ScaleResult = Log2LMUL + 1; - break; - case 64: - Log2ScaleResult = Log2LMUL; - break; - } - // Illegal vscale result would be less than 1 - if (Log2ScaleResult < 0) - return None; - return 1 << Log2ScaleResult; -} - -void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } - -LMULType &LMULType::operator*=(uint32_t RHS) { - assert(isPowerOf2_32(RHS)); - this->Log2LMUL = this->Log2LMUL + Log2_32(RHS); - return *this; -} - -RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype) - : BT(BT), LMUL(LMULType(Log2LMUL)) { - applyBasicType(); - applyModifier(prototype); - Valid = verifyType(); - if (Valid) { - initBuiltinStr(); - initTypeStr(); - if (isVector()) { - initClangBuiltinStr(); - } - } -} - -// clang-format off -// boolean type are encoded the ratio of n (SEW/LMUL) -// SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 -// c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t -// IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | 
nxv16i1 | nxv32i1 | nxv64i1 - -// type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 -// -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- -// i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 -// i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 -// i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 -// i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 -// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 -// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 -// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 -// clang-format on - -bool RVVType::verifyType() const { - if (ScalarType == Invalid) - return false; - if (isScalar()) - return true; - if (!Scale.hasValue()) - return false; - if (isFloat() && ElementBitwidth == 8) - return false; - unsigned V = Scale.getValue(); - switch (ElementBitwidth) { - case 1: - case 8: - // Check Scale is 1,2,4,8,16,32,64 - return (V <= 64 && isPowerOf2_32(V)); - case 16: - // Check Scale is 1,2,4,8,16,32 - return (V <= 32 && isPowerOf2_32(V)); - case 32: - // Check Scale is 1,2,4,8,16 - return (V <= 16 && isPowerOf2_32(V)); - case 64: - // Check Scale is 1,2,4,8 - return (V <= 8 && isPowerOf2_32(V)); - } - return false; -} - -void RVVType::initBuiltinStr() { - assert(isValid() && "RVVType is invalid"); - switch (ScalarType) { - case ScalarTypeKind::Void: - BuiltinStr = "v"; - return; - case ScalarTypeKind::Size_t: - BuiltinStr = "z"; - if (IsImmediate) - BuiltinStr = "I" + BuiltinStr; - if (IsPointer) - BuiltinStr += "*"; - return; - case ScalarTypeKind::Ptrdiff_t: - BuiltinStr = "Y"; - return; - case ScalarTypeKind::UnsignedLong: - BuiltinStr = "ULi"; - return; - case ScalarTypeKind::SignedLong: - BuiltinStr = "Li"; - return; - case ScalarTypeKind::Boolean: - assert(ElementBitwidth == 1); - BuiltinStr += "b"; - break; - case ScalarTypeKind::SignedInteger: - case 
ScalarTypeKind::UnsignedInteger: - switch (ElementBitwidth) { - case 8: - BuiltinStr += "c"; - break; - case 16: - BuiltinStr += "s"; - break; - case 32: - BuiltinStr += "i"; - break; - case 64: - BuiltinStr += "Wi"; - break; - default: - llvm_unreachable("Unhandled ElementBitwidth!"); - } - if (isSignedInteger()) - BuiltinStr = "S" + BuiltinStr; - else - BuiltinStr = "U" + BuiltinStr; - break; - case ScalarTypeKind::Float: - switch (ElementBitwidth) { - case 16: - BuiltinStr += "x"; - break; - case 32: - BuiltinStr += "f"; - break; - case 64: - BuiltinStr += "d"; - break; - default: - llvm_unreachable("Unhandled ElementBitwidth!"); - } - break; - default: - llvm_unreachable("ScalarType is invalid!"); - } - if (IsImmediate) - BuiltinStr = "I" + BuiltinStr; - if (isScalar()) { - if (IsConstant) - BuiltinStr += "C"; - if (IsPointer) - BuiltinStr += "*"; - return; - } - BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr; - // Pointer to vector types. Defined for Zvlsseg load intrinsics. - // Zvlsseg load intrinsics have pointer type arguments to store the loaded - // vector values. 
- if (IsPointer) - BuiltinStr += "*"; -} - -void RVVType::initClangBuiltinStr() { - assert(isValid() && "RVVType is invalid"); - assert(isVector() && "Handle Vector type only"); - - ClangBuiltinStr = "__rvv_"; - switch (ScalarType) { - case ScalarTypeKind::Boolean: - ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t"; - return; - case ScalarTypeKind::Float: - ClangBuiltinStr += "float"; - break; - case ScalarTypeKind::SignedInteger: - ClangBuiltinStr += "int"; - break; - case ScalarTypeKind::UnsignedInteger: - ClangBuiltinStr += "uint"; - break; - default: - llvm_unreachable("ScalarTypeKind is invalid"); - } - ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; -} - -void RVVType::initTypeStr() { - assert(isValid() && "RVVType is invalid"); - - if (IsConstant) - Str += "const "; - - auto getTypeString = [&](StringRef TypeStr) { - if (isScalar()) - return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); - return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") - .str(); - }; - - switch (ScalarType) { - case ScalarTypeKind::Void: - Str = "void"; - return; - case ScalarTypeKind::Size_t: - Str = "size_t"; - if (IsPointer) - Str += " *"; - return; - case ScalarTypeKind::Ptrdiff_t: - Str = "ptrdiff_t"; - return; - case ScalarTypeKind::UnsignedLong: - Str = "unsigned long"; - return; - case ScalarTypeKind::SignedLong: - Str = "long"; - return; - case ScalarTypeKind::Boolean: - if (isScalar()) - Str += "bool"; - else - // Vector bool is special case, the formulate is - // `vbool_t = MVT::nxv<64/N>i1` ex. 
vbool16_t = MVT::4i1 - Str += "vbool" + utostr(64 / Scale.getValue()) + "_t"; - break; - case ScalarTypeKind::Float: - if (isScalar()) { - if (ElementBitwidth == 64) - Str += "double"; - else if (ElementBitwidth == 32) - Str += "float"; - else if (ElementBitwidth == 16) - Str += "_Float16"; - else - llvm_unreachable("Unhandled floating type."); - } else - Str += getTypeString("float"); - break; - case ScalarTypeKind::SignedInteger: - Str += getTypeString("int"); - break; - case ScalarTypeKind::UnsignedInteger: - Str += getTypeString("uint"); - break; - default: - llvm_unreachable("ScalarType is invalid!"); - } - if (IsPointer) - Str += " *"; -} - -void RVVType::initShortStr() { - switch (ScalarType) { - case ScalarTypeKind::Boolean: - assert(isVector()); - ShortStr = "b" + utostr(64 / Scale.getValue()); - return; - case ScalarTypeKind::Float: - ShortStr = "f" + utostr(ElementBitwidth); - break; - case ScalarTypeKind::SignedInteger: - ShortStr = "i" + utostr(ElementBitwidth); - break; - case ScalarTypeKind::UnsignedInteger: - ShortStr = "u" + utostr(ElementBitwidth); - break; - default: - PrintFatalError("Unhandled case!"); - } - if (isVector()) - ShortStr += LMUL.str(); -} - -void RVVType::applyBasicType() { - switch (BT) { +static RVVBasicType ParseRVVBasicType(char c) { + switch (c) { case 'c': - ElementBitwidth = 8; - ScalarType = ScalarTypeKind::SignedInteger; + return RVVBasicTypeInt8; break; case 's': - ElementBitwidth = 16; - ScalarType = ScalarTypeKind::SignedInteger; + return RVVBasicTypeInt16; break; case 'i': - ElementBitwidth = 32; - ScalarType = ScalarTypeKind::SignedInteger; + return RVVBasicTypeInt32; break; case 'l': - ElementBitwidth = 64; - ScalarType = ScalarTypeKind::SignedInteger; + return RVVBasicTypeInt64; break; case 'x': - ElementBitwidth = 16; - ScalarType = ScalarTypeKind::Float; + return RVVBasicTypeFloat16; break; case 'f': - ElementBitwidth = 32; - ScalarType = ScalarTypeKind::Float; + return RVVBasicTypeFloat32; break; case 'd': - 
ElementBitwidth = 64; - ScalarType = ScalarTypeKind::Float; + return RVVBasicTypeFloat64; break; - default: - PrintFatalError("Unhandled type code!"); - } - assert(ElementBitwidth != 0 && "Bad element bitwidth!"); -} -void RVVType::applyModifier(StringRef Transformer) { - if (Transformer.empty()) - return; - // Handle primitive type transformer - auto PType = Transformer.back(); - switch (PType) { - case 'e': - Scale = 0; - break; - case 'v': - Scale = LMUL.getScale(ElementBitwidth); - break; - case 'w': - ElementBitwidth *= 2; - LMUL *= 2; - Scale = LMUL.getScale(ElementBitwidth); - break; - case 'q': - ElementBitwidth *= 4; - LMUL *= 4; - Scale = LMUL.getScale(ElementBitwidth); - break; - case 'o': - ElementBitwidth *= 8; - LMUL *= 8; - Scale = LMUL.getScale(ElementBitwidth); - break; - case 'm': - ScalarType = ScalarTypeKind::Boolean; - Scale = LMUL.getScale(ElementBitwidth); - ElementBitwidth = 1; - break; - case '0': - ScalarType = ScalarTypeKind::Void; - break; - case 'z': - ScalarType = ScalarTypeKind::Size_t; - break; - case 't': - ScalarType = ScalarTypeKind::Ptrdiff_t; - break; - case 'u': - ScalarType = ScalarTypeKind::UnsignedLong; - break; - case 'l': - ScalarType = ScalarTypeKind::SignedLong; - break; default: - PrintFatalError("Illegal primitive type transformers!"); - } - Transformer = Transformer.drop_back(); - - // Extract and compute complex type transformer. It can only appear one time. 
- if (Transformer.startswith("(")) { - size_t Idx = Transformer.find(')'); - assert(Idx != StringRef::npos); - StringRef ComplexType = Transformer.slice(1, Idx); - Transformer = Transformer.drop_front(Idx + 1); - assert(!Transformer.contains('(') && - "Only allow one complex type transformer"); - - auto UpdateAndCheckComplexProto = [&]() { - Scale = LMUL.getScale(ElementBitwidth); - const StringRef VectorPrototypes("vwqom"); - if (!VectorPrototypes.contains(PType)) - PrintFatalError("Complex type transformer only supports vector type!"); - if (Transformer.find_first_of("PCKWS") != StringRef::npos) - PrintFatalError( - "Illegal type transformer for Complex type transformer"); - }; - auto ComputeFixedLog2LMUL = - [&](StringRef Value, - std::function Compare) { - int32_t Log2LMUL; - Value.getAsInteger(10, Log2LMUL); - if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { - ScalarType = Invalid; - return false; - } - // Update new LMUL - LMUL = LMULType(Log2LMUL); - UpdateAndCheckComplexProto(); - return true; - }; - auto ComplexTT = ComplexType.split(":"); - if (ComplexTT.first == "Log2EEW") { - uint32_t Log2EEW; - ComplexTT.second.getAsInteger(10, Log2EEW); - // update new elmul = (eew/sew) * lmul - LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); - // update new eew - ElementBitwidth = 1 << Log2EEW; - ScalarType = ScalarTypeKind::SignedInteger; - UpdateAndCheckComplexProto(); - } else if (ComplexTT.first == "FixedSEW") { - uint32_t NewSEW; - ComplexTT.second.getAsInteger(10, NewSEW); - // Set invalid type if src and dst SEW are same. 
- if (ElementBitwidth == NewSEW) { - ScalarType = Invalid; - return; - } - // Update new SEW - ElementBitwidth = NewSEW; - UpdateAndCheckComplexProto(); - } else if (ComplexTT.first == "LFixedLog2LMUL") { - // New LMUL should be larger than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater())) - return; - } else if (ComplexTT.first == "SFixedLog2LMUL") { - // New LMUL should be smaller than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less())) - return; - } else { - PrintFatalError("Illegal complex type transformers!"); - } - } - - // Compute the remain type transformers - for (char I : Transformer) { - switch (I) { - case 'P': - if (IsConstant) - PrintFatalError("'P' transformer cannot be used after 'C'"); - if (IsPointer) - PrintFatalError("'P' transformer cannot be used twice"); - IsPointer = true; - break; - case 'C': - if (IsConstant) - PrintFatalError("'C' transformer cannot be used twice"); - IsConstant = true; - break; - case 'K': - IsImmediate = true; - break; - case 'U': - ScalarType = ScalarTypeKind::UnsignedInteger; - break; - case 'I': - ScalarType = ScalarTypeKind::SignedInteger; - break; - case 'F': - ScalarType = ScalarTypeKind::Float; - break; - case 'S': - LMUL = LMULType(0); - // Update ElementBitwidth need to update Scale too. 
- Scale = LMUL.getScale(ElementBitwidth); - break; - default: - PrintFatalError("Illegal non-primitive type transformer!"); - } + return RVVBasicTypeUnknown; } } @@ -837,7 +286,7 @@ OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n"; if (hasPolicy()) OS << " Ops.push_back(ConstantInt::get(Ops.back()->getType()," - " TAIL_UNDISTURBED));\n"; + " TAIL_UNDISTURBED));\n"; } else { OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n"; } @@ -860,32 +309,6 @@ OS << " break;\n"; } -void RVVIntrinsic::emitIntrinsicFuncDef(raw_ostream &OS) const { - OS << "__attribute__((__clang_builtin_alias__("; - OS << "__builtin_rvv_" << getBuiltinName() << ")))\n"; - OS << OutputType->getTypeStr() << " " << getName() << "("; - // Emit function arguments - if (!InputTypes.empty()) { - ListSeparator LS; - for (unsigned i = 0; i < InputTypes.size(); ++i) - OS << LS << InputTypes[i]->getTypeStr(); - } - OS << ");\n"; -} - -void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const { - OS << "__attribute__((__clang_builtin_alias__("; - OS << "__builtin_rvv_" << getBuiltinName() << ")))\n"; - OS << OutputType->getTypeStr() << " " << getMangledName() << "("; - // Emit function arguments - if (!InputTypes.empty()) { - ListSeparator LS; - for (unsigned i = 0; i < InputTypes.size(); ++i) - OS << LS << InputTypes[i]->getTypeStr(); - } - OS << ");\n"; -} - //===----------------------------------------------------------------------===// // RVVEmitter implementation //===----------------------------------------------------------------------===// @@ -917,12 +340,10 @@ OS << "#ifdef __cplusplus\n"; OS << "extern \"C\" {\n"; OS << "#endif\n\n"; + OS << "#pragma clang riscv intrinsic vector\n\n"; createRVVHeaders(OS); - std::vector> Defs; - createRVVIntrinsics(Defs); - // Print header code if (!HeaderCode.empty()) { OS << HeaderCode; @@ -936,24 +357,25 @@ constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3}; // Print RVV boolean types. 
for (int Log2LMUL : Log2LMULs) { - auto T = computeType('c', Log2LMUL, "m"); + auto T = computeType(RVVBasicTypeInt8, Log2LMUL, "m"); if (T.hasValue()) printType(T.getValue()); } // Print RVV int/float types. for (char I : StringRef("csil")) { + auto BT = ParseRVVBasicType(I); for (int Log2LMUL : Log2LMULs) { - auto T = computeType(I, Log2LMUL, "v"); + auto T = computeType(BT, Log2LMUL, "v"); if (T.hasValue()) { printType(T.getValue()); - auto UT = computeType(I, Log2LMUL, "Uv"); + auto UT = computeType(BT, Log2LMUL, "Uv"); printType(UT.getValue()); } } } OS << "#if defined(__riscv_zfh)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('x', Log2LMUL, "v"); + auto T = computeType(RVVBasicTypeFloat16, Log2LMUL, "v"); if (T.hasValue()) printType(T.getValue()); } @@ -961,7 +383,7 @@ OS << "#if defined(__riscv_f)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('f', Log2LMUL, "v"); + auto T = computeType(RVVBasicTypeFloat32, Log2LMUL, "v"); if (T.hasValue()) printType(T.getValue()); } @@ -969,7 +391,7 @@ OS << "#if defined(__riscv_d)\n"; for (int Log2LMUL : Log2LMULs) { - auto T = computeType('d', Log2LMUL, "v"); + auto T = computeType(RVVBasicTypeFloat64, Log2LMUL, "v"); if (T.hasValue()) printType(T.getValue()); } @@ -981,31 +403,8 @@ return A->getRISCVExtensions() < B->getRISCVExtensions(); }); - OS << "#define __rvv_ai static __inline__\n"; - - // Print intrinsic functions with macro - emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { - OS << "__rvv_ai "; - Inst.emitIntrinsicFuncDef(OS); - }); - - OS << "#undef __rvv_ai\n\n"; - OS << "#define __riscv_v_intrinsic_overloading 1\n"; - // Print Overloaded APIs - OS << "#define __rvv_aio static __inline__ " - "__attribute__((__overloadable__))\n"; - - emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { - if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded()) - return; - OS << "__rvv_aio "; - Inst.emitMangledFuncDef(OS); - }); - - OS << "#undef 
__rvv_aio\n"; - OS << "\n#ifdef __cplusplus\n"; OS << "}\n"; OS << "#endif // __cplusplus\n"; @@ -1013,9 +412,6 @@ } void RVVEmitter::createBuiltins(raw_ostream &OS) { - std::vector> Defs; - createRVVIntrinsics(Defs); - // Map to keep track of which builtin names have already been emitted. StringMap BuiltinMap; @@ -1046,8 +442,6 @@ } void RVVEmitter::createCodeGen(raw_ostream &OS) { - std::vector> Defs; - createRVVIntrinsics(Defs); // IR name could be empty, use the stable sort preserves the relative order. llvm::stable_sort(Defs, [](const std::unique_ptr &A, const std::unique_ptr &B) { @@ -1095,27 +489,11 @@ OS << "\n"; } -void RVVEmitter::parsePrototypes(StringRef Prototypes, - std::function Handler) { - const StringRef Primaries("evwqom0ztul"); - while (!Prototypes.empty()) { - size_t Idx = 0; - // Skip over complex prototype because it could contain primitive type - // character. - if (Prototypes[0] == '(') - Idx = Prototypes.find_first_of(')'); - Idx = Prototypes.find_first_of(Primaries, Idx); - assert(Idx != StringRef::npos); - Handler(Prototypes.slice(0, Idx + 1)); - Prototypes = Prototypes.drop_front(Idx + 1); - } -} - -std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL, +std::string RVVEmitter::getSuffixStr(RVVBasicType BT, int Log2LMUL, StringRef Prototypes) { SmallVector SuffixStrs; - parsePrototypes(Prototypes, [&](StringRef Proto) { - auto T = computeType(Type, Log2LMUL, Proto); + RVVParsePrototypes(Prototypes, [&](StringRef Proto) { + auto T = computeType(BT, Log2LMUL, Proto); SuffixStrs.push_back(T.getValue()->getShortStr()); }); return join(SuffixStrs, "_"); @@ -1154,7 +532,7 @@ // Parse prototype and create a list of primitive type with transformers // (operand) in ProtoSeq. ProtoSeq[0] is output operand. 
SmallVector ProtoSeq; - parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) { + RVVParsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) { ProtoSeq.push_back(Proto.str()); }); @@ -1197,13 +575,14 @@ // Create Intrinsics for each type and LMUL. for (char I : TypeRange) { for (int Log2LMUL : Log2LMULList) { - Optional Types = computeTypes(I, Log2LMUL, NF, ProtoSeq); + auto BT = ParseRVVBasicType(I); + Optional Types = computeTypes(BT, Log2LMUL, NF, ProtoSeq); // Ignored to create new intrinsic if there are any illegal types. if (!Types.hasValue()) continue; - auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto); - auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto); + auto SuffixStr = getSuffixStr(BT, Log2LMUL, SuffixProto); + auto MangledSuffixStr = getSuffixStr(BT, Log2LMUL, MangledSuffixProto); // Create a non-mask intrinsic Out.push_back(std::make_unique( Name, SuffixStr, MangledName, MangledSuffixStr, IRName, @@ -1213,7 +592,7 @@ if (HasMask) { // Create a mask intrinsic Optional MaskTypes = - computeTypes(I, Log2LMUL, NF, ProtoMaskSeq); + computeTypes(BT, Log2LMUL, NF, ProtoMaskSeq); Out.push_back(std::make_unique( Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask, /*IsMask=*/true, HasMaskedOffOperand, HasVL, HasPolicy, @@ -1222,6 +601,32 @@ } } // end for Log2LMULList } // end for TypeRange + + // Create SemaRecord + SemaRecord SR; + + SR.Name = Name.str(); + SR.MangledName = MangledName.str(); + SR.TypeRange = TypeRange.str(); + SR.Log2LMULList = Log2LMULList; + SR.RequiredExtension = RequiredExtension.str(); + SR.NF = NF; + + SR.ProtoSeq = std::move(ProtoSeq); + + if (HasMask) + SR.ProtoMaskSeq = std::move(ProtoMaskSeq); + + auto InitSuffixProtoSeq = [&](SmallVectorImpl &PS, + StringRef Prototypes) { + RVVParsePrototypes(Prototypes, + [&](StringRef Proto) { PS.push_back(Proto.str()); }); + }; + + InitSuffixProtoSeq(SR.SuffixProto, SuffixProto); + InitSuffixProtoSeq(SR.MangledSuffixProto, MangledSuffixProto); + + 
SemaRecords.push_back(SR); } } @@ -1235,7 +640,7 @@ } Optional -RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, +RVVEmitter::computeTypes(RVVBasicType BT, int Log2LMUL, unsigned NF, ArrayRef PrototypeSeq) { // LMUL x NF must be less than or equal to 8. if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) @@ -1252,9 +657,10 @@ return Types; } -Optional RVVEmitter::computeType(BasicType BT, int Log2LMUL, +Optional RVVEmitter::computeType(RVVBasicType BT, int Log2LMUL, StringRef Proto) { - std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str(); + std::string Idx = + Twine(Twine(static_cast(BT)) + Twine(Log2LMUL) + Proto).str(); // Search first auto It = LegalTypes.find(Idx); if (It != LegalTypes.end()) @@ -1273,41 +679,275 @@ return llvm::None; } -void RVVEmitter::emitArchMacroAndBody( - std::vector> &Defs, raw_ostream &OS, - std::function PrintBody) { - uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions(); - bool NeedEndif = emitExtDefStr(PrevExt, OS); - for (auto &Def : Defs) { - uint8_t CurExt = Def->getRISCVExtensions(); - if (CurExt != PrevExt) { - if (NeedEndif) - OS << "#endif\n\n"; - NeedEndif = emitExtDefStr(CurExt, OS); - PrevExt = CurExt; +static void emitSemaPrototypeType(StringRef T, raw_ostream &OS) { + auto PType = T.back(); + StringRef PrimitiveType; + switch (PType) { + case 'e': + PrimitiveType = "RVVPrimTypeScalar"; + break; + case 'v': + PrimitiveType = "RVVPrimTypeVector"; + break; + case 'w': + PrimitiveType = "RVVPrimType2XWideningVector"; + break; + case 'q': + PrimitiveType = "RVVPrimType4XWideningVector"; + break; + case 'o': + PrimitiveType = "RVVPrimType8XWideningVector"; + break; + case 'm': + PrimitiveType = "RVVPrimTypeMaskVector"; + break; + case '0': + PrimitiveType = "RVVPrimTypeVoid"; + break; + case 'z': + PrimitiveType = "RVVPrimTypeSize"; + break; + case 't': + PrimitiveType = "RVVPrimTypePtrdiff"; + break; + case 'u': + PrimitiveType = "RVVPrimTypeUnsignedLong"; + break; + case 'l': + PrimitiveType =
"RVVPrimTypeSignedLong"; + break; + default: + PrintFatalError("Illegal primitive type transformers!"); + } + OS << "{" << PrimitiveType << ", "; + T = T.drop_back(); + + // Extract and compute complex type transformer. It can only appear one time. + if (T.startswith("(")) { + size_t Idx = T.find(')'); + assert(Idx != StringRef::npos); + StringRef ComplexType = T.slice(1, Idx); + const char *ComplexTypeString = + StringSwitch(ComplexType) + .Case("Log2EEW:3", "RVVVecTypeLog2EEW3") + .Case("Log2EEW:4", "RVVVecTypeLog2EEW4") + .Case("Log2EEW:5", "RVVVecTypeLog2EEW5") + .Case("Log2EEW:6", "RVVVecTypeLog2EEW6") + .Case("FixedSEW:8", "RVVVecTypeFixedSEW8") + .Case("FixedSEW:16", "RVVVecTypeFixedSEW16") + .Case("FixedSEW:32", "RVVVecTypeFixedSEW32") + .Case("FixedSEW:64", "RVVVecTypeFixedSEW64") + .Case("LFixedLog2LMUL:-3", "RVVVecTypeLFixedLog2LMULN3") + .Case("LFixedLog2LMUL:-2", "RVVVecTypeLFixedLog2LMULN2") + .Case("LFixedLog2LMUL:-1", "RVVVecTypeLFixedLog2LMULN1") + .Case("LFixedLog2LMUL:0", "RVVVecTypeLFixedLog2LMUL0") + .Case("LFixedLog2LMUL:1", "RVVVecTypeLFixedLog2LMUL1") + .Case("LFixedLog2LMUL:2", "RVVVecTypeLFixedLog2LMUL2") + .Case("LFixedLog2LMUL:3", "RVVVecTypeLFixedLog2LMUL3") + .Case("SFixedLog2LMUL:-3", "RVVVecTypeSFixedLog2LMULN3") + .Case("SFixedLog2LMUL:-2", "RVVVecTypeSFixedLog2LMULN2") + .Case("SFixedLog2LMUL:-1", "RVVVecTypeSFixedLog2LMULN1") + .Case("SFixedLog2LMUL:0", "RVVVecTypeSFixedLog2LMUL0") + .Case("SFixedLog2LMUL:1", "RVVVecTypeSFixedLog2LMUL1") + .Case("SFixedLog2LMUL:2", "RVVVecTypeSFixedLog2LMUL2") + .Case("SFixedLog2LMUL:3", "RVVVecTypeSFixedLog2LMUL3") + .Default(nullptr); + assert(ComplexTypeString != nullptr); + OS << ComplexTypeString << ", "; + T = T.drop_front(Idx + 1); + } else { + OS << "RVVVecTypeNoModifier, "; + } + + // Compute the remaining type transformers + OS << "0"; + for (char I : T) { + OS << " | "; + switch (I) { + case 'P': + OS << "RVVTypeModifierPointer"; + break; + case 'C': + OS << "RVVTypeModifierConst"; +
break; + case 'K': + OS << "RVVTypeModifierImmediate"; + break; + case 'U': + OS << "RVVTypeModifierUnsignedInteger"; + break; + case 'I': + OS << "RVVTypeModifierSignedInteger"; + break; + case 'F': + OS << "RVVTypeModifierFloat"; + break; + case 'S': + OS << "RVVTypeModifierLMUL1"; + break; + default: + PrintFatalError("Illegal non-primitive type transformer!"); } - if (Def->hasAutoDef()) - PrintBody(OS, *Def); } - if (NeedEndif) - OS << "#endif\n\n"; + OS << "}"; } -bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) { - if (Extents == RISCVExtension::Basic) - return false; - OS << "#if "; - ListSeparator LS(" && "); - if (Extents & RISCVExtension::F) - OS << LS << "defined(__riscv_f)"; - if (Extents & RISCVExtension::D) - OS << LS << "defined(__riscv_d)"; - if (Extents & RISCVExtension::Zfh) - OS << LS << "defined(__riscv_zfh)"; - if (Extents & RISCVExtension::Zvlsseg) - OS << LS << "defined(__riscv_zvlsseg)"; - OS << "\n"; - return true; +unsigned +RVVEmitter::GetSemaSignatureIndex(const SmallVector &Signature) { + if (Signature.empty()) + return 0; + + // Reuse an existing window of the table if Signature already appears in it. + if (Signature.size() <= SemaSignatureTable.size()) { + size_t Bound = SemaSignatureTable.size() - Signature.size() + 1; + for (size_t Index = 0; Index < Bound; ++Index) { + bool Match = true; + for (size_t i = 0; i < Signature.size(); ++i) { + if (Signature[i] != SemaSignatureTable[Index + i]) { + Match = false; + break; + } + } + // Reuse if found in table. + if (Match) + return Index; + } + } + + // Insert Signature into SemaSignatureTable if not found in the table. + size_t Index = SemaSignatureTable.size(); + for (const auto &Type : Signature) { + SemaSignatureTable.push_back(Type); + } + return Index; +} + +void RVVEmitter::ConstructSemaSignatureTable() { + // Insert signatures longest first, so that shorter signatures are more likely + // to be found as sub-sequences of already-inserted ones; this reduces the + // table size by ~10%.
+ struct Compare { + bool operator()(const SmallVector &A, + const SmallVector &B) const { + if (A.size() != B.size()) + return A.size() > B.size(); + + size_t Len = A.size(); + for (size_t i = 0; i < Len; ++i) { + if (A[i] != B[i]) + return A[i] > B[i]; + } + + return false; + } + }; + + std::set, Compare> Signatures; + auto InsertToSignatureSet = [&](const SmallVector &Signature) { + if (Signature.empty()) + return; + + Signatures.insert(Signature); + }; + + for (const auto &Record : SemaRecords) { + InsertToSignatureSet(Record.ProtoSeq); + InsertToSignatureSet(Record.ProtoMaskSeq); + InsertToSignatureSet(Record.SuffixProto); + InsertToSignatureSet(Record.MangledSuffixProto); + } + + for (const auto &Sig : Signatures) { + GetSemaSignatureIndex(Sig); + } +} + +void RVVEmitter::EmitSemaSignatureTable(raw_ostream &OS) { + OS << "#ifdef DECL_SIGNATURE_TABLE\n"; + for (const auto &Sig : SemaSignatureTable) { + emitSemaPrototypeType(Sig, OS); + OS << ",\n"; + } + OS << "#endif\n"; +} + +void RVVEmitter::EmitSemaRecords(raw_ostream &OS) { + OS << "#ifdef DECL_INTRINSIC_RECORDS\n"; + for (const auto &SR : SemaRecords) { + OS << "{" + << "\"" << SR.Name << "\", "; + + if (SR.MangledName.empty()) + OS << "nullptr, "; + else + OS << "\"" << SR.MangledName << "\", "; + + OS << GetSemaSignatureIndex(SR.ProtoSeq) << ", "; + OS << GetSemaSignatureIndex(SR.ProtoMaskSeq) << ", "; + OS << GetSemaSignatureIndex(SR.SuffixProto) << ", "; + OS << GetSemaSignatureIndex(SR.MangledSuffixProto) << ", "; + + OS << SR.ProtoSeq.size() << ", "; + OS << SR.ProtoMaskSeq.size() << ", "; + OS << SR.SuffixProto.size() << ", "; + OS << SR.MangledSuffixProto.size() << ", "; + + // Only Zvlsseg has a Sema-side requirement flag; reject anything else. + if (!SR.RequiredExtension.empty() && SR.RequiredExtension != "Zvlsseg") + PrintFatalError("Unhandled RequiredExtension: " + SR.RequiredExtension); + OS << (SR.RequiredExtension.empty() ? "0, " : "RVVRequireZvlsseg, "); + + OS << " /* Type Range Mask*/"; + ListSeparator TRLS(" | "); + for (auto T : SR.TypeRange) { + StringRef TypeMask; + switch (T) { + case 'c': + TypeMask = "RVVBasicTypeInt8"; + break; + case 's': + TypeMask =
"RVVBasicTypeInt16"; + break; + case 'i': + TypeMask = "RVVBasicTypeInt32"; + break; + case 'l': + TypeMask = "RVVBasicTypeInt64"; + break; + case 'x': + TypeMask = "RVVBasicTypeFloat16"; + break; + case 'f': + TypeMask = "RVVBasicTypeFloat32"; + break; + case 'd': + TypeMask = "RVVBasicTypeFloat64"; + break; + default: + TypeMask = ""; + llvm_unreachable("Unknown TypeRang letter."); + } + OS << TRLS << TypeMask; + } + OS << ","; + OS << " /* LMUL Mask = */ "; + unsigned Log2LMULMask = 0; + for (int Log2LMUL : SR.Log2LMULList) { + Log2LMULMask |= 1 << (Log2LMUL + 3); + } + OS << Log2LMULMask << ", "; + + OS << SR.NF << "},\n"; + } + OS << "#endif\n"; +} + +void RVVEmitter::createSema(raw_ostream &OS) { + emitSourceFileHeader("RISC-V Vector Builtin handling", OS); + + ConstructSemaSignatureTable(); + EmitSemaSignatureTable(OS); + EmitSemaRecords(OS); } namespace clang { @@ -1323,4 +963,8 @@ RVVEmitter(Records).createCodeGen(OS); } +void EmitRVVBuiltinSema(RecordKeeper &Records, raw_ostream &OS) { + RVVEmitter(Records).createSema(OS); +} + } // End namespace clang diff --git a/clang/utils/TableGen/TableGen.cpp b/clang/utils/TableGen/TableGen.cpp --- a/clang/utils/TableGen/TableGen.cpp +++ b/clang/utils/TableGen/TableGen.cpp @@ -88,6 +88,7 @@ GenRISCVVectorHeader, GenRISCVVectorBuiltins, GenRISCVVectorBuiltinCG, + GenRISCVVectorBuiltinSema, GenAttrDocs, GenDiagDocs, GenOptDocs, @@ -243,6 +244,8 @@ "Generate riscv_vector_builtins.inc for clang"), clEnumValN(GenRISCVVectorBuiltinCG, "gen-riscv-vector-builtin-codegen", "Generate riscv_vector_builtin_cg.inc for clang"), + clEnumValN(GenRISCVVectorBuiltinSema, "gen-riscv-vector-builtin-sema", + "Generate riscv_vector_builtin_sema.inc for clang"), clEnumValN(GenAttrDocs, "gen-attr-docs", "Generate attribute documentation"), clEnumValN(GenDiagDocs, "gen-diag-docs", @@ -458,6 +461,9 @@ case GenRISCVVectorBuiltinCG: EmitRVVBuiltinCG(Records, OS); break; + case GenRISCVVectorBuiltinSema: + EmitRVVBuiltinSema(Records, OS); + 
break; case GenAttrDocs: EmitClangAttrDocs(Records, OS); break; diff --git a/clang/utils/TableGen/TableGenBackends.h b/clang/utils/TableGen/TableGenBackends.h --- a/clang/utils/TableGen/TableGenBackends.h +++ b/clang/utils/TableGen/TableGenBackends.h @@ -110,6 +110,7 @@ void EmitRVVHeader(llvm::RecordKeeper &Records, llvm::raw_ostream &OS); void EmitRVVBuiltins(llvm::RecordKeeper &Records, llvm::raw_ostream &OS); void EmitRVVBuiltinCG(llvm::RecordKeeper &Records, llvm::raw_ostream &OS); +void EmitRVVBuiltinSema(llvm::RecordKeeper &Records, llvm::raw_ostream &OS); void EmitCdeHeader(llvm::RecordKeeper &Records, llvm::raw_ostream &OS); void EmitCdeBuiltinDef(llvm::RecordKeeper &Records, llvm::raw_ostream &OS); diff --git a/llvm/docs/CommandGuide/tblgen.rst b/llvm/docs/CommandGuide/tblgen.rst --- a/llvm/docs/CommandGuide/tblgen.rst +++ b/llvm/docs/CommandGuide/tblgen.rst @@ -348,6 +348,10 @@ Generate ``riscv_vector_builtin_cg.inc`` for Clang. +.. option:: -gen-riscv-vector-builtin-sema + + Generate ``riscv_vector_builtin_sema.inc`` for Clang. + .. option:: -gen-attr-docs Generate attribute documentation. diff --git a/llvm/include/llvm/Support/RISCVVIntrinsicUtils.h b/llvm/include/llvm/Support/RISCVVIntrinsicUtils.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/Support/RISCVVIntrinsicUtils.h @@ -0,0 +1,221 @@ +//===-- RISCVVIntrinsicUtils.h - RISC-V Vector Intrinsic Utils --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_SUPPORT_RISCVVINTRINSICUTILS_H +#define LLVM_SUPPORT_RISCVVINTRINSICUTILS_H + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace llvm { + +namespace RISCV { + +using VScaleVal = Optional; + +enum RVVBasicType { + RVVBasicTypeUnknown = 0, + RVVBasicTypeInt8 = 1 << 0, + RVVBasicTypeInt16 = 1 << 1, + RVVBasicTypeInt32 = 1 << 2, + RVVBasicTypeInt64 = 1 << 3, + RVVBasicTypeFloat16 = 1 << 4, + RVVBasicTypeFloat32 = 1 << 5, + RVVBasicTypeFloat64 = 1 << 6, + RVVBasicTypeMaxOffset = 6, +}; + +enum RVVPrimitiveType { + RVVPrimTypeInvalid, + RVVPrimTypeScalar, + RVVPrimTypeVector, + RVVPrimType2XWideningVector, + RVVPrimType4XWideningVector, + RVVPrimType8XWideningVector, + RVVPrimTypeMaskVector, + RVVPrimTypeVoid, + RVVPrimTypeSize, + RVVPrimTypePtrdiff, + RVVPrimTypeUnsignedLong, + RVVPrimTypeSignedLong, +}; + +enum RVVVectorTypeModifier { + RVVVecTypeNoModifier, + RVVVecTypeLog2EEW3, + RVVVecTypeLog2EEW4, + RVVVecTypeLog2EEW5, + RVVVecTypeLog2EEW6, + RVVVecTypeFixedSEW8, + RVVVecTypeFixedSEW16, + RVVVecTypeFixedSEW32, + RVVVecTypeFixedSEW64, + RVVVecTypeLFixedLog2LMULN3, + RVVVecTypeLFixedLog2LMULN2, + RVVVecTypeLFixedLog2LMULN1, + RVVVecTypeLFixedLog2LMUL0, + RVVVecTypeLFixedLog2LMUL1, + RVVVecTypeLFixedLog2LMUL2, + RVVVecTypeLFixedLog2LMUL3, + RVVVecTypeSFixedLog2LMULN3, + RVVVecTypeSFixedLog2LMULN2, + RVVVecTypeSFixedLog2LMULN1, + RVVVecTypeSFixedLog2LMUL0, + RVVVecTypeSFixedLog2LMUL1, + RVVVecTypeSFixedLog2LMUL2, + RVVVecTypeSFixedLog2LMUL3, +}; + +enum RVVTypeModifier { + RVVTypeModifierPointer = 1 << 0, + RVVTypeModifierConst = 1 << 1, + RVVTypeModifierImmediate = 1 << 2, + RVVTypeModifierUnsignedInteger = 1 << 3, + RVVTypeModifierSignedInteger = 1 << 4, + RVVTypeModifierFloat = 1 << 5, + RVVTypeModifierLMUL1 = 1 << 6, + RVVTypeModifierMaskMax = 
RVVTypeModifierLMUL1, +}; + +// Exponential LMUL +struct LMULType { + int Log2LMUL; + LMULType(int Log2LMUL); + // Return the C/C++ string representation of LMUL + std::string str() const; + Optional getScale(unsigned ElementBitwidth) const; + void MulLog2LMUL(int Log2LMUL); + LMULType &operator*=(uint32_t RHS); +}; + +// This class is compact representation of a valid and invalid RVVType. +enum ScalarTypeKind { + STK_Void, + STK_Size_t, + STK_Ptrdiff_t, + STK_UnsignedLong, + STK_SignedLong, + STK_Boolean, + STK_SignedInteger, + STK_UnsignedInteger, + STK_Float, + STK_Invalid, +}; + +struct RVVTypeProfile { + uint8_t PrimitiveType; + uint8_t VectorTypeModifier; + uint8_t TypeModifierMask; +}; + +class RVVType { + ScalarTypeKind ScalarType = STK_Invalid; + LMULType LMUL; + bool IsPointer = false; + // IsConstant indices are "int", but have the constant expression. + bool IsImmediate = false; + // Const qualifier for pointer to const object or object of const type. + bool IsConstant = false; + unsigned ElementBitwidth = 0; + VScaleVal Scale = 0; + bool Valid; + + std::string BuiltinStr; + std::string ClangBuiltinStr; + std::string Str; + std::string ShortStr; + std::string MangledStr; + +public: + RVVType() : RVVType(RVVBasicTypeUnknown, 0, StringRef()) {} + RVVType(RVVBasicType BT, int Log2LMUL, StringRef prototype); + RVVType(RVVBasicType BT, int Log2LMUL, RVVTypeProfile Profile); + + // Return the string + StringRef getMangledStr(); + + // Return the string representation of a type, which is an encoded string for + // passing to the BUILTIN() macro in Builtins.def. + const std::string &getBuiltinStr() const { return BuiltinStr; } + + // Return the clang builtin type for RVV vector type which are used in the + // riscv_vector.h header file. + const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; } + + // Return the C/C++ string representation of a type for use in the + // riscv_vector.h header file. 
+ const std::string &getTypeStr() const { return Str; } + + // Return the short name of a type for C/C++ name suffix. + const std::string &getShortStr() { + // Not all types are used in short name, so compute the short name by + // demanded. + if (ShortStr.empty()) + initShortStr(); + return ShortStr; + } + + bool isValid() const { return Valid; } + bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; } + bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; } + bool isFloat() const { return ScalarType == STK_Float; } + bool isSignedInteger() const { return ScalarType == STK_SignedInteger; } + bool isFloatVector(unsigned Width) const { + return isVector() && isFloat() && ElementBitwidth == Width; + } + bool isFloat(unsigned Width) const { + return isFloat() && ElementBitwidth == Width; + } + + bool isPointer() const { return IsPointer; } + bool isConstant() const { return IsConstant; } + unsigned getElementBitwidth() const { return ElementBitwidth; }; + VScaleVal getScale() const { return Scale; }; + + ScalarTypeKind getScalarType() const { return ScalarType; } + +private: + // Verify RVV vector type and set Valid. + bool verifyType() const; + + // Creates a type based on basic types of TypeRange + void applyBasicType(RVVBasicType BT); + + // Applies a prototype modifier to the current type. The result maybe an + // invalid type. + void applyModifier(StringRef prototype); + void applyModifier(const RVVTypeProfile &prototype); + + void applyLog2EEW(unsigned Log2EEW); + void applyFixedSEW(unsigned NewSEW); + void applyFixedLog2LMUL(int Log2LMUL, bool LargerThan); + + // Compute and record a string for legal type. + void initBuiltinStr(); + // Compute and record a builtin RVV vector type string. + void initClangBuiltinStr(); + // Compute and record a type string for used in the header. + void initTypeStr(); + // Compute and record a short name of a type for C/C++ name suffix. 
+ void initShortStr(); +}; + +using RVVTypePtr = RVVType *; +using RVVTypes = std::vector; + +// Slice Prototypes string into sub prototype string and process each sub +// prototype string individually in the Handler. +void RVVParsePrototypes(StringRef Prototypes, + std::function Handler); + +} // namespace RISCV +} // namespace llvm + +#endif diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt --- a/llvm/lib/Support/CMakeLists.txt +++ b/llvm/lib/Support/CMakeLists.txt @@ -186,6 +186,7 @@ RISCVAttributes.cpp RISCVAttributeParser.cpp RISCVISAInfo.cpp + RISCVVIntrinsicUtils.cpp ScaledNumber.cpp ScopedPrinter.cpp SHA1.cpp diff --git a/llvm/lib/Support/RISCVVIntrinsicUtils.cpp b/llvm/lib/Support/RISCVVIntrinsicUtils.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Support/RISCVVIntrinsicUtils.cpp @@ -0,0 +1,748 @@ +//===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/RISCVVIntrinsicUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/Twine.h" +#include + +namespace llvm { +namespace RISCV { + +LMULType::LMULType(int NewLog2LMUL) { + // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 + assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); + Log2LMUL = NewLog2LMUL; +} + +std::string LMULType::str() const { + if (Log2LMUL < 0) + return "mf" + utostr(1ULL << (-Log2LMUL)); + return "m" + utostr(1ULL << Log2LMUL); +} + +VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { + int Log2ScaleResult = 0; + switch (ElementBitwidth) { + default: + break; + case 8: + Log2ScaleResult = Log2LMUL + 3; + break; + case 16: + Log2ScaleResult = Log2LMUL + 2; + break; + case 32: + Log2ScaleResult = Log2LMUL + 1; + break; + case 64: + Log2ScaleResult = Log2LMUL; + break; + } + // Illegal vscale result would be less than 1 + if (Log2ScaleResult < 0) + return None; + return 1 << Log2ScaleResult; +} + +void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } + +LMULType &LMULType::operator*=(uint32_t RHS) { + assert(isPowerOf2_32(RHS)); + this->Log2LMUL = this->Log2LMUL + Log2_32(RHS); + return *this; +} + +RVVType::RVVType(RVVBasicType BT, int Log2LMUL, StringRef prototype) + : LMUL(LMULType(Log2LMUL)) { + applyBasicType(BT); + applyModifier(prototype); + Valid = verifyType(); + if (Valid) { + initBuiltinStr(); + initTypeStr(); + if (isVector()) { + initClangBuiltinStr(); + } + } +} + +StringRef RVVType::getMangledStr() { + if (MangledStr.empty()) { + if (!Valid) + MangledStr = "invalid"; + else { + MangledStr = (Twine(ScalarType) + Twine(IsPointer) + 
Twine(IsImmediate) + + Twine(IsConstant) + Twine(ElementBitwidth)) + .str(); + if (!isScalar()) + MangledStr += (Twine(Scale.getValue()) + LMUL.str()).str(); + } + } + + return MangledStr; +} + +// clang-format off +// boolean type are encoded the ratio of n (SEW/LMUL) +// SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 +// c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t +// IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 + +// type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 +// -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- +// i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 +// i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 +// i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 +// i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 +// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 +// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 +// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 +// clang-format on +bool RVVType::verifyType() const { + if (ScalarType == STK_Invalid) + return false; + if (isScalar()) + return true; + if (!Scale.hasValue()) + return false; + if (isFloat() && ElementBitwidth == 8) + return false; + unsigned V = Scale.getValue(); + switch (ElementBitwidth) { + case 1: + case 8: + // Check Scale is 1,2,4,8,16,32,64 + return (V <= 64 && isPowerOf2_32(V)); + case 16: + // Check Scale is 1,2,4,8,16,32 + return (V <= 32 && isPowerOf2_32(V)); + case 32: + // Check Scale is 1,2,4,8,16 + return (V <= 16 && isPowerOf2_32(V)); + case 64: + // Check Scale is 1,2,4,8 + return (V <= 8 && isPowerOf2_32(V)); + } + return false; +} + +void RVVType::initBuiltinStr() { + assert(isValid() && "RVVType is invalid"); + switch (ScalarType) { + case STK_Void: + BuiltinStr = "v"; + return; + case STK_Size_t: + BuiltinStr = "z"; + if (IsImmediate) + 
BuiltinStr = "I" + BuiltinStr; + if (IsPointer) + BuiltinStr += "*"; + return; + case STK_Ptrdiff_t: + BuiltinStr = "Y"; + return; + case STK_UnsignedLong: + BuiltinStr = "ULi"; + return; + case STK_SignedLong: + BuiltinStr = "Li"; + return; + case STK_Boolean: + assert(ElementBitwidth == 1); + BuiltinStr += "b"; + break; + case STK_SignedInteger: + case STK_UnsignedInteger: + switch (ElementBitwidth) { + case 8: + BuiltinStr += "c"; + break; + case 16: + BuiltinStr += "s"; + break; + case 32: + BuiltinStr += "i"; + break; + case 64: + BuiltinStr += "Wi"; + break; + default: + llvm_unreachable("Unhandled ElementBitwidth!"); + } + if (isSignedInteger()) + BuiltinStr = "S" + BuiltinStr; + else + BuiltinStr = "U" + BuiltinStr; + break; + case STK_Float: + switch (ElementBitwidth) { + case 16: + BuiltinStr += "x"; + break; + case 32: + BuiltinStr += "f"; + break; + case 64: + BuiltinStr += "d"; + break; + default: + llvm_unreachable("Unhandled ElementBitwidth!"); + } + break; + default: + llvm_unreachable("ScalarType is invalid!"); + } + if (IsImmediate) + BuiltinStr = "I" + BuiltinStr; + if (isScalar()) { + if (IsConstant) + BuiltinStr += "C"; + if (IsPointer) + BuiltinStr += "*"; + return; + } + BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr; + // Pointer to vector types. Defined for Zvlsseg load intrinsics. + // Zvlsseg load intrinsics have pointer type arguments to store the loaded + // vector values. 
+ if (IsPointer) + BuiltinStr += "*"; +} +void RVVType::initClangBuiltinStr() { + assert(isValid() && "RVVType is invalid"); + assert(isVector() && "Handle Vector type only"); + + ClangBuiltinStr = "__rvv_"; + switch (ScalarType) { + case STK_Boolean: + ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t"; + return; + case STK_Float: + ClangBuiltinStr += "float"; + break; + case STK_SignedInteger: + ClangBuiltinStr += "int"; + break; + case STK_UnsignedInteger: + ClangBuiltinStr += "uint"; + break; + default: + llvm_unreachable("ScalarTypeKind is invalid"); + } + ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; +} + +void RVVType::initTypeStr() { + assert(isValid() && "RVVType is invalid"); + + if (IsConstant) + Str += "const "; + + auto getTypeString = [&](StringRef TypeStr) { + if (isScalar()) + return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); + return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") + .str(); + }; + + switch (ScalarType) { + case STK_Void: + Str = "void"; + return; + case STK_Size_t: + Str = "size_t"; + if (IsPointer) + Str += " *"; + return; + case STK_Ptrdiff_t: + Str = "ptrdiff_t"; + return; + case STK_UnsignedLong: + Str = "unsigned long"; + return; + case STK_SignedLong: + Str = "long"; + return; + case STK_Boolean: + if (isScalar()) + Str += "bool"; + else + // Vector bool is special case, the formulate is + // `vbool_t = MVT::nxv<64/N>i1` ex. 
vbool16_t = MVT::4i1 + Str += "vbool" + utostr(64 / Scale.getValue()) + "_t"; + break; + case STK_Float: + if (isScalar()) { + if (ElementBitwidth == 64) + Str += "double"; + else if (ElementBitwidth == 32) + Str += "float"; + else if (ElementBitwidth == 16) + Str += "_Float16"; + else + llvm_unreachable("Unhandled floating type."); + } else + Str += getTypeString("float"); + break; + case STK_SignedInteger: + Str += getTypeString("int"); + break; + case STK_UnsignedInteger: + Str += getTypeString("uint"); + break; + default: + llvm_unreachable("ScalarType is invalid!"); + } + if (IsPointer) + Str += " *"; +} +void RVVType::initShortStr() { + switch (ScalarType) { + case STK_Boolean: + assert(isVector()); + ShortStr = "b" + utostr(64 / Scale.getValue()); + return; + case STK_Float: + ShortStr = "f" + utostr(ElementBitwidth); + break; + case STK_SignedInteger: + ShortStr = "i" + utostr(ElementBitwidth); + break; + case STK_UnsignedInteger: + ShortStr = "u" + utostr(ElementBitwidth); + break; + default: + llvm_unreachable("Unhandled case!"); + } + if (isVector()) + ShortStr += LMUL.str(); +} + +void RVVType::applyBasicType(RVVBasicType BT) { + switch (BT) { + case RVVBasicTypeInt8: + ElementBitwidth = 8; + ScalarType = STK_SignedInteger; + break; + case RVVBasicTypeInt16: + ElementBitwidth = 16; + ScalarType = STK_SignedInteger; + break; + case RVVBasicTypeInt32: + ElementBitwidth = 32; + ScalarType = STK_SignedInteger; + break; + case RVVBasicTypeInt64: + ElementBitwidth = 64; + ScalarType = STK_SignedInteger; + break; + case RVVBasicTypeFloat16: + ElementBitwidth = 16; + ScalarType = STK_Float; + break; + case RVVBasicTypeFloat32: + ElementBitwidth = 32; + ScalarType = STK_Float; + break; + case RVVBasicTypeFloat64: + ElementBitwidth = 64; + ScalarType = STK_Float; + break; + default: + assert(false); + } + assert(ElementBitwidth != 0 && "Bad element bitwidth!"); +} + +void RVVType::applyModifier(StringRef Transformer) { + if (Transformer.empty()) + return; + // 
Handle primitive type transformer + auto PType = Transformer.back(); + switch (PType) { + case 'e': + Scale = 0; + break; + case 'v': + Scale = LMUL.getScale(ElementBitwidth); + break; + case 'w': + ElementBitwidth *= 2; + LMUL *= 2; + Scale = LMUL.getScale(ElementBitwidth); + break; + case 'q': + ElementBitwidth *= 4; + LMUL *= 4; + Scale = LMUL.getScale(ElementBitwidth); + break; + case 'o': + ElementBitwidth *= 8; + LMUL *= 8; + Scale = LMUL.getScale(ElementBitwidth); + break; + case 'm': + ScalarType = STK_Boolean; + Scale = LMUL.getScale(ElementBitwidth); + ElementBitwidth = 1; + break; + case '0': + ScalarType = STK_Void; + break; + case 'z': + ScalarType = STK_Size_t; + break; + case 't': + ScalarType = STK_Ptrdiff_t; + break; + case 'u': + ScalarType = STK_UnsignedLong; + break; + case 'l': + ScalarType = STK_SignedLong; + break; + default: + llvm_unreachable("Illegal primitive type transformers!"); + } + Transformer = Transformer.drop_back(); + + // Extract and compute complex type transformer. It can only appear one time. 
+ if (Transformer.startswith("(")) { + size_t Idx = Transformer.find(')'); + assert(Idx != StringRef::npos); + StringRef ComplexType = Transformer.slice(1, Idx); + Transformer = Transformer.drop_front(Idx + 1); + assert(!Transformer.contains('(') && + "Only allow one complex type transformer"); + + auto UpdateAndCheckComplexProto = [&]() { + Scale = LMUL.getScale(ElementBitwidth); + const StringRef VectorPrototypes("vwqom"); + if (!VectorPrototypes.contains(PType)) + llvm_unreachable("Complex type transformer only supports vector type!"); + if (Transformer.find_first_of("PCKWS") != StringRef::npos) + llvm_unreachable( + "Illegal type transformer for Complex type transformer"); + }; + auto ComputeFixedLog2LMUL = + [&](StringRef Value, + std::function Compare) { + int32_t Log2LMUL; + Value.getAsInteger(10, Log2LMUL); + if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { + ScalarType = STK_Invalid; + return false; + } + // Update new LMUL + LMUL = LMULType(Log2LMUL); + UpdateAndCheckComplexProto(); + return true; + }; + auto ComplexTT = ComplexType.split(":"); + if (ComplexTT.first == "Log2EEW") { + uint32_t Log2EEW; + ComplexTT.second.getAsInteger(10, Log2EEW); + // update new elmul = (eew/sew) * lmul + LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); + // update new eew + ElementBitwidth = 1 << Log2EEW; + ScalarType = STK_SignedInteger; + UpdateAndCheckComplexProto(); + } else if (ComplexTT.first == "FixedSEW") { + uint32_t NewSEW; + ComplexTT.second.getAsInteger(10, NewSEW); + // Set invalid type if src and dst SEW are same. 
+ if (ElementBitwidth == NewSEW) { + ScalarType = STK_Invalid; + return; + } + // Update new SEW + ElementBitwidth = NewSEW; + UpdateAndCheckComplexProto(); + } else if (ComplexTT.first == "LFixedLog2LMUL") { + // New LMUL should be larger than old + if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater())) + return; + } else if (ComplexTT.first == "SFixedLog2LMUL") { + // New LMUL should be smaller than old + if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less())) + return; + } else { + llvm_unreachable("Illegal complex type transformers!"); + } + } + + // Compute the remain type transformers + for (char I : Transformer) { + switch (I) { + case 'P': + if (IsConstant) + llvm_unreachable("'P' transformer cannot be used after 'C'"); + if (IsPointer) + llvm_unreachable("'P' transformer cannot be used twice"); + IsPointer = true; + break; + case 'C': + if (IsConstant) + llvm_unreachable("'C' transformer cannot be used twice"); + IsConstant = true; + break; + case 'K': + IsImmediate = true; + IsConstant = true; + break; + case 'U': + ScalarType = STK_UnsignedInteger; + break; + case 'I': + ScalarType = STK_SignedInteger; + break; + case 'F': + ScalarType = STK_Float; + break; + case 'S': + LMUL = LMULType(0); + // Update ElementBitwidth need to update Scale too. + Scale = LMUL.getScale(ElementBitwidth); + break; + default: + llvm_unreachable("Illegal non-primitive type transformer!"); + } + } +} +RVVType::RVVType(RVVBasicType BT, int Log2LMUL, RVVTypeProfile prototype) + : LMUL(LMULType(Log2LMUL)) { + applyBasicType(BT); + applyModifier(prototype); + Valid = verifyType(); +} + +void RVVType::applyLog2EEW(unsigned Log2EEW) { + // update new elmul = (eew/sew) * lmul + LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); + // update new eew + ElementBitwidth = 1 << Log2EEW; + ScalarType = STK_SignedInteger; + Scale = LMUL.getScale(ElementBitwidth); +} + +void RVVType::applyFixedSEW(unsigned NewSEW) { + // Set invalid type if src and dst SEW are same. 
+ if (ElementBitwidth == NewSEW) { + ScalarType = STK_Invalid; + return; + } + // Update new SEW + ElementBitwidth = NewSEW; + Scale = LMUL.getScale(ElementBitwidth); +} + +void RVVType::applyFixedLog2LMUL(int Log2LMUL, bool LargerThan) { + if (LargerThan) { + if (Log2LMUL < LMUL.Log2LMUL) { + ScalarType = STK_Invalid; + return; + } + } else { + if (Log2LMUL > LMUL.Log2LMUL) { + ScalarType = STK_Invalid; + return; + } + } + // Update new LMUL + LMUL = LMULType(Log2LMUL); + Scale = LMUL.getScale(ElementBitwidth); +} + +void RVVType::applyModifier(const RVVTypeProfile &Transformer) { + // Handle primitive type transformer + switch (Transformer.PrimitiveType) { + case RVVPrimTypeScalar: + Scale = 0; + break; + case RVVPrimTypeVector: + Scale = LMUL.getScale(ElementBitwidth); + break; + case RVVPrimType2XWideningVector: + ElementBitwidth *= 2; + LMUL *= 2; + Scale = LMUL.getScale(ElementBitwidth); + break; + case RVVPrimType4XWideningVector: + ElementBitwidth *= 4; + LMUL *= 4; + Scale = LMUL.getScale(ElementBitwidth); + break; + case RVVPrimType8XWideningVector: + ElementBitwidth *= 8; + LMUL *= 8; + Scale = LMUL.getScale(ElementBitwidth); + break; + case RVVPrimTypeMaskVector: + ScalarType = STK_Boolean; + Scale = LMUL.getScale(ElementBitwidth); + ElementBitwidth = 1; + break; + case RVVPrimTypeVoid: + ScalarType = STK_Void; + break; + case RVVPrimTypeSize: + ScalarType = STK_Size_t; + break; + case RVVPrimTypePtrdiff: + ScalarType = STK_Ptrdiff_t; + break; + case RVVPrimTypeUnsignedLong: + ScalarType = STK_UnsignedLong; + break; + case RVVPrimTypeSignedLong: + ScalarType = STK_SignedLong; + break; + case RVVPrimTypeInvalid: + ScalarType = STK_Invalid; + return; + default: + assert(false && "Illegal primitive type transformers!"); + } + + switch (Transformer.VectorTypeModifier) { + case RVVVecTypeLog2EEW3: + applyLog2EEW(3); + break; + case RVVVecTypeLog2EEW4: + applyLog2EEW(4); + break; + case RVVVecTypeLog2EEW5: + applyLog2EEW(5); + break; + case 
RVVVecTypeLog2EEW6: + applyLog2EEW(6); + break; + case RVVVecTypeFixedSEW8: + applyFixedSEW(8); + break; + case RVVVecTypeFixedSEW16: + applyFixedSEW(16); + break; + case RVVVecTypeFixedSEW32: + applyFixedSEW(32); + break; + case RVVVecTypeFixedSEW64: + applyFixedSEW(64); + break; + case RVVVecTypeLFixedLog2LMULN3: + applyFixedLog2LMUL(-3, /* LargerThan= */ true); + break; + case RVVVecTypeLFixedLog2LMULN2: + applyFixedLog2LMUL(-2, /* LargerThan= */ true); + break; + case RVVVecTypeLFixedLog2LMULN1: + applyFixedLog2LMUL(-1, /* LargerThan= */ true); + break; + case RVVVecTypeLFixedLog2LMUL0: + applyFixedLog2LMUL(0, /* LargerThan= */ true); + break; + case RVVVecTypeLFixedLog2LMUL1: + applyFixedLog2LMUL(1, /* LargerThan= */ true); + break; + case RVVVecTypeLFixedLog2LMUL2: + applyFixedLog2LMUL(2, /* LargerThan= */ true); + break; + case RVVVecTypeLFixedLog2LMUL3: + applyFixedLog2LMUL(3, /* LargerThan= */ true); + break; + case RVVVecTypeSFixedLog2LMULN3: + applyFixedLog2LMUL(-3, /* LargerThan= */ false); + break; + case RVVVecTypeSFixedLog2LMULN2: + applyFixedLog2LMUL(-2, /* LargerThan= */ false); + break; + case RVVVecTypeSFixedLog2LMULN1: + applyFixedLog2LMUL(-1, /* LargerThan= */ false); + break; + case RVVVecTypeSFixedLog2LMUL0: + applyFixedLog2LMUL(0, /* LargerThan= */ false); + break; + case RVVVecTypeSFixedLog2LMUL1: + applyFixedLog2LMUL(1, /* LargerThan= */ false); + break; + case RVVVecTypeSFixedLog2LMUL2: + applyFixedLog2LMUL(2, /* LargerThan= */ false); + break; + case RVVVecTypeSFixedLog2LMUL3: + applyFixedLog2LMUL(3, /* LargerThan= */ false); + break; + case RVVVecTypeNoModifier: + break; + default: + assert(false && "Illegal vector type modifier!"); + } + + for (unsigned TypeModifierMaskShift = 0; + TypeModifierMaskShift <= RVVTypeModifierMaskMax; + ++TypeModifierMaskShift) { + unsigned TypeModifierMask = 1 << TypeModifierMaskShift; + if (!(Transformer.TypeModifierMask & TypeModifierMask)) + continue; + switch (TypeModifierMask) { + case 
RVVTypeModifierPointer: + IsPointer = true; + break; + case RVVTypeModifierConst: + IsConstant = true; + break; + case RVVTypeModifierImmediate: + IsImmediate = true; + IsConstant = true; + break; + case RVVTypeModifierUnsignedInteger: + ScalarType = STK_UnsignedInteger; + break; + case RVVTypeModifierSignedInteger: + ScalarType = STK_SignedInteger; + break; + case RVVTypeModifierFloat: + ScalarType = STK_Float; + break; + case RVVTypeModifierLMUL1: + LMUL = LMULType(0); + // Update ElementBitwidth need to update Scale too. + Scale = LMUL.getScale(ElementBitwidth); + break; + default: + assert(false && "Unknown type modifier mask!"); + } + } +} + +void RVVParsePrototypes(StringRef Prototypes, + std::function Handler) { + const StringRef Primaries("evwqom0ztul"); + while (!Prototypes.empty()) { + size_t Idx = 0; + // Skip over complex prototype because it could contain primitive type + // character. + if (Prototypes[0] == '(') + Idx = Prototypes.find_first_of(')'); + Idx = Prototypes.find_first_of(Primaries, Idx); + assert(Idx != StringRef::npos); + Handler(Prototypes.slice(0, Idx + 1)); + Prototypes = Prototypes.drop_front(Idx + 1); + } +} + +} // namespace RISCV +} // namespace llvm