Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -778,7 +778,7 @@
     unsigned Cost = LT.first;
 
     if (Src->isVectorTy() &&
-        Src->getPrimitiveSizeInBits() < LT.second.getSizeInBits()) {
+        Src->getScalableSizeInBits().MinSize < LT.second.getSizeInBits()) {
       // This is a vector load that legalizes to a larger type than the vector
       // itself. Unless the corresponding extending load or truncating store is
       // legal, then this will scalarize.
Index: include/llvm/IR/DataLayout.h
===================================================================
--- include/llvm/IR/DataLayout.h
+++ include/llvm/IR/DataLayout.h
@@ -404,6 +404,8 @@
   /// have a size (Type::isSized() must return true).
   uint64_t getTypeSizeInBits(Type *Ty) const;
 
+  ScalableSize getScalableTypeSizeInBits(Type *Ty) const;
+
   /// Returns the maximum number of bytes that may be overwritten by
   /// storing the specified type.
   ///
@@ -412,6 +414,12 @@
     return (getTypeSizeInBits(Ty) + 7) / 8;
   }
 
+  ScalableSize getScalableTypeStoreSize(Type *Ty) const {
+    // Is overloading bits/bytes wise?
+    auto Bits = getScalableTypeSizeInBits(Ty);
+    return ScalableSize((Bits.MinSize+7)/8, Bits.Scalable);
+  }
+
   /// Returns the maximum number of bits that may be overwritten by
   /// storing the specified type; always a multiple of 8.
   ///
@@ -420,6 +428,11 @@
     return 8 * getTypeStoreSize(Ty);
   }
 
+  ScalableSize getScalableTypeStoreSizeInBits(Type *Ty) const {
+    auto Bytes = getScalableTypeStoreSize(Ty);
+    return {Bytes.MinSize * 8, Bytes.Scalable};
+  }
+
   /// Returns the offset in bytes between successive objects of the
   /// specified type, including alignment padding.
   ///
@@ -430,6 +443,13 @@
     return alignTo(getTypeStoreSize(Ty), getABITypeAlignment(Ty));
   }
 
+  ScalableSize getScalableTypeAllocSize(Type *Ty) const {
+    auto Bytes = getScalableTypeStoreSize(Ty);
+    Bytes.MinSize = alignTo(Bytes.MinSize, getABITypeAlignment(Ty));
+
+    return Bytes;
+  }
+
   /// Returns the offset in bits between successive objects of the
   /// specified type, including alignment padding; always a multiple of 8.
   ///
@@ -439,6 +459,11 @@
     return 8 * getTypeAllocSize(Ty);
   }
 
+  ScalableSize getScalableTypeAllocSizeInBits(Type *Ty) const {
+    auto Bytes = getScalableTypeAllocSize(Ty);
+    return {Bytes.MinSize * 8, Bytes.Scalable};
+  }
+
   /// Returns the minimum ABI-required alignment for the specified type.
   unsigned getABITypeAlignment(Type *Ty) const;
@@ -590,6 +615,8 @@
     return 80;
   case Type::VectorTyID: {
     VectorType *VTy = cast<VectorType>(Ty);
+    assert(!VTy->isScalable() &&
+           "Scalable vector sizes cannot be represented by a scalar");
     return VTy->getNumElements() * getTypeSizeInBits(VTy->getElementType());
   }
   default:
@@ -597,6 +624,20 @@
   }
 }
 
+inline ScalableSize
+DataLayout::getScalableTypeSizeInBits(Type *Ty) const {
+  switch(Ty->getTypeID()) {
+  default:
+    return {getTypeSizeInBits(Ty), false};
+  case Type::VectorTyID: {
+    VectorType *VTy = cast<VectorType>(Ty);
+    auto EltCnt = VTy->getElementCount();
+    uint64_t MinBits = EltCnt.Min * getTypeSizeInBits(VTy->getElementType());
+    return {MinBits, EltCnt.Scalable};
+  }
+  }
+}
+
 } // end namespace llvm
 
 #endif // LLVM_IR_DATALAYOUT_H
