diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h --- a/llvm/include/llvm-c/Core.h +++ b/llvm/include/llvm-c/Core.h @@ -160,7 +160,9 @@ LLVMVectorTypeKind, /**< SIMD 'packed' format, or other vector type */ LLVMMetadataTypeKind, /**< Metadata */ LLVMX86_MMXTypeKind, /**< X86 MMX */ - LLVMTokenTypeKind /**< Tokens */ + LLVMTokenTypeKind, /**< Tokens */ + LLVMFixedVectorTypeKind, /**< Fixed width SIMD vector type */ + LLVMScalableVectorTypeKind /**< Scalable SIMD vector type */ } LLVMTypeKind; typedef enum { diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h --- a/llvm/include/llvm/IR/DerivedTypes.h +++ b/llvm/include/llvm/IR/DerivedTypes.h @@ -386,7 +386,7 @@ return cast<ArrayType>(this)->getNumElements(); } -/// Class to represent vector types. +/// Base class of all SIMD vector types. class VectorType : public Type { /// A fully specified VectorType is of the form <vscale x n x Ty>. 'n' is the /// minimum number of elements of type Ty contained within the vector, and @@ -403,16 +403,12 @@ /// The element type of the vector. Type *ContainedType; - /// Minumum number of elements in the vector. - uint64_t NumElements; - VectorType(Type *ElType, unsigned NumEl, bool Scalable = false); - VectorType(Type *ElType, ElementCount EC); + /// The element count of this vector. + ElementCount EC; - // If true, the total number of elements is an unknown multiple of the - // minimum 'NumElements'. Otherwise the total number of elements is exactly - // equal to 'NumElements'. - bool Scalable; +protected: + VectorType(Type *ElType, ElementCount EC, Type::TypeID TID); public: VectorType(const VectorType &) = delete; @@ -420,7 +416,7 @@ /// For scalable vectors, this will return the minimum number of elements /// in the vector.
- uint64_t getNumElements() const { return NumElements; } + uint64_t getNumElements() const { return EC.Min; } Type *getElementType() const { return ContainedType; } /// This static method is the primary way to construct an VectorType. @@ -430,6 +426,15 @@ return VectorType::get(ElementType, {NumElements, Scalable}); } + /// Construct a VectorType that has the same shape as some other VectorType. + static VectorType *get(Type *ElementType, VectorType *Other) { + return VectorType::get(ElementType, Other->getElementCount()); + } + + static VectorType *get(Type *ElementType, const VectorType *Other) { + return VectorType::get(ElementType, Other->getElementCount()); + } + /// This static method gets a VectorType with the same number of elements as /// the input type, and the element type is an integer type of the same width /// as the input element type. @@ -510,14 +515,10 @@ ElementCount getElementCount() const { - uint64_t MinimumEltCnt = getNumElements(); - assert(MinimumEltCnt <= UINT_MAX && "Too many elements in vector"); - return { (unsigned)MinimumEltCnt, Scalable }; + return EC; } /// Returns whether or not this is a scalable vector (meaning the total /// element count is a multiple of the minimum). - bool isScalable() const { - return Scalable; - } + bool isScalable() const { return EC.Scalable; } /// Return the minimum number of bits in the Vector type. /// Returns zero when the vector is a vector of pointers. @@ -527,12 +528,47 @@ /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *T) { - return T->getTypeID() == VectorTyID; + switch (T->getTypeID()) { + case VectorTyID: + case FixedVectorTyID: + case ScalableVectorTyID: + return true; + default: + return false; + } } }; bool Type::isVectorTy() const { return isa<VectorType>(this); } +/// Class to represent fixed width SIMD vectors. +class FixedVectorType : public VectorType { +protected: + FixedVectorType(Type *ElTy, unsigned NumElts) + : VectorType(ElTy, {NumElts, false}, FixedVectorTyID) {} + +public: + static FixedVectorType *get(Type *ElementType, unsigned NumElts); + + static bool classof(const Type *T) { + return T->getTypeID() == FixedVectorTyID; + } +}; + +/// Class to represent scalable SIMD vectors. +class ScalableVectorType : public VectorType { +protected: + ScalableVectorType(Type *ElTy, unsigned MinNumElts) + : VectorType(ElTy, {MinNumElts, true}, ScalableVectorTyID) {} + +public: + static ScalableVectorType *get(Type *ElementType, unsigned MinNumElts); + + static bool classof(const Type *T) { + return T->getTypeID() == ScalableVectorTyID; + } +}; + /// Class to represent pointers. class PointerType : public Type { explicit PointerType(Type *ElType, unsigned AddrSpace); diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -54,26 +54,28 @@ /// enum TypeID { // PrimitiveTypes - make sure LastPrimitiveTyID stays up to date.
- VoidTyID = 0, ///< 0: type with no size - HalfTyID, ///< 1: 16-bit floating point type - FloatTyID, ///< 2: 32-bit floating point type - DoubleTyID, ///< 3: 64-bit floating point type - X86_FP80TyID, ///< 4: 80-bit floating point type (X87) - FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa) - PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC) - LabelTyID, ///< 7: Labels - MetadataTyID, ///< 8: Metadata - X86_MMXTyID, ///< 9: MMX vectors (64 bits, X86 specific) - TokenTyID, ///< 10: Tokens + VoidTyID = 0, ///< 0: type with no size + HalfTyID, ///< 1: 16-bit floating point type + FloatTyID, ///< 2: 32-bit floating point type + DoubleTyID, ///< 3: 64-bit floating point type + X86_FP80TyID, ///< 4: 80-bit floating point type (X87) + FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa) + PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC) + LabelTyID, ///< 7: Labels + MetadataTyID, ///< 8: Metadata + X86_MMXTyID, ///< 9: MMX vectors (64 bits, X86 specific) + TokenTyID, ///< 10: Tokens // Derived types... see DerivedTypes.h file. // Make sure FirstDerivedTyID stays up to date!
- IntegerTyID, ///< 11: Arbitrary bit width integers - FunctionTyID, ///< 12: Functions - StructTyID, ///< 13: Structures - ArrayTyID, ///< 14: Arrays - PointerTyID, ///< 15: Pointers - VectorTyID ///< 16: SIMD 'packed' format, or other vector type + IntegerTyID, ///< 11: Arbitrary bit width integers + FunctionTyID, ///< 12: Functions + StructTyID, ///< 13: Structures + ArrayTyID, ///< 14: Arrays + PointerTyID, ///< 15: Pointers + VectorTyID, ///< 16: SIMD 'packed' format, or other vector type + FixedVectorTyID, ///< 17: Fixed width SIMD vector type + ScalableVectorTyID ///< 18: Scalable SIMD vector type }; private: diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -650,7 +650,9 @@ OS << ']'; return; } - case Type::VectorTyID: { + case Type::VectorTyID: + case Type::FixedVectorTyID: + case Type::ScalableVectorTyID: { VectorType *PTy = cast<VectorType>(Ty); OS << "<"; if (PTy->isScalable()) diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp --- a/llvm/lib/IR/Core.cpp +++ b/llvm/lib/IR/Core.cpp @@ -507,6 +507,12 @@ return LLVMX86_MMXTypeKind; case Type::TokenTyID: return LLVMTokenTypeKind; + case Type::FixedVectorTyID: + // Report fixed-width vectors as LLVMVectorTypeKind so that existing C API + // clients, which test for LLVMVectorTypeKind, keep recognizing them. + return LLVMVectorTypeKind; + case Type::ScalableVectorTyID: + return LLVMScalableVectorTypeKind; } llvm_unreachable("Unhandled TypeID."); } diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -123,9 +123,11 @@ case Type::X86_MMXTyID: return TypeSize::Fixed(64); case Type::IntegerTyID: return TypeSize::Fixed(cast<IntegerType>(this)->getBitWidth()); - case Type::VectorTyID: { + case Type::VectorTyID: + case Type::FixedVectorTyID: + case Type::ScalableVectorTyID: { const VectorType *VTy = cast<VectorType>(this); - return TypeSize(VTy->getBitWidth(), VTy->isScalable()); + return TypeSize(VTy->getBitWidth(), isa<ScalableVectorType>(VTy)); } default: return TypeSize::Fixed(0); } @@ -589,30 +591,65 @@ // VectorType Implementation
//===----------------------------------------------------------------------===// -VectorType::VectorType(Type *ElType, ElementCount EC) - : Type(ElType->getContext(), VectorTyID), ContainedType(ElType), - NumElements(EC.Min), Scalable(EC.Scalable) { +VectorType::VectorType(Type *ElType, ElementCount EC, Type::TypeID TID) + : Type(ElType->getContext(), TID), ContainedType(ElType), EC(EC) { ContainedTys = &ContainedType; NumContainedTys = 1; } VectorType *VectorType::get(Type *ElementType, ElementCount EC) { - assert(EC.Min > 0 && "#Elements of a VectorType must be greater than 0"); + if (EC.Scalable) + return ScalableVectorType::get(ElementType, EC.Min); + + return FixedVectorType::get(ElementType, EC.Min); +} + +bool VectorType::isValidElementType(Type *ElemTy) { + return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() || + ElemTy->isPointerTy(); +} + +//===----------------------------------------------------------------------===// +// FixedVectorType Implementation +//===----------------------------------------------------------------------===// + +FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) { + assert(NumElts > 0 && "#Elements of a VectorType must be greater than 0"); assert(isValidElementType(ElementType) && "Element type of a VectorType must " "be an integer, floating point, or " "pointer type."); + ElementCount EC(NumElts, false); + LLVMContextImpl *pImpl = ElementType->getContext().pImpl; - VectorType *&Entry = ElementType->getContext().pImpl - ->VectorTypes[std::make_pair(ElementType, EC)]; + VectorType *&Entry = ElementType->getContext() + .pImpl->VectorTypes[std::make_pair(ElementType, EC)]; + if (!Entry) - Entry = new (pImpl->Alloc) VectorType(ElementType, EC); - return Entry; + Entry = new (pImpl->Alloc) FixedVectorType(ElementType, NumElts); + return cast<FixedVectorType>(Entry); } -bool VectorType::isValidElementType(Type *ElemTy) { - return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() || - ElemTy->isPointerTy();
+//===----------------------------------------------------------------------===// +// ScalableVectorType Implementation +//===----------------------------------------------------------------------===// + +ScalableVectorType *ScalableVectorType::get(Type *ElementType, + unsigned MinNumElts) { + assert(MinNumElts > 0 && "#Elements of a VectorType must be greater than 0"); + assert(isValidElementType(ElementType) && "Element type of a VectorType must " + "be an integer, floating point, or " + "pointer type."); + + ElementCount EC(MinNumElts, true); + + LLVMContextImpl *pImpl = ElementType->getContext().pImpl; + VectorType *&Entry = ElementType->getContext() + .pImpl->VectorTypes[std::make_pair(ElementType, EC)]; + + if (!Entry) + Entry = new (pImpl->Alloc) ScalableVectorType(ElementType, MinNumElts); + return cast<ScalableVectorType>(Entry); } //===----------------------------------------------------------------------===//