Index: docs/LangRef.rst
===================================================================
--- docs/LangRef.rst
+++ docs/LangRef.rst
@@ -2480,30 +2480,36 @@
 A vector type is a simple derived type that represents a vector of
 elements. Vector types are used when multiple primitive data are
 operated in parallel using a single instruction (SIMD). A vector type
-requires a size (number of elements) and an underlying primitive data
-type. Vector types are considered :ref:`first class <t_firstclass>`.
+requires a size (number of elements), an underlying primitive data type,
+and a scalable property to represent vectors where the exact hardware
+vector length is unknown at compile time. Vector types are considered
+:ref:`first class <t_firstclass>`.
 
 :Syntax:
 
 ::
 
-      < <# elements> x <elementtype> >
+      < <# elements> x <elementtype> >          ; Fixed-length vector
+      < m x <# elements> x <elementtype> >      ; Scalable vector
 
 The number of elements is a constant integer value larger than 0;
 elementtype may be any integer, floating point or pointer type. Vectors
-of size zero are not allowed.
+of size zero are not allowed. For scalable vectors, the total number of
+elements is an unknown integer multiple of the specified number of elements.
 
 :Examples:
 
-+-------------------+--------------------------------------------------+
-| ``<4 x i32>``     | Vector of 4 32-bit integer values.               |
-+-------------------+--------------------------------------------------+
-| ``<8 x float>``   | Vector of 8 32-bit floating-point values.        |
-+-------------------+--------------------------------------------------+
-| ``<2 x i64>``     | Vector of 2 64-bit integer values.               |
-+-------------------+--------------------------------------------------+
-| ``<4 x i64*>``    | Vector of 4 pointers to 64-bit integer values.   |
-+-------------------+--------------------------------------------------+
++-------------------+----------------------------------------------------+
+| ``<4 x i32>``     | Vector of 4 32-bit integer values.                 |
++-------------------+----------------------------------------------------+
+| ``<8 x float>``   | Vector of 8 32-bit floating-point values.          |
++-------------------+----------------------------------------------------+
+| ``<2 x i64>``     | Vector of 2 64-bit integer values.                 |
++-------------------+----------------------------------------------------+
+| ``<4 x i64*>``    | Vector of 4 pointers to 64-bit integer values.     |
++-------------------+----------------------------------------------------+
+| ``<m x 4 x i32>`` | Vector with a multiple of 4 32-bit integer values. |
++-------------------+----------------------------------------------------+
 
 .. _t_label:
 
Index: include/llvm/IR/DerivedTypes.h
===================================================================
--- include/llvm/IR/DerivedTypes.h
+++ include/llvm/IR/DerivedTypes.h
@@ -367,14 +367,56 @@
 
 /// Class to represent vector types.
 class VectorType : public SequentialType {
-  VectorType(Type *ElType, unsigned NumEl);
+public:
+  /// A fully specified VectorType is of the form <m x n x Ty>. 'n' is the
+  /// minimum number of elements of type Ty contained within the vector, 'm'
+  /// indicates multiples of the minimum and the total element count is
+  /// the result of 'm' * 'n'. However, for all targets 'm' is expected to be
+  /// either statically unknown at compile time or guaranteed to be one.
+  /// If 'm' is known to be 1, then the extra term is discarded in textual IR:
+  ///
+  /// <4 x i32>     - a vector containing 4 i32s
+  /// <m x 4 x i32> - a vector containing an unknown integer multiple of 4 i32s
+  class ElementCount {
+  public:
+    unsigned Min;  // Minimum number of vector elements.
+    bool Scalable; // if true, NumElements is an unknown multiple of 'Min'
+
+    ElementCount(unsigned Min, bool Scalable)
+    : Min(Min), Scalable(Scalable) {}
+
+    /// Multiply the minimum element count by RHS; 'Scalable' is carried over.
+    ElementCount operator*(unsigned RHS) const {
+      return { Min * RHS, Scalable };
+    }
+    /// Divide the minimum element count by RHS; 'Scalable' is carried over.
+    ElementCount operator/(unsigned RHS) const {
+      return { Min / RHS, Scalable };
+    }
+
+    bool operator==(const ElementCount &RHS) const {
+      return Min == RHS.Min && Scalable == RHS.Scalable;
+    }
+  };
+
+private:
+  VectorType(Type *ElType, ElementCount EC);
+
+  // If true, the total number of elements is an unknown multiple of the
+  // minimum 'NumElements' from SequentialType. Otherwise the total number
+  // of elements is exactly equal to 'NumElements'
+  bool Scalable;
 
 public:
   VectorType(const VectorType &) = delete;
   VectorType &operator=(const VectorType &) = delete;
 
   /// This static method is the primary way to construct an VectorType.
-  static VectorType *get(Type *ElementType, unsigned NumElements);
+  static VectorType *get(Type *ElementType, ElementCount EC);
+  static VectorType *get(Type *ElementType, unsigned NumElements,
+                         bool Scalable = false) {
+    return VectorType::get(ElementType, {NumElements, Scalable});
+  }
 
   /// This static method gets a VectorType with the same number of elements as
   /// the input type, and the element type is an integer type of the same width
@@ -383,7 +425,7 @@
     unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits();
     assert(EltBits && "Element size must be of a non-zero size");
     Type *EltTy = IntegerType::get(VTy->getContext(), EltBits);
-    return VectorType::get(EltTy, VTy->getNumElements());
+    return VectorType::get(EltTy, VTy->getElementCount());
   }
 
   /// This static method is like getInteger except that the element types are