Index: include/llvm/IR/Type.h
===================================================================
--- include/llvm/IR/Type.h
+++ include/llvm/IR/Type.h
@@ -22,6 +22,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/ScalableSize.h"
 #include <cassert>
 #include <cstdint>
 #include <iterator>
@@ -289,6 +290,11 @@
   ///
   unsigned getPrimitiveSizeInBits() const LLVM_READONLY;
 
+  // Returns a ScalableSize for the type in question. This should be used in
+  // place of getPrimitiveSizeInBits in places where the type may be a
+  // VectorType with the Scalable flag set.
+  ScalableSize getScalableSizeInBits() const LLVM_READONLY;
+
   /// If this is a vector type, return the getPrimitiveSizeInBits value for the
   /// element type. Otherwise return the getPrimitiveSizeInBits value for this
   /// type.
Index: include/llvm/Support/ScalableSize.h
===================================================================
--- include/llvm/Support/ScalableSize.h
+++ include/llvm/Support/ScalableSize.h
@@ -40,6 +40,58 @@
   }
 };
 
+struct ScalableSize {
+  uint64_t MinSize;
+  bool Scalable;
+
+  ScalableSize(uint64_t MinSize, bool Scalable)
+    : MinSize(MinSize), Scalable(Scalable) {}
+
+  ScalableSize() {}
+
+  bool operator==(const ScalableSize& RHS) const {
+    if (Scalable == RHS.Scalable)
+      return MinSize == RHS.MinSize;
+
+    llvm_unreachable("Size comparison of scalable and fixed types");
+  }
+
+  bool operator!=(const ScalableSize& RHS) const {
+    if (Scalable == RHS.Scalable)
+      return MinSize != RHS.MinSize;
+
+    llvm_unreachable("Size comparison of scalable and fixed types");
+  }
+
+  bool operator<(const ScalableSize& RHS) const {
+    if (Scalable == RHS.Scalable)
+      return MinSize < RHS.MinSize;
+
+    llvm_unreachable("Size comparison of scalable and fixed types");
+  }
+
+  bool operator<=(const ScalableSize& RHS) const {
+    if (Scalable == RHS.Scalable)
+      return MinSize <= RHS.MinSize;
+
+    llvm_unreachable("Size comparison of scalable and fixed types");
+  }
+
+  bool operator>(const ScalableSize& RHS) const {
+    if (Scalable == RHS.Scalable)
+      return MinSize > RHS.MinSize;
+
+    llvm_unreachable("Size comparison of scalable and fixed types");
+  }
+
+  bool operator>=(const ScalableSize& RHS) const {
+    if (Scalable == RHS.Scalable)
+      return MinSize >= RHS.MinSize;
+
+    llvm_unreachable("Size comparison of scalable and fixed types");
+  }
+};
+
 } // end namespace llvm
 
 #endif // LLVM_SCALABLESIZE_H
