diff --git a/llvm/include/llvm/Support/LowLevelTypeImpl.h b/llvm/include/llvm/Support/LowLevelTypeImpl.h
--- a/llvm/include/llvm/Support/LowLevelTypeImpl.h
+++ b/llvm/include/llvm/Support/LowLevelTypeImpl.h
@@ -42,31 +42,37 @@
   /// Get a low-level scalar or aggregate "bag of bits".
   static LLT scalar(unsigned SizeInBits) {
     assert(SizeInBits > 0 && "invalid scalar size");
-    return LLT{/*isPointer=*/false, /*isVector=*/false, /*NumElements=*/0,
-               SizeInBits, /*AddressSpace=*/0};
+    return LLT{/*isPointer=*/false, /*isVector=*/false,
+               ElementCount::getFixed(0), SizeInBits,
+               /*AddressSpace=*/0};
   }
 
   /// Get a low-level pointer in the given address space.
   static LLT pointer(unsigned AddressSpace, unsigned SizeInBits) {
     assert(SizeInBits > 0 && "invalid pointer size");
-    return LLT{/*isPointer=*/true, /*isVector=*/false, /*NumElements=*/0,
-               SizeInBits, AddressSpace};
+    return LLT{/*isPointer=*/true, /*isVector=*/false,
+               ElementCount::getFixed(0), SizeInBits, AddressSpace};
   }
 
   /// Get a low-level vector of some number of elements and element width.