@@ -391,7 +433,7 @@
   static VectorType *getExtendedElementVectorType(VectorType *VTy) {
     unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits();
     Type *EltTy = IntegerType::get(VTy->getContext(), EltBits * 2);
-    return VectorType::get(EltTy, VTy->getNumElements());
+    return VectorType::get(EltTy, VTy->getElementCount());
   }
 
   /// This static method is like getInteger except that the element types are
@@ -401,29 +443,43 @@
     assert((EltBits & 1) == 0 &&
            "Cannot truncate vector element with odd bit-width");
     Type *EltTy = IntegerType::get(VTy->getContext(), EltBits / 2);
-    return VectorType::get(EltTy, VTy->getNumElements());
+    return VectorType::get(EltTy, VTy->getElementCount());
   }
 
   /// This static method returns a VectorType with half as many elements as the
   /// input type and the same element type.
   static VectorType *getHalfElementsVectorType(VectorType *VTy) {
-    unsigned NumElts = VTy->getNumElements();
-    assert ((NumElts & 1) == 0 &&
+    auto EltCnt = VTy->getElementCount();
+    assert ((EltCnt.Min & 1) == 0 &&
             "Cannot halve vector with odd number of elements.");
-    return VectorType::get(VTy->getElementType(), NumElts/2);
+    return VectorType::get(VTy->getElementType(), EltCnt/2);
   }
 
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
-    unsigned NumElts = VTy->getNumElements();
-    return VectorType::get(VTy->getElementType(), NumElts*2);
+    auto EltCnt = VTy->getElementCount();
+    return VectorType::get(VTy->getElementType(), EltCnt*2);
   }
 
   /// Return true if the specified type is valid as a element type.
   static bool isValidElementType(Type *ElemTy);
 
-  /// Return the number of bits in the Vector type.
+  /// Return an ElementCount instance to represent the (possibly scalable)
+  /// number of elements in the vector
+  ElementCount getElementCount() const {
+    uint64_t MinimumEltCnt = getNumElements();
+    assert(MinimumEltCnt <= UINT_MAX && "Too many elements in vector");
+    return { static_cast<unsigned>(MinimumEltCnt), Scalable };
+  }
+
+  /// Returns whether or not this is a scalable vector (meaning the total
+  /// element count is a multiple of the minimum).
+  bool isScalable() const {
+    return Scalable;
+  }
+
+  /// Return the minimum number of bits in the Vector type.
   /// Returns zero when the vector is a vector of pointers.
   unsigned getBitWidth() const {
     return getNumElements() * getElementType()->getPrimitiveSizeInBits();
@@ -439,6 +495,10 @@
   return cast<VectorType>(this)->getNumElements();
 }
 
+bool Type::getVectorIsScalable() const {
+  return cast<VectorType>(this)->isScalable();
+}
+
 /// Class to represent pointers.
 class PointerType : public Type {
   explicit PointerType(Type *ElType, unsigned AddrSpace);
Index: include/llvm/IR/Type.h
===================================================================
--- include/llvm/IR/Type.h
+++ include/llvm/IR/Type.h
@@ -353,6 +353,7 @@
     return ContainedTys[0];
   }
 