Index: lib/Analysis/InstructionSimplify.cpp
===================================================================
--- lib/Analysis/InstructionSimplify.cpp
+++ lib/Analysis/InstructionSimplify.cpp
@@ -3304,7 +3304,7 @@
     // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input
     // if the integer type is the same size as the pointer type.
     if (MaxRecurse && isa<PtrToIntInst>(LI) &&
-        Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) {
+        Q.DL.getScalableTypeSizeInBits(SrcTy) == DstTy->getScalableSizeInBits()) {
       if (Constant *RHSC = dyn_cast<Constant>(RHS)) {
         // Transfer the cast to the constant.
         if (Value *V = SimplifyICmpInst(Pred, SrcOp,
Index: lib/CodeGen/CodeGenPrepare.cpp
===================================================================
--- lib/CodeGen/CodeGenPrepare.cpp
+++ lib/CodeGen/CodeGenPrepare.cpp
@@ -6284,6 +6284,12 @@
                                 const TargetLowering &TLI) {
   // Handle simple but common cases only.
   Type *StoreType = SI.getValueOperand()->getType();
+
+  // Don't try this with scalable vectors for now
+  if (auto *VTy = dyn_cast<VectorType>(StoreType))
+    if (VTy->isScalable())
+      return false;
+
   if (DL.getTypeStoreSizeInBits(StoreType) != DL.getTypeSizeInBits(StoreType) ||
       DL.getTypeSizeInBits(StoreType) == 0)
     return false;
Index: lib/IR/DataLayout.cpp
===================================================================
--- lib/IR/DataLayout.cpp
+++ lib/IR/DataLayout.cpp
@@ -716,7 +716,9 @@
     llvm_unreachable("Bad type for getAlignment!!!");
   }
 
-  return getAlignmentInfo(AlignType, getTypeSizeInBits(Ty), abi_or_pref, Ty);
+  // We only care about the minimum size for alignment
+  auto Size = getScalableTypeSizeInBits(Ty);
+  return getAlignmentInfo(AlignType, Size.MinSize, abi_or_pref, Ty);
 }
 
 unsigned DataLayout::getABITypeAlignment(Type *Ty) const {
Index: lib/IR/Instructions.cpp
===================================================================
--- lib/IR/Instructions.cpp
+++ lib/IR/Instructions.cpp
@@ -2684,8 +2684,8 @@
   }
 
   // Get the bit sizes, we'll need these
-  unsigned SrcBits = SrcTy->getPrimitiveSizeInBits();   // 0 for ptr
-  unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr
+  auto SrcBits = SrcTy->getScalableSizeInBits();   // 0 for ptr
+  auto DestBits = DestTy->getScalableSizeInBits(); // 0 for ptr
 
   // Run through the possibilities ...
   if (DestTy->isIntegerTy()) {                      // Casting to integral
@@ -2746,12 +2746,12 @@
     }
   }
 
-  unsigned SrcBits = SrcTy->getPrimitiveSizeInBits();   // 0 for ptr
-  unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr
+  auto SrcBits = SrcTy->getScalableSizeInBits();   // 0 for ptr
+  auto DestBits = DestTy->getScalableSizeInBits(); // 0 for ptr
 
   // Could still have vectors of pointers if the number of elements doesn't
   // match
-  if (SrcBits == 0 || DestBits == 0)
+  if (SrcBits.MinSize == 0 || DestBits.MinSize == 0)
     return false;
 
   if (SrcBits != DestBits)
@@ -2961,7 +2961,8 @@
   // For non-pointer cases, the cast is okay if the source and destination bit
   // widths are identical.
   if (!SrcPtrTy)
-    return SrcTy->getPrimitiveSizeInBits() == DstTy->getPrimitiveSizeInBits();
+    return SrcTy->getScalableSizeInBits() ==
+           DstTy->getScalableSizeInBits();
 
   // If both are pointers then the address spaces must match.
   if (SrcPtrTy->getAddressSpace() != DstPtrTy->getAddressSpace())
Index: lib/IR/Type.cpp
===================================================================
--- lib/IR/Type.cpp
+++ lib/IR/Type.cpp
@@ -122,7 +122,11 @@
   case Type::PPC_FP128TyID: return 128;
   case Type::X86_MMXTyID: return 64;
   case Type::IntegerTyID: return cast<IntegerType>(this)->getBitWidth();
-  case Type::VectorTyID: return cast<VectorType>(this)->getBitWidth();
+  case Type::VectorTyID: {
+    const VectorType *VTy = cast<VectorType>(this);
+    assert(!VTy->isScalable() && "Scalable vectors are not a primitive type");
+    return VTy->getBitWidth();
+  }
   default: return 0;
   }
 }
@@ -131,6 +135,13 @@
   return getScalarType()->getPrimitiveSizeInBits();
 }
 