-  /// \p NumElements must be at least 2.
-  static LLT vector(uint16_t NumElements, unsigned ScalarSizeInBits) {
-    assert(NumElements > 1 && "invalid number of vector elements");
+  /// \p NumElements must be at least 2, or at least 1 if \p Scalable is set.
+  static LLT vector(uint16_t NumElements, unsigned ScalarSizeInBits,
+                    bool Scalable = false) {
+    assert((Scalable ? NumElements > 0 : NumElements > 1) &&
+           "invalid number of vector elements");
     assert(ScalarSizeInBits > 0 && "invalid vector element size");
-    return LLT{/*isPointer=*/false, /*isVector=*/true, NumElements,
-               ScalarSizeInBits, /*AddressSpace=*/0};
+    return LLT{/*isPointer=*/false, /*isVector=*/true,
+               ElementCount::get(NumElements, Scalable), ScalarSizeInBits,
+               /*AddressSpace=*/0};
   }
 
   /// Get a low-level vector of some number of elements and element type.
-  static LLT vector(uint16_t NumElements, LLT ScalarTy) {
-    assert(NumElements > 1 && "invalid number of vector elements");
+  static LLT vector(uint16_t NumElements, LLT ScalarTy, bool Scalable = false) {
+    assert((Scalable ? NumElements > 0 : NumElements > 1) &&
+           "invalid number of vector elements");
     assert(!ScalarTy.isVector() && "invalid vector element type");
-    return LLT{ScalarTy.isPointer(), /*isVector=*/true, NumElements,
+    return LLT{ScalarTy.isPointer(), /*isVector=*/true,
+               ElementCount::get(NumElements, Scalable),
                ScalarTy.getSizeInBits(),
                ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0};
   }
@@ -79,9 +85,9 @@
     return scalarOrVector(NumElements, LLT::scalar(ScalarSize));
   }
 
-  explicit LLT(bool isPointer, bool isVector, uint16_t NumElements,
+  explicit LLT(bool isPointer, bool isVector, ElementCount EC,
                unsigned SizeInBits, unsigned AddressSpace) {
-    init(isPointer, isVector, NumElements, SizeInBits, AddressSpace);
+    init(isPointer, isVector, EC, SizeInBits, AddressSpace);
   }
   explicit LLT() : IsPointer(false), IsVector(false), RawData(0) {}
 
@@ -98,18 +104,39 @@
   /// Returns the number of elements in a vector LLT. Must only be called on
   /// vector types.
   uint16_t getNumElements() const {
+    if (isScalable())
+      llvm::reportInvalidSizeRequest(
+          "Possible incorrect use of LLT::getNumElements() for "
+          "scalable vector. Scalable flag may be dropped, use "
+          "LLT::getElementCount() instead");
+    return getElementCount().getKnownMinValue();
+  }
+
+  /// Returns true if the LLT is a scalable vector. Must only be called on
+  /// vector types.
+  bool isScalable() const {
+    assert(isVector() && "Expected a vector type");
+    return IsPointer ? getFieldValue(PointerVectorScalableFieldInfo)
+                     : getFieldValue(VectorScalableFieldInfo);
+  }
+
+  /// Returns the element count of a vector LLT: the stored (minimum) number of
+  /// elements together with the scalable flag. Must only be called on vectors.
+  ElementCount getElementCount() const {
     assert(IsVector && "cannot get number of elements on scalar/aggregate");
-    if (!IsPointer)
-      return getFieldValue(VectorElementsFieldInfo);
-    else
-      return getFieldValue(PointerVectorElementsFieldInfo);
+    return ElementCount::get(IsPointer
+                                 ? getFieldValue(PointerVectorElementsFieldInfo)
+                                 : getFieldValue(VectorElementsFieldInfo),
+                             isScalable());
   }
 
   /// Returns the total size of the type. Must only be called on sized types.
   unsigned getSizeInBits() const {
     if (isPointer() || isScalar())
       return getScalarSizeInBits();
-    return getScalarSizeInBits() * getNumElements();
+    // FIXME: This should return a TypeSize in order to work for scalable
+    // vectors.
+    return getScalarSizeInBits() * getElementCount().getKnownMinValue();
   }
 
   /// Returns the total size of the type in bytes, i.e. number of whole bytes
@@ -125,7 +152,9 @@
   /// If this type is a vector, return a vector with the same number of elements
   /// but the new element type. Otherwise, return the new element type.
   LLT changeElementType(LLT NewEltTy) const {
-    return isVector() ? LLT::vector(getNumElements(), NewEltTy) : NewEltTy;
+    return isVector() ? LLT::vector(getElementCount().getKnownMinValue(),
+                                    NewEltTy, isScalable())
+                      : NewEltTy;
   }
 
   /// If this type is a vector, return a vector with the same number of elements
@@ -134,13 +163,16 @@
   LLT changeElementSize(unsigned NewEltSize) const {
     assert(!getScalarType().isPointer() &&
            "invalid to directly change element size for pointers");
-    return isVector() ? LLT::vector(getNumElements(), NewEltSize)
+    return isVector() ? LLT::vector(getElementCount().getKnownMinValue(),
+                                    NewEltSize, isScalable())
                       : LLT::scalar(NewEltSize);
   }
 
   /// Return a vector or scalar with the same element type and the new number of
   /// elements.
   LLT changeNumElements(unsigned NewNumElts) const {
+    assert((!isVector() || !isScalable()) &&
+           "Cannot use changeNumElements on a scalable vector");
     return LLT::scalarOrVector(NewNumElts, getScalarType());
   }
 
@@ -237,22 +269,37 @@
   static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 0};
   static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{
       24, PointerSizeFieldInfo[0] + PointerSizeFieldInfo[1]};
+  static_assert((PointerAddressSpaceFieldInfo[0] +
+                 PointerAddressSpaceFieldInfo[1]) <= 62,
+                "Insufficient bits to encode all data");
   /// * Vector-of-non-pointer (isPointer == 0 && isVector == 1):
   ///   NumElements: 16;
   ///   SizeOfElement: 32;
+  ///   Scalable: 1;
   static const constexpr BitFieldInfo VectorElementsFieldInfo{16, 0};
   static const constexpr BitFieldInfo VectorSizeFieldInfo{
       32, VectorElementsFieldInfo[0] + VectorElementsFieldInfo[1]};
+  static const constexpr BitFieldInfo VectorScalableFieldInfo{
+      1, VectorSizeFieldInfo[0] + VectorSizeFieldInfo[1]};
+  static_assert((VectorScalableFieldInfo[0] + VectorScalableFieldInfo[1]) <= 62,
+                "Insufficient bits to encode all data");
   /// * Vector-of-pointer (isPointer == 1 && isVector == 1):
   ///   NumElements: 16;
   ///   SizeOfElement: 16;
   ///   AddressSpace: 24;
+  ///   Scalable: 1;
   static const constexpr BitFieldInfo PointerVectorElementsFieldInfo{16, 0};
   static const constexpr BitFieldInfo PointerVectorSizeFieldInfo{
       16,
       PointerVectorElementsFieldInfo[1] + PointerVectorElementsFieldInfo[0]};
   static const constexpr BitFieldInfo PointerVectorAddressSpaceFieldInfo{
       24, PointerVectorSizeFieldInfo[1] + PointerVectorSizeFieldInfo[0]};
+  static const constexpr BitFieldInfo PointerVectorScalableFieldInfo{
+      1, PointerVectorAddressSpaceFieldInfo[0] +
+             PointerVectorAddressSpaceFieldInfo[1]};
+  static_assert((PointerVectorScalableFieldInfo[0] +
+                 PointerVectorScalableFieldInfo[1]) <= 62,
+                "Insufficient bits to encode all data");
 
   uint64_t IsPointer : 1;
   uint64_t IsVector : 1;
@@ -273,8 +320,8 @@
     return getMask(FieldInfo) & (RawData >> FieldInfo[1]);
   }
 
-  void init(bool IsPointer, bool IsVector, uint16_t NumElements,
-            unsigned SizeInBits, unsigned AddressSpace) {
+  void init(bool IsPointer, bool IsVector, ElementCount EC, unsigned SizeInBits,
+            unsigned AddressSpace) {
     this->IsPointer = IsPointer;
     this->IsVector = IsVector;
     if (!IsVector) {
@@ -284,15 +331,20 @@
         RawData = maskAndShift(SizeInBits, PointerSizeFieldInfo) |
                   maskAndShift(AddressSpace, PointerAddressSpaceFieldInfo);
     } else {
-      assert(NumElements > 1 && "invalid number of vector elements");
+      assert(EC.isVector() && "invalid number of vector elements");
       if (!IsPointer)
-        RawData = maskAndShift(NumElements, VectorElementsFieldInfo) |
-                  maskAndShift(SizeInBits, VectorSizeFieldInfo);
+        RawData =
+            maskAndShift(EC.getKnownMinValue(), VectorElementsFieldInfo) |
+            maskAndShift(SizeInBits, VectorSizeFieldInfo) |
+            maskAndShift(EC.isScalable() ? 1 : 0, VectorScalableFieldInfo);
       else
         RawData =
-            maskAndShift(NumElements, PointerVectorElementsFieldInfo) |
+            maskAndShift(EC.getKnownMinValue(),
+                         PointerVectorElementsFieldInfo) |
             maskAndShift(SizeInBits, PointerVectorSizeFieldInfo) |
-            maskAndShift(AddressSpace, PointerVectorAddressSpaceFieldInfo);
+            maskAndShift(AddressSpace, PointerVectorAddressSpaceFieldInfo) |
+            maskAndShift(EC.isScalable() ? 1 : 0,
+                         PointerVectorScalableFieldInfo);
     }
   }
 
diff --git a/llvm/lib/CodeGen/LowLevelType.cpp b/llvm/lib/CodeGen/LowLevelType.cpp
--- a/llvm/lib/CodeGen/LowLevelType.cpp
+++ b/llvm/lib/CodeGen/LowLevelType.cpp
@@ -20,11 +20,11 @@
 
 LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
   if (auto VTy = dyn_cast<VectorType>(&Ty)) {
-    auto NumElements = cast<FixedVectorType>(VTy)->getNumElements();
+    auto EC = VTy->getElementCount();
     LLT ScalarTy = getLLTForType(*VTy->getElementType(), DL);
-    if (NumElements == 1)
+    if (EC.isScalar())
       return ScalarTy;
-    return LLT::vector(NumElements, ScalarTy);
+    return LLT::vector(EC.getKnownMinValue(), ScalarTy, EC.isScalable());
   }
 
   if (auto PTy = dyn_cast<PointerType>(&Ty)) {
diff --git a/llvm/lib/Support/LowLevelType.cpp b/llvm/lib/Support/LowLevelType.cpp
--- a/llvm/lib/Support/LowLevelType.cpp
+++ b/llvm/lib/Support/LowLevelType.cpp
@@ -18,13 +18,13 @@
 LLT::LLT(MVT VT) {
   if (VT.isVector()) {
-    init(/*IsPointer=*/false, VT.getVectorNumElements() > 1,
-         VT.getVectorNumElements(), VT.getVectorElementType().getSizeInBits(),
+    init(/*IsPointer=*/false, VT.getVectorElementCount().isVector(),
+         VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
          /*AddressSpace=*/0);
   } else if (VT.isValid()) {
     // Aggregates are no different from real scalars as far as GlobalISel is
     // concerned.
     assert(VT.getSizeInBits().isNonZero() && "invalid zero-sized type");
-    init(/*IsPointer=*/false, /*IsVector=*/false, /*NumElements=*/0,
+    init(/*IsPointer=*/false, /*IsVector=*/false, ElementCount::getFixed(0),
          VT.getSizeInBits(), /*AddressSpace=*/0);
   } else {
     IsPointer = false;
@@ -34,9 +34,10 @@
 }
 
 void LLT::print(raw_ostream &OS) const {
-  if (isVector())
-    OS << "<" << getNumElements() << " x " << getElementType() << ">";
-  else if (isPointer())
+  if (isVector()) {
+    OS << "<";
+    OS << getElementCount() << " x " << getElementType() << ">";
+  } else if (isPointer())
     OS << "p" << getAddressSpace();
   else if (isValid()) {
     assert(isScalar() && "unexpected type");
@@ -49,7 +50,9 @@
 const constexpr LLT::BitFieldInfo LLT::PointerSizeFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::PointerAddressSpaceFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::VectorElementsFieldInfo;
+const constexpr LLT::BitFieldInfo LLT::VectorScalableFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::VectorSizeFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::PointerVectorElementsFieldInfo;
+const constexpr LLT::BitFieldInfo LLT::PointerVectorScalableFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::PointerVectorSizeFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::PointerVectorAddressSpaceFieldInfo;
diff --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
--- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
+++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
@@ -11,6 +11,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/TypeSize.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -50,13 +51,19 @@
   DataLayout DL("");
 
   for (unsigned S : {1U, 17U, 32U, 64U, 0xfffU}) {
-    for (uint16_t Elts : {2U, 3U, 4U, 32U, 0xffU}) {
+    for (auto EC :
+         {ElementCount::getFixed(2), ElementCount::getFixed(3),
+          ElementCount::getFixed(4), ElementCount::getFixed(32),
+          ElementCount::getFixed(0xff), ElementCount::getScalable(2),
+          ElementCount::getScalable(3), ElementCount::getScalable(4),
+          ElementCount::getScalable(32), ElementCount::getScalable(0xff)}) {
       const LLT STy = LLT::scalar(S);
-      const LLT VTy = LLT::vector(Elts, S);
+      const LLT VTy = LLT::vector(EC.getKnownMinValue(), S, EC.isScalable());
 
       // Test the alternative vector().
       {
-        const LLT VSTy = LLT::vector(Elts, STy);
+        const LLT VSTy =
+            LLT::vector(EC.getKnownMinValue(), STy, EC.isScalable());
         EXPECT_EQ(VTy, VSTy);
       }
 
@@ -71,9 +78,10 @@
       ASSERT_FALSE(VTy.isPointer());
 
       // Test sizes.
-      EXPECT_EQ(S * Elts, VTy.getSizeInBits());
       EXPECT_EQ(S, VTy.getScalarSizeInBits());
-      EXPECT_EQ(Elts, VTy.getNumElements());
+      EXPECT_EQ(EC, VTy.getElementCount());
+      if (!EC.isScalable())
+        EXPECT_EQ(S * EC.getFixedValue(), VTy.getSizeInBits());
 
       // Test equality operators.
       EXPECT_TRUE(VTy == VTy);
@@ -85,7 +93,7 @@
 
       // Test Type->LLT conversion.
       Type *IRSTy = IntegerType::get(C, S);
-      Type *IRTy = FixedVectorType::get(IRSTy, Elts);
+      Type *IRTy = VectorType::get(IRSTy, EC);
       EXPECT_EQ(VTy, getLLTForType(*IRTy, DL));
     }
   }
@@ -136,6 +144,22 @@
 
   EXPECT_EQ(V2P1, V2P0.changeElementType(P1));
   EXPECT_EQ(V2S32, V2P0.changeElementType(S32));
+
+  // Similar tests for scalable vectors.
+  const LLT NXV2S32 = LLT::vector(2, 32, true);
+  const LLT NXV2S64 = LLT::vector(2, 64, true);
+
+  const LLT NXV2P0 = LLT::vector(2, P0, true);
+  const LLT NXV2P1 = LLT::vector(2, P1, true);
+
+  EXPECT_EQ(NXV2S64, NXV2S32.changeElementType(S64));
+  EXPECT_EQ(NXV2S32, NXV2S64.changeElementType(S32));
+
+  EXPECT_EQ(NXV2S64, NXV2S32.changeElementSize(64));
+  EXPECT_EQ(NXV2S32, NXV2S64.changeElementSize(32));
+
+  EXPECT_EQ(NXV2P1, NXV2P0.changeElementType(P1));
+  EXPECT_EQ(NXV2S32, NXV2P0.changeElementType(S32));
 }
 
 TEST(LowLevelTypeTest, ChangeNumElements) {
@@ -191,9 +215,14 @@
   for (unsigned AS : {0U, 1U, 127U, 0xffffU,
         static_cast<unsigned>(maxUIntN(23)),
         static_cast<unsigned>(maxUIntN(24))}) {
-    for (unsigned NumElts : {2, 3, 4, 256, 65535}) {
+    for (ElementCount EC :
+         {ElementCount::getFixed(2), ElementCount::getFixed(3),
+          ElementCount::getFixed(4), ElementCount::getFixed(256),
+          ElementCount::getFixed(65535), ElementCount::getScalable(2),
+          ElementCount::getScalable(3), ElementCount::getScalable(4),
+          ElementCount::getScalable(256), ElementCount::getScalable(65535)}) {
       const LLT Ty = LLT::pointer(AS, DL.getPointerSizeInBits(AS));
-      const LLT VTy = LLT::vector(NumElts, Ty);
+      const LLT VTy = LLT::vector(EC.getKnownMinValue(), Ty, EC.isScalable());
 
       // Test kind.
       ASSERT_TRUE(Ty.isValid());
@@ -222,8 +251,8 @@
       // Test Type->LLT conversion.
       Type *IRTy = PointerType::get(IntegerType::get(C, 8), AS);
       EXPECT_EQ(Ty, getLLTForType(*IRTy, DL));
-      Type *IRVTy = FixedVectorType::get(
-          PointerType::get(IntegerType::get(C, 8), AS), NumElts);
+      Type *IRVTy =
+          VectorType::get(PointerType::get(IntegerType::get(C, 8), AS), EC);
       EXPECT_EQ(VTy, getLLTForType(*IRVTy, DL));
     }
   }
diff --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp
--- a/llvm/utils/TableGen/GlobalISelEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp
@@ -118,7 +118,9 @@
     return;
   }
   if (Ty.isVector()) {
-    OS << "GILLT_v" << Ty.getNumElements() << "s" << Ty.getScalarSizeInBits();
+    OS << (Ty.isScalable() ? "GILLT_nxv" : "GILLT_v")
+       << Ty.getElementCount().getKnownMinValue() << "s"
+       << Ty.getScalarSizeInBits();
     return;
   }
   if (Ty.isPointer()) {
@@ -136,8 +138,8 @@
     return;
   }
   if (Ty.isVector()) {
-    OS << "LLT::vector(" << Ty.getNumElements() << ", "
-       << Ty.getScalarSizeInBits() << ")";
+    OS << "LLT::vector(" << Ty.getElementCount().getKnownMinValue() << ", "
+       << Ty.getScalarSizeInBits() << ", " << Ty.isScalable() << ")";
     return;
   }
   if (Ty.isPointer() && Ty.getSizeInBits() > 0) {
@@ -169,9 +171,14 @@
   if (Ty.isPointer() && Ty.getAddressSpace() != Other.Ty.getAddressSpace())
     return Ty.getAddressSpace() < Other.Ty.getAddressSpace();
 
-  if (Ty.isVector() && Ty.getNumElements() != Other.Ty.getNumElements())
-    return Ty.getNumElements() < Other.Ty.getNumElements();
+  if (Ty.isVector() && Ty.getElementCount() != Other.Ty.getElementCount())
+    return std::make_tuple(Ty.isScalable(),
+                           Ty.getElementCount().getKnownMinValue()) <
+           std::make_tuple(Other.Ty.isScalable(),
+                           Other.Ty.getElementCount().getKnownMinValue());
 
+  assert((!Ty.isVector() || Ty.isScalable() == Other.Ty.isScalable()) &&
+         "Unexpected mismatch of scalable property");
   return Ty.getSizeInBits() < Other.Ty.getSizeInBits();
 }
 
@@ -187,12 +194,10 @@
 static Optional<LLTCodeGen> MVTToLLT(MVT::SimpleValueType SVT) {
   MVT VT(SVT);
 
-  if (VT.isScalableVector())
-    return None;
-
-  if (VT.isFixedLengthVector() && VT.getVectorNumElements() != 1)
-    return LLTCodeGen(
-        LLT::vector(VT.getVectorNumElements(), VT.getScalarSizeInBits()));
+  if (VT.isVector() && !VT.getVectorElementCount().isScalar())
+    return LLTCodeGen(LLT::vector(VT.getVectorElementCount().getKnownMinValue(),
+                                  VT.getScalarSizeInBits(),
+                                  VT.isScalableVector()));
 
   if (VT.isInteger() || VT.isFloatingPoint())
     return LLTCodeGen(LLT::scalar(VT.getSizeInBits()));