+  inline bool getVectorIsScalable() const;
   inline unsigned getVectorNumElements() const;
   Type *getVectorElementType() const {
     assert(getTypeID() == VectorTyID);
Index: lib/AsmParser/LLLexer.cpp
===================================================================
--- lib/AsmParser/LLLexer.cpp
+++ lib/AsmParser/LLLexer.cpp
@@ -682,6 +682,7 @@
   KEYWORD(xchg); KEYWORD(nand); KEYWORD(max); KEYWORD(min); KEYWORD(umax);
   KEYWORD(umin);
 
+  KEYWORD(m);
   KEYWORD(x);
   KEYWORD(blockaddress);
 
Index: lib/AsmParser/LLParser.cpp
===================================================================
--- lib/AsmParser/LLParser.cpp
+++ lib/AsmParser/LLParser.cpp
@@ -2429,7 +2429,19 @@
 /// Type
 ///   ::= '[' APSINTVAL 'x' Types ']'
 ///   ::= '<' APSINTVAL 'x' Types '>'
+///   ::= '<' 'm' 'x' APSINTVAL 'x' Types '>'
 bool LLParser::ParseArrayVectorType(Type *&Result, bool isVector) {
+  bool Scalable = false;
+
+  if (isVector && Lex.getKind() == lltok::kw_m) {
+    Lex.Lex(); // consume the 'm'
+
+    if (ParseToken(lltok::kw_x, "expected 'x' after scalable vector specifier"))
+      return true;
+
+    Scalable = true;
+  }
+
   if (Lex.getKind() != lltok::APSInt || Lex.getAPSIntVal().isSigned() ||
       Lex.getAPSIntVal().getBitWidth() > 64)
     return TokError("expected number in address space");
@@ -2456,7 +2468,7 @@
       return Error(SizeLoc, "size too large for vector");
     if (!VectorType::isValidElementType(EltTy))
       return Error(TypeLoc, "invalid vector element type");
-    Result = VectorType::get(EltTy, unsigned(Size));
+    Result = VectorType::get(EltTy, unsigned(Size), Scalable);
   } else {
     if (!ArrayType::isValidElementType(EltTy))
       return Error(TypeLoc, "invalid array element type");
Index: lib/AsmParser/LLToken.h
===================================================================
--- lib/AsmParser/LLToken.h
+++ lib/AsmParser/LLToken.h
@@ -37,6 +37,7 @@
   exclaim, // !
   bar,     // |
 
+  kw_m,
   kw_x,
   kw_true,
   kw_false,
Index: lib/Bitcode/Reader/BitcodeReader.cpp
===================================================================
--- lib/Bitcode/Reader/BitcodeReader.cpp
+++ lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1672,7 +1672,7 @@
         return error("Invalid type");
       ResultTy = ArrayType::get(ResultTy, Record[0]);
       break;
-    case bitc::TYPE_CODE_VECTOR:    // VECTOR: [numelts, eltty]
+    case bitc::TYPE_CODE_VECTOR:    // VECTOR: [numelts, eltty, scalable]
       if (Record.size() < 2)
         return error("Invalid record");
       if (Record[0] == 0)
@@ -1680,7 +1680,8 @@
       ResultTy = getTypeByID(Record[1]);
       if (!ResultTy || !StructType::isValidElementType(ResultTy))
         return error("Invalid type");
-      ResultTy = VectorType::get(ResultTy, Record[0]);
+      bool Scalable = Record.size() > 2 ? Record[2] : false;
+      ResultTy = VectorType::get(ResultTy, Record[0], Scalable);
       break;
     }
 
Index: lib/Bitcode/Writer/BitcodeWriter.cpp
===================================================================
--- lib/Bitcode/Writer/BitcodeWriter.cpp
+++ lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -933,10 +933,11 @@
     }
     case Type::VectorTyID: {
       VectorType *VT = cast<VectorType>(T);
-      // VECTOR [numelts, eltty]
+      // VECTOR [numelts, eltty, scalable]
       Code = bitc::TYPE_CODE_VECTOR;
       TypeVals.push_back(VT->getNumElements());
       TypeVals.push_back(VE.getTypeID(VT->getElementType()));
+      TypeVals.push_back(VT->isScalable());
       break;
     }
     }