+ScalableSize Type::getScalableSizeInBits() const {
+  if (auto *VTy = dyn_cast<VectorType>(this))
+    return {VTy->getBitWidth(), VTy->isScalable()};
+
+  return {getPrimitiveSizeInBits(), false};
+}
+
 int Type::getFPMantissaWidth() const {
   if (auto *VTy = dyn_cast<VectorType>(this))
     return VTy->getElementType()->getFPMantissaWidth();
Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8682,6 +8682,11 @@
   if (AM.HasBaseReg && AM.BaseOffs && AM.Scale)
     return false;
 
+  // For now, just allow base reg only addressing for SVE
+  if (auto *VTy = dyn_cast<VectorType>(Ty))
+    if (VTy->isScalable())
+      return AM.HasBaseReg && !(AM.BaseOffs || AM.Scale);
+
   // check reg + imm case:
   // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
   uint64_t NumBytes = 0;
Index: lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3004,8 +3004,8 @@
     if (match(Mask, m_SExt(m_Value(BoolVec))) &&
         BoolVec->getType()->isVectorTy() &&
         BoolVec->getType()->getScalarSizeInBits() == 1) {
-      assert(Mask->getType()->getPrimitiveSizeInBits() ==
-             II->getType()->getPrimitiveSizeInBits() &&
+      assert(Mask->getType()->getScalableSizeInBits() ==
+             II->getType()->getScalableSizeInBits() &&
              "Not expecting mask and operands with different sizes");
 
       unsigned NumMaskElts = Mask->getType()->getVectorNumElements();
Index: unittests/IR/VectorTypesTest.cpp
===================================================================
--- unittests/IR/VectorTypesTest.cpp
+++ unittests/IR/VectorTypesTest.cpp
@@ -7,8 +7,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Type.h"
 #include "llvm/Support/ScalableSize.h"
 #include "gtest/gtest.h"
 using namespace llvm;
@@ -88,4 +90,54 @@
   ASSERT_TRUE(EltCnt.Scalable);
 }
 
+TEST(VectorTypesTest, FixedLenComparisons) {
+  LLVMContext Ctx;
+
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  Type *Int64Ty = Type::getInt64Ty(Ctx);
+
+  VectorType *V2Int32Ty = VectorType::get(Int32Ty, 2);
+  VectorType *V4Int32Ty = VectorType::get(Int32Ty, 4);
+
+  VectorType *V2Int64Ty = VectorType::get(Int64Ty, 2);
+
+  ScalableSize V2I32Len = V2Int32Ty->getScalableSizeInBits();
+  EXPECT_EQ(V2I32Len.MinSize, 64U);
+  EXPECT_FALSE(V2I32Len.Scalable);
+
+  EXPECT_LT(V2Int32Ty->getScalableSizeInBits(),
+            V4Int32Ty->getScalableSizeInBits());
+  EXPECT_GT(V2Int64Ty->getScalableSizeInBits(),
+            V2Int32Ty->getScalableSizeInBits());
+  EXPECT_EQ(V4Int32Ty->getScalableSizeInBits(),
+            V2Int64Ty->getScalableSizeInBits());
+  EXPECT_NE(V2Int32Ty->getScalableSizeInBits(),
+            V2Int64Ty->getScalableSizeInBits());
+}
+
+TEST(VectorTypesTest, ScalableComparisons) {
+  LLVMContext Ctx;
+
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  Type *Int64Ty = Type::getInt64Ty(Ctx);
+
+  VectorType *ScV2Int32Ty = VectorType::get(Int32Ty, {2, true});
+  VectorType *ScV4Int32Ty = VectorType::get(Int32Ty, {4, true});
+
+  VectorType *ScV2Int64Ty = VectorType::get(Int64Ty, {2, true});
+
+  ScalableSize ScV2I32Len = ScV2Int32Ty->getScalableSizeInBits();
+  EXPECT_EQ(ScV2I32Len.MinSize, 64U);
+  EXPECT_TRUE(ScV2I32Len.Scalable);
+
+  EXPECT_LT(ScV2Int32Ty->getScalableSizeInBits(),
+            ScV4Int32Ty->getScalableSizeInBits());
+  EXPECT_GT(ScV2Int64Ty->getScalableSizeInBits(),
+            ScV2Int32Ty->getScalableSizeInBits());
+  EXPECT_EQ(ScV4Int32Ty->getScalableSizeInBits(),
+            ScV2Int64Ty->getScalableSizeInBits());
+  EXPECT_NE(ScV2Int32Ty->getScalableSizeInBits(),
+            ScV2Int64Ty->getScalableSizeInBits());
+}
+
 }