diff --git a/llvm/include/llvm/Analysis/MemoryLocation.h b/llvm/include/llvm/Analysis/MemoryLocation.h --- a/llvm/include/llvm/Analysis/MemoryLocation.h +++ b/llvm/include/llvm/Analysis/MemoryLocation.h @@ -19,6 +19,7 @@ #include "llvm/ADT/Optional.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" +#include "llvm/Support/ScalableSize.h" namespace llvm { @@ -89,6 +90,10 @@ static LocationSize precise(uint64_t Value) { return LocationSize(Value); } + static LocationSize precise(ScalableSize SSize) { + return LocationSize(SSize.Scalable ? Unknown : SSize.MinSize); + } + static LocationSize upperBound(uint64_t Value) { // You can't go lower than 0, so give a precise result. if (LLVM_UNLIKELY(Value == 0)) @@ -245,6 +250,14 @@ const AAMDNodes &AATags = AAMDNodes()) : Ptr(Ptr), Size(Size), AATags(AATags) {} + explicit MemoryLocation(const Value *Ptr, + ScalableSize SSize, + const AAMDNodes &AATags = AAMDNodes()) + : Ptr(Ptr), Size(SSize.MinSize), AATags(AATags) { + if (SSize.Scalable) + Size = UnknownSize; + } + MemoryLocation getWithNewPtr(const Value *NewPtr) const { MemoryLocation Copy(*this); Copy.Ptr = NewPtr; diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -746,7 +746,7 @@ uint64_t Field = ConstIdx->getZExtValue(); BaseOffset += DL.getStructLayout(STy)->getElementOffset(Field); } else { - int64_t ElementSize = DL.getTypeAllocSize(GTI.getIndexedType()); + int64_t ElementSize = DL.getMinTypeAllocSize(GTI.getIndexedType()); if (ConstIdx) { BaseOffset += ConstIdx->getValue().sextOrTrunc(PtrSizeBits) * ElementSize; diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -654,7 +654,7 @@ // Check for NOOP conversions. if (SrcLT.first == DstLT.first && - SrcLT.second.getSizeInBits() == DstLT.second.getSizeInBits()) { + SrcLT.second.getScalableSizeInBits() == DstLT.second.getScalableSizeInBits()) { // Bitcast between types that are legalized to the same type are free. if (Opcode == Instruction::BitCast || Opcode == Instruction::Trunc) @@ -838,7 +838,7 @@ unsigned Cost = LT.first; if (Src->isVectorTy() && - Src->getPrimitiveSizeInBits() < LT.second.getSizeInBits()) { + Src->getScalableSizeInBits() < LT.second.getScalableSizeInBits()) { // This is a vector load that legalizes to a larger type than the vector // itself. Unless the corresponding extending load or truncating store is // legal, then this will scalarize. 
diff --git a/llvm/include/llvm/CodeGen/MachineFunction.h b/llvm/include/llvm/CodeGen/MachineFunction.h --- a/llvm/include/llvm/CodeGen/MachineFunction.h +++ b/llvm/include/llvm/CodeGen/MachineFunction.h @@ -746,6 +746,15 @@ AtomicOrdering Ordering = AtomicOrdering::NotAtomic, AtomicOrdering FailureOrdering = AtomicOrdering::NotAtomic); + MachineMemOperand *getMachineMemOperand( + MachinePointerInfo PtrInfo, MachineMemOperand::Flags f, + ScalableSize s, unsigned base_alignment, + const AAMDNodes &AAInfo = AAMDNodes(), + const MDNode *Ranges = nullptr, + SyncScope::ID SSID = SyncScope::System, + AtomicOrdering Ordering = AtomicOrdering::NotAtomic, + AtomicOrdering FailureOrdering = AtomicOrdering::NotAtomic); + /// getMachineMemOperand - Allocate a new MachineMemOperand by copying /// an existing one, adjusting by an offset and using the given size. /// MachineMemOperands are owned by the MachineFunction and need not be diff --git a/llvm/include/llvm/CodeGen/MachineMemOperand.h b/llvm/include/llvm/CodeGen/MachineMemOperand.h --- a/llvm/include/llvm/CodeGen/MachineMemOperand.h +++ b/llvm/include/llvm/CodeGen/MachineMemOperand.h @@ -21,6 +21,7 @@ #include "llvm/IR/Value.h" // PointerLikeTypeTraits #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/DataTypes.h" +#include "llvm/Support/ScalableSize.h" namespace llvm { @@ -167,7 +168,7 @@ }; MachinePointerInfo PtrInfo; - uint64_t Size; + ScalableSize Size; Flags FlagVals; uint16_t BaseAlignLog2; // log_2(base_alignment) + 1 MachineAtomicInfo AtomicInfo; @@ -188,6 +189,14 @@ AtomicOrdering Ordering = AtomicOrdering::NotAtomic, AtomicOrdering FailureOrdering = AtomicOrdering::NotAtomic); + MachineMemOperand(MachinePointerInfo PtrInfo, Flags flags, + ScalableSize s, uint64_t a, + const AAMDNodes &AAInfo = AAMDNodes(), + const MDNode *Ranges = nullptr, + SyncScope::ID SSID = SyncScope::System, + AtomicOrdering Ordering = AtomicOrdering::NotAtomic, + AtomicOrdering FailureOrdering = AtomicOrdering::NotAtomic); + const MachinePointerInfo &getPointerInfo() const { return PtrInfo; } /// Return the base address of the memory access. This may either be a normal @@ -218,10 +227,16 @@ unsigned getAddrSpace() const { return PtrInfo.getAddrSpace(); } /// Return the size in bytes of the memory reference. - uint64_t getSize() const { return Size; } + uint64_t getSize() const { return Size.getFixedSize(); } + + /// Return the size in (potentially scalable) bytes + ScalableSize getScalableSize() const { return Size; } /// Return the size in bits of the memory reference. - uint64_t getSizeInBits() const { return Size * 8; } + uint64_t getSizeInBits() const { return Size.getFixedSize() * 8; } + + /// Return the size in (potentially scalable) bits + ScalableSize getScalableSizeInBits() const { return Size * 8; } /// Return the minimum known alignment in bytes of the actual memory /// reference. 
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -42,6 +42,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MachineValueType.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -174,6 +175,14 @@ return getValueType().getSizeInBits(); } + ScalableSize getValueScalableSizeInBits() const { + return getValueType().getScalableSizeInBits(); + } + + unsigned getValueMinSizeInBits() const { + return getValueType().getMinSizeInBits(); + } + unsigned getScalarValueSizeInBits() const { return getValueType().getScalarType().getSizeInBits(); } @@ -1015,6 +1024,14 @@ return getValueType(ResNo).getSizeInBits(); } + ScalableSize getValueScalableSizeInBits(unsigned ResNo) const { + return getValueType(ResNo).getScalableSizeInBits(); + } + + unsigned getValueMinSizeInBits(unsigned ResNo) const { + return getValueType(ResNo).getMinSizeInBits(); + } + using value_iterator = const EVT *; value_iterator value_begin() const { return ValueList; } diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h --- a/llvm/include/llvm/CodeGen/ValueTypes.h +++ b/llvm/include/llvm/CodeGen/ValueTypes.h @@ -18,6 +18,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/MachineValueType.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -209,7 +210,8 @@ /// Return true if the bit size is a multiple of 8. bool isByteSized() const { - return (getSizeInBits() & 7) == 0; + ScalableSize Bits = getScalableSizeInBits(); + return (Bits.MinSize & 7) == 0; } /// Return true if the size is a power-of-two number of bytes. @@ -221,31 +223,31 @@ /// Return true if this has the same number of bits as VT. bool bitsEq(EVT VT) const { if (EVT::operator==(VT)) return true; - return getSizeInBits() == VT.getSizeInBits(); + return getScalableSizeInBits() == VT.getScalableSizeInBits(); } /// Return true if this has more bits than VT. bool bitsGT(EVT VT) const { if (EVT::operator==(VT)) return false; - return getSizeInBits() > VT.getSizeInBits(); + return getScalableSizeInBits() > VT.getScalableSizeInBits(); } /// Return true if this has no less bits than VT. bool bitsGE(EVT VT) const { if (EVT::operator==(VT)) return true; - return getSizeInBits() >= VT.getSizeInBits(); + return getScalableSizeInBits() >= VT.getScalableSizeInBits(); } /// Return true if this has less bits than VT. bool bitsLT(EVT VT) const { if (EVT::operator==(VT)) return false; - return getSizeInBits() < VT.getSizeInBits(); + return getScalableSizeInBits() < VT.getScalableSizeInBits(); } /// Return true if this has no more bits than VT. bool bitsLE(EVT VT) const { if (EVT::operator==(VT)) return true; - return getSizeInBits() <= VT.getSizeInBits(); + return getScalableSizeInBits() <= VT.getScalableSizeInBits(); } /// Return the SimpleValueType held in the specified simple EVT. 
@@ -294,6 +296,18 @@ return getExtendedSizeInBits(); } + ScalableSize getScalableSizeInBits() const { + if (isSimple()) + return V.getScalableSizeInBits(); + return getScalableExtendedSizeInBits(); + } + + unsigned getMinSizeInBits() const { + if (isSimple()) + return V.getMinSizeInBits(); + return getMinExtendedSizeInBits(); + } + unsigned getScalarSizeInBits() const { return getScalarType().getSizeInBits(); } @@ -304,12 +318,29 @@ return (getSizeInBits() + 7) / 8; } + ScalableSize getScalableStoreSize() const { + ScalableSize SizeInBits = getScalableSizeInBits(); + return { (SizeInBits.MinSize + 7) / 8, SizeInBits.Scalable }; + } + + unsigned getMinStoreSize() const { + return (getMinSizeInBits() + 7) / 8; + } + /// Return the number of bits overwritten by a store of the specified value /// type. unsigned getStoreSizeInBits() const { return getStoreSize() * 8; } + ScalableSize getScalableStoreSizeInBits() const { + return getScalableStoreSize() * 8; + } + + unsigned getMinStoreSizeInBits() const { + return getMinStoreSize() * 8; + } + /// Rounds the bit-width of the given integer EVT up to the nearest power of /// two (and at least to eight), and returns the integer EVT with that /// number of bits. @@ -429,6 +460,8 @@ EVT getExtendedVectorElementType() const; unsigned getExtendedVectorNumElements() const LLVM_READONLY; unsigned getExtendedSizeInBits() const LLVM_READONLY; + ScalableSize getScalableExtendedSizeInBits() const LLVM_READONLY; + unsigned getMinExtendedSizeInBits() const LLVM_READONLY; }; } // end namespace llvm diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h --- a/llvm/include/llvm/IR/DataLayout.h +++ b/llvm/include/llvm/IR/DataLayout.h @@ -29,6 +29,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -437,6 +438,9 @@ /// have a size (Type::isSized() must return true). uint64_t getTypeSizeInBits(Type *Ty) const; + ScalableSize getScalableTypeSizeInBits(Type *Ty) const; + uint64_t getMinTypeSizeInBits(Type *Ty) const; + /// Returns the maximum number of bytes that may be overwritten by /// storing the specified type. /// @@ -445,6 +449,16 @@ return (getTypeSizeInBits(Ty) + 7) / 8; } + ScalableSize getScalableTypeStoreSize(Type *Ty) const { + // Is overloading bits/bytes wise? + auto Bits = getScalableTypeSizeInBits(Ty); + return ScalableSize((Bits.MinSize+7)/8, Bits.Scalable); + } + + uint64_t getMinTypeStoreSize(Type *Ty) const { + return (getScalableTypeSizeInBits(Ty).getMinSize()+7)/8; + } + /// Returns the maximum number of bits that may be overwritten by /// storing the specified type; always a multiple of 8. /// @@ -453,12 +467,21 @@ return 8 * getTypeStoreSize(Ty); } + ScalableSize getScalableTypeStoreSizeInBits(Type *Ty) const { + auto Bytes = getScalableTypeStoreSize(Ty); + return {Bytes.MinSize * 8, Bytes.Scalable}; + } + + uint64_t getMinTypeStoreSizeInBits(Type *Ty) const { + return 8 * getMinTypeStoreSize(Ty); + } + /// Returns true if no extra padding bits are needed when storing the /// specified type. /// /// For example, returns false for i19 that has a 24-bit store size. 
bool typeSizeEqualsStoreSize(Type *Ty) const { - return getTypeSizeInBits(Ty) == getTypeStoreSizeInBits(Ty); + return getScalableTypeSizeInBits(Ty) == getScalableTypeStoreSizeInBits(Ty); } /// Returns the offset in bytes between successive objects of the @@ -471,6 +494,17 @@ return alignTo(getTypeStoreSize(Ty), getABITypeAlignment(Ty)); } + ScalableSize getScalableTypeAllocSize(Type *Ty) const { + auto Bytes = getScalableTypeStoreSize(Ty); + Bytes.MinSize = alignTo(Bytes.MinSize, getABITypeAlignment(Ty)); + + return Bytes; + } + + uint64_t getMinTypeAllocSize(Type *Ty) const { + return alignTo(getMinTypeStoreSize(Ty), getABITypeAlignment(Ty)); + } + /// Returns the offset in bits between successive objects of the /// specified type, including alignment padding; always a multiple of 8. /// @@ -480,6 +514,15 @@ return 8 * getTypeAllocSize(Ty); } + ScalableSize getScalableTypeAllocSizeInBits(Type *Ty) const { + auto Bytes = getScalableTypeAllocSize(Ty); + return {Bytes.MinSize * 8, Bytes.Scalable}; + } + + uint64_t getMinTypeAllocSizeInBits(Type *Ty) const { + return 8 * getMinTypeAllocSize(Ty); + } + /// Returns the minimum ABI-required alignment for the specified type. unsigned getABITypeAlignment(Type *Ty) const; @@ -631,6 +674,8 @@ return 80; case Type::VectorTyID: { VectorType *VTy = cast(Ty); + assert(!VTy->isScalable() && + "Scalable vector sizes cannot be represented by a scalar"); return VTy->getNumElements() * getTypeSizeInBits(VTy->getElementType()); } default: @@ -638,6 +683,23 @@ } } +inline ScalableSize DataLayout::getScalableTypeSizeInBits(Type *Ty) const { + switch(Ty->getTypeID()) { + default: + return {getTypeSizeInBits(Ty), false}; + case Type::VectorTyID: { + VectorType *VTy = cast(Ty); + auto EltCnt = VTy->getElementCount(); + uint64_t MinBits = EltCnt.Min * getTypeSizeInBits(VTy->getElementType()); + return {MinBits, EltCnt.Scalable}; + } + } +} + +inline uint64_t DataLayout::getMinTypeSizeInBits(Type *Ty) const { + return getScalableTypeSizeInBits(Ty).getMinSize(); +} + } // end namespace llvm #endif // LLVM_IR_DATALAYOUT_H diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -21,6 +21,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -288,6 +289,19 @@ /// unsigned getPrimitiveSizeInBits() const LLVM_READONLY; + // Returns a ScalableSize for the type in question. This should be used in + // place of getPrimitiveSizeInBits in places where the type may be a + // VectorType with the Scalable flag set. + ScalableSize getScalableSizeInBits() const LLVM_READONLY; + + /// Returns the minimum known size in bits, ignoring whether the type might + /// be a scalable vector. + unsigned getMinSizeInBits() const LLVM_READONLY; + + /// Returns the minimum known size in bits, asserting if called on a scalable + /// vector type. + unsigned getFixedSizeInBits() const LLVM_READONLY; + /// If this is a vector type, return the getPrimitiveSizeInBits value for the /// element type. Otherwise return the getPrimitiveSizeInBits value for this /// type. 
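(Reviewer aside, not part of the patch: a minimal sketch of how the new DataLayout/ScalableSize queries declared above are intended to be used. The helper name and scenario are hypothetical; the calls match the declarations in this diff.)

static bool haveSameStoreSize(const llvm::DataLayout &DL,
                              llvm::Type *A, llvm::Type *B) {
  // ScalableSize pairs a minimum byte count with a flag saying whether the
  // true size is that minimum multiplied by the runtime value of vscale.
  llvm::ScalableSize SA = DL.getScalableTypeStoreSize(A);
  llvm::ScalableSize SB = DL.getScalableTypeStoreSize(B);

  // operator== is total: a scalable size and a fixed size simply compare
  // unequal, whereas the ordered comparisons (<, <=, >, >=) hit
  // llvm_unreachable when the Scalable flags differ.
  return SA == SB;
}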
diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h --- a/llvm/include/llvm/Support/MachineValueType.h +++ b/llvm/include/llvm/Support/MachineValueType.h @@ -17,6 +17,7 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ScalableSize.h" #include namespace llvm { @@ -668,7 +669,60 @@ return { getVectorNumElements(), isScalableVector() }; } + ScalableSize getScalableSizeInBits() const { + switch (SimpleTy) { + default: return { getSizeInBits(), false }; + case nxv1i1: return { 1U, true }; + case nxv2i1: return { 2U, true }; + case nxv4i1: return { 4U, true }; + case nxv1i8: + case nxv8i1: return { 8U, true }; + case nxv16i1: + case nxv2i8: + case nxv1i16: return { 16U, true }; + case nxv32i1: + case nxv4i8: + case nxv2i16: + case nxv1i32: + case nxv2f16: + case nxv1f32: return { 32U, true }; + case nxv8i8: + case nxv4i16: + case nxv2i32: + case nxv1i64: + case nxv4f16: + case nxv2f32: + case nxv1f64: return { 64U, true }; + case nxv16i8: + case nxv8i16: + case nxv4i32: + case nxv2i64: + case nxv8f16: + case nxv4f32: + case nxv2f64: return { 128U, true }; + case nxv32i8: + case nxv16i16: + case nxv8i32: + case nxv4i64: + case nxv8f32: + case nxv4f64: return { 256U, true }; + case nxv32i16: + case nxv16i32: + case nxv8i64: + case nxv16f32: + case nxv8f64: return { 512U, true }; + case nxv32i32: + case nxv16i64: return { 1024U, true }; + case nxv32i64: return { 2048U, true }; + } + } + + unsigned getMinSizeInBits() const { + return getScalableSizeInBits().getMinSize(); + } + unsigned getSizeInBits() const { + assert(!isScalableVector() && "getSizeInBits called on scalable vector"); switch (SimpleTy) { default: llvm_unreachable("getSizeInBits called on extended MVT."); @@ -688,25 +742,17 @@ case Metadata: llvm_unreachable("Value type is metadata."); case i1: - case v1i1: - case nxv1i1: return 1; - case v2i1: - case nxv2i1: return 2; - case v4i1: - case nxv4i1: return 4; + case v1i1: return 1; + case v2i1: return 2; + case v4i1: return 4; case i8 : case v1i8: - case v8i1: - case nxv1i8: - case nxv8i1: return 8; + case v8i1: return 8; case i16 : case f16: case v16i1: case v2i8: - case v1i16: - case nxv16i1: - case nxv2i8: - case nxv1i16: return 16; + case v1i16: return 16; case f32 : case i32 : case v32i1: @@ -714,13 +760,7 @@ case v2i16: case v2f16: case v1f32: - case v1i32: - case nxv32i1: - case nxv4i8: - case nxv2i16: - case nxv1i32: - case nxv2f16: - case nxv1f32: return 32; + case v1i32: return 32; case x86mmx: case f64 : case i64 : @@ -731,14 +771,7 @@ case v1i64: case v4f16: case v2f32: - case v1f64: - case nxv8i8: - case nxv4i16: - case nxv2i32: - case nxv1i64: - case nxv4f16: - case nxv2f32: - case nxv1f64: return 64; + case v1f64: return 64; case f80 : return 80; case v3i32: case v3f32: return 96; @@ -753,14 +786,7 @@ case v1i128: case v8f16: case v4f32: - case v2f64: - case nxv16i8: - case nxv8i16: - case nxv4i32: - case nxv2i64: - case nxv8f16: - case nxv4f32: - case nxv2f64: return 128; + case v2f64: return 128; case v5i32: case v5f32: return 160; case v32i8: @@ -768,39 +794,25 @@ case v8i32: case v4i64: case v8f32: - case v4f64: - case nxv32i8: - case nxv16i16: - case nxv8i32: - case nxv4i64: - case nxv8f32: - case nxv4f64: return 256; + case v4f64: return 256; case v512i1: case v64i8: case v32i16: case v16i32: case v8i64: case v16f32: - case v8f64: - case nxv32i16: - case nxv16i32: - case nxv8i64: - case nxv16f32: - case nxv8f64: return 512; 
+ case v8f64: return 512; case v1024i1: case v128i8: case v64i16: case v32i32: case v16i64: - case v32f32: - case nxv32i32: - case nxv16i64: return 1024; + case v32f32: return 1024; case v256i8: case v128i16: case v64i32: case v32i64: - case v64f32: - case nxv32i64: return 2048; + case v64f32: return 2048; case v128i32: case v128f32: return 4096; case v256i32: @@ -825,30 +837,48 @@ return (getSizeInBits() + 7) / 8; } + ScalableSize getScalableStoreSize() const { + ScalableSize SizeInBits = getScalableSizeInBits(); + return { (SizeInBits.MinSize + 7) / 8, SizeInBits.Scalable }; + } + + unsigned getMinStoreSize() const { + return getScalableStoreSize().MinSize; + } + /// Return the number of bits overwritten by a store of the specified value /// type. unsigned getStoreSizeInBits() const { return getStoreSize() * 8; } + ScalableSize getScalableStoreSizeInBits() const { + ScalableSize SizeInBytes = getScalableStoreSize(); + return { SizeInBytes.MinSize * 8, SizeInBytes.Scalable }; + } + + unsigned getMinStoreSizeInBits() const { + return getScalableStoreSizeInBits().MinSize; + } + /// Return true if this has more bits than VT. bool bitsGT(MVT VT) const { - return getSizeInBits() > VT.getSizeInBits(); + return getScalableSizeInBits() > VT.getScalableSizeInBits(); } /// Return true if this has no less bits than VT. bool bitsGE(MVT VT) const { - return getSizeInBits() >= VT.getSizeInBits(); + return getScalableSizeInBits() >= VT.getScalableSizeInBits(); } /// Return true if this has less bits than VT. bool bitsLT(MVT VT) const { - return getSizeInBits() < VT.getSizeInBits(); + return getScalableSizeInBits() < VT.getScalableSizeInBits(); } /// Return true if this has no more bits than VT. bool bitsLE(MVT VT) const { - return getSizeInBits() <= VT.getSizeInBits(); + return getScalableSizeInBits() <= VT.getScalableSizeInBits(); } static MVT getFloatingPointVT(unsigned BitWidth) { diff --git a/llvm/include/llvm/Support/ScalableSize.h b/llvm/include/llvm/Support/ScalableSize.h --- a/llvm/include/llvm/Support/ScalableSize.h +++ b/llvm/include/llvm/Support/ScalableSize.h @@ -38,6 +38,79 @@ } }; +struct ScalableSize { + uint64_t MinSize; + bool Scalable; + + constexpr ScalableSize(uint64_t MinSize, bool Scalable) + : MinSize(MinSize), Scalable(Scalable) {} + + ScalableSize() = delete; + + bool operator==(const ScalableSize& RHS) const { + if (Scalable == RHS.Scalable) + return MinSize == RHS.MinSize; + + return false; + } + + bool operator!=(const ScalableSize& RHS) const { + if (Scalable == RHS.Scalable) + return MinSize != RHS.MinSize; + + return true; + } + + bool operator<(const ScalableSize& RHS) const { + if (Scalable == RHS.Scalable) + return MinSize < RHS.MinSize; + + llvm_unreachable("Size comparison of scalable and fixed types"); + } + + bool operator<=(const ScalableSize& RHS) const { + if (Scalable == RHS.Scalable) + return MinSize <= RHS.MinSize; + + llvm_unreachable("Size comparison of scalable and fixed types"); + } + + bool operator>(const ScalableSize& RHS) const { + if (Scalable == RHS.Scalable) + return MinSize > RHS.MinSize; + + llvm_unreachable("Size comparison of scalable and fixed types"); + } + + bool operator>=(const ScalableSize& RHS) const { + if (Scalable == RHS.Scalable) + return MinSize >= RHS.MinSize; + + llvm_unreachable("Size comparison of scalable and fixed types"); + } + + ScalableSize operator*(unsigned RHS) const { + return { MinSize * RHS, Scalable }; + } + + ScalableSize operator/(unsigned RHS) const { + return { MinSize / RHS, Scalable }; + } + + uint64_t 
getFixedSize() const { + assert(!Scalable && "Request for a fixed size on a scalable object"); + return MinSize; + } + + uint64_t getMinSize() const { + return MinSize; + } + + bool isScalable() const { + return Scalable; + } +}; + } // end namespace llvm #endif // LLVM_SUPPORT_SCALABLESIZE_H diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp --- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp +++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp @@ -510,6 +510,14 @@ return false; } + // Don't try to decompose pointers based on scalable vectors for now, since + // the offset won't take vscale into account. + if (auto *VecTy = dyn_cast(GEPOp->getSourceElementType())) + if (VecTy->isScalable()) { + Decomposed.Base = V; + return false; + } + unsigned AS = GEPOp->getPointerAddressSpace(); // Walk the indices of the GEP, accumulating them into BaseOff/VarIndices. gep_type_iterator GTI = gep_type_begin(GEPOp); diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -327,7 +327,8 @@ // If the type sizes are the same and a cast is legal, just directly // cast the constant. - if (DL.getTypeSizeInBits(DestTy) == DL.getTypeSizeInBits(SrcTy)) { + if (DL.getScalableTypeSizeInBits(DestTy) == + DL.getScalableTypeSizeInBits(SrcTy)) { Instruction::CastOps Cast = Instruction::BitCast; // If we are going from a pointer to int or vice versa, we spell the cast // differently. diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -3095,7 +3095,7 @@ // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input // if the integer type is the same size as the pointer type. if (MaxRecurse && isa(LI) && - Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) { + Q.DL.getScalableTypeSizeInBits(SrcTy) == DstTy->getScalableSizeInBits()) { if (Constant *RHSC = dyn_cast(RHS)) { // Transfer the cast to the constant. if (Value *V = SimplifyICmpInst(Pred, SrcOp, @@ -3854,6 +3854,10 @@ return Ops[0]; Type *Ty = SrcTy; + if (auto *VecTy = dyn_cast(Ty)) + if (VecTy->isScalable()) + return nullptr; + if (Ty->isSized()) { Value *P; uint64_t C; @@ -3902,12 +3906,13 @@ } } - if (Q.DL.getTypeAllocSize(LastType) == 1 && + if (Q.DL.getScalableTypeAllocSize(LastType) == ScalableSize(1U, false) && all_of(Ops.slice(1).drop_back(1), [](Value *Idx) { return match(Idx, m_Zero()); })) { unsigned IdxWidth = Q.DL.getIndexSizeInBits(Ops[0]->getType()->getPointerAddressSpace()); - if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == IdxWidth) { + if (Q.DL.getScalableTypeSizeInBits(Ops.back()->getType()) == + ScalableSize(IdxWidth, false)) { APInt BasePtrOffset(IdxWidth, 0); Value *StrippedBasePtr = Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL, diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp --- a/llvm/lib/Analysis/Loads.cpp +++ b/llvm/lib/Analysis/Loads.cpp @@ -142,6 +142,10 @@ if (!Ty->isSized()) return false; + if (auto *VecTy = dyn_cast(Ty)) + if (VecTy->isScalable()) + return false; + SmallPtrSet Visited; return ::isDereferenceableAndAlignedPointer( V, Align, @@ -361,7 +365,8 @@ const DataLayout &DL = ScanBB->getModule()->getDataLayout(); // Try to get the store size for the type. 
- auto AccessSize = LocationSize::precise(DL.getTypeStoreSize(AccessTy)); + auto AccessSize = + LocationSize::precise(DL.getScalableTypeStoreSize(AccessTy)); Value *StrippedPtr = Ptr->stripPointerCasts(); diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp --- a/llvm/lib/Analysis/MemoryLocation.cpp +++ b/llvm/lib/Analysis/MemoryLocation.cpp @@ -38,7 +38,8 @@ return MemoryLocation( LI->getPointerOperand(), - LocationSize::precise(DL.getTypeStoreSize(LI->getType())), AATags); + LocationSize::precise(DL.getScalableTypeStoreSize(LI->getType())), + AATags); } MemoryLocation MemoryLocation::get(const StoreInst *SI) { @@ -47,7 +48,7 @@ const auto &DL = SI->getModule()->getDataLayout(); return MemoryLocation(SI->getPointerOperand(), - LocationSize::precise(DL.getTypeStoreSize( + LocationSize::precise(DL.getScalableTypeStoreSize( SI->getValueOperand()->getType())), AATags); } diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -65,6 +65,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -1325,7 +1326,7 @@ break; } unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); - uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy); + uint64_t TypeSize = Q.DL.getMinTypeAllocSize(IndexedTy); LocalKnown.Zero = LocalKnown.One = APInt(GEPOpiBits, 0); computeKnownBits(Index, LocalKnown, Depth + 1, Q); TrailZ = std::min(TrailZ, @@ -1866,7 +1867,7 @@ } // If we have a zero-sized type, the index doesn't matter. Keep looping. - if (Q.DL.getTypeAllocSize(GTI.getIndexedType()) == 0) + if (Q.DL.getMinTypeAllocSize(GTI.getIndexedType()) == 0) continue; // Fast path the constant operand case both for efficiency and so we don't diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -84,6 +84,7 @@ #include "llvm/Support/MachineValueType.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Support/ScalableSize.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -4157,7 +4158,12 @@ cast(AddrInst->getOperand(i))->getZExtValue(); ConstantOffset += SL->getElementOffset(Idx); } else { - uint64_t TypeSize = DL.getTypeAllocSize(GTI.getIndexedType()); + // TODO: Think about this one a bit more... + auto Size = DL.getScalableTypeAllocSize(GTI.getIndexedType()); + // Offsets cannot be fully calculated at compile time for scalable types + if (Size.isScalable()) + return false; + uint64_t TypeSize = Size.getMinSize(); if (ConstantInt *CI = dyn_cast(AddrInst->getOperand(i))) { const APInt &CVal = CI->getValue(); if (CVal.getMinSignedBits() <= 64) { @@ -6678,8 +6684,14 @@ /// during code expansion. static bool splitMergedValStore(StoreInst &SI, const DataLayout &DL, const TargetLowering &TLI) { - // Handle simple but common cases only. Type *StoreType = SI.getValueOperand()->getType(); + + // Don't try this with scalable vectors for now + if (auto *VTy = dyn_cast(StoreType)) + if (VTy->isScalable()) + return false; + + // Handle simple but common cases only. 
if (!DL.typeSizeEqualsStoreSize(StoreType) || DL.getTypeSizeInBits(StoreType) == 0) return false; diff --git a/llvm/lib/CodeGen/MachineFunction.cpp b/llvm/lib/CodeGen/MachineFunction.cpp --- a/llvm/lib/CodeGen/MachineFunction.cpp +++ b/llvm/lib/CodeGen/MachineFunction.cpp @@ -406,6 +406,16 @@ SSID, Ordering, FailureOrdering); } +MachineMemOperand *MachineFunction::getMachineMemOperand( + MachinePointerInfo PtrInfo, MachineMemOperand::Flags f, + ScalableSize s, unsigned base_alignment, const AAMDNodes &AAInfo, + const MDNode *Ranges, SyncScope::ID SSID, AtomicOrdering Ordering, + AtomicOrdering FailureOrdering) { + return new (Allocator) + MachineMemOperand(PtrInfo, f, s, base_alignment, AAInfo, Ranges, + SSID, Ordering, FailureOrdering); +} + MachineMemOperand * MachineFunction::getMachineMemOperand(const MachineMemOperand *MMO, int64_t Offset, uint64_t Size) { @@ -431,7 +441,7 @@ MachinePointerInfo(MMO->getPseudoValue(), MMO->getOffset()); return new (Allocator) - MachineMemOperand(MPI, MMO->getFlags(), MMO->getSize(), + MachineMemOperand(MPI, MMO->getFlags(), MMO->getScalableSize(), MMO->getBaseAlignment(), AAInfo, MMO->getRanges(), MMO->getSyncScopeID(), MMO->getOrdering(), MMO->getFailureOrdering()); diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp --- a/llvm/lib/CodeGen/MachineInstr.cpp +++ b/llvm/lib/CodeGen/MachineInstr.cpp @@ -1236,8 +1236,12 @@ int64_t OffsetB = MMOb->getOffset(); int64_t MinOffset = std::min(OffsetA, OffsetB); - uint64_t WidthA = MMOa->getSize(); - uint64_t WidthB = MMOb->getSize(); + ScalableSize SizeA = MMOa->getScalableSize(); + ScalableSize SizeB = MMOb->getScalableSize(); + if (SizeA.isScalable() || SizeB.isScalable()) + return true; + uint64_t WidthA = SizeA.getFixedSize(); + uint64_t WidthB = SizeB.getFixedSize(); bool KnownWidthA = WidthA != MemoryLocation::UnknownSize; bool KnownWidthB = WidthB != MemoryLocation::UnknownSize; @@ -1343,7 +1347,7 @@ // If we have an AliasAnalysis, ask it whether the memory is constant. 
if (AA && AA->pointsToConstantMemory( - MemoryLocation(V, MMO->getSize(), MMO->getAAInfo()))) + MemoryLocation(V, MMO->getScalableSize(), MMO->getAAInfo()))) continue; } diff --git a/llvm/lib/CodeGen/MachineOperand.cpp b/llvm/lib/CodeGen/MachineOperand.cpp --- a/llvm/lib/CodeGen/MachineOperand.cpp +++ b/llvm/lib/CodeGen/MachineOperand.cpp @@ -1002,8 +1002,30 @@ const MDNode *Ranges, SyncScope::ID SSID, AtomicOrdering Ordering, AtomicOrdering FailureOrdering) - : PtrInfo(ptrinfo), Size(s), FlagVals(f), BaseAlignLog2(Log2_32(a) + 1), - AAInfo(AAInfo), Ranges(Ranges) { + : PtrInfo(ptrinfo), Size(s, false), FlagVals(f), + BaseAlignLog2(Log2_32(a) + 1), AAInfo(AAInfo), Ranges(Ranges) { + assert((PtrInfo.V.isNull() || PtrInfo.V.is() || + isa(PtrInfo.V.get()->getType())) && + "invalid pointer value"); + assert(getBaseAlignment() == a && "Alignment is not a power of 2!"); + assert((isLoad() || isStore()) && "Not a load/store!"); + + AtomicInfo.SSID = static_cast(SSID); + assert(getSyncScopeID() == SSID && "Value truncated"); + AtomicInfo.Ordering = static_cast(Ordering); + assert(getOrdering() == Ordering && "Value truncated"); + AtomicInfo.FailureOrdering = static_cast(FailureOrdering); + assert(getFailureOrdering() == FailureOrdering && "Value truncated"); +} + +MachineMemOperand::MachineMemOperand(MachinePointerInfo ptrinfo, Flags f, + ScalableSize s, uint64_t a, + const AAMDNodes &AAInfo, + const MDNode *Ranges, SyncScope::ID SSID, + AtomicOrdering Ordering, + AtomicOrdering FailureOrdering) + : PtrInfo(ptrinfo), Size(s), FlagVals(f), + BaseAlignLog2(Log2_32(a) + 1), AAInfo(AAInfo), Ranges(Ranges) { assert((PtrInfo.V.isNull() || PtrInfo.V.is() || isa(PtrInfo.V.get()->getType())) && "invalid pointer value"); @@ -1022,7 +1044,8 @@ /// void MachineMemOperand::Profile(FoldingSetNodeID &ID) const { ID.AddInteger(getOffset()); - ID.AddInteger(Size); + ID.AddInteger(Size.getMinSize()); + ID.AddBoolean(Size.isScalable()); ID.AddPointer(getOpaqueValue()); ID.AddInteger(getFlags()); ID.AddInteger(getBaseAlignment()); @@ -1032,7 +1055,7 @@ // The Value and Offset may differ due to CSE. But the flags and size // should be the same. assert(MMO->getFlags() == getFlags() && "Flags mismatch!"); - assert(MMO->getSize() == getSize() && "Size mismatch!"); + assert(MMO->getScalableSize() == getScalableSize() && "Size mismatch!"); if (MMO->getBaseAlignment() >= getBaseAlignment()) { // Update the alignment value. @@ -1098,10 +1121,14 @@ if (getFailureOrdering() != AtomicOrdering::NotAtomic) OS << toIRString(getFailureOrdering()) << ' '; - if (getSize() == MemoryLocation::UnknownSize) + auto Size = getScalableSize(); + if (Size.getMinSize() == MemoryLocation::UnknownSize) OS << "unknown-size"; - else - OS << getSize(); + else { + OS << Size.getMinSize(); + if (Size.isScalable()) + OS << "(scalable)"; + } if (const Value *Val = getValue()) { OS << ((isLoad() && isStore()) ? " on " : isLoad() ? 
" from " : " into "); @@ -1149,7 +1176,7 @@ } } MachineOperand::printOperandOffset(OS, getOffset()); - if (getBaseAlignment() != getSize()) + if (Size.isScalable() || (getBaseAlignment() != Size.getMinSize())) OS << ", align " << getBaseAlignment(); auto AAInfo = getAAInfo(); if (AAInfo.TBAA) { diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -206,8 +206,8 @@ for (MVT VT : MVT::all_valuetypes()) if (EVT(VT).isSimple() && VT != MVT::Other && TLI.isTypeLegal(EVT(VT)) && - VT.getSizeInBits() >= MaximumLegalStoreInBits) - MaximumLegalStoreInBits = VT.getSizeInBits(); + VT.getMinSizeInBits() >= MaximumLegalStoreInBits) + MaximumLegalStoreInBits = VT.getMinSizeInBits(); } void ConsiderForPruning(SDNode *N) { @@ -9365,7 +9365,7 @@ // Does the setcc have the same vector size as the casted select? SDValue SetCC = VSel.getOperand(0); EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType()); - if (SetCCVT.getSizeInBits() != VT.getSizeInBits()) + if (SetCCVT.getScalableSizeInBits() != VT.getScalableSizeInBits()) return SDValue(); // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B) @@ -9637,7 +9637,7 @@ // for that matter). Check to see that they are the same size. If so, // we know that the element size of the sext'd result matches the // element size of the compare operands. - if (VT.getSizeInBits() == SVT.getSizeInBits()) + if (VT.getScalableSizeInBits() == SVT.getScalableSizeInBits()) return DAG.getSetCC(DL, VT, N00, N01, CC); // If the desired elements are smaller or larger than the source @@ -10629,11 +10629,13 @@ EVT ExTy = N0.getValueType(); EVT TrTy = N->getValueType(0); - unsigned NumElem = VecTy.getVectorNumElements(); + auto EltCnt = VecTy.getVectorElementCount(); unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits(); - EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, SizeRatio * NumElem); - assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size"); + + EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, EltCnt * SizeRatio); + assert(NVT.getScalableSizeInBits() == VecTy.getScalableSizeInBits() && + "Invalid Size"); SDValue EltNo = N0->getOperand(1); if (isa(EltNo) && isTypeLegal(NVT)) { @@ -14670,6 +14672,9 @@ !LD->getValueType(0).isInteger()) return false; + if (LD->getValueType(0).isScalableVector()) + return false; + // Keep track of already used bits to detect overlapping values. // In that case, we will just abort the transformation. APInt UsedBits(LD->getValueSizeInBits(0), 0); @@ -15029,6 +15034,8 @@ Value.hasOneUse()) { LoadSDNode *LD = cast(Value); EVT VT = LD->getMemoryVT(); + if (VT.isScalableVector()) + return SDValue(); if (!VT.isFloatingPoint() || VT != ST->getMemoryVT() || LD->isNonTemporal() || @@ -15520,6 +15527,9 @@ return false; EVT MemVT = St->getMemoryVT(); + if (MemVT.isScalableVector()) + return false; + int64_t ElementSizeBytes = MemVT.getStoreSize(); unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1; @@ -17693,6 +17703,9 @@ if (TLI.isTypeLegal(OpVT)) return SDValue(); + if (OpVT.isScalableVector()) + return SDValue(); + SDLoc DL(N); EVT VT = N->getValueType(0); SmallVector Ops; @@ -17791,7 +17804,7 @@ // Ensure that we are extracting a subvector from a vector the same // size as the result. 
- if (ExtVT.getSizeInBits() != VT.getSizeInBits()) + if (ExtVT.getScalableSizeInBits() != VT.getScalableSizeInBits()) return SDValue(); // Scale the subvector index to account for any bitcast. @@ -18067,8 +18080,8 @@ "Extract index is not a multiple of the vector length."); // Bail out if this is not a proper multiple width extraction. - unsigned WideWidth = WideBVT.getSizeInBits(); - unsigned NarrowWidth = VT.getSizeInBits(); + unsigned WideWidth = WideBVT.getMinSizeInBits(); + unsigned NarrowWidth = VT.getMinSizeInBits(); if (WideWidth % NarrowWidth != 0) return SDValue(); @@ -20417,6 +20430,15 @@ /// Return true if there is any possibility that the two addresses overlap. bool DAGCombiner::isAlias(SDNode *Op0, SDNode *Op1) const { + // Be very conservative and say there's a possible alias if either Op is + // a scalable vector. + if (auto MemOp0 = dyn_cast(Op0)) + if (MemOp0->getMemoryVT().isScalableVector()) + return true; + + if (auto MemOp1 = dyn_cast(Op1)) + if (MemOp1->getMemoryVT().isScalableVector()) + return true; struct MemUseCharacteristics { bool IsVolatile; diff --git a/llvm/lib/CodeGen/SelectionDAG/FunctionLoweringInfo.cpp b/llvm/lib/CodeGen/SelectionDAG/FunctionLoweringInfo.cpp --- a/llvm/lib/CodeGen/SelectionDAG/FunctionLoweringInfo.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/FunctionLoweringInfo.cpp @@ -143,7 +143,9 @@ if (AI->isStaticAlloca() && (TFI->isStackRealignable() || (Align <= StackAlign))) { const ConstantInt *CUI = cast(AI->getArraySize()); - uint64_t TySize = MF->getDataLayout().getTypeAllocSize(Ty); + // Scalable vector stack slots are handled later, so just use + // the minimum size here. + uint64_t TySize = MF->getDataLayout().getMinTypeAllocSize(Ty); TySize *= CUI->getZExtValue(); // Get total allocated size. if (TySize == 0) TySize = 1; // Don't create zero-sized stack objects. diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -527,21 +527,22 @@ LLVM_DEBUG(dbgs() << "Legalizing truncating store operations\n"); SDValue Value = ST->getValue(); EVT StVT = ST->getMemoryVT(); - unsigned StWidth = StVT.getSizeInBits(); + unsigned StWidth = StVT.getMinSizeInBits(); + bool Scalable = StVT.getScalableSizeInBits().isScalable(); auto &DL = DAG.getDataLayout(); - if (StWidth != StVT.getStoreSizeInBits()) { + if (StWidth != StVT.getMinStoreSizeInBits() && !Scalable) { // Promote to a byte-sized store with upper bits zero if not // storing an integral number of bytes. For example, promote // TRUNCSTORE:i1 X -> TRUNCSTORE:i8 (and X, 1) EVT NVT = EVT::getIntegerVT(*DAG.getContext(), - StVT.getStoreSizeInBits()); + StVT.getMinStoreSizeInBits()); Value = DAG.getZeroExtendInReg(Value, dl, StVT); SDValue Result = DAG.getTruncStore(Chain, dl, Value, Ptr, ST->getPointerInfo(), NVT, Alignment, MMOFlags, AAInfo); ReplaceNode(SDValue(Node, 0), Result); - } else if (StWidth & (StWidth - 1)) { + } else if ((StWidth & (StWidth - 1)) && !Scalable) { // If not storing a power-of-2 number of bits, expand as two stores. 
assert(!StVT.isVector() && "Unsupported truncstore!"); unsigned LogStWidth = Log2_32(StWidth); @@ -708,12 +709,12 @@ LLVM_DEBUG(dbgs() << "Legalizing extending load operation\n"); EVT SrcVT = LD->getMemoryVT(); - unsigned SrcWidth = SrcVT.getSizeInBits(); + unsigned SrcWidth = SrcVT.getMinSizeInBits(); unsigned Alignment = LD->getAlignment(); MachineMemOperand::Flags MMOFlags = LD->getMemOperand()->getFlags(); AAMDNodes AAInfo = LD->getAAInfo(); - if (SrcWidth != SrcVT.getStoreSizeInBits() && + if (SrcWidth != SrcVT.getMinStoreSizeInBits() && // Some targets pretend to have an i1 loading operation, and actually // load an i8. This trick is correct for ZEXTLOAD because the top 7 // bits are guaranteed to be zero; it helps the optimizers understand @@ -726,7 +727,7 @@ TargetLowering::Promote)) { // Promote to a byte-sized load if not loading an integral number of // bytes. For example, promote EXTLOAD:i20 -> EXTLOAD:i24. - unsigned NewWidth = SrcVT.getStoreSizeInBits(); + unsigned NewWidth = SrcVT.getMinStoreSizeInBits(); EVT NVT = EVT::getIntegerVT(*DAG.getContext(), NewWidth); SDValue Ch; @@ -4399,7 +4400,8 @@ // (i32 (extract_vector_elt castx, (2 * y + 1))) // - assert(NVT.isVector() && OVT.getSizeInBits() == NVT.getSizeInBits() && + assert(NVT.isVector() && + OVT.getScalableSizeInBits() == NVT.getScalableSizeInBits() && "Invalid promote type for extract_vector_elt"); assert(NewEltVT.bitsLT(EltVT) && "not handled"); @@ -4445,7 +4447,8 @@ // (extract_vector_elt casty, 0), 2 * z), // (extract_vector_elt casty, 1), (2 * z + 1)) - assert(NVT.isVector() && OVT.getSizeInBits() == NVT.getSizeInBits() && + assert(NVT.isVector() && + OVT.getScalableSizeInBits() == NVT.getScalableSizeInBits() && "Invalid promote type for insert_vector_elt"); assert(NewEltVT.bitsLT(EltVT) && "not handled"); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Support/ScalableSize.h" using namespace llvm; #define DEBUG_TYPE "legalize-types" @@ -1577,7 +1578,8 @@ MachineMemOperand *MMO = DAG.getMachineFunction(). getMachineMemOperand(MLD->getPointerInfo(), - MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), + MachineMemOperand::MOLoad, + LoMemVT.getScalableStoreSize(), Alignment, MLD->getAAInfo(), MLD->getRanges()); Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, MaskLo, PassThruLo, LoMemVT, MMO, @@ -1646,7 +1648,8 @@ MachineMemOperand *MMO = DAG.getMachineFunction(). getMachineMemOperand(MGT->getPointerInfo(), - MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), + MachineMemOperand::MOLoad, + LoMemVT.getScalableStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale}; @@ -2288,7 +2291,8 @@ MachineMemOperand *MMO = DAG.getMachineFunction(). getMachineMemOperand(MGT->getPointerInfo(), - MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), + MachineMemOperand::MOLoad, + LoMemVT.getScalableStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale}; @@ -2297,7 +2301,8 @@ MMO = DAG.getMachineFunction(). 
getMachineMemOperand(MGT->getPointerInfo(), - MachineMemOperand::MOLoad, HiMemVT.getStoreSize(), + MachineMemOperand::MOLoad, + HiMemVT.getScalableStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); @@ -2413,7 +2418,8 @@ SDValue Lo; MachineMemOperand *MMO = DAG.getMachineFunction(). getMachineMemOperand(N->getPointerInfo(), - MachineMemOperand::MOStore, LoMemVT.getStoreSize(), + MachineMemOperand::MOStore, + LoMemVT.getScalableStoreSize(), Alignment, N->getAAInfo(), N->getRanges()); SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale}; @@ -2422,7 +2428,8 @@ MMO = DAG.getMachineFunction(). getMachineMemOperand(N->getPointerInfo(), - MachineMemOperand::MOStore, HiMemVT.getStoreSize(), + MachineMemOperand::MOStore, + HiMemVT.getScalableStoreSize(), Alignment, N->getAAInfo(), N->getRanges()); // The order of the Scatter operation after split is well defined. The "Hi" @@ -2455,7 +2462,7 @@ if (!LoMemVT.isByteSized() || !HiMemVT.isByteSized()) return TLI.scalarizeVectorStore(N, DAG); - unsigned IncrementSize = LoMemVT.getSizeInBits()/8; + unsigned IncrementSize = LoMemVT.getMinSizeInBits()/8; if (isTruncating) Lo = DAG.getTruncStore(Ch, DL, Lo, Ptr, N->getPointerInfo(), LoMemVT, @@ -3845,7 +3852,7 @@ EVT VSelVT = N->getValueType(0); // Only handle vector types which are a power of 2. - if (!isPowerOf2_64(VSelVT.getSizeInBits())) + if (!isPowerOf2_64(VSelVT.getMinSizeInBits())) return SDValue(); // Don't touch if this will be scalarized. @@ -4597,7 +4604,7 @@ unsigned Width, EVT WidenVT, unsigned Align = 0, unsigned WidenEx = 0) { EVT WidenEltVT = WidenVT.getVectorElementType(); - unsigned WidenWidth = WidenVT.getSizeInBits(); + unsigned WidenWidth = WidenVT.getMinSizeInBits(); unsigned WidenEltWidth = WidenEltVT.getSizeInBits(); unsigned AlignInBits = Align*8; @@ -4633,7 +4640,9 @@ for (VT = (unsigned)MVT::LAST_VECTOR_VALUETYPE; VT >= (unsigned)MVT::FIRST_VECTOR_VALUETYPE; --VT) { EVT MemVT = (MVT::SimpleValueType) VT; - unsigned MemVTWidth = MemVT.getSizeInBits(); + if (MemVT.isScalableVector() != WidenVT.isScalableVector()) + continue; + unsigned MemVTWidth = MemVT.getMinSizeInBits(); auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT); if ((Action == TargetLowering::TypeLegal || Action == TargetLowering::TypePromoteInteger) && diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -1944,7 +1944,8 @@ SDValue SelectionDAG::CreateStackTemporary(EVT VT, unsigned minAlign) { MachineFrameInfo &MFI = getMachineFunction().getFrameInfo(); - unsigned ByteSize = VT.getStoreSize(); + // Stack ID handled elsewhere for now... + unsigned ByteSize = VT.getMinStoreSize(); Type *Ty = VT.getTypeForEVT(*getContext()); unsigned StackAlign = std::max((unsigned)getDataLayout().getPrefTypeAlignment(Ty), minAlign); @@ -4615,7 +4616,7 @@ break; case ISD::BITCAST: // Basic sanity checking. - assert(VT.getSizeInBits() == Operand.getValueSizeInBits() && + assert(VT.getScalableSizeInBits() == Operand.getValueScalableSizeInBits() && "Cannot BITCAST between types of different sizes!"); if (VT == Operand.getValueType()) return Operand; // noop conversion. if (OpOpcode == ISD::BITCAST) // bitconv(bitconv(x)) -> bitconv(x) @@ -5145,7 +5146,8 @@ // amounts. This catches things like trying to shift an i1024 value by an // i8, which is easy to fall into in generic code that uses // TLI.getShiftAmount(). 
- assert(N2.getValueSizeInBits() >= Log2_32_Ceil(N1.getValueSizeInBits()) && + assert(N2.getValueMinSizeInBits() >= + Log2_32_Ceil(N1.getValueMinSizeInBits()) && "Invalid use of small shift amount with oversized value!"); // Always fold shifts of i1 values so the code generator doesn't need to @@ -6712,7 +6714,8 @@ MachineFunction &MF = getMachineFunction(); MachineMemOperand *MMO = MF.getMachineMemOperand( - PtrInfo, MMOFlags, MemVT.getStoreSize(), Alignment, AAInfo, Ranges); + PtrInfo, MMOFlags, MemVT.getScalableStoreSize(), Alignment, + AAInfo, Ranges); return getLoad(AM, ExtType, VT, dl, Chain, Ptr, Offset, MemVT, MMO); } @@ -6833,7 +6836,8 @@ MachineFunction &MF = getMachineFunction(); MachineMemOperand *MMO = MF.getMachineMemOperand( - PtrInfo, MMOFlags, Val.getValueType().getStoreSize(), Alignment, AAInfo); + PtrInfo, MMOFlags, Val.getValueType().getScalableStoreSize(), + Alignment, AAInfo); return getStore(Chain, dl, Val, Ptr, MMO); } @@ -6885,7 +6889,7 @@ MachineFunction &MF = getMachineFunction(); MachineMemOperand *MMO = MF.getMachineMemOperand( - PtrInfo, MMOFlags, SVT.getStoreSize(), Alignment, AAInfo); + PtrInfo, MMOFlags, SVT.getScalableStoreSize(), Alignment, AAInfo); return getTruncStore(Chain, dl, Val, Ptr, SVT, MMO); } @@ -8834,7 +8838,8 @@ // We check here that the size of the memory operand fits within the size of // the MMO. This is because the MMO might indicate only a possible address // range instead of specifying the affected memory addresses precisely. - assert(memvt.getStoreSize() <= MMO->getSize() && "Size mismatch!"); + assert(memvt.getScalableStoreSize() <= MMO->getScalableSize() && + "Size mismatch!"); } /// Profile - Gather unique data for the node. diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -385,8 +385,8 @@ assert(NumRegs == NumParts && "Part count doesn't match vector breakdown!"); NumParts = NumRegs; // Silence a compiler warning. assert(RegisterVT == PartVT && "Part type doesn't match vector breakdown!"); - assert(RegisterVT.getSizeInBits() == - Parts[0].getSimpleValueType().getSizeInBits() && + assert(RegisterVT.getScalableSizeInBits() == + Parts[0].getSimpleValueType().getScalableSizeInBits() && "Part type sizes don't match!"); // Assemble the parts into intermediate operands. @@ -440,10 +440,10 @@ } // Vector/Vector bitcast. - if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits()) + if (ValueVT.getScalableSizeInBits() == PartEVT.getScalableSizeInBits()) return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val); - assert(PartEVT.getVectorNumElements() == ValueVT.getVectorNumElements() && + assert(PartEVT.getVectorElementCount() == ValueVT.getVectorElementCount() && "Cannot handle this kind of promotion"); // Promoted vector extract return DAG.getAnyExtOrTrunc(Val, DL, ValueVT); @@ -668,7 +668,7 @@ EVT PartEVT = PartVT; if (PartEVT == ValueVT) { // Nothing to do. - } else if (PartVT.getSizeInBits() == ValueVT.getSizeInBits()) { + } else if (PartVT.getScalableSizeInBits() == ValueVT.getScalableSizeInBits()) { // Bitconvert vector->vector case. Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val); } else if (SDValue Widened = widenVectorToPartType(DAG, Val, DL, PartVT)) { @@ -4285,7 +4285,8 @@ MachineMemOperand *MMO = DAG.getMachineFunction(). 
getMachineMemOperand(MachinePointerInfo(PtrOperand), - MachineMemOperand::MOStore, VT.getStoreSize(), + MachineMemOperand::MOStore, + VT.getScalableStoreSize(), Alignment, AAInfo); SDValue StoreNode = DAG.getMaskedStore(getRoot(), sdl, Src0, Ptr, Mask, VT, MMO, false /* Truncating */, @@ -4380,7 +4381,8 @@ const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; MachineMemOperand *MMO = DAG.getMachineFunction(). getMachineMemOperand(MachinePointerInfo(MemOpBasePtr), - MachineMemOperand::MOStore, VT.getStoreSize(), + MachineMemOperand::MOStore, + VT.getScalableStoreSize(), Alignment, AAInfo); if (!UniformBase) { Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); @@ -4494,7 +4496,8 @@ MachineMemOperand *MMO = DAG.getMachineFunction(). getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : nullptr), - MachineMemOperand::MOLoad, VT.getStoreSize(), + MachineMemOperand::MOLoad, + VT.getScalableStoreSize(), Alignment, AAInfo, Ranges); if (!UniformBase) { @@ -9112,7 +9115,7 @@ // if it isn't first piece, alignment must be 1 ISD::OutputArg MyFlags(Flags, Parts[j].getValueType(), VT, i < CLI.NumFixedArgs, - i, j*Parts[j].getValueType().getStoreSize()); + i, j*Parts[j].getValueType().getMinStoreSize()); if (NumParts > 1 && j == 0) MyFlags.Flags.setSplit(); else if (j != 0) { @@ -9358,8 +9361,8 @@ const auto *Arg = dyn_cast(Val); if (!Arg || Arg->hasInAllocaAttr() || Arg->hasByValAttr() || Arg->getType()->isEmptyTy() || - DL.getTypeStoreSize(Arg->getType()) != - DL.getTypeAllocSize(AI->getAllocatedType()) || + DL.getScalableTypeStoreSize(Arg->getType()) != + DL.getScalableTypeAllocSize(AI->getAllocatedType()) || ArgCopyElisionCandidates.count(Arg)) { *Info = StaticAllocaInfo::Clobbered; continue; @@ -9580,7 +9583,7 @@ *CurDAG->getContext(), F.getCallingConv(), VT); for (unsigned i = 0; i != NumRegs; ++i) { ISD::InputArg MyFlags(Flags, RegisterVT, VT, isArgValueUsed, - ArgNo, PartBase+i*RegisterVT.getStoreSize()); + ArgNo, PartBase+i*RegisterVT.getMinStoreSize()); if (NumRegs > 1 && i == 0) MyFlags.Flags.setSplit(); // if it isn't first piece, alignment must be 1 @@ -9593,7 +9596,7 @@ } if (NeedsRegBlock && Value == NumValues - 1) Ins[Ins.size() - 1].Flags.setInConsecutiveRegsLast(); - PartBase += VT.getStoreSize(); + PartBase += VT.getMinStoreSize(); } } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -3539,8 +3539,8 @@ assert((NodeToMatch->getValueType(i) == Res.getValueType() || NodeToMatch->getValueType(i) == MVT::iPTR || Res.getValueType() == MVT::iPTR || - NodeToMatch->getValueType(i).getSizeInBits() == - Res.getValueSizeInBits()) && + NodeToMatch->getValueType(i).getScalableSizeInBits() == + Res.getValueScalableSizeInBits()) && "invalid replacement"); ReplaceUses(SDValue(NodeToMatch, i), Res); } diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -978,7 +978,7 @@ NewVT = EltTy; IntermediateVT = NewVT; - unsigned NewVTSize = NewVT.getSizeInBits(); + unsigned NewVTSize = NewVT.getMinSizeInBits(); // Convert sizes such as i33 to i64. if (!isPowerOf2_32(NewVTSize)) @@ -987,7 +987,7 @@ MVT DestVT = TLI->getRegisterType(NewVT); RegisterVT = DestVT; if (EVT(DestVT).bitsLT(NewVT)) // Value is expanded, e.g. i64 -> i16. 
- return NumVectorRegs*(NewVTSize/DestVT.getSizeInBits()); + return NumVectorRegs*(NewVTSize/DestVT.getMinSizeInBits()); // Otherwise, promotion or legal types use the same number of registers as // the vector decimated to the appropriate level. @@ -1413,14 +1413,14 @@ MVT DestVT = getRegisterType(Context, NewVT); RegisterVT = DestVT; - unsigned NewVTSize = NewVT.getSizeInBits(); + unsigned NewVTSize = NewVT.getMinSizeInBits(); // Convert sizes such as i33 to i64. if (!isPowerOf2_32(NewVTSize)) NewVTSize = NextPowerOf2(NewVTSize); if (EVT(DestVT).bitsLT(NewVT)) // Value is expanded, e.g. i64 -> i16. - return NumVectorRegs*(NewVTSize/DestVT.getSizeInBits()); + return NumVectorRegs*(NewVTSize/DestVT.getMinSizeInBits()); // Otherwise, promotion or legal types use the same number of registers as // the vector decimated to the appropriate level. diff --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp --- a/llvm/lib/CodeGen/ValueTypes.cpp +++ b/llvm/lib/CodeGen/ValueTypes.cpp @@ -110,6 +110,18 @@ llvm_unreachable("Unrecognized extended type!"); } +ScalableSize EVT::getScalableExtendedSizeInBits() const { + assert(isExtended() && "Type is not extended!"); + if (VectorType *VTy = dyn_cast(LLVMTy)) + return VTy->getScalableSizeInBits(); + return { getExtendedSizeInBits(), false }; +} + +unsigned EVT::getMinExtendedSizeInBits() const { + assert(isExtended() && "Type is not extended!"); + return getScalableExtendedSizeInBits().getMinSize(); +} + /// getEVTString - This function returns value type as a string, e.g. "i32". std::string EVT::getEVTString() const { switch (V.SimpleTy) { diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp --- a/llvm/lib/IR/DataLayout.cpp +++ b/llvm/lib/IR/DataLayout.cpp @@ -740,7 +740,8 @@ llvm_unreachable("Bad type for getAlignment!!!"); } - return getAlignmentInfo(AlignType, getTypeSizeInBits(Ty), abi_or_pref, Ty); + // We only care about the minimum size for alignment + return getAlignmentInfo(AlignType, getMinTypeSizeInBits(Ty), abi_or_pref, Ty); } unsigned DataLayout::getABITypeAlignment(Type *Ty) const { diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -2968,8 +2968,8 @@ } // Get the bit sizes, we'll need these - unsigned SrcBits = SrcTy->getPrimitiveSizeInBits(); // 0 for ptr - unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr + auto SrcBits = SrcTy->getScalableSizeInBits(); // 0 for ptr + auto DestBits = DestTy->getScalableSizeInBits(); // 0 for ptr // Run through the possibilities ... if (DestTy->isIntegerTy()) { // Casting to integral @@ -3030,12 +3030,12 @@ } } - unsigned SrcBits = SrcTy->getPrimitiveSizeInBits(); // 0 for ptr - unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr + auto SrcBits = SrcTy->getScalableSizeInBits(); // 0 for ptr + auto DestBits = DestTy->getScalableSizeInBits(); // 0 for ptr // Could still have vectors of pointers if the number of elements doesn't // match - if (SrcBits == 0 || DestBits == 0) + if (SrcBits.MinSize == 0 || DestBits.MinSize == 0) return false; if (SrcBits != DestBits) @@ -3245,7 +3245,7 @@ // For non-pointer cases, the cast is okay if the source and destination bit // widths are identical. if (!SrcPtrTy) - return SrcTy->getPrimitiveSizeInBits() == DstTy->getPrimitiveSizeInBits(); + return SrcTy->getScalableSizeInBits() == DstTy->getScalableSizeInBits(); // If both are pointers then the address spaces must match. 
if (SrcPtrTy->getAddressSpace() != DstPtrTy->getAddressSpace()) diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -121,7 +121,11 @@ case Type::PPC_FP128TyID: return 128; case Type::X86_MMXTyID: return 64; case Type::IntegerTyID: return cast<IntegerType>(this)->getBitWidth(); - case Type::VectorTyID: return cast<VectorType>(this)->getBitWidth(); + case Type::VectorTyID: { + const VectorType *VTy = cast<VectorType>(this); + assert(!VTy->isScalable() && "Scalable vectors are not a primitive type"); + return VTy->getBitWidth(); + } default: return 0; } } @@ -130,6 +134,23 @@ return getScalarType()->getPrimitiveSizeInBits(); } +ScalableSize Type::getScalableSizeInBits() const { + if (auto *VTy = dyn_cast<VectorType>(this)) + return {VTy->getBitWidth(), VTy->isScalable()}; + + return {getPrimitiveSizeInBits(), false}; +} + +unsigned Type::getMinSizeInBits() const { + return getScalableSizeInBits().MinSize; +} + +unsigned Type::getFixedSizeInBits() const { + auto Size = getScalableSizeInBits(); + assert(!Size.Scalable && "Request for a fixed size on a scalable vector"); + return Size.MinSize; +} + int Type::getFPMantissaWidth() const { if (auto *VTy = dyn_cast<VectorType>(this)) return VTy->getElementType()->getFPMantissaWidth(); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3771,7 +3771,7 @@ // common case. It should also work for fundamental types too. uint32_t BEAlign = 0; unsigned OpSize = Flags.isByVal() ? Flags.getByValSize() * 8 - : VA.getValVT().getSizeInBits(); + : VA.getValVT().getMinSizeInBits(); OpSize = (OpSize + 7) / 8; if (!Subtarget->isLittleEndian() && !Flags.isByVal() && !Flags.isInConsecutiveRegs()) { @@ -8777,6 +8777,11 @@ if (AM.HasBaseReg && AM.BaseOffs && AM.Scale) return false; + // For now, just allow base reg only addressing for SVE + if (auto *VTy = dyn_cast<VectorType>(Ty)) + if (VTy->isScalable()) + return AM.HasBaseReg && !(AM.BaseOffs || AM.Scale); + // check reg + imm case: // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12 uint64_t NumBytes = 0; @@ -9133,7 +9138,8 @@ EVT VT = N->getValueType(0); if (!VT.isVector() || N->getOperand(0)->getOpcode() != ISD::AND || N->getOperand(0)->getOperand(0)->getOpcode() != ISD::SETCC || - VT.getSizeInBits() != N->getOperand(0)->getValueType(0).getSizeInBits()) + VT.getScalableSizeInBits() != + N->getOperand(0)->getValueType(0).getScalableSizeInBits()) return SDValue(); // Now check that the other operand of the AND is a constant. We could @@ -9560,7 +9566,7 @@ // Only interested in 64-bit vectors as the ultimate result. EVT VT = N->getValueType(0); - if (!VT.isVector()) + if (!VT.isVector() || VT.isScalableVector()) return SDValue(); if (VT.getSimpleVT().getSizeInBits() != 64) return SDValue(); @@ -10224,7 +10230,7 @@ // If the source VT is a 64-bit vector, we can play games and get the // better results we want.
- if (SrcVT.getSizeInBits() != 64) + if (SrcVT.getScalableSizeInBits() != ScalableSize(64U, false)) return SDValue(); unsigned SrcEltSize = SrcVT.getScalarSizeInBits(); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -16,6 +16,7 @@ #include "llvm/CodeGen/TargetLowering.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ScalableSize.h" #include <algorithm> using namespace llvm; @@ -594,6 +595,11 @@ // We don't lower some vector selects well that are wider than the register // width. if (ValTy->isVectorTy() && ISD == ISD::SELECT) { + if (ST->hasSVE()) { + EVT SelValTy = TLI->getValueType(DL, ValTy); + if (SelValTy.isScalableVector()) + return SelValTy.getMinSizeInBits() / 128; + } // We would need this many instructions to hide the scalarization happening. const int AmortizationCost = 20; static const TypeConversionCostTblEntry diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp --- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp +++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp @@ -1435,6 +1435,9 @@ }; for (MVT VT : MVT::vector_valuetypes()) { + if (VT.isScalableVector()) + continue; + for (unsigned VectExpOp : VectExpOps) setOperationAction(VectExpOp, VT, Expand); @@ -1844,7 +1847,7 @@ TargetLoweringBase::LegalizeTypeAction HexagonTargetLowering::getPreferredVectorAction(MVT VT) const { - if (VT.getVectorNumElements() == 1) + if (VT.getVectorNumElements() == 1 || VT.isScalableVector()) return TargetLoweringBase::TypeScalarizeVector; // Always widen vectors of i1. diff --git a/llvm/lib/Target/Hexagon/HexagonSubtarget.h b/llvm/lib/Target/Hexagon/HexagonSubtarget.h --- a/llvm/lib/Target/Hexagon/HexagonSubtarget.h +++ b/llvm/lib/Target/Hexagon/HexagonSubtarget.h @@ -228,7 +228,7 @@ } bool isHVXVectorType(MVT VecTy, bool IncludeBool = false) const { - if (!VecTy.isVector() || !useHVXOps()) + if (!VecTy.isVector() || !useHVXOps() || VecTy.isScalableVector()) return false; MVT ElemTy = VecTy.getVectorElementType(); if (!IncludeBool && ElemTy == MVT::i1) diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -45,6 +45,8 @@ bool HexagonTTIImpl::isTypeForHVX(Type *VecTy) const { assert(VecTy->isVectorTy()); + if (cast<VectorType>(VecTy)->isScalable()) + return false; // Avoid types like <2 x i32*>. if (!cast<VectorType>(VecTy)->getElementType()->isIntegerTy()) return false; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -629,6 +629,7 @@ // Do not perform canonicalization if minmax pattern is found (to avoid // infinite loop).
if (!Ty->isIntegerTy() && Ty->isSized() && + !DL.getScalableTypeStoreSize(Ty).Scalable && DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && DL.typeSizeEqualsStoreSize(Ty) && !DL.isNonIntegralPointerType(Ty) && diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1658,7 +1658,7 @@ // If the element type has zero size then any index over it is equivalent // to an index of zero, so replace it with zero if it is not zero already. Type *EltTy = GTI.getIndexedType(); - if (EltTy->isSized() && DL.getTypeAllocSize(EltTy) == 0) + if (EltTy->isSized() && DL.getMinTypeAllocSize(EltTy) == 0) if (!isa<Constant>(*I) || !cast<Constant>(*I)->isNullValue()) { *I = Constant::getNullValue(NewIndexType); MadeChange = true; @@ -1906,7 +1906,7 @@ unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == DL.getIndexSizeInBits(AS)) { - uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType); + uint64_t TyAllocSize = DL.getMinTypeAllocSize(GEPEltType); bool Matched = false; uint64_t C; @@ -2047,8 +2047,8 @@ if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) { // Check that changing the type amounts to dividing the index by a scale // factor. - uint64_t ResSize = DL.getTypeAllocSize(GEPEltType); - uint64_t SrcSize = DL.getTypeAllocSize(StrippedPtrEltTy); + uint64_t ResSize = DL.getMinTypeAllocSize(GEPEltType); + uint64_t SrcSize = DL.getMinTypeAllocSize(StrippedPtrEltTy); if (ResSize && SrcSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -662,7 +662,7 @@ public: SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS) : PtrUseVisitor<SliceBuilder>(DL), - AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {} + AllocSize(DL.getMinTypeAllocSize(AI.getAllocatedType())), AS(AS) {} private: void markAsDead(Instruction &I) { @@ -752,7 +752,7 @@ // type.
APInt Index = OpC->getValue().sextOrTrunc(Offset.getBitWidth()); GEPOffset += Index * APInt(Offset.getBitWidth(), - DL.getTypeAllocSize(GTI.getIndexedType())); + DL.getMinTypeAllocSize(GTI.getIndexedType())); } // If this index has computed an intermediate pointer which is not @@ -787,7 +787,7 @@ LI.getPointerAddressSpace() != DL.getAllocaAddrSpace()) return PI.setAborted(&LI); - uint64_t Size = DL.getTypeStoreSize(LI.getType()); + uint64_t Size = DL.getMinTypeStoreSize(LI.getType()); return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile()); } @@ -802,7 +802,7 @@ SI.getPointerAddressSpace() != DL.getAllocaAddrSpace()) return PI.setAborted(&SI); - uint64_t Size = DL.getTypeStoreSize(ValOp->getType()); + uint64_t Size = DL.getMinTypeStoreSize(ValOp->getType()); // If this memory access can be shown to *statically* extend outside the // bounds of the allocation, it's behavior is undefined, so simply @@ -1493,7 +1493,8 @@ if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) { Type *ElementTy = ArrTy->getElementType(); - APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy)); + APInt ElementSize(Offset.getBitWidth(), + DL.getTypeAllocSize(ElementTy)); APInt NumSkippedElements = Offset.sdiv(ElementSize); if (NumSkippedElements.ugt(ArrTy->getNumElements())) return nullptr; @@ -1515,7 +1516,7 @@ unsigned Index = SL->getElementContainingOffset(StructOffset); Offset -= APInt(Offset.getBitWidth(), SL->getElementOffset(Index)); Type *ElementTy = STy->getElementType(Index); - if (Offset.uge(DL.getTypeAllocSize(ElementTy))) + if (Offset.uge(DL.getMinTypeAllocSize(ElementTy))) return nullptr; // The offset points into alignment padding. Indices.push_back(IRB.getInt32(Index)); @@ -1547,7 +1548,7 @@ Type *ElementTy = Ty->getElementType(); if (!ElementTy->isSized()) return nullptr; // We can't GEP through an unsized element. - APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy)); + APInt ElementSize(Offset.getBitWidth(), DL.getMinTypeAllocSize(ElementTy)); if (ElementSize == 0) return nullptr; // Zero-length arrays can't help us build a natural GEP. APInt NumSkippedElements = Offset.sdiv(ElementSize); @@ -1718,7 +1719,8 @@ return false; } - if (DL.getTypeSizeInBits(NewTy) != DL.getTypeSizeInBits(OldTy)) + if (DL.getScalableTypeSizeInBits(NewTy) != + DL.getScalableTypeSizeInBits(OldTy)) return false; if (!NewTy->isSingleValueType() || !OldTy->isSingleValueType()) return false; @@ -1963,7 +1965,7 @@ // that aren't byte sized. if (ElementSize % 8) return false; - assert((DL.getTypeSizeInBits(VTy) % 8) == 0 && + assert((DL.getMinTypeSizeInBits(VTy) % 8) == 0 && "vector size not a multiple of element size?"); ElementSize /= 8; @@ -2077,7 +2079,11 @@ /// promote the resulting alloca. static bool isIntegerWideningViable(Partition &P, Type *AllocaTy, const DataLayout &DL) { - uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy); + ScalableSize AllocaSize = DL.getScalableTypeSizeInBits(AllocaTy); + if (AllocaSize.Scalable) + return false; + + uint64_t SizeInBits = AllocaSize.MinSize; // Don't create integer types larger than the maximum bitwidth. if (SizeInBits > IntegerType::MAX_INT_BITS) return false; @@ -2492,7 +2498,7 @@ Type *TargetTy = IsSplit ?
Type::getIntNTy(LI.getContext(), SliceSize * 8) : LI.getType(); - const bool IsLoadPastEnd = DL.getTypeStoreSize(TargetTy) > SliceSize; + const bool IsLoadPastEnd = DL.getMinTypeStoreSize(TargetTy) > SliceSize; bool IsPtrAdjusted = false; Value *V; if (VecTy) { @@ -2667,7 +2673,8 @@ if (IntTy && V->getType()->isIntegerTy()) return rewriteIntegerStore(V, SI, AATags); - const bool IsStorePastEnd = DL.getTypeStoreSize(V->getType()) > SliceSize; + const bool IsStorePastEnd = + DL.getMinTypeStoreSize(V->getType()) > SliceSize; StoreInst *NewSI; if (NewBeginOffset == NewAllocaBeginOffset && NewEndOffset == NewAllocaEndOffset && @@ -3458,8 +3465,8 @@ if (Ty->isSingleValueType()) return Ty; - uint64_t AllocSize = DL.getTypeAllocSize(Ty); - uint64_t TypeSize = DL.getTypeSizeInBits(Ty); + uint64_t AllocSize = DL.getMinTypeAllocSize(Ty); + uint64_t TypeSize = DL.getMinTypeSizeInBits(Ty); Type *InnerTy; if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) { @@ -3472,8 +3479,8 @@ return Ty; } - if (AllocSize > DL.getTypeAllocSize(InnerTy) || - TypeSize > DL.getTypeSizeInBits(InnerTy)) + if (AllocSize > DL.getMinTypeAllocSize(InnerTy) || + TypeSize > DL.getMinTypeSizeInBits(InnerTy)) return Ty; return stripAggregateTypeWrapping(DL, InnerTy); @@ -3494,10 +3501,10 @@ /// return a type if necessary. static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, uint64_t Size) { - if (Offset == 0 && DL.getTypeAllocSize(Ty) == Size) + if (Offset == 0 && DL.getMinTypeAllocSize(Ty) == Size) return stripAggregateTypeWrapping(DL, Ty); - if (Offset > DL.getTypeAllocSize(Ty) || - (DL.getTypeAllocSize(Ty) - Offset) < Size) + if (Offset > DL.getMinTypeAllocSize(Ty) || - wait
uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte); Fragments.push_back(Fragment(NewAI, P.beginOffset() * SizeOfByte, Size)); @@ -4346,7 +4359,7 @@ auto *Expr = DbgDeclares.front()->getExpression(); auto VarSize = Var->getSizeInBits(); DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); - uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType()); + uint64_t AllocaSize = DL.getMinTypeSizeInBits(AI.getAllocatedType()); for (auto Fragment : Fragments) { // Create a fragment expression describing the new partition or reuse AI's // expression if there is only one partition. @@ -4435,7 +4448,8 @@ // Skip alloca forms that this analysis can't handle. if (AI.isArrayAllocation() || !AI.getAllocatedType()->isSized() || - DL.getTypeAllocSize(AI.getAllocatedType()) == 0) + DL.getMinTypeAllocSize(AI.getAllocatedType()) == 0 || + DL.getScalableTypeAllocSize(AI.getAllocatedType()).Scalable) return false; bool Changed = false; diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -1274,7 +1274,7 @@ /// least n bits. static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { const DataLayout &DL = DII->getModule()->getDataLayout(); - uint64_t ValueSize = DL.getTypeAllocSizeInBits(ValTy); + uint64_t ValueSize = DL.getMinTypeAllocSizeInBits(ValTy); if (auto FragmentSize = DII->getFragmentSizeInBits()) return ValueSize >= *FragmentSize; // We can't always calculate the size of the DI variable (e.g. if it is a diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -4693,7 +4693,7 @@ // If V is a store, just return the width of the stored value without // traversing the expression tree. This is the common case. if (auto *Store = dyn_cast<StoreInst>(V)) - return DL->getTypeSizeInBits(Store->getValueOperand()->getType()); + return DL->getMinTypeSizeInBits(Store->getValueOperand()->getType()); // If V is not a store, we can traverse the expression tree to find loads // that feed it. The type of the loaded value may indicate a more suitable // width than V's type. We want to base the vector element size on the width // of memory operations where possible. @@ -4742,7 +4742,7 @@ // If we didn't encounter a memory access in the expression tree, or if we // gave up for some reason, just return the width of V. if (!MaxWidth || FoundUnknownInst) - return DL->getTypeSizeInBits(V->getType()); + return DL->getMinTypeSizeInBits(V->getType()); // Otherwise, return the maximum width we found.
return MaxWidth; diff --git a/llvm/unittests/IR/VectorTypesTest.cpp b/llvm/unittests/IR/VectorTypesTest.cpp --- a/llvm/unittests/IR/VectorTypesTest.cpp +++ b/llvm/unittests/IR/VectorTypesTest.cpp @@ -160,5 +160,78 @@ EXPECT_EQ(EltCnt.Min, 8U); ASSERT_TRUE(EltCnt.Scalable); } +TEST(VectorTypesTest, FixedLenComparisons) { + LLVMContext Ctx; + + Type *Int32Ty = Type::getInt32Ty(Ctx); + Type *Int64Ty = Type::getInt64Ty(Ctx); + + VectorType *V2Int32Ty = VectorType::get(Int32Ty, 2); + VectorType *V4Int32Ty = VectorType::get(Int32Ty, 4); + + VectorType *V2Int64Ty = VectorType::get(Int64Ty, 2); + + ScalableSize V2I32Len = V2Int32Ty->getScalableSizeInBits(); + EXPECT_EQ(V2I32Len.MinSize, 64U); + EXPECT_FALSE(V2I32Len.Scalable); + + EXPECT_LT(V2Int32Ty->getScalableSizeInBits(), + V4Int32Ty->getScalableSizeInBits()); + EXPECT_GT(V2Int64Ty->getScalableSizeInBits(), + V2Int32Ty->getScalableSizeInBits()); + EXPECT_EQ(V4Int32Ty->getScalableSizeInBits(), + V2Int64Ty->getScalableSizeInBits()); + EXPECT_NE(V2Int32Ty->getScalableSizeInBits(), + V2Int64Ty->getScalableSizeInBits()); + + // Check that a fixed-only comparison works for fixed size vectors. + EXPECT_EQ(V2Int64Ty->getFixedSizeInBits(), + V4Int32Ty->getFixedSizeInBits()); +} + +TEST(VectorTypesTest, ScalableComparisons) { + LLVMContext Ctx; + + Type *Int32Ty = Type::getInt32Ty(Ctx); + Type *Int64Ty = Type::getInt64Ty(Ctx); + + VectorType *ScV2Int32Ty = VectorType::get(Int32Ty, {2, true}); + VectorType *ScV4Int32Ty = VectorType::get(Int32Ty, {4, true}); + + VectorType *ScV2Int64Ty = VectorType::get(Int64Ty, {2, true}); + + ScalableSize ScV2I32Len = ScV2Int32Ty->getScalableSizeInBits(); + EXPECT_EQ(ScV2I32Len.MinSize, 64U); + EXPECT_TRUE(ScV2I32Len.Scalable); + + EXPECT_LT(ScV2Int32Ty->getScalableSizeInBits(), + ScV4Int32Ty->getScalableSizeInBits()); + EXPECT_GT(ScV2Int64Ty->getScalableSizeInBits(), + ScV2Int32Ty->getScalableSizeInBits()); + EXPECT_EQ(ScV4Int32Ty->getScalableSizeInBits(), + ScV2Int64Ty->getScalableSizeInBits()); + EXPECT_NE(ScV2Int32Ty->getScalableSizeInBits(), + ScV2Int64Ty->getScalableSizeInBits()); +} + +TEST(VectorTypesTest, CrossComparisons) { + LLVMContext Ctx; + + Type *Int32Ty = Type::getInt32Ty(Ctx); + + VectorType *V4Int32Ty = VectorType::get(Int32Ty, {4, false}); + VectorType *ScV4Int32Ty = VectorType::get(Int32Ty, {4, true}); + + // Even though the minimum size is the same, a scalable vector could be + // larger so we don't consider them to be the same size. + EXPECT_NE(V4Int32Ty->getScalableSizeInBits(), + ScV4Int32Ty->getScalableSizeInBits()); + // If we are only checking the minimum, then they are the same size. + EXPECT_EQ(V4Int32Ty->getMinSizeInBits(), + ScV4Int32Ty->getMinSizeInBits()); + + // We can't use ordering comparisons (<,<=,>,>=) between scalable and + // non-scalable vector sizes. +} } // end anonymous namespace diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp --- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp +++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp @@ -503,9 +503,16 @@ } auto LT = [](MVT A, MVT B) -> bool { + // Always treat non-scalable MVTs as smaller than scalable MVTs for the + // purposes of ordering. 
+ if (A.isScalableVector() && !B.isScalableVector()) + return false; + if (!A.isScalableVector() && B.isScalableVector()) + return true; + return A.getScalarSizeInBits() < B.getScalarSizeInBits() || (A.getScalarSizeInBits() == B.getScalarSizeInBits() && - A.getSizeInBits() < B.getSizeInBits()); + A.getScalableSizeInBits() < B.getScalableSizeInBits()); }; auto LE = [&LT](MVT A, MVT B) -> bool { // This function is used when removing elements: when a vector is compared @@ -513,8 +520,13 @@ if (A.isVector() != B.isVector()) return false; + // We also don't want to remove elements when they're both vectors with the + // same minimum number of lanes, but one is scalable and the other not. + if (A.isScalableVector() != B.isScalableVector()) + return false; + return LT(A, B) || (A.getScalarSizeInBits() == B.getScalarSizeInBits() && - A.getSizeInBits() == B.getSizeInBits()); + A.getScalableSizeInBits() == B.getScalableSizeInBits()); }; for (unsigned M : Modes) {
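For reference, the changes above lean on the ScalableSize helper declared in llvm/Support/ScalableSize.h, which is not reproduced in this excerpt. The following is a minimal sketch only, inferred from the call sites and unit tests above (public MinSize/Scalable members, a (MinSize, Scalable) constructor, getMinSize(), equality that distinguishes scalable from fixed sizes, and ordering that is only defined when both sides have the same scalability); the real header may differ in detail.

// Hypothetical sketch of the ScalableSize interface assumed by this patch;
// not the actual llvm/Support/ScalableSize.h.
#include <cassert>
#include <cstdint>

namespace llvm {

struct ScalableSize {
  uint64_t MinSize = 0;  // Minimum size (bits or bytes, depending on context).
  bool Scalable = false; // True if the real size is a runtime multiple of MinSize.

  ScalableSize() = default;
  ScalableSize(uint64_t MinSize, bool Scalable)
      : MinSize(MinSize), Scalable(Scalable) {}

  uint64_t getMinSize() const { return MinSize; }

  // A scalable size is never equal to a fixed size, even with the same
  // minimum, matching the CrossComparisons unit test above.
  friend bool operator==(ScalableSize LHS, ScalableSize RHS) {
    return LHS.MinSize == RHS.MinSize && LHS.Scalable == RHS.Scalable;
  }
  friend bool operator!=(ScalableSize LHS, ScalableSize RHS) {
    return !(LHS == RHS);
  }

  // Ordering is only meaningful between sizes with the same scalability,
  // as noted in the unit tests above.
  friend bool operator<(ScalableSize LHS, ScalableSize RHS) {
    assert(LHS.Scalable == RHS.Scalable &&
           "Ordering comparison between scalable and fixed sizes");
    return LHS.MinSize < RHS.MinSize;
  }
  friend bool operator>(ScalableSize LHS, ScalableSize RHS) {
    return RHS < LHS;
  }
};

} // namespace llvm

Under these assumptions, expressions such as ScalableSize(64U, false) and brace-initialised returns like {getPrimitiveSizeInBits(), false} construct the helper directly, and the EXPECT_EQ/EXPECT_NE/EXPECT_LT/EXPECT_GT checks in VectorTypesTest.cpp exercise the comparison operators shown here.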