Index: lib/IR/AsmWriter.cpp
===================================================================
--- lib/IR/AsmWriter.cpp
+++ lib/IR/AsmWriter.cpp
@@ -537,7 +537,10 @@
   }
   case Type::VectorTyID: {
     VectorType *PTy = cast<VectorType>(Ty);
-    OS << "<" << PTy->getNumElements() << " x ";
+    OS << "<";
+    if (PTy->isScalable())
+      OS << "m x ";
+    OS << PTy->getNumElements() << " x ";
     print(PTy->getElementType(), OS);
     OS << '>';
     return;
Index: lib/IR/LLVMContextImpl.h
===================================================================
--- lib/IR/LLVMContextImpl.h
+++ lib/IR/LLVMContextImpl.h
@@ -1189,7 +1189,7 @@
   unsigned NamedStructTypesUniqueID;
 
   DenseMap<std::pair<Type *, uint64_t>, ArrayType*> ArrayTypes;
-  DenseMap<std::pair<Type *, unsigned>, VectorType*> VectorTypes;
+  DenseMap<std::pair<Type *, uint64_t>, VectorType*> VectorTypes;
   DenseMap<Type*, PointerType*> PointerTypes;  // Pointers in AddrSpace = 0
   DenseMap<std::pair<Type*, unsigned>, PointerType*> ASPointerTypes;
 
Index: lib/IR/Type.cpp
===================================================================
--- lib/IR/Type.cpp
+++ lib/IR/Type.cpp
@@ -619,21 +619,25 @@
 //                          VectorType Implementation
 //===----------------------------------------------------------------------===//
 
-VectorType::VectorType(Type *ElType, unsigned NumEl)
-  : SequentialType(VectorTyID, ElType, NumEl) {}
+VectorType::VectorType(Type *ElType, ElementCount EC)
+  : SequentialType(VectorTyID, ElType, EC.Min), Scalable(EC.Scalable) {}
 
-VectorType *VectorType::get(Type *ElementType, unsigned NumElements) {
-  assert(NumElements > 0 && "#Elements of a VectorType must be greater than 0");
+VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
+  assert(EC.Min > 0 && "#Elements of a VectorType must be greater than 0");
   assert(isValidElementType(ElementType) && "Element type of a VectorType must "
                                             "be an integer, floating point, or "
                                             "pointer type.");
 
+
+  uint64_t LookupEC = static_cast<uint64_t>(EC.Min) |
+                      (EC.Scalable ? (1ull << 63) : 0ull);
+
   LLVMContextImpl *pImpl = ElementType->getContext().pImpl;
   VectorType *&Entry = ElementType->getContext().pImpl
-    ->VectorTypes[std::make_pair(ElementType, NumElements)];
+    ->VectorTypes[std::make_pair(ElementType, LookupEC)];
 
   if (!Entry)
-    Entry = new (pImpl->TypeAllocator) VectorType(ElementType, NumElements);
+    Entry = new (pImpl->TypeAllocator) VectorType(ElementType, EC);
   return Entry;
 }
 
Index: test/Bitcode/compatibility.ll
===================================================================
--- test/Bitcode/compatibility.ll
+++ test/Bitcode/compatibility.ll
@@ -823,6 +823,10 @@
   ; CHECK: %t7 = alloca x86_mmx
   %t8 = alloca %opaquety*
   ; CHECK: %t8 = alloca %opaquety*
+  %t9 = alloca <4 x i32>
+  ; CHECK: %t9 = alloca <4 x i32>
+  %t10 = alloca <m x 4 x i32>
+  ; CHECK: %t10 = alloca <m x 4 x i32>
 
   ret void
 }
Index: unittests/IR/CMakeLists.txt
===================================================================
--- unittests/IR/CMakeLists.txt
+++ unittests/IR/CMakeLists.txt
@@ -30,6 +30,7 @@
   ValueHandleTest.cpp
   ValueMapTest.cpp
   ValueTest.cpp
+  VectorTypesTest.cpp
   VerifierTest.cpp
   WaymarkTest.cpp
   )
