Index: docs/LangRef.rst =================================================================== --- docs/LangRef.rst +++ docs/LangRef.rst @@ -675,6 +675,9 @@ Variables and aliases can have a :ref:`Thread Local Storage Model <tls_model>`. +:ref:`Scalable vectors <t_vector>` cannot be global variables or members of +structs or arrays because their size is unknown at compile time. + Syntax:: @<GlobalVarName> = [Linkage] [PreemptionSpecifier] [Visibility] @@ -2730,30 +2733,40 @@ A vector type is a simple derived type that represents a vector of elements. Vector types are used when multiple primitive data are operated in parallel using a single instruction (SIMD). A vector type -requires a size (number of elements) and an underlying primitive data -type. Vector types are considered :ref:`first class <t_firstclass>`. +requires a size (number of elements), an underlying primitive data type, +and a scalable property to represent vectors where the exact hardware +vector length is unknown at compile time. Vector types are considered +:ref:`first class <t_firstclass>`. :Syntax: :: - < <# elements> x <elementtype> > + < <# elements> x <elementtype> > ; Fixed-length vector + < scalable <# elements> x <elementtype> > ; Scalable vector The number of elements is a constant integer value larger than 0; elementtype may be any integer, floating-point or pointer type. Vectors -of size zero are not allowed. +of size zero are not allowed. For scalable vectors, the total number of +elements is a constant multiple (called vscale) of the specified number +of elements; vscale is a positive integer that is unknown at compile time +and the same hardware-dependent constant for all scalable vectors at run +time. The size of a specific scalable vector type is thus constant within +IR, even if the exact size in bytes cannot be determined until run time. :Examples: -+-------------------+--------------------------------------------------+ -| ``<4 x i32>`` | Vector of 4 32-bit integer values. 
| -+-------------------+--------------------------------------------------+ -| ``<8 x float>`` | Vector of 8 32-bit floating-point values. | -+-------------------+--------------------------------------------------+ -| ``<2 x i64>`` | Vector of 2 64-bit integer values. | -+-------------------+--------------------------------------------------+ -| ``<4 x i64*>`` | Vector of 4 pointers to 64-bit integer values. | -+-------------------+--------------------------------------------------+ ++------------------------+----------------------------------------------------+ +| ``<4 x i32>`` | Vector of 4 32-bit integer values. | ++------------------------+----------------------------------------------------+ +| ``<8 x float>`` | Vector of 8 32-bit floating-point values. | ++------------------------+----------------------------------------------------+ +| ``<2 x i64>`` | Vector of 2 64-bit integer values. | ++------------------------+----------------------------------------------------+ +| ``<4 x i64*>`` | Vector of 4 pointers to 64-bit integer values. | ++------------------------+----------------------------------------------------+ +| ``<scalable 4 x i32>`` | Vector with a multiple of 4 32-bit integer values. | ++------------------------+----------------------------------------------------+ .. _t_label: @@ -8113,6 +8126,7 @@ :: <result> = extractelement <n x <ty>> <val>, <ty2> <idx> ; yields <ty> + <result> = extractelement <scalable n x <ty>> <val>, <ty2> <idx> ; yields <ty> Overview: """"""""" @@ -8133,8 +8147,11 @@ The result is a scalar of the same type as the element type of ``val``. Its value is the value at position ``idx`` of ``val``. If ``idx`` -exceeds the length of ``val``, the result is a -:ref:`poison value <poisonvalues>`. +exceeds the length of ``val`` for a fixed-length vector, the result is a +:ref:`poison value <poisonvalues>`. For a scalable vector, if the value of +``idx`` might exceed the runtime length of the vector, the result in IR is +a value of the appropriate type but the runtime behaviour of the code is +implementation defined.
Example: """""""" @@ -8154,6 +8171,7 @@ :: <result> = insertelement <n x <ty>> <val>, <ty> <elt>, <ty2> <idx> ; yields <n x <ty>> + <result> = insertelement <scalable n x <ty>> <val>, <ty> <elt>, <ty2> <idx> ; yields <scalable n x <ty>> Overview: """"""""" @@ -8175,8 +8193,11 @@ The result is a vector of the same type as ``val``. Its element values are those of ``val`` except at position ``idx``, where it gets the value -``elt``. If ``idx`` exceeds the length of ``val``, the result -is a :ref:`poison value <poisonvalues>`. +``elt``. If ``idx`` exceeds the length of ``val`` for a fixed-length vector, +the result is a :ref:`poison value <poisonvalues>`. For a scalable vector, +if the value of ``idx`` might exceed the runtime length of the vector, the +result in IR is a value of the appropriate type but the runtime behaviour +of the code is implementation defined. Example: """""""" Index: include/llvm/ADT/DenseMapInfo.h =================================================================== --- include/llvm/ADT/DenseMapInfo.h +++ include/llvm/ADT/DenseMapInfo.h @@ -17,6 +17,7 @@ #include "llvm/ADT/Hashing.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/PointerLikeTypeTraits.h" +#include "llvm/Support/ScalableSize.h" #include <cassert> #include <cstddef> #include <cstdint> @@ -268,6 +269,22 @@ static bool isEqual(hash_code LHS, hash_code RHS) { return LHS == RHS; } }; +// Provide DenseMapInfo for ElementCount so it can be used in DenseMap keys. +template <> struct DenseMapInfo<ElementCount> { + static inline ElementCount getEmptyKey() { return {~0U, true}; } + static inline ElementCount getTombstoneKey() { return {~0U - 1, false}; } + static unsigned getHashValue(const ElementCount& EltCnt) { + if (EltCnt.Scalable) + return (EltCnt.Min * 37U) - 1U; + + return EltCnt.Min * 37U; + } + + static bool isEqual(const ElementCount& LHS, const ElementCount& RHS) { + return LHS == RHS; + } +}; + } // end namespace llvm #endif // LLVM_ADT_DENSEMAPINFO_H Index: include/llvm/IR/DerivedTypes.h =================================================================== --- include/llvm/IR/DerivedTypes.h +++ include/llvm/IR/DerivedTypes.h @@ -23,6 +23,7 @@ #include "llvm/IR/Type.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Compiler.h"
+#include "llvm/Support/ScalableSize.h" #include <cassert> #include <cstdint> @@ -387,6 +388,8 @@ SequentialType(const SequentialType &) = delete; SequentialType &operator=(const SequentialType &) = delete; + /// For scalable vectors, this will return the minimum number of elements + /// in the vector. uint64_t getNumElements() const { return NumElements; } Type *getElementType() const { return ContainedType; } @@ -422,14 +425,37 @@ /// Class to represent vector types. class VectorType : public SequentialType { - VectorType(Type *ElType, unsigned NumEl); + /// A fully specified VectorType is of the form <scalable n x Ty>. 'n' is the + /// minimum number of elements of type Ty contained within the vector, and + /// 'scalable' indicates that the total element count is an integer multiple + /// of 'n', where the multiple is either guaranteed to be one, or is + /// statically unknown at compile time. + /// + /// If the multiple is known to be 1, then the extra term is discarded in + /// textual IR: + /// + /// <4 x i32> - a vector containing 4 i32s + /// <scalable 4 x i32> - a vector containing an unknown integer multiple + /// of 4 i32s + + // Every VectorType is built from an ElementCount; see VectorType::get(). + VectorType(Type *ElType, ElementCount EC); + + // If true, the total number of elements is an unknown multiple of the + // minimum 'NumElements' from SequentialType. Otherwise the total number + // of elements is exactly equal to 'NumElements'. + bool Scalable; public: VectorType(const VectorType &) = delete; VectorType &operator=(const VectorType &) = delete; /// This static method is the primary way to construct an VectorType.
- static VectorType *get(Type *ElementType, unsigned NumElements); + static VectorType *get(Type *ElementType, ElementCount EC); + static VectorType *get(Type *ElementType, unsigned NumElements, + bool Scalable = false) { + return VectorType::get(ElementType, {NumElements, Scalable}); + } /// This static method gets a VectorType with the same number of elements as /// the input type, and the element type is an integer type of the same width @@ -438,7 +464,7 @@ unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits(); assert(EltBits && "Element size must be of a non-zero size"); Type *EltTy = IntegerType::get(VTy->getContext(), EltBits); - return VectorType::get(EltTy, VTy->getNumElements()); + return VectorType::get(EltTy, VTy->getElementCount()); } /// This static method is like getInteger except that the element types are @@ -446,7 +472,7 @@ static VectorType *getExtendedElementVectorType(VectorType *VTy) { unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits(); Type *EltTy = IntegerType::get(VTy->getContext(), EltBits * 2); - return VectorType::get(EltTy, VTy->getNumElements()); + return VectorType::get(EltTy, VTy->getElementCount()); } /// This static method is like getInteger except that the element types are @@ -456,29 +482,45 @@ assert((EltBits & 1) == 0 && "Cannot truncate vector element with odd bit-width"); Type *EltTy = IntegerType::get(VTy->getContext(), EltBits / 2); - return VectorType::get(EltTy, VTy->getNumElements()); + return VectorType::get(EltTy, VTy->getElementCount()); } /// This static method returns a VectorType with half as many elements as the /// input type and the same element type.
static VectorType *getHalfElementsVectorType(VectorType *VTy) { - unsigned NumElts = VTy->getNumElements(); - assert ((NumElts & 1) == 0 && + ElementCount EltCnt = VTy->getElementCount(); + assert ((EltCnt.Min & 1) == 0 && "Cannot halve vector with odd number of elements."); - return VectorType::get(VTy->getElementType(), NumElts/2); + return VectorType::get(VTy->getElementType(), EltCnt/2); } /// This static method returns a VectorType with twice as many elements as the /// input type and the same element type. static VectorType *getDoubleElementsVectorType(VectorType *VTy) { - unsigned NumElts = VTy->getNumElements(); - return VectorType::get(VTy->getElementType(), NumElts*2); + ElementCount EltCnt = VTy->getElementCount(); + assert((VTy->getNumElements() * 2ull) <= UINT_MAX && + "Too many elements in vector"); + return VectorType::get(VTy->getElementType(), EltCnt*2); } /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); - /// Return the number of bits in the Vector type. + /// Return an ElementCount instance to represent the (possibly scalable) + /// number of elements in the vector. + ElementCount getElementCount() const { + uint64_t MinimumEltCnt = getNumElements(); + assert(MinimumEltCnt <= UINT_MAX && "Too many elements in vector"); + return { static_cast<unsigned>(MinimumEltCnt), Scalable }; + } + + /// Returns whether or not this is a scalable vector (meaning the total + /// element count is a multiple of the minimum). + bool isScalable() const { + return Scalable; + } + + /// Return the minimum number of bits in the Vector type. /// Returns zero when the vector is a vector of pointers. unsigned getBitWidth() const { return getNumElements() * getElementType()->getPrimitiveSizeInBits(); @@ -494,6 +536,10 @@ return cast<VectorType>(this)->getNumElements(); } +bool Type::getVectorIsScalable() const { + return cast<VectorType>(this)->isScalable(); +} + /// Class to represent pointers.
class PointerType : public Type { explicit PointerType(Type *ElType, unsigned AddrSpace); Index: include/llvm/IR/Type.h =================================================================== --- include/llvm/IR/Type.h +++ include/llvm/IR/Type.h @@ -366,6 +366,7 @@ return ContainedTys[0]; } + inline bool getVectorIsScalable() const; inline unsigned getVectorNumElements() const; Type *getVectorElementType() const { assert(getTypeID() == VectorTyID); Index: include/llvm/Support/ScalableSize.h =================================================================== --- /dev/null +++ include/llvm/Support/ScalableSize.h @@ -0,0 +1,46 @@ +//===- ScalableSize.h - Scalable vector size info ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides a class that describes the element count of IR vector +// types, which may be scalable. It provides convenience operators so that +// it can be used in much the same way as a single scalar value. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_SUPPORT_SCALABLESIZE_H +#define LLVM_SUPPORT_SCALABLESIZE_H + +namespace llvm { + +class ElementCount { +public: + unsigned Min; // Minimum number of vector elements. + bool Scalable; // If true, the total element count is a multiple of 'Min' + // determined at runtime rather than compile time.
+ + ElementCount(unsigned Min, bool Scalable) + : Min(Min), Scalable(Scalable) {} + + ElementCount operator*(unsigned RHS) const { + return { Min * RHS, Scalable }; + } + ElementCount operator/(unsigned RHS) const { + return { Min / RHS, Scalable }; + } + + bool operator==(const ElementCount& RHS) const { + return Min == RHS.Min && Scalable == RHS.Scalable; + } + bool operator!=(const ElementCount& RHS) const { + return !(*this == RHS); + } +}; + +} // end namespace llvm + +#endif // LLVM_SUPPORT_SCALABLESIZE_H Index: lib/AsmParser/LLLexer.cpp =================================================================== --- lib/AsmParser/LLLexer.cpp +++ lib/AsmParser/LLLexer.cpp @@ -706,6 +706,7 @@ KEYWORD(xchg); KEYWORD(nand); KEYWORD(max); KEYWORD(min); KEYWORD(umax); KEYWORD(umin); + KEYWORD(scalable); KEYWORD(x); KEYWORD(blockaddress); Index: lib/AsmParser/LLParser.cpp =================================================================== --- lib/AsmParser/LLParser.cpp +++ lib/AsmParser/LLParser.cpp @@ -2694,7 +2694,16 @@ /// Type /// ::= '[' APSINTVAL 'x' Types ']' /// ::= '<' APSINTVAL 'x' Types '>' +/// ::= '<' 'scalable' APSINTVAL 'x' Types '>' bool LLParser::ParseArrayVectorType(Type *&Result, bool isVector) { + bool Scalable = false; + + if (isVector && Lex.getKind() == lltok::kw_scalable) { + Lex.Lex(); // consume the 'scalable' + + Scalable = true; + } + if (Lex.getKind() != lltok::APSInt || Lex.getAPSIntVal().isSigned() || Lex.getAPSIntVal().getBitWidth() > 64) return TokError("expected number in address space"); @@ -2721,7 +2730,7 @@ return Error(SizeLoc, "size too large for vector"); if (!VectorType::isValidElementType(EltTy)) return Error(TypeLoc, "invalid vector element type"); - Result = VectorType::get(EltTy, unsigned(Size)); + Result = VectorType::get(EltTy, unsigned(Size), Scalable); } else { if (!ArrayType::isValidElementType(EltTy)) return Error(TypeLoc, "invalid array element type"); Index: lib/AsmParser/LLToken.h =================================================================== --- lib/AsmParser/LLToken.h +++ lib/AsmParser/LLToken.h @@ -37,6
+37,7 @@ bar, // | colon, // : + kw_scalable, kw_x, kw_true, kw_false, Index: lib/Bitcode/Reader/BitcodeReader.cpp =================================================================== --- lib/Bitcode/Reader/BitcodeReader.cpp +++ lib/Bitcode/Reader/BitcodeReader.cpp @@ -1757,7 +1757,8 @@ return error("Invalid type"); ResultTy = ArrayType::get(ResultTy, Record[0]); break; - case bitc::TYPE_CODE_VECTOR: // VECTOR: [numelts, eltty] + case bitc::TYPE_CODE_VECTOR: // VECTOR: [numelts, eltty] or + // [numelts, eltty, scalable] if (Record.size() < 2) return error("Invalid record"); if (Record[0] == 0) @@ -1765,7 +1766,8 @@ ResultTy = getTypeByID(Record[1]); if (!ResultTy || !StructType::isValidElementType(ResultTy)) return error("Invalid type"); - ResultTy = VectorType::get(ResultTy, Record[0]); + bool Scalable = Record.size() > 2 && Record[2]; + ResultTy = VectorType::get(ResultTy, Record[0], Scalable); break; } Index: lib/Bitcode/Writer/BitcodeWriter.cpp =================================================================== --- lib/Bitcode/Writer/BitcodeWriter.cpp +++ lib/Bitcode/Writer/BitcodeWriter.cpp @@ -931,10 +931,13 @@ } case Type::VectorTyID: { VectorType *VT = cast<VectorType>(T); - // VECTOR [numelts, eltty] + // VECTOR [numelts, eltty] or + // [numelts, eltty, scalable] Code = bitc::TYPE_CODE_VECTOR; TypeVals.push_back(VT->getNumElements()); TypeVals.push_back(VE.getTypeID(VT->getElementType())); + if (VT->isScalable()) + TypeVals.push_back(VT->isScalable()); break; } } Index: lib/IR/AsmWriter.cpp =================================================================== --- lib/IR/AsmWriter.cpp +++ lib/IR/AsmWriter.cpp @@ -620,7 +620,10 @@ } case Type::VectorTyID: { VectorType *PTy = cast<VectorType>(Ty); - OS << "<" << PTy->getNumElements() << " x "; + OS << "<"; + if (PTy->isScalable()) + OS << "scalable "; + OS << PTy->getNumElements() << " x "; print(PTy->getElementType(), OS); OS << '>'; return; Index: lib/IR/LLVMContextImpl.h
=================================================================== --- lib/IR/LLVMContextImpl.h +++ lib/IR/LLVMContextImpl.h @@ -1335,7 +1335,7 @@ unsigned NamedStructTypesUniqueID = 0; DenseMap<std::pair<Type *, uint64_t>, ArrayType*> ArrayTypes; - DenseMap<std::pair<Type *, unsigned>, VectorType*> VectorTypes; + DenseMap<std::pair<Type *, ElementCount>, VectorType*> VectorTypes; DenseMap<Type*, PointerType*> PointerTypes; // Pointers in AddrSpace = 0 DenseMap<std::pair<Type*, unsigned>, PointerType*> ASPointerTypes; Index: lib/IR/Type.cpp =================================================================== --- lib/IR/Type.cpp +++ lib/IR/Type.cpp @@ -599,21 +599,21 @@ // VectorType Implementation //===----------------------------------------------------------------------===// -VectorType::VectorType(Type *ElType, unsigned NumEl) - : SequentialType(VectorTyID, ElType, NumEl) {} +VectorType::VectorType(Type *ElType, ElementCount EC) + : SequentialType(VectorTyID, ElType, EC.Min), Scalable(EC.Scalable) {} -VectorType *VectorType::get(Type *ElementType, unsigned NumElements) { - assert(NumElements > 0 && "#Elements of a VectorType must be greater than 0"); +VectorType *VectorType::get(Type *ElementType, ElementCount EC) { + assert(EC.Min > 0 && "#Elements of a VectorType must be greater than 0"); assert(isValidElementType(ElementType) && "Element type of a VectorType must " "be an integer, floating point, or " "pointer type."); LLVMContextImpl *pImpl = ElementType->getContext().pImpl; VectorType *&Entry = ElementType->getContext().pImpl - ->VectorTypes[std::make_pair(ElementType, NumElements)]; + ->VectorTypes[std::make_pair(ElementType, EC)]; if (!Entry) - Entry = new (pImpl->TypeAllocator) VectorType(ElementType, NumElements); + Entry = new (pImpl->TypeAllocator) VectorType(ElementType, EC); return Entry; } Index: lib/IR/Verifier.cpp =================================================================== --- lib/IR/Verifier.cpp +++ lib/IR/Verifier.cpp @@ -43,6 +43,7 @@ // //===----------------------------------------------------------------------===// +#include "LLVMContextImpl.h" #include
"llvm/IR/Verifier.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" @@ -307,6 +308,7 @@ TBAAVerifier TBAAVerifyHelper; void checkAtomicMemAccessSize(Type *Ty, const Instruction *I); + static bool containsScalableVectorValue(const Type *Ty); public: explicit Verifier(raw_ostream *OS, bool ShouldTreatBrokenDebugInfoAsError, @@ -318,6 +320,33 @@ bool hasBrokenDebugInfo() const { return BrokenDebugInfo; } + bool verifyTypes(const Module &M) { + LLVMContext &Ctx = M.getContext(); + for (auto &Entry : Ctx.pImpl->ArrayTypes) { + ArrayType *ATy = Entry.second; + if (containsScalableVectorValue(ATy)) { + CheckFailed("Arrays cannot contain scalable vectors", ATy, &M); + Broken = true; + } + } + + for (StructType* STy : Ctx.pImpl->AnonStructTypes) + if (containsScalableVectorValue(STy)) { + CheckFailed("Structs cannot contain scalable vectors", STy, &M); + Broken = true; + } + + for (auto &Entry : Ctx.pImpl->NamedStructTypes) { + StructType *STy = Entry.second; + if (containsScalableVectorValue(STy)) { + CheckFailed("Structs cannot contain scalable vectors", STy, &M); + Broken = true; + } + } + + return !Broken; + } + bool verify(const Function &F) { assert(F.getParent() == &M && "An instance of this class only works with a specific module!"); @@ -387,6 +416,8 @@ verifyCompileUnits(); + verifyTypes(M); + verifyDeoptimizeCallingConvs(); DISubprogramAttachments.clear(); return !Broken; @@ -613,6 +644,35 @@ }); } +// Check for a scalable vector type, making sure to look through arrays and +// structs. Pointers to scalable vectors don't count, since we know what the +// size of a pointer is.
+static bool containsScalableVectorValueRecursive(const Type *Ty, + SmallPtrSetImpl<const Type *> &Visited) { + // Skip types already examined during this query; insert() reports in + // .second whether Ty was newly added. + if (!Visited.insert(Ty).second) + return false; + + if (auto *VTy = dyn_cast<VectorType>(Ty)) + return VTy->isScalable(); + + if (auto *ATy = dyn_cast<ArrayType>(Ty)) + return containsScalableVectorValueRecursive(ATy->getElementType(), Visited); + + if (auto *STy = dyn_cast<StructType>(Ty)) + for (Type *EltTy : STy->elements()) + if (containsScalableVectorValueRecursive(EltTy, Visited)) + return true; + + return false; +} + +bool Verifier::containsScalableVectorValue(const Type *Ty) { + SmallPtrSet<const Type *, 8> Visited; + return containsScalableVectorValueRecursive(Ty, Visited); +} + void Verifier::visitGlobalVariable(const GlobalVariable &GV) { if (GV.hasInitializer()) { Assert(GV.getInitializer()->getType() == GV.getValueType(), @@ -691,6 +751,12 @@ "DIGlobalVariableExpression"); } + // Scalable vectors cannot be global variables, since we don't know + // the runtime size. Need to look inside structs/arrays to find the + // underlying element type as well.
+ if (containsScalableVectorValue(GV.getValueType())) + CheckFailed("Globals cannot contain scalable vectors", &GV); + if (!GV.hasInitializer()) { visitGlobalValue(GV); return; Index: test/Bitcode/compatibility.ll =================================================================== --- test/Bitcode/compatibility.ll +++ test/Bitcode/compatibility.ll @@ -871,6 +871,10 @@ ; CHECK: %t7 = alloca x86_mmx %t8 = alloca %opaquety* ; CHECK: %t8 = alloca %opaquety* + %t9 = alloca <4 x i32> + ; CHECK: %t9 = alloca <4 x i32> + %t10 = alloca <scalable 4 x i32> + ; CHECK: %t10 = alloca <scalable 4 x i32> ret void } Index: test/Verifier/scalable-aggregates.ll =================================================================== --- /dev/null +++ test/Verifier/scalable-aggregates.ll @@ -0,0 +1,31 @@ +; RUN: not opt -S -verify < %s 2>&1 | FileCheck %s + +;; Arrays and Structs cannot contain scalable vectors, since we don't +;; know the size at compile time and the container types need to have +;; a known size. + +; CHECK-DAG: Arrays cannot contain scalable vectors +; CHECK-DAG: [2 x { i32, <scalable 4 x i32> }]; ModuleID = '<stdin>' +; CHECK-DAG: Arrays cannot contain scalable vectors +; CHECK-DAG: [4 x <scalable 4 x i32>]; ModuleID = '<stdin>' +; CHECK-DAG: Arrays cannot contain scalable vectors +; CHECK-DAG: [2 x <scalable 4 x i32>]; ModuleID = '<stdin>' +; CHECK-DAG: Structs cannot contain scalable vectors +; CHECK-DAG: { i64, [4 x <scalable 4 x i32>] }; ModuleID = '<stdin>' +; CHECK-DAG: Structs cannot contain scalable vectors +; CHECK-DAG: { i32, <scalable 4 x i32> }; ModuleID = '<stdin>' +; CHECK-DAG: Structs cannot contain scalable vectors +; CHECK-DAG: { <scalable 4 x i32>, <scalable 4 x i32> }; ModuleID = '<stdin>' +; CHECK-DAG: Structs cannot contain scalable vectors +; CHECK-DAG: %sty = type { i64, <scalable 4 x i32> }; ModuleID = '<stdin>' +%sty = type { i64, <scalable 4 x i32> } + +define void @scalable_aggregates() { + %array = alloca [2 x <scalable 4 x i32>] + %struct = alloca { <scalable 4 x i32>, <scalable 4 x i32> } + %named_struct = alloca %sty + %s_in_a = alloca [2 x { i32, <scalable 4 x i32> } ] + %a_in_s = alloca { i64, [4 x <scalable 4 x i32>] } + ret void +} \ No newline at end of file Index: test/Verifier/scalable-global-vars.ll ===================================================================
--- /dev/null +++ test/Verifier/scalable-global-vars.ll @@ -0,0 +1,24 @@ +; RUN: not opt -S -verify < %s 2>&1 | FileCheck %s + +;; Global variables cannot be scalable vectors, since we don't +;; know the size at compile time. + +; CHECK: Globals cannot contain scalable vectors +; CHECK-NEXT: <scalable 4 x i32>* @ScalableVecGlobal +@ScalableVecGlobal = global <scalable 4 x i32> zeroinitializer + +; CHECK: Globals cannot contain scalable vectors +; CHECK-NEXT: [64 x <scalable 4 x i32>]* @ScalableVecGlobalArray +@ScalableVecGlobalArray = global [64 x <scalable 4 x i32>] zeroinitializer + +; CHECK: Globals cannot contain scalable vectors +; CHECK-NEXT: { <scalable 4 x i32>, <scalable 4 x i32> }* @ScalableVecGlobalStruct +@ScalableVecGlobalStruct = global { <scalable 4 x i32>, <scalable 4 x i32> } zeroinitializer + +; CHECK: Globals cannot contain scalable vectors +; CHECK-NEXT: { [4 x i32], [2 x { <scalable 4 x i32>, <scalable 4 x i32> }] }* @ScalableVecMixed +@ScalableVecMixed = global { [4 x i32], [2 x { <scalable 4 x i32>, <scalable 4 x i32> }]} zeroinitializer + +;; Global _pointers_ to scalable vectors are fine +; CHECK-NOT: Globals cannot contain scalable vectors +@ScalableVecPtr = global <scalable 4 x i32>* zeroinitializer Index: unittests/IR/CMakeLists.txt =================================================================== --- unittests/IR/CMakeLists.txt +++ unittests/IR/CMakeLists.txt @@ -37,6 +37,7 @@ ValueHandleTest.cpp ValueMapTest.cpp ValueTest.cpp + VectorTypesTest.cpp VerifierTest.cpp WaymarkTest.cpp ) Index: unittests/IR/VectorTypesTest.cpp =================================================================== --- /dev/null +++ unittests/IR/VectorTypesTest.cpp @@ -0,0 +1,166 @@ +//===--- llvm/unittest/IR/VectorTypesTest.cpp - vector types unit tests ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/ScalableSize.h" +#include "gtest/gtest.h" +using namespace llvm; + +namespace { +TEST(VectorTypesTest, FixedLength) { + LLVMContext Ctx; + + Type *Int16Ty = Type::getInt16Ty(Ctx); + Type *Int32Ty = Type::getInt32Ty(Ctx); + Type *Int64Ty = Type::getInt64Ty(Ctx); + Type *Float64Ty = Type::getDoubleTy(Ctx); + + VectorType *V8Int32Ty = VectorType::get(Int32Ty, 8); + ASSERT_FALSE(V8Int32Ty->isScalable()); + EXPECT_EQ(V8Int32Ty->getNumElements(), 8U); + EXPECT_EQ(V8Int32Ty->getElementType()->getScalarSizeInBits(), 32U); + + VectorType *V8Int16Ty = VectorType::get(Int16Ty, {8, false}); + ASSERT_FALSE(V8Int16Ty->isScalable()); + EXPECT_EQ(V8Int16Ty->getNumElements(), 8U); + EXPECT_EQ(V8Int16Ty->getElementType()->getScalarSizeInBits(), 16U); + + ElementCount EltCnt(4, false); + VectorType *V4Int64Ty = VectorType::get(Int64Ty, EltCnt); + ASSERT_FALSE(V4Int64Ty->isScalable()); + EXPECT_EQ(V4Int64Ty->getNumElements(), 4U); + EXPECT_EQ(V4Int64Ty->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *V2Int64Ty = VectorType::get(Int64Ty, EltCnt/2); + ASSERT_FALSE(V2Int64Ty->isScalable()); + EXPECT_EQ(V2Int64Ty->getNumElements(), 2U); + EXPECT_EQ(V2Int64Ty->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *V8Int64Ty = VectorType::get(Int64Ty, EltCnt*2); + ASSERT_FALSE(V8Int64Ty->isScalable()); + EXPECT_EQ(V8Int64Ty->getNumElements(), 8U); + EXPECT_EQ(V8Int64Ty->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *V4Float64Ty = VectorType::get(Float64Ty, EltCnt); + ASSERT_FALSE(V4Float64Ty->isScalable()); + EXPECT_EQ(V4Float64Ty->getNumElements(), 4U); + EXPECT_EQ(V4Float64Ty->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *ExtTy = VectorType::getExtendedElementVectorType(V8Int16Ty); +
EXPECT_EQ(ExtTy, V8Int32Ty); + ASSERT_FALSE(ExtTy->isScalable()); + EXPECT_EQ(ExtTy->getNumElements(), 8U); + EXPECT_EQ(ExtTy->getElementType()->getScalarSizeInBits(), 32U); + + VectorType *TruncTy = VectorType::getTruncatedElementVectorType(V8Int32Ty); + EXPECT_EQ(TruncTy, V8Int16Ty); + ASSERT_FALSE(TruncTy->isScalable()); + EXPECT_EQ(TruncTy->getNumElements(), 8U); + EXPECT_EQ(TruncTy->getElementType()->getScalarSizeInBits(), 16U); + + VectorType *HalvedTy = VectorType::getHalfElementsVectorType(V4Int64Ty); + EXPECT_EQ(HalvedTy, V2Int64Ty); + ASSERT_FALSE(HalvedTy->isScalable()); + EXPECT_EQ(HalvedTy->getNumElements(), 2U); + EXPECT_EQ(HalvedTy->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *DoubledTy = VectorType::getDoubleElementsVectorType(V4Int64Ty); + EXPECT_EQ(DoubledTy, V8Int64Ty); + ASSERT_FALSE(DoubledTy->isScalable()); + EXPECT_EQ(DoubledTy->getNumElements(), 8U); + EXPECT_EQ(DoubledTy->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *ConvTy = VectorType::getInteger(V4Float64Ty); + EXPECT_EQ(ConvTy, V4Int64Ty); + ASSERT_FALSE(ConvTy->isScalable()); + EXPECT_EQ(ConvTy->getNumElements(), 4U); + EXPECT_EQ(ConvTy->getElementType()->getScalarSizeInBits(), 64U); + + EltCnt = V8Int64Ty->getElementCount(); + EXPECT_EQ(EltCnt.Min, 8U); + ASSERT_FALSE(EltCnt.Scalable); +} + +TEST(VectorTypesTest, Scalable) { + LLVMContext Ctx; + + Type *Int16Ty = Type::getInt16Ty(Ctx); + Type *Int32Ty = Type::getInt32Ty(Ctx); + Type *Int64Ty = Type::getInt64Ty(Ctx); + Type *Float64Ty = Type::getDoubleTy(Ctx); + + VectorType *ScV8Int32Ty = VectorType::get(Int32Ty, 8, true); + ASSERT_TRUE(ScV8Int32Ty->isScalable()); + EXPECT_EQ(ScV8Int32Ty->getNumElements(), 8U); + EXPECT_EQ(ScV8Int32Ty->getElementType()->getScalarSizeInBits(), 32U); + // Fixed-length and scalable vectors of the same shape are distinct. + EXPECT_NE(ScV8Int32Ty, VectorType::get(Int32Ty, 8)); + + VectorType *ScV8Int16Ty = VectorType::get(Int16Ty, {8, true}); + ASSERT_TRUE(ScV8Int16Ty->isScalable()); + EXPECT_EQ(ScV8Int16Ty->getNumElements(), 8U); + EXPECT_EQ(ScV8Int16Ty->getElementType()->getScalarSizeInBits(),
16U); + + ElementCount EltCnt(4, true); + VectorType *ScV4Int64Ty = VectorType::get(Int64Ty, EltCnt); + ASSERT_TRUE(ScV4Int64Ty->isScalable()); + EXPECT_EQ(ScV4Int64Ty->getNumElements(), 4U); + EXPECT_EQ(ScV4Int64Ty->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *ScV2Int64Ty = VectorType::get(Int64Ty, EltCnt/2); + ASSERT_TRUE(ScV2Int64Ty->isScalable()); + EXPECT_EQ(ScV2Int64Ty->getNumElements(), 2U); + EXPECT_EQ(ScV2Int64Ty->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *ScV8Int64Ty = VectorType::get(Int64Ty, EltCnt*2); + ASSERT_TRUE(ScV8Int64Ty->isScalable()); + EXPECT_EQ(ScV8Int64Ty->getNumElements(), 8U); + EXPECT_EQ(ScV8Int64Ty->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *ScV4Float64Ty = VectorType::get(Float64Ty, EltCnt); + ASSERT_TRUE(ScV4Float64Ty->isScalable()); + EXPECT_EQ(ScV4Float64Ty->getNumElements(), 4U); + EXPECT_EQ(ScV4Float64Ty->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *ExtTy = VectorType::getExtendedElementVectorType(ScV8Int16Ty); + EXPECT_EQ(ExtTy, ScV8Int32Ty); + ASSERT_TRUE(ExtTy->isScalable()); + EXPECT_EQ(ExtTy->getNumElements(), 8U); + EXPECT_EQ(ExtTy->getElementType()->getScalarSizeInBits(), 32U); + + VectorType *TruncTy = VectorType::getTruncatedElementVectorType(ScV8Int32Ty); + EXPECT_EQ(TruncTy, ScV8Int16Ty); + ASSERT_TRUE(TruncTy->isScalable()); + EXPECT_EQ(TruncTy->getNumElements(), 8U); + EXPECT_EQ(TruncTy->getElementType()->getScalarSizeInBits(), 16U); + + VectorType *HalvedTy = VectorType::getHalfElementsVectorType(ScV4Int64Ty); + EXPECT_EQ(HalvedTy, ScV2Int64Ty); + ASSERT_TRUE(HalvedTy->isScalable()); + EXPECT_EQ(HalvedTy->getNumElements(), 2U); + EXPECT_EQ(HalvedTy->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *DoubledTy = VectorType::getDoubleElementsVectorType(ScV4Int64Ty); + EXPECT_EQ(DoubledTy, ScV8Int64Ty); + ASSERT_TRUE(DoubledTy->isScalable()); + EXPECT_EQ(DoubledTy->getNumElements(), 8U); +
EXPECT_EQ(DoubledTy->getElementType()->getScalarSizeInBits(), 64U); + + VectorType *ConvTy = VectorType::getInteger(ScV4Float64Ty); + EXPECT_EQ(ConvTy, ScV4Int64Ty); + ASSERT_TRUE(ConvTy->isScalable()); + EXPECT_EQ(ConvTy->getNumElements(), 4U); + EXPECT_EQ(ConvTy->getElementType()->getScalarSizeInBits(), 64U); + + EltCnt = ScV8Int64Ty->getElementCount(); + EXPECT_EQ(EltCnt.Min, 8U); + ASSERT_TRUE(EltCnt.Scalable); +} + +} // end anonymous namespace