diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h new file mode 100644 --- /dev/null +++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h @@ -0,0 +1,215 @@ +//===--- RISCVVIntrinsicUtils.h - RISC-V Vector Intrinsic Utils -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CLANG_SUPPORT_RISCVVINTRINSICUTILS_H +#define CLANG_SUPPORT_RISCVVINTRINSICUTILS_H + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include +#include +#include + +namespace clang { +namespace RISCV { + +using BasicType = char; +using VScaleVal = llvm::Optional; + +// Exponential LMUL +struct LMULType { + int Log2LMUL; + LMULType(int Log2LMUL); + // Return the C/C++ string representation of LMUL + std::string str() const; + llvm::Optional getScale(unsigned ElementBitwidth) const; + void MulLog2LMUL(int Log2LMUL); + LMULType &operator*=(uint32_t RHS); +}; + +// This class is compact representation of a valid and invalid RVVType. +class RVVType { + enum ScalarTypeKind : uint32_t { + Void, + Size_t, + Ptrdiff_t, + UnsignedLong, + SignedLong, + Boolean, + SignedInteger, + UnsignedInteger, + Float, + Invalid, + }; + BasicType BT; + ScalarTypeKind ScalarType = Invalid; + LMULType LMUL; + bool IsPointer = false; + // IsConstant indices are "int", but have the constant expression. + bool IsImmediate = false; + // Const qualifier for pointer to const object or object of const type. + bool IsConstant = false; + unsigned ElementBitwidth = 0; + VScaleVal Scale = 0; + bool Valid; + + std::string BuiltinStr; + std::string ClangBuiltinStr; + std::string Str; + std::string ShortStr; + +public: + RVVType() : RVVType(BasicType(), 0, llvm::StringRef()) {} + RVVType(BasicType BT, int Log2LMUL, llvm::StringRef prototype); + + // Return the string representation of a type, which is an encoded string for + // passing to the BUILTIN() macro in Builtins.def. + const std::string &getBuiltinStr() const { return BuiltinStr; } + + // Return the clang builtin type for RVV vector type which are used in the + // riscv_vector.h header file. + const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; } + + // Return the C/C++ string representation of a type for use in the + // riscv_vector.h header file. + const std::string &getTypeStr() const { return Str; } + + // Return the short name of a type for C/C++ name suffix. + const std::string &getShortStr() { + // Not all types are used in short name, so compute the short name by + // demanded. + if (ShortStr.empty()) + initShortStr(); + return ShortStr; + } + + bool isValid() const { return Valid; } + bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; } + bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; } + bool isVector(unsigned Width) const { + return isVector() && ElementBitwidth == Width; + } + bool isFloat() const { return ScalarType == ScalarTypeKind::Float; } + bool isSignedInteger() const { + return ScalarType == ScalarTypeKind::SignedInteger; + } + bool isFloatVector(unsigned Width) const { + return isVector() && isFloat() && ElementBitwidth == Width; + } + bool isFloat(unsigned Width) const { + return isFloat() && ElementBitwidth == Width; + } + +private: + // Verify RVV vector type and set Valid. + bool verifyType() const; + + // Creates a type based on basic types of TypeRange + void applyBasicType(); + + // Applies a prototype modifier to the current type. The result maybe an + // invalid type. + void applyModifier(llvm::StringRef prototype); + + // Compute and record a string for legal type. + void initBuiltinStr(); + // Compute and record a builtin RVV vector type string. + void initClangBuiltinStr(); + // Compute and record a type string for used in the header. + void initTypeStr(); + // Compute and record a short name of a type for C/C++ name suffix. + void initShortStr(); +}; + +using RVVTypePtr = RVVType *; +using RVVTypes = std::vector; +using RISCVPredefinedMacroT = uint8_t; + +enum RISCVPredefinedMacro : RISCVPredefinedMacroT { + Basic = 0, + V = 1 << 1, + Zvfh = 1 << 2, + RV64 = 1 << 3, + VectorMaxELen64 = 1 << 4, + VectorMaxELenFp32 = 1 << 5, + VectorMaxELenFp64 = 1 << 6, +}; + +enum PolicyScheme : uint8_t { + SchemeNone, + HasPassthruOperand, + HasPolicyOperand, +}; + +// TODO refactor RVVIntrinsic class design after support all intrinsic +// combination. This represents an instantiation of an intrinsic with a +// particular type and prototype +class RVVIntrinsic { + +private: + std::string BuiltinName; // Builtin name + std::string Name; // C intrinsic name. + std::string MangledName; + std::string IRName; + bool IsMasked; + bool HasVL; + PolicyScheme Scheme; + bool HasUnMaskedOverloaded; + bool HasBuiltinAlias; + std::string ManualCodegen; + RVVTypePtr OutputType; // Builtin output type + RVVTypes InputTypes; // Builtin input types + // The types we use to obtain the specific LLVM intrinsic. They are index of + // InputTypes. -1 means the return type. + std::vector IntrinsicTypes; + RISCVPredefinedMacroT RISCVPredefinedMacros = 0; + unsigned NF = 1; + +public: + RVVIntrinsic(llvm::StringRef Name, llvm::StringRef Suffix, llvm::StringRef MangledName, + llvm::StringRef MangledSuffix, llvm::StringRef IRName, bool IsMasked, + bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, + bool HasUnMaskedOverloaded, bool HasBuiltinAlias, + llvm::StringRef ManualCodegen, const RVVTypes &Types, + const std::vector &IntrinsicTypes, + const std::vector &RequiredFeatures, unsigned NF); + ~RVVIntrinsic() = default; + + RVVTypePtr getOutputType() const { return OutputType; } + const RVVTypes &getInputTypes() const { return InputTypes; } + llvm::StringRef getBuiltinName() const { return BuiltinName; } + llvm::StringRef getName() const { return Name; } + llvm::StringRef getMangledName() const { return MangledName; } + bool hasVL() const { return HasVL; } + bool hasPolicy() const { return Scheme != SchemeNone; } + bool hasPassthruOperand() const { return Scheme == HasPassthruOperand; } + bool hasPolicyOperand() const { return Scheme == HasPolicyOperand; } + bool hasUnMaskedOverloaded() const { return HasUnMaskedOverloaded; } + bool hasBuiltinAlias() const { return HasBuiltinAlias; } + bool hasManualCodegen() const { return !ManualCodegen.empty(); } + bool isMasked() const { return IsMasked; } + llvm::StringRef getIRName() const { return IRName; } + llvm::StringRef getManualCodegen() const { return ManualCodegen; } + PolicyScheme getPolicyScheme() const { return Scheme; } + RISCVPredefinedMacroT getRISCVPredefinedMacros() const { + return RISCVPredefinedMacros; + } + unsigned getNF() const { return NF; } + const std::vector &getIntrinsicTypes() const { + return IntrinsicTypes; + } + + // Return the type string for a BUILTIN() macro in Builtins.def. + std::string getBuiltinTypeStr() const; +}; + +} // end namespace RISCV + +} // end namespace clang + +#endif // CLANG_SUPPORT_RISCVVINTRINSICUTILS_H diff --git a/clang/lib/CMakeLists.txt b/clang/lib/CMakeLists.txt --- a/clang/lib/CMakeLists.txt +++ b/clang/lib/CMakeLists.txt @@ -29,3 +29,4 @@ add_subdirectory(Testing) endif() add_subdirectory(Interpreter) +add_subdirectory(Support) diff --git a/clang/lib/Support/CMakeLists.txt b/clang/lib/Support/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/clang/lib/Support/CMakeLists.txt @@ -0,0 +1,16 @@ +set(LLVM_COMMON_DEPENDS_OLD ${LLVM_COMMON_DEPENDS}) + +# Drop clang-tablegen-targets from LLVM_COMMON_DEPENDS. +# so that we could use clangSupport within clang-tblgen and other clang +# component. +list(REMOVE_ITEM LLVM_COMMON_DEPENDS clang-tablegen-targets) + +set(LLVM_LINK_COMPONENTS + Support + ) + +add_clang_library(clangSupport + RISCVVIntrinsicUtils.cpp + ) + +set(LLVM_COMMON_DEPENDS ${LLVM_COMMON_DEPENDS_OLD}) diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp new file mode 100644 --- /dev/null +++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp @@ -0,0 +1,597 @@ +//===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Support/RISCVVIntrinsicUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace llvm; + +namespace clang { +namespace RISCV { + +//===----------------------------------------------------------------------===// +// Type implementation +//===----------------------------------------------------------------------===// + +LMULType::LMULType(int NewLog2LMUL) { + // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 + assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); + Log2LMUL = NewLog2LMUL; +} + +std::string LMULType::str() const { + if (Log2LMUL < 0) + return "mf" + utostr(1ULL << (-Log2LMUL)); + return "m" + utostr(1ULL << Log2LMUL); +} + +VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { + int Log2ScaleResult = 0; + switch (ElementBitwidth) { + default: + break; + case 8: + Log2ScaleResult = Log2LMUL + 3; + break; + case 16: + Log2ScaleResult = Log2LMUL + 2; + break; + case 32: + Log2ScaleResult = Log2LMUL + 1; + break; + case 64: + Log2ScaleResult = Log2LMUL; + break; + } + // Illegal vscale result would be less than 1 + if (Log2ScaleResult < 0) + return llvm::None; + return 1 << Log2ScaleResult; +} + +void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } + +LMULType &LMULType::operator*=(uint32_t RHS) { + assert(isPowerOf2_32(RHS)); + this->Log2LMUL = this->Log2LMUL + Log2_32(RHS); + return *this; +} + +RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype) + : BT(BT), LMUL(LMULType(Log2LMUL)) { + applyBasicType(); + applyModifier(prototype); + Valid = verifyType(); + if (Valid) { + initBuiltinStr(); + initTypeStr(); + if (isVector()) { + initClangBuiltinStr(); + } + } +} + +// clang-format off +// boolean type are encoded the ratio of n (SEW/LMUL) +// SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 +// c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t +// IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 + +// type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 +// -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- +// i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 +// i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 +// i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 +// i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 +// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 +// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 +// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 +// clang-format on + +bool RVVType::verifyType() const { + if (ScalarType == Invalid) + return false; + if (isScalar()) + return true; + if (!Scale.hasValue()) + return false; + if (isFloat() && ElementBitwidth == 8) + return false; + unsigned V = Scale.getValue(); + switch (ElementBitwidth) { + case 1: + case 8: + // Check Scale is 1,2,4,8,16,32,64 + return (V <= 64 && isPowerOf2_32(V)); + case 16: + // Check Scale is 1,2,4,8,16,32 + return (V <= 32 && isPowerOf2_32(V)); + case 32: + // Check Scale is 1,2,4,8,16 + return (V <= 16 && isPowerOf2_32(V)); + case 64: + // Check Scale is 1,2,4,8 + return (V <= 8 && isPowerOf2_32(V)); + } + return false; +} + +void RVVType::initBuiltinStr() { + assert(isValid() && "RVVType is invalid"); + switch (ScalarType) { + case ScalarTypeKind::Void: + BuiltinStr = "v"; + return; + case ScalarTypeKind::Size_t: + BuiltinStr = "z"; + if (IsImmediate) + BuiltinStr = "I" + BuiltinStr; + if (IsPointer) + BuiltinStr += "*"; + return; + case ScalarTypeKind::Ptrdiff_t: + BuiltinStr = "Y"; + return; + case ScalarTypeKind::UnsignedLong: + BuiltinStr = "ULi"; + return; + case ScalarTypeKind::SignedLong: + BuiltinStr = "Li"; + return; + case ScalarTypeKind::Boolean: + assert(ElementBitwidth == 1); + BuiltinStr += "b"; + break; + case ScalarTypeKind::SignedInteger: + case ScalarTypeKind::UnsignedInteger: + switch (ElementBitwidth) { + case 8: + BuiltinStr += "c"; + break; + case 16: + BuiltinStr += "s"; + break; + case 32: + BuiltinStr += "i"; + break; + case 64: + BuiltinStr += "Wi"; + break; + default: + llvm_unreachable("Unhandled ElementBitwidth!"); + } + if (isSignedInteger()) + BuiltinStr = "S" + BuiltinStr; + else + BuiltinStr = "U" + BuiltinStr; + break; + case ScalarTypeKind::Float: + switch (ElementBitwidth) { + case 16: + BuiltinStr += "x"; + break; + case 32: + BuiltinStr += "f"; + break; + case 64: + BuiltinStr += "d"; + break; + default: + llvm_unreachable("Unhandled ElementBitwidth!"); + } + break; + default: + llvm_unreachable("ScalarType is invalid!"); + } + if (IsImmediate) + BuiltinStr = "I" + BuiltinStr; + if (isScalar()) { + if (IsConstant) + BuiltinStr += "C"; + if (IsPointer) + BuiltinStr += "*"; + return; + } + BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr; + // Pointer to vector types. Defined for segment load intrinsics. + // segment load intrinsics have pointer type arguments to store the loaded + // vector values. + if (IsPointer) + BuiltinStr += "*"; +} + +void RVVType::initClangBuiltinStr() { + assert(isValid() && "RVVType is invalid"); + assert(isVector() && "Handle Vector type only"); + + ClangBuiltinStr = "__rvv_"; + switch (ScalarType) { + case ScalarTypeKind::Boolean: + ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t"; + return; + case ScalarTypeKind::Float: + ClangBuiltinStr += "float"; + break; + case ScalarTypeKind::SignedInteger: + ClangBuiltinStr += "int"; + break; + case ScalarTypeKind::UnsignedInteger: + ClangBuiltinStr += "uint"; + break; + default: + llvm_unreachable("ScalarTypeKind is invalid"); + } + ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; +} + +void RVVType::initTypeStr() { + assert(isValid() && "RVVType is invalid"); + + if (IsConstant) + Str += "const "; + + auto getTypeString = [&](StringRef TypeStr) { + if (isScalar()) + return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); + return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") + .str(); + }; + + switch (ScalarType) { + case ScalarTypeKind::Void: + Str = "void"; + return; + case ScalarTypeKind::Size_t: + Str = "size_t"; + if (IsPointer) + Str += " *"; + return; + case ScalarTypeKind::Ptrdiff_t: + Str = "ptrdiff_t"; + return; + case ScalarTypeKind::UnsignedLong: + Str = "unsigned long"; + return; + case ScalarTypeKind::SignedLong: + Str = "long"; + return; + case ScalarTypeKind::Boolean: + if (isScalar()) + Str += "bool"; + else + // Vector bool is special case, the formulate is + // `vbool_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 + Str += "vbool" + utostr(64 / Scale.getValue()) + "_t"; + break; + case ScalarTypeKind::Float: + if (isScalar()) { + if (ElementBitwidth == 64) + Str += "double"; + else if (ElementBitwidth == 32) + Str += "float"; + else if (ElementBitwidth == 16) + Str += "_Float16"; + else + llvm_unreachable("Unhandled floating type."); + } else + Str += getTypeString("float"); + break; + case ScalarTypeKind::SignedInteger: + Str += getTypeString("int"); + break; + case ScalarTypeKind::UnsignedInteger: + Str += getTypeString("uint"); + break; + default: + llvm_unreachable("ScalarType is invalid!"); + } + if (IsPointer) + Str += " *"; +} + +void RVVType::initShortStr() { + switch (ScalarType) { + case ScalarTypeKind::Boolean: + assert(isVector()); + ShortStr = "b" + utostr(64 / Scale.getValue()); + return; + case ScalarTypeKind::Float: + ShortStr = "f" + utostr(ElementBitwidth); + break; + case ScalarTypeKind::SignedInteger: + ShortStr = "i" + utostr(ElementBitwidth); + break; + case ScalarTypeKind::UnsignedInteger: + ShortStr = "u" + utostr(ElementBitwidth); + break; + default: + llvm_unreachable("Unhandled case!"); + } + if (isVector()) + ShortStr += LMUL.str(); +} + +void RVVType::applyBasicType() { + switch (BT) { + case 'c': + ElementBitwidth = 8; + ScalarType = ScalarTypeKind::SignedInteger; + break; + case 's': + ElementBitwidth = 16; + ScalarType = ScalarTypeKind::SignedInteger; + break; + case 'i': + ElementBitwidth = 32; + ScalarType = ScalarTypeKind::SignedInteger; + break; + case 'l': + ElementBitwidth = 64; + ScalarType = ScalarTypeKind::SignedInteger; + break; + case 'x': + ElementBitwidth = 16; + ScalarType = ScalarTypeKind::Float; + break; + case 'f': + ElementBitwidth = 32; + ScalarType = ScalarTypeKind::Float; + break; + case 'd': + ElementBitwidth = 64; + ScalarType = ScalarTypeKind::Float; + break; + default: + llvm_unreachable("Unhandled type code!"); + } + assert(ElementBitwidth != 0 && "Bad element bitwidth!"); +} + +void RVVType::applyModifier(StringRef Transformer) { + if (Transformer.empty()) + return; + // Handle primitive type transformer + auto PType = Transformer.back(); + switch (PType) { + case 'e': + Scale = 0; + break; + case 'v': + Scale = LMUL.getScale(ElementBitwidth); + break; + case 'w': + ElementBitwidth *= 2; + LMUL *= 2; + Scale = LMUL.getScale(ElementBitwidth); + break; + case 'q': + ElementBitwidth *= 4; + LMUL *= 4; + Scale = LMUL.getScale(ElementBitwidth); + break; + case 'o': + ElementBitwidth *= 8; + LMUL *= 8; + Scale = LMUL.getScale(ElementBitwidth); + break; + case 'm': + ScalarType = ScalarTypeKind::Boolean; + Scale = LMUL.getScale(ElementBitwidth); + ElementBitwidth = 1; + break; + case '0': + ScalarType = ScalarTypeKind::Void; + break; + case 'z': + ScalarType = ScalarTypeKind::Size_t; + break; + case 't': + ScalarType = ScalarTypeKind::Ptrdiff_t; + break; + case 'u': + ScalarType = ScalarTypeKind::UnsignedLong; + break; + case 'l': + ScalarType = ScalarTypeKind::SignedLong; + break; + default: + llvm_unreachable("Illegal primitive type transformers!"); + } + Transformer = Transformer.drop_back(); + + // Extract and compute complex type transformer. It can only appear one time. + if (Transformer.startswith("(")) { + size_t Idx = Transformer.find(')'); + assert(Idx != StringRef::npos); + StringRef ComplexType = Transformer.slice(1, Idx); + Transformer = Transformer.drop_front(Idx + 1); + assert(!Transformer.contains('(') && + "Only allow one complex type transformer"); + + auto UpdateAndCheckComplexProto = [&]() { + Scale = LMUL.getScale(ElementBitwidth); + const StringRef VectorPrototypes("vwqom"); + if (!VectorPrototypes.contains(PType)) + llvm_unreachable("Complex type transformer only supports vector type!"); + if (Transformer.find_first_of("PCKWS") != StringRef::npos) + llvm_unreachable( + "Illegal type transformer for Complex type transformer"); + }; + auto ComputeFixedLog2LMUL = + [&](StringRef Value, + std::function Compare) { + int32_t Log2LMUL; + Value.getAsInteger(10, Log2LMUL); + if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { + ScalarType = Invalid; + return false; + } + // Update new LMUL + LMUL = LMULType(Log2LMUL); + UpdateAndCheckComplexProto(); + return true; + }; + auto ComplexTT = ComplexType.split(":"); + if (ComplexTT.first == "Log2EEW") { + uint32_t Log2EEW; + ComplexTT.second.getAsInteger(10, Log2EEW); + // update new elmul = (eew/sew) * lmul + LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); + // update new eew + ElementBitwidth = 1 << Log2EEW; + ScalarType = ScalarTypeKind::SignedInteger; + UpdateAndCheckComplexProto(); + } else if (ComplexTT.first == "FixedSEW") { + uint32_t NewSEW; + ComplexTT.second.getAsInteger(10, NewSEW); + // Set invalid type if src and dst SEW are same. + if (ElementBitwidth == NewSEW) { + ScalarType = Invalid; + return; + } + // Update new SEW + ElementBitwidth = NewSEW; + UpdateAndCheckComplexProto(); + } else if (ComplexTT.first == "LFixedLog2LMUL") { + // New LMUL should be larger than old + if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater())) + return; + } else if (ComplexTT.first == "SFixedLog2LMUL") { + // New LMUL should be smaller than old + if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less())) + return; + } else { + llvm_unreachable("Illegal complex type transformers!"); + } + } + + // Compute the remain type transformers + for (char I : Transformer) { + switch (I) { + case 'P': + if (IsConstant) + llvm_unreachable("'P' transformer cannot be used after 'C'"); + if (IsPointer) + llvm_unreachable("'P' transformer cannot be used twice"); + IsPointer = true; + break; + case 'C': + if (IsConstant) + llvm_unreachable("'C' transformer cannot be used twice"); + IsConstant = true; + break; + case 'K': + IsImmediate = true; + break; + case 'U': + ScalarType = ScalarTypeKind::UnsignedInteger; + break; + case 'I': + ScalarType = ScalarTypeKind::SignedInteger; + break; + case 'F': + ScalarType = ScalarTypeKind::Float; + break; + case 'S': + LMUL = LMULType(0); + // Update ElementBitwidth need to update Scale too. + Scale = LMUL.getScale(ElementBitwidth); + break; + default: + llvm_unreachable("Illegal non-primitive type transformer!"); + } + } +} + +//===----------------------------------------------------------------------===// +// RVVIntrinsic implementation +//===----------------------------------------------------------------------===// +RVVIntrinsic::RVVIntrinsic( + StringRef NewName, StringRef Suffix, StringRef NewMangledName, + StringRef MangledSuffix, StringRef IRName, bool IsMasked, + bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, + bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen, + const RVVTypes &OutInTypes, const std::vector &NewIntrinsicTypes, + const std::vector &RequiredFeatures, unsigned NF) + : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme), + HasUnMaskedOverloaded(HasUnMaskedOverloaded), + HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()), + NF(NF) { + + // Init BuiltinName, Name and MangledName + BuiltinName = NewName.str(); + Name = BuiltinName; + if (NewMangledName.empty()) + MangledName = NewName.split("_").first.str(); + else + MangledName = NewMangledName.str(); + if (!Suffix.empty()) + Name += "_" + Suffix.str(); + if (!MangledSuffix.empty()) + MangledName += "_" + MangledSuffix.str(); + if (IsMasked) { + BuiltinName += "_m"; + Name += "_m"; + } + + // Init RISC-V extensions + for (const auto &T : OutInTypes) { + if (T->isFloatVector(16) || T->isFloat(16)) + RISCVPredefinedMacros |= RISCVPredefinedMacro::Zvfh; + if (T->isFloatVector(32)) + RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp32; + if (T->isFloatVector(64)) + RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp64; + if (T->isVector(64)) + RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELen64; + } + for (auto Feature : RequiredFeatures) { + if (Feature == "RV64") + RISCVPredefinedMacros |= RISCVPredefinedMacro::RV64; + // Note: Full multiply instruction (mulh, mulhu, mulhsu, smul) for EEW=64 + // require V. + if (Feature == "FullMultiply" && + (RISCVPredefinedMacros & RISCVPredefinedMacro::VectorMaxELen64)) + RISCVPredefinedMacros |= RISCVPredefinedMacro::V; + } + + // Init OutputType and InputTypes + OutputType = OutInTypes[0]; + InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); + + // IntrinsicTypes is unmasked TA version index. Need to update it + // if there is merge operand (It is always in first operand). + IntrinsicTypes = NewIntrinsicTypes; + if ((IsMasked && HasMaskedOffOperand) || + (!IsMasked && hasPassthruOperand())) { + for (auto &I : IntrinsicTypes) { + if (I >= 0) + I += NF; + } + } +} + +std::string RVVIntrinsic::getBuiltinTypeStr() const { + std::string S; + S += OutputType->getBuiltinStr(); + for (const auto &T : InputTypes) { + S += T->getBuiltinStr(); + } + return S; +} + +} // end namespace RISCV +} // end namespace clang diff --git a/clang/utils/TableGen/CMakeLists.txt b/clang/utils/TableGen/CMakeLists.txt --- a/clang/utils/TableGen/CMakeLists.txt +++ b/clang/utils/TableGen/CMakeLists.txt @@ -22,4 +22,7 @@ SveEmitter.cpp TableGen.cpp ) + +target_link_libraries(clang-tblgen PRIVATE clangSupport) + set_target_properties(clang-tblgen PROPERTIES FOLDER "Clang tablegenning") diff --git a/clang/utils/TableGen/RISCVVEmitter.cpp b/clang/utils/TableGen/RISCVVEmitter.cpp --- a/clang/utils/TableGen/RISCVVEmitter.cpp +++ b/clang/utils/TableGen/RISCVVEmitter.cpp @@ -14,6 +14,7 @@ // //===----------------------------------------------------------------------===// +#include "clang/Support/RISCVVIntrinsicUtils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringExtras.h" @@ -25,206 +26,9 @@ #include using namespace llvm; -using BasicType = char; -using VScaleVal = Optional; +using namespace clang::RISCV; namespace { - -// Exponential LMUL -struct LMULType { - int Log2LMUL; - LMULType(int Log2LMUL); - // Return the C/C++ string representation of LMUL - std::string str() const; - Optional getScale(unsigned ElementBitwidth) const; - void MulLog2LMUL(int Log2LMUL); - LMULType &operator*=(uint32_t RHS); -}; - -// This class is compact representation of a valid and invalid RVVType. -class RVVType { - enum ScalarTypeKind : uint32_t { - Void, - Size_t, - Ptrdiff_t, - UnsignedLong, - SignedLong, - Boolean, - SignedInteger, - UnsignedInteger, - Float, - Invalid, - }; - BasicType BT; - ScalarTypeKind ScalarType = Invalid; - LMULType LMUL; - bool IsPointer = false; - // IsConstant indices are "int", but have the constant expression. - bool IsImmediate = false; - // Const qualifier for pointer to const object or object of const type. - bool IsConstant = false; - unsigned ElementBitwidth = 0; - VScaleVal Scale = 0; - bool Valid; - - std::string BuiltinStr; - std::string ClangBuiltinStr; - std::string Str; - std::string ShortStr; - -public: - RVVType() : RVVType(BasicType(), 0, StringRef()) {} - RVVType(BasicType BT, int Log2LMUL, StringRef prototype); - - // Return the string representation of a type, which is an encoded string for - // passing to the BUILTIN() macro in Builtins.def. - const std::string &getBuiltinStr() const { return BuiltinStr; } - - // Return the clang builtin type for RVV vector type which are used in the - // riscv_vector.h header file. - const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; } - - // Return the C/C++ string representation of a type for use in the - // riscv_vector.h header file. - const std::string &getTypeStr() const { return Str; } - - // Return the short name of a type for C/C++ name suffix. - const std::string &getShortStr() { - // Not all types are used in short name, so compute the short name by - // demanded. - if (ShortStr.empty()) - initShortStr(); - return ShortStr; - } - - bool isValid() const { return Valid; } - bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; } - bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; } - bool isVector(unsigned Width) const { - return isVector() && ElementBitwidth == Width; - } - bool isFloat() const { return ScalarType == ScalarTypeKind::Float; } - bool isSignedInteger() const { - return ScalarType == ScalarTypeKind::SignedInteger; - } - bool isFloatVector(unsigned Width) const { - return isVector() && isFloat() && ElementBitwidth == Width; - } - bool isFloat(unsigned Width) const { - return isFloat() && ElementBitwidth == Width; - } - -private: - // Verify RVV vector type and set Valid. - bool verifyType() const; - - // Creates a type based on basic types of TypeRange - void applyBasicType(); - - // Applies a prototype modifier to the current type. The result maybe an - // invalid type. - void applyModifier(StringRef prototype); - - // Compute and record a string for legal type. - void initBuiltinStr(); - // Compute and record a builtin RVV vector type string. - void initClangBuiltinStr(); - // Compute and record a type string for used in the header. - void initTypeStr(); - // Compute and record a short name of a type for C/C++ name suffix. - void initShortStr(); -}; - -using RVVTypePtr = RVVType *; -using RVVTypes = std::vector; -using RISCVPredefinedMacroT = uint8_t; - -enum RISCVPredefinedMacro : RISCVPredefinedMacroT { - Basic = 0, - V = 1 << 1, - Zvfh = 1 << 2, - RV64 = 1 << 3, - VectorMaxELen64 = 1 << 4, - VectorMaxELenFp32 = 1 << 5, - VectorMaxELenFp64 = 1 << 6, -}; - -enum PolicyScheme : uint8_t { - SchemeNone, - HasPassthruOperand, - HasPolicyOperand, -}; - -// TODO refactor RVVIntrinsic class design after support all intrinsic -// combination. This represents an instantiation of an intrinsic with a -// particular type and prototype -class RVVIntrinsic { - -private: - std::string BuiltinName; // Builtin name - std::string Name; // C intrinsic name. - std::string MangledName; - std::string IRName; - bool IsMasked; - bool HasVL; - PolicyScheme Scheme; - bool HasUnMaskedOverloaded; - bool HasBuiltinAlias; - std::string ManualCodegen; - RVVTypePtr OutputType; // Builtin output type - RVVTypes InputTypes; // Builtin input types - // The types we use to obtain the specific LLVM intrinsic. They are index of - // InputTypes. -1 means the return type. - std::vector IntrinsicTypes; - RISCVPredefinedMacroT RISCVPredefinedMacros = 0; - unsigned NF = 1; - -public: - RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName, - StringRef MangledSuffix, StringRef IRName, bool IsMasked, - bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, - bool HasUnMaskedOverloaded, bool HasBuiltinAlias, - StringRef ManualCodegen, const RVVTypes &Types, - const std::vector &IntrinsicTypes, - const std::vector &RequiredFeatures, unsigned NF); - ~RVVIntrinsic() = default; - - StringRef getBuiltinName() const { return BuiltinName; } - StringRef getName() const { return Name; } - StringRef getMangledName() const { return MangledName; } - bool hasVL() const { return HasVL; } - bool hasPolicy() const { return Scheme != SchemeNone; } - bool hasPassthruOperand() const { return Scheme == HasPassthruOperand; } - bool hasPolicyOperand() const { return Scheme == HasPolicyOperand; } - bool hasUnMaskedOverloaded() const { return HasUnMaskedOverloaded; } - bool hasBuiltinAlias() const { return HasBuiltinAlias; } - bool hasManualCodegen() const { return !ManualCodegen.empty(); } - bool isMasked() const { return IsMasked; } - StringRef getIRName() const { return IRName; } - StringRef getManualCodegen() const { return ManualCodegen; } - PolicyScheme getPolicyScheme() const { return Scheme; } - RISCVPredefinedMacroT getRISCVPredefinedMacros() const { - return RISCVPredefinedMacros; - } - unsigned getNF() const { return NF; } - const std::vector &getIntrinsicTypes() const { - return IntrinsicTypes; - } - - // Return the type string for a BUILTIN() macro in Builtins.def. - std::string getBuiltinTypeStr() const; - - // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should - // init the RVVIntrinsic ID and IntrinsicTypes. - void emitCodeGenSwitchBody(raw_ostream &o) const; - - // Emit the macros for mapping C/C++ intrinsic function to builtin functions. - void emitIntrinsicFuncDef(raw_ostream &o) const; - - // Emit the mangled function definition. - void emitMangledFuncDef(raw_ostream &o) const; -}; - class RVVEmitter { private: RecordKeeper &Records; @@ -277,602 +81,31 @@ } // namespace -//===----------------------------------------------------------------------===// -// Type implementation -//===----------------------------------------------------------------------===// - -LMULType::LMULType(int NewLog2LMUL) { - // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 - assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); - Log2LMUL = NewLog2LMUL; -} - -std::string LMULType::str() const { - if (Log2LMUL < 0) - return "mf" + utostr(1ULL << (-Log2LMUL)); - return "m" + utostr(1ULL << Log2LMUL); -} - -VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { - int Log2ScaleResult = 0; - switch (ElementBitwidth) { - default: - break; - case 8: - Log2ScaleResult = Log2LMUL + 3; - break; - case 16: - Log2ScaleResult = Log2LMUL + 2; - break; - case 32: - Log2ScaleResult = Log2LMUL + 1; - break; - case 64: - Log2ScaleResult = Log2LMUL; - break; - } - // Illegal vscale result would be less than 1 - if (Log2ScaleResult < 0) - return llvm::None; - return 1 << Log2ScaleResult; -} - -void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } - -LMULType &LMULType::operator*=(uint32_t RHS) { - assert(isPowerOf2_32(RHS)); - this->Log2LMUL = this->Log2LMUL + Log2_32(RHS); - return *this; -} - -RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype) - : BT(BT), LMUL(LMULType(Log2LMUL)) { - applyBasicType(); - applyModifier(prototype); - Valid = verifyType(); - if (Valid) { - initBuiltinStr(); - initTypeStr(); - if (isVector()) { - initClangBuiltinStr(); - } - } -} - -// clang-format off -// boolean type are encoded the ratio of n (SEW/LMUL) -// SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 -// c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t -// IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 - -// type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 -// -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- -// i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 -// i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 -// i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 -// i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 -// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 -// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 -// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 -// clang-format on - -bool RVVType::verifyType() const { - if (ScalarType == Invalid) - return false; - if (isScalar()) - return true; - if (!Scale.hasValue()) - return false; - if (isFloat() && ElementBitwidth == 8) - return false; - unsigned V = Scale.getValue(); - switch (ElementBitwidth) { - case 1: - case 8: - // Check Scale is 1,2,4,8,16,32,64 - return (V <= 64 && isPowerOf2_32(V)); - case 16: - // Check Scale is 1,2,4,8,16,32 - return (V <= 32 && isPowerOf2_32(V)); - case 32: - // Check Scale is 1,2,4,8,16 - return (V <= 16 && isPowerOf2_32(V)); - case 64: - // Check Scale is 1,2,4,8 - return (V <= 8 && isPowerOf2_32(V)); - } - return false; -} - -void RVVType::initBuiltinStr() { - assert(isValid() && "RVVType is invalid"); - switch (ScalarType) { - case ScalarTypeKind::Void: - BuiltinStr = "v"; - return; - case ScalarTypeKind::Size_t: - BuiltinStr = "z"; - if (IsImmediate) - BuiltinStr = "I" + BuiltinStr; - if (IsPointer) - BuiltinStr += "*"; - return; - case ScalarTypeKind::Ptrdiff_t: - BuiltinStr = "Y"; - return; - case ScalarTypeKind::UnsignedLong: - BuiltinStr = "ULi"; - return; - case ScalarTypeKind::SignedLong: - BuiltinStr = "Li"; - return; - case ScalarTypeKind::Boolean: - assert(ElementBitwidth == 1); - BuiltinStr += "b"; - break; - case ScalarTypeKind::SignedInteger: - case ScalarTypeKind::UnsignedInteger: - switch (ElementBitwidth) { - case 8: - BuiltinStr += "c"; - break; - case 16: - BuiltinStr += "s"; - break; - case 32: - BuiltinStr += "i"; - break; - case 64: - BuiltinStr += "Wi"; - break; - default: - llvm_unreachable("Unhandled ElementBitwidth!"); - } - if (isSignedInteger()) - BuiltinStr = "S" + BuiltinStr; - else - BuiltinStr = "U" + BuiltinStr; - break; - case ScalarTypeKind::Float: - switch (ElementBitwidth) { - case 16: - BuiltinStr += "x"; - break; - case 32: - BuiltinStr += "f"; - break; - case 64: - BuiltinStr += "d"; - break; - default: - llvm_unreachable("Unhandled ElementBitwidth!"); - } - break; - default: - llvm_unreachable("ScalarType is invalid!"); - } - if (IsImmediate) - BuiltinStr = "I" + BuiltinStr; - if (isScalar()) { - if (IsConstant) - BuiltinStr += "C"; - if (IsPointer) - BuiltinStr += "*"; - return; - } - BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr; - // Pointer to vector types. Defined for segment load intrinsics. - // segment load intrinsics have pointer type arguments to store the loaded - // vector values. - if (IsPointer) - BuiltinStr += "*"; -} - -void RVVType::initClangBuiltinStr() { - assert(isValid() && "RVVType is invalid"); - assert(isVector() && "Handle Vector type only"); - - ClangBuiltinStr = "__rvv_"; - switch (ScalarType) { - case ScalarTypeKind::Boolean: - ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t"; - return; - case ScalarTypeKind::Float: - ClangBuiltinStr += "float"; - break; - case ScalarTypeKind::SignedInteger: - ClangBuiltinStr += "int"; - break; - case ScalarTypeKind::UnsignedInteger: - ClangBuiltinStr += "uint"; - break; - default: - llvm_unreachable("ScalarTypeKind is invalid"); - } - ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; -} - -void RVVType::initTypeStr() { - assert(isValid() && "RVVType is invalid"); - - if (IsConstant) - Str += "const "; - - auto getTypeString = [&](StringRef TypeStr) { - if (isScalar()) - return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); - return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") - .str(); - }; - - switch (ScalarType) { - case ScalarTypeKind::Void: - Str = "void"; - return; - case ScalarTypeKind::Size_t: - Str = "size_t"; - if (IsPointer) - Str += " *"; - return; - case ScalarTypeKind::Ptrdiff_t: - Str = "ptrdiff_t"; - return; - case ScalarTypeKind::UnsignedLong: - Str = "unsigned long"; - return; - case ScalarTypeKind::SignedLong: - Str = "long"; - return; - case ScalarTypeKind::Boolean: - if (isScalar()) - Str += "bool"; - else - // Vector bool is special case, the formulate is - // `vbool_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 - Str += "vbool" + utostr(64 / Scale.getValue()) + "_t"; - break; - case ScalarTypeKind::Float: - if (isScalar()) { - if (ElementBitwidth == 64) - Str += "double"; - else if (ElementBitwidth == 32) - Str += "float"; - else if (ElementBitwidth == 16) - Str += "_Float16"; - else - llvm_unreachable("Unhandled floating type."); - } else - Str += getTypeString("float"); - break; - case ScalarTypeKind::SignedInteger: - Str += getTypeString("int"); - break; - case ScalarTypeKind::UnsignedInteger: - Str += getTypeString("uint"); - break; - default: - llvm_unreachable("ScalarType is invalid!"); - } - if (IsPointer) - Str += " *"; -} - -void RVVType::initShortStr() { - switch (ScalarType) { - case ScalarTypeKind::Boolean: - assert(isVector()); - ShortStr = "b" + utostr(64 / Scale.getValue()); - return; - case ScalarTypeKind::Float: - ShortStr = "f" + utostr(ElementBitwidth); - break; - case ScalarTypeKind::SignedInteger: - ShortStr = "i" + utostr(ElementBitwidth); - break; - case ScalarTypeKind::UnsignedInteger: - ShortStr = "u" + utostr(ElementBitwidth); - break; - default: - PrintFatalError("Unhandled case!"); - } - if (isVector()) - ShortStr += LMUL.str(); -} - -void RVVType::applyBasicType() { - switch (BT) { - case 'c': - ElementBitwidth = 8; - ScalarType = ScalarTypeKind::SignedInteger; - break; - case 's': - ElementBitwidth = 16; - ScalarType = ScalarTypeKind::SignedInteger; - break; - case 'i': - ElementBitwidth = 32; - ScalarType = ScalarTypeKind::SignedInteger; - break; - case 'l': - ElementBitwidth = 64; - ScalarType = ScalarTypeKind::SignedInteger; - break; - case 'x': - ElementBitwidth = 16; - ScalarType = ScalarTypeKind::Float; - break; - case 'f': - ElementBitwidth = 32; - ScalarType = ScalarTypeKind::Float; - break; - case 'd': - ElementBitwidth = 64; - ScalarType = ScalarTypeKind::Float; - break; - default: - PrintFatalError("Unhandled type code!"); - } - assert(ElementBitwidth != 0 && "Bad element bitwidth!"); -} - -void RVVType::applyModifier(StringRef Transformer) { - if (Transformer.empty()) - return; - // Handle primitive type transformer - auto PType = Transformer.back(); - switch (PType) { - case 'e': - Scale = 0; - break; - case 'v': - Scale = LMUL.getScale(ElementBitwidth); - break; - case 'w': - ElementBitwidth *= 2; - LMUL *= 2; - Scale = LMUL.getScale(ElementBitwidth); - break; - case 'q': - ElementBitwidth *= 4; - LMUL *= 4; - Scale = LMUL.getScale(ElementBitwidth); - break; - case 'o': - ElementBitwidth *= 8; - LMUL *= 8; - Scale = LMUL.getScale(ElementBitwidth); - break; - case 'm': - ScalarType = ScalarTypeKind::Boolean; - Scale = LMUL.getScale(ElementBitwidth); - ElementBitwidth = 1; - break; - case '0': - ScalarType = ScalarTypeKind::Void; - break; - case 'z': - ScalarType = ScalarTypeKind::Size_t; - break; - case 't': - ScalarType = ScalarTypeKind::Ptrdiff_t; - break; - case 'u': - ScalarType = ScalarTypeKind::UnsignedLong; - break; - case 'l': - ScalarType = ScalarTypeKind::SignedLong; - break; - default: - PrintFatalError("Illegal primitive type transformers!"); - } - Transformer = Transformer.drop_back(); - - // Extract and compute complex type transformer. It can only appear one time. - if (Transformer.startswith("(")) { - size_t Idx = Transformer.find(')'); - assert(Idx != StringRef::npos); - StringRef ComplexType = Transformer.slice(1, Idx); - Transformer = Transformer.drop_front(Idx + 1); - assert(!Transformer.contains('(') && - "Only allow one complex type transformer"); - - auto UpdateAndCheckComplexProto = [&]() { - Scale = LMUL.getScale(ElementBitwidth); - const StringRef VectorPrototypes("vwqom"); - if (!VectorPrototypes.contains(PType)) - PrintFatalError("Complex type transformer only supports vector type!"); - if (Transformer.find_first_of("PCKWS") != StringRef::npos) - PrintFatalError( - "Illegal type transformer for Complex type transformer"); - }; - auto ComputeFixedLog2LMUL = - [&](StringRef Value, - std::function Compare) { - int32_t Log2LMUL; - Value.getAsInteger(10, Log2LMUL); - if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { - ScalarType = Invalid; - return false; - } - // Update new LMUL - LMUL = LMULType(Log2LMUL); - UpdateAndCheckComplexProto(); - return true; - }; - auto ComplexTT = ComplexType.split(":"); - if (ComplexTT.first == "Log2EEW") { - uint32_t Log2EEW; - ComplexTT.second.getAsInteger(10, Log2EEW); - // update new elmul = (eew/sew) * lmul - LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); - // update new eew - ElementBitwidth = 1 << Log2EEW; - ScalarType = ScalarTypeKind::SignedInteger; - UpdateAndCheckComplexProto(); - } else if (ComplexTT.first == "FixedSEW") { - uint32_t NewSEW; - ComplexTT.second.getAsInteger(10, NewSEW); - // Set invalid type if src and dst SEW are same. - if (ElementBitwidth == NewSEW) { - ScalarType = Invalid; - return; - } - // Update new SEW - ElementBitwidth = NewSEW; - UpdateAndCheckComplexProto(); - } else if (ComplexTT.first == "LFixedLog2LMUL") { - // New LMUL should be larger than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater())) - return; - } else if (ComplexTT.first == "SFixedLog2LMUL") { - // New LMUL should be smaller than old - if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less())) - return; - } else { - PrintFatalError("Illegal complex type transformers!"); - } - } - - // Compute the remain type transformers - for (char I : Transformer) { - switch (I) { - case 'P': - if (IsConstant) - PrintFatalError("'P' transformer cannot be used after 'C'"); - if (IsPointer) - PrintFatalError("'P' transformer cannot be used twice"); - IsPointer = true; - break; - case 'C': - if (IsConstant) - PrintFatalError("'C' transformer cannot be used twice"); - IsConstant = true; - break; - case 'K': - IsImmediate = true; - break; - case 'U': - ScalarType = ScalarTypeKind::UnsignedInteger; - break; - case 'I': - ScalarType = ScalarTypeKind::SignedInteger; - break; - case 'F': - ScalarType = ScalarTypeKind::Float; - break; - case 'S': - LMUL = LMULType(0); - // Update ElementBitwidth need to update Scale too. - Scale = LMUL.getScale(ElementBitwidth); - break; - default: - PrintFatalError("Illegal non-primitive type transformer!"); - } - } -} - -//===----------------------------------------------------------------------===// -// RVVIntrinsic implementation -//===----------------------------------------------------------------------===// -RVVIntrinsic::RVVIntrinsic( - StringRef NewName, StringRef Suffix, StringRef NewMangledName, - StringRef MangledSuffix, StringRef IRName, bool IsMasked, - bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, - bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen, - const RVVTypes &OutInTypes, const std::vector &NewIntrinsicTypes, - const std::vector &RequiredFeatures, unsigned NF) - : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme), - HasUnMaskedOverloaded(HasUnMaskedOverloaded), - HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()), - NF(NF) { - - // Init BuiltinName, Name and MangledName - BuiltinName = NewName.str(); - Name = BuiltinName; - if (NewMangledName.empty()) - MangledName = NewName.split("_").first.str(); - else - MangledName = NewMangledName.str(); - if (!Suffix.empty()) - Name += "_" + Suffix.str(); - if (!MangledSuffix.empty()) - MangledName += "_" + MangledSuffix.str(); - if (IsMasked) { - BuiltinName += "_m"; - Name += "_m"; - } - - // Init RISC-V extensions - for (const auto &T : OutInTypes) { - if (T->isFloatVector(16) || T->isFloat(16)) - RISCVPredefinedMacros |= RISCVPredefinedMacro::Zvfh; - if (T->isFloatVector(32)) - RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp32; - if (T->isFloatVector(64)) - RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp64; - if (T->isVector(64)) - RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELen64; - } - for (auto Feature : RequiredFeatures) { - if (Feature == "RV64") - RISCVPredefinedMacros |= RISCVPredefinedMacro::RV64; - // Note: Full multiply instruction (mulh, mulhu, mulhsu, smul) for EEW=64 - // require V. - if (Feature == "FullMultiply" && - (RISCVPredefinedMacros & RISCVPredefinedMacro::VectorMaxELen64)) - RISCVPredefinedMacros |= RISCVPredefinedMacro::V; - } - - // Init OutputType and InputTypes - OutputType = OutInTypes[0]; - InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); - - // IntrinsicTypes is unmasked TA version index. Need to update it - // if there is merge operand (It is always in first operand). - IntrinsicTypes = NewIntrinsicTypes; - if ((IsMasked && HasMaskedOffOperand) || - (!IsMasked && hasPassthruOperand())) { - for (auto &I : IntrinsicTypes) { - if (I >= 0) - I += NF; - } - } -} - -std::string RVVIntrinsic::getBuiltinTypeStr() const { - std::string S; - S += OutputType->getBuiltinStr(); - for (const auto &T : InputTypes) { - S += T->getBuiltinStr(); - } - return S; -} - -void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const { - if (!getIRName().empty()) - OS << " ID = Intrinsic::riscv_" + getIRName() + ";\n"; - if (NF >= 2) - OS << " NF = " + utostr(getNF()) + ";\n"; - if (hasManualCodegen()) { - OS << ManualCodegen; +void emitCodeGenSwitchBody(const RVVIntrinsic *RVVI, raw_ostream &OS) { + if (!RVVI->getIRName().empty()) + OS << " ID = Intrinsic::riscv_" + RVVI->getIRName() + ";\n"; + if (RVVI->getNF() >= 2) + OS << " NF = " + utostr(RVVI->getNF()) + ";\n"; + if (RVVI->hasManualCodegen()) { + OS << RVVI->getManualCodegen(); OS << "break;\n"; return; } - if (isMasked()) { - if (hasVL()) { + if (RVVI->isMasked()) { + if (RVVI->hasVL()) { OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n"; - if (hasPolicyOperand()) + if (RVVI->hasPolicyOperand()) OS << " Ops.push_back(ConstantInt::get(Ops.back()->getType()," " TAIL_UNDISTURBED));\n"; } else { OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n"; } } else { - if (hasPolicyOperand()) + if (RVVI->hasPolicyOperand()) OS << " Ops.push_back(ConstantInt::get(Ops.back()->getType(), " "TAIL_UNDISTURBED));\n"; - else if (hasPassthruOperand()) { + else if (RVVI->hasPassthruOperand()) { OS << " Ops.push_back(llvm::UndefValue::get(ResultType));\n"; OS << " std::rotate(Ops.rbegin(), Ops.rbegin() + 1, Ops.rend());\n"; } @@ -880,7 +113,7 @@ OS << " IntrinsicTypes = {"; ListSeparator LS; - for (const auto &Idx : IntrinsicTypes) { + for (const auto &Idx : RVVI->getIntrinsicTypes()) { if (Idx == -1) OS << LS << "ResultType"; else @@ -889,17 +122,18 @@ // VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is // always last operand. - if (hasVL()) + if (RVVI->hasVL()) OS << ", Ops.back()->getType()"; OS << "};\n"; OS << " break;\n"; } -void RVVIntrinsic::emitIntrinsicFuncDef(raw_ostream &OS) const { +void emitIntrinsicFuncDef(const RVVIntrinsic &RVVI, raw_ostream &OS) { OS << "__attribute__((__clang_builtin_alias__("; - OS << "__builtin_rvv_" << getBuiltinName() << ")))\n"; - OS << OutputType->getTypeStr() << " " << getName() << "("; + OS << "__builtin_rvv_" << RVVI.getBuiltinName() << ")))\n"; + OS << RVVI.getOutputType()->getTypeStr() << " " << RVVI.getName() << "("; // Emit function arguments + const RVVTypes &InputTypes = RVVI.getInputTypes(); if (!InputTypes.empty()) { ListSeparator LS; for (unsigned i = 0; i < InputTypes.size(); ++i) @@ -908,11 +142,13 @@ OS << ");\n"; } -void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const { +void emitMangledFuncDef(const RVVIntrinsic &RVVI, raw_ostream &OS) { OS << "__attribute__((__clang_builtin_alias__("; - OS << "__builtin_rvv_" << getBuiltinName() << ")))\n"; - OS << OutputType->getTypeStr() << " " << getMangledName() << "("; + OS << "__builtin_rvv_" << RVVI.getBuiltinName() << ")))\n"; + OS << RVVI.getOutputType()->getTypeStr() << " " << RVVI.getMangledName() + << "("; // Emit function arguments + const RVVTypes &InputTypes = RVVI.getInputTypes(); if (!InputTypes.empty()) { ListSeparator LS; for (unsigned i = 0; i < InputTypes.size(); ++i) @@ -1016,7 +252,7 @@ // Print intrinsic functions with macro emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { OS << "__rvv_ai "; - Inst.emitIntrinsicFuncDef(OS); + emitIntrinsicFuncDef(Inst, OS); }); OS << "#undef __rvv_ai\n\n"; @@ -1031,7 +267,7 @@ if (!Inst.isMasked() && !Inst.hasUnMaskedOverloaded()) return; OS << "__rvv_aio "; - Inst.emitMangledFuncDef(OS); + emitMangledFuncDef(Inst, OS); }); OS << "#undef __rvv_aio\n"; @@ -1092,7 +328,7 @@ StringRef CurIRName = Def->getIRName(); if (CurIRName != PrevDef->getIRName() || (Def->getManualCodegen() != PrevDef->getManualCodegen())) { - PrevDef->emitCodeGenSwitchBody(OS); + emitCodeGenSwitchBody(PrevDef, OS); } PrevDef = Def.get(); @@ -1119,7 +355,7 @@ else if (P.first->second->getIntrinsicTypes() != Def->getIntrinsicTypes()) PrintFatalError("Builtin with same name has different IntrinsicTypes"); } - Defs.back()->emitCodeGenSwitchBody(OS); + emitCodeGenSwitchBody(Defs.back().get(), OS); OS << "\n"; }