Index: unittests/IR/VectorTypesTest.cpp
===================================================================
--- /dev/null
+++ unittests/IR/VectorTypesTest.cpp
@@ -0,0 +1,90 @@
+//===--- llvm/unittest/IR/VectorTypesTest.cpp - vector types unit tests ---===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/LLVMContext.h"
+#include "gtest/gtest.h"
+using namespace llvm;
+
+namespace {
+TEST(VectorTypesTest, FixedLength) {
+  LLVMContext Ctx;
+
+  Type *Int16Ty = Type::getInt16Ty(Ctx);
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  Type *Int64Ty = Type::getInt64Ty(Ctx);
+  Type *Float64Ty = Type::getDoubleTy(Ctx);
+
+  VectorType *V8Int32Ty = VectorType::get(Int32Ty, 8);
+  ASSERT_FALSE(V8Int32Ty->isScalable());
+  VectorType *V8Int16Ty = VectorType::get(Int16Ty, {8, false});
+  ASSERT_FALSE(V8Int16Ty->isScalable());
+
+  VectorType::ElementCount EltCnt(4, false);
+  VectorType *V4Int64Ty = VectorType::get(Int64Ty, EltCnt);
+  ASSERT_FALSE(V4Int64Ty->isScalable());
+  VectorType *V2Int64Ty = VectorType::get(Int64Ty, EltCnt/2);
+  ASSERT_FALSE(V2Int64Ty->isScalable());
+  VectorType *V8Int64Ty = VectorType::get(Int64Ty, EltCnt*2);
+  ASSERT_FALSE(V8Int64Ty->isScalable());
+  VectorType *V4Float64Ty = VectorType::get(Float64Ty, EltCnt);
+  ASSERT_FALSE(V4Float64Ty->isScalable());
+
+  EXPECT_EQ(VectorType::getExtendedElementVectorType(V8Int16Ty), V8Int32Ty);
+  EXPECT_EQ(VectorType::getTruncatedElementVectorType(V8Int32Ty), V8Int16Ty);
+
+  EXPECT_EQ(VectorType::getHalfElementsVectorType(V4Int64Ty), V2Int64Ty);
+  EXPECT_EQ(VectorType::getDoubleElementsVectorType(V4Int64Ty), V8Int64Ty);
+
+  EXPECT_EQ(VectorType::getInteger(V4Float64Ty), V4Int64Ty);
+
+  EltCnt = V8Int64Ty->getElementCount();
+  EXPECT_EQ(EltCnt.Min, 8U);
+  ASSERT_FALSE(EltCnt.Scalable);
+}
+
+TEST(VectorTypesTest, Scalable) {
+  LLVMContext Ctx;
+
+  Type *Int16Ty = Type::getInt16Ty(Ctx);
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  Type *Int64Ty = Type::getInt64Ty(Ctx);
+  Type *Float64Ty = Type::getDoubleTy(Ctx);
+
+  VectorType *ScV8Int32Ty = VectorType::get(Int32Ty, 8, true);
+  ASSERT_TRUE(ScV8Int32Ty->isScalable());
+  VectorType *ScV8Int16Ty = VectorType::get(Int16Ty, {8, true});
+  ASSERT_TRUE(ScV8Int16Ty->isScalable());
+
+  VectorType::ElementCount EltCnt(4, true);
+  VectorType *ScV4Int64Ty = VectorType::get(Int64Ty, EltCnt);
+  ASSERT_TRUE(ScV4Int64Ty->isScalable());
+  VectorType *ScV2Int64Ty = VectorType::get(Int64Ty, EltCnt/2);
+  ASSERT_TRUE(ScV2Int64Ty->isScalable());
+  VectorType *ScV8Int64Ty = VectorType::get(Int64Ty, EltCnt*2);
+  ASSERT_TRUE(ScV8Int64Ty->isScalable());
+  VectorType *ScV4Float64Ty = VectorType::get(Float64Ty, EltCnt);
+  ASSERT_TRUE(ScV4Float64Ty->isScalable());
+
+
+  EXPECT_EQ(VectorType::getExtendedElementVectorType(ScV8Int16Ty), ScV8Int32Ty);
+  EXPECT_EQ(VectorType::getTruncatedElementVectorType(ScV8Int32Ty),
+            ScV8Int16Ty);
+
+  EXPECT_EQ(VectorType::getHalfElementsVectorType(ScV4Int64Ty), ScV2Int64Ty);
+  EXPECT_EQ(VectorType::getDoubleElementsVectorType(ScV4Int64Ty), ScV8Int64Ty);
+
+  EXPECT_EQ(VectorType::getInteger(ScV4Float64Ty), ScV4Int64Ty);
+
+  EltCnt = ScV8Int64Ty->getElementCount();
+  EXPECT_EQ(EltCnt.Min, 8U);
+  ASSERT_TRUE(EltCnt.Scalable);
+}
+
+}