diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h --- a/llvm/include/llvm/CodeGen/ValueTypes.h +++ b/llvm/include/llvm/CodeGen/ValueTypes.h @@ -18,6 +18,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/MachineValueType.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -209,7 +210,7 @@ /// Return true if the bit size is a multiple of 8. bool isByteSized() const { - return (getSizeInBits() & 7) == 0; + return (getMinSizeInBits() & 7) == 0; } /// Return true if the size is a power-of-two number of bytes. @@ -221,31 +222,31 @@ /// Return true if this has the same number of bits as VT. bool bitsEq(EVT VT) const { if (EVT::operator==(VT)) return true; - return getSizeInBits() == VT.getSizeInBits(); + return getScalableSizeInBits() == VT.getScalableSizeInBits(); } /// Return true if this has more bits than VT. bool bitsGT(EVT VT) const { if (EVT::operator==(VT)) return false; - return getSizeInBits() > VT.getSizeInBits(); + return getScalableSizeInBits() > VT.getScalableSizeInBits(); } /// Return true if this has no less bits than VT. bool bitsGE(EVT VT) const { if (EVT::operator==(VT)) return true; - return getSizeInBits() >= VT.getSizeInBits(); + return getScalableSizeInBits() >= VT.getScalableSizeInBits(); } /// Return true if this has less bits than VT. bool bitsLT(EVT VT) const { if (EVT::operator==(VT)) return false; - return getSizeInBits() < VT.getSizeInBits(); + return getScalableSizeInBits() < VT.getScalableSizeInBits(); } /// Return true if this has no more bits than VT. bool bitsLE(EVT VT) const { if (EVT::operator==(VT)) return true; - return getSizeInBits() <= VT.getSizeInBits(); + return getScalableSizeInBits() <= VT.getScalableSizeInBits(); } /// Return the SimpleValueType held in the specified simple EVT. @@ -287,29 +288,79 @@ return {getExtendedVectorNumElements(), false}; } - /// Return the size of the specified value type in bits. + /// Return the size of the specified value type in bits. An assert will + /// occur if this is called on a scalable vector type. unsigned getSizeInBits() const { if (isSimple()) return V.getSizeInBits(); return getExtendedSizeInBits(); } + /// Returns the size of the specified value type as a minimum number of + /// bits and a boolean indicating whether the runtime size is exactly that + /// size (if false) or if it's an integer multiple of that minimum (true). + ScalableSize getScalableSizeInBits() const { + if (isSimple()) + return V.getScalableSizeInBits(); + return getScalableExtendedSizeInBits(); + } + + /// Returns the size of the type in bits. If the type is scalable, this + /// quantity represents the minimum size. If the type is not scalable, + /// it represents the exact size. + unsigned getMinSizeInBits() const { + if (isSimple()) + return V.getMinSizeInBits(); + return getMinExtendedSizeInBits(); + } + unsigned getScalarSizeInBits() const { return getScalarType().getSizeInBits(); } /// Return the number of bytes overwritten by a store of the specified value - /// type. + /// type. An assert will occur if this is called on a scalable vector type. unsigned getStoreSize() const { return (getSizeInBits() + 7) / 8; } - /// Return the number of bits overwritten by a store of the specified value - /// type. 
+ /// Returns the minimum number of bytes overwritten by a store of the + /// specified value type, along with a boolean indicating whether the + /// runtime size written to is exactly that size (if false) or if it's an + /// integer multiple of that minimum (true). + ScalableSize getScalableStoreSize() const { + ScalableSize SizeInBits = getScalableSizeInBits(); + return { (SizeInBits.getMinSize() + 7) / 8, SizeInBits.isScalable() }; + } + + /// Returns the number of bytes overwritten by a store of the specified + /// value type. If the type is scalable, this quantity represents the + /// minimum size. If not scalable, it represents the exact size. + unsigned getMinStoreSize() const { + return (getMinSizeInBits() + 7) / 8; + } + + /// Returns the number of bits overwritten by a store of the specified value + /// type. An assert will occur if this is called on a scalable vector type. unsigned getStoreSizeInBits() const { return getStoreSize() * 8; } + /// Returns the minimum number of bits overwritten by a store of the + /// specified value type, along with a boolean indicating whether the + /// runtime size written to is exactly that size (if false) or if it's an + /// integer multiple of that minimum (true). + ScalableSize getScalableStoreSizeInBits() const { + return getScalableStoreSize() * 8; + } + + /// Returns the number of bits overwritten by a store of the specified value + /// type. If the type is scalable, this quantity represents the minimum + /// size. If not scalable, it represents the exact size. + unsigned getMinStoreSizeInBits() const { + return getMinStoreSize() * 8; + } + /// Rounds the bit-width of the given integer EVT up to the nearest power of /// two (and at least to eight), and returns the integer EVT with that /// number of bits. @@ -429,6 +480,8 @@ EVT getExtendedVectorElementType() const; unsigned getExtendedVectorNumElements() const LLVM_READONLY; unsigned getExtendedSizeInBits() const LLVM_READONLY; + ScalableSize getScalableExtendedSizeInBits() const LLVM_READONLY; + unsigned getMinExtendedSizeInBits() const LLVM_READONLY; }; } // end namespace llvm diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h --- a/llvm/include/llvm/IR/DataLayout.h +++ b/llvm/include/llvm/IR/DataLayout.h @@ -30,6 +30,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/Alignment.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -436,30 +437,77 @@ /// /// For example, returns 36 for i36 and 80 for x86_fp80. The type passed must /// have a size (Type::isSized() must return true). + /// + /// An assert will occur if this is called on a scalable vector type. uint64_t getTypeSizeInBits(Type *Ty) const; + /// Returns the minimum number of bits necessary to hold the specified type + /// and a boolean indicating whether the runtime size is exactly that size + /// (if false) or if it's an integer multiple of that minimum (true). + ScalableSize getScalableTypeSizeInBits(Type *Ty) const; + + /// Returns the size of the type in bits. If the type is scalable, this + /// quantity represents the minimum size. If the type is not scalable, + /// it represents the exact size. + uint64_t getMinTypeSizeInBits(Type *Ty) const; + /// Returns the maximum number of bytes that may be overwritten by /// storing the specified type. /// /// For example, returns 5 for i36 and 10 for x86_fp80. + /// + /// An assert will occur if this is called on a scalable vector type. 
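// --- Editor's illustrative sketch; not part of the patch. It shows how a
// caller is expected to use the new EVT queries instead of getSizeInBits()
// once scalable vectors may appear. The helper name and the 128-bit limit
// are hypothetical.
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/Support/ScalableSize.h"
using namespace llvm;

static bool fitsIn128BitRegister(EVT VT) {
  // getSizeInBits() asserts on scalable vectors after this patch, so query
  // the ScalableSize form and handle the scalable case explicitly.
  ScalableSize Size = VT.getScalableSizeInBits();
  if (Size.isScalable())
    return false; // Runtime width unknown at compile time; be conservative.
  return Size.getFixedSize() <= 128;
}
// --- End of editor's sketch.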
uint64_t getTypeStoreSize(Type *Ty) const { return (getTypeSizeInBits(Ty) + 7) / 8; } + /// Returns the number of bytes overwritten by a store of the specified type, + /// along with a boolean indicating whether the runtime size written to is + /// exactly that size (if false) or if it's an integer multiple of that + /// size (true). + ScalableSize getScalableTypeStoreSize(Type *Ty) const { + auto Bits = getScalableTypeSizeInBits(Ty); + return ScalableSize((Bits.getMinSize()+7)/8, Bits.isScalable()); + } + + /// Returns the number of bytes overwritten by a store of the specified type. + /// If the type is scalable, this quantity represents the minimum size. If + /// not scalable, it represents the exact size. + uint64_t getMinTypeStoreSize(Type *Ty) const { + return (getScalableTypeSizeInBits(Ty).getMinSize()+7)/8; + } + /// Returns the maximum number of bits that may be overwritten by /// storing the specified type; always a multiple of 8. /// /// For example, returns 40 for i36 and 80 for x86_fp80. + /// + /// An assert will occur if this is called on a scalable vector type. uint64_t getTypeStoreSizeInBits(Type *Ty) const { return 8 * getTypeStoreSize(Ty); } + /// Returns the number of bits overwritten by a store of the specified type, + /// along with a boolean indicating whether the runtime size written to is + /// exactly that size (if false) or if it's an integer multiple of that + /// size (true). + ScalableSize getScalableTypeStoreSizeInBits(Type *Ty) const { + auto Bytes = getScalableTypeStoreSize(Ty); + return {Bytes.getMinSize() * 8, Bytes.isScalable()}; + } + + /// Returns the number of bits overwritten by a store of the specified type. + /// If the type is scalable, this quantity represents the minimum size. If + /// not scalable, it represents the exact size. + uint64_t getMinTypeStoreSizeInBits(Type *Ty) const { + return 8 * getMinTypeStoreSize(Ty); + } /// Returns true if no extra padding bits are needed when storing the /// specified type. /// /// For example, returns false for i19 that has a 24-bit store size. bool typeSizeEqualsStoreSize(Type *Ty) const { - return getTypeSizeInBits(Ty) == getTypeStoreSizeInBits(Ty); + return getScalableTypeSizeInBits(Ty) == getScalableTypeStoreSizeInBits(Ty); } /// Returns the offset in bytes between successive objects of the @@ -467,20 +515,60 @@ /// /// This is the amount that alloca reserves for this type. For example, /// returns 12 or 16 for x86_fp80, depending on alignment. + /// + /// An assert will occur if this is called on a scalable vector type. uint64_t getTypeAllocSize(Type *Ty) const { // Round up to the next alignment boundary. return alignTo(getTypeStoreSize(Ty), getABITypeAlignment(Ty)); } + /// Returns the offset in bytes between successive object of the specified + /// type (including alignment padding), along with a boolean indicating + /// whether the runtime size written to is exactly that size (if false) or if + /// it's an integer multiple of that size (true). + ScalableSize getScalableTypeAllocSize(Type *Ty) const { + auto Bytes = getScalableTypeStoreSize(Ty); + uint64_t MinAlignedSize = alignTo(Bytes.getMinSize(), + getABITypeAlignment(Ty)); + return ScalableSize(MinAlignedSize, Bytes.isScalable()); + } + + /// Returns the offset in bytes between successive objects of the + /// specified type, including alignment padding. + /// If the type is scalable, this quantity represents the minimum size. If + /// not scalable, it represents the exact size. 
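// --- Editor's illustrative sketch; not part of the patch. It shows the
// expected results of the new DataLayout store-size queries: a fixed
// <4 x i32> and a scalable <vscale x 4 x i32> both report a 16-byte minimum
// store size, but only the fixed type's size is exact.
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include <cassert>
using namespace llvm;

static void storeSizeExamples() {
  LLVMContext Ctx;
  DataLayout DL("");
  Type *I32 = Type::getInt32Ty(Ctx);
  VectorType *Fixed = VectorType::get(I32, {4, /*Scalable=*/false});
  VectorType *Scalable = VectorType::get(I32, {4, /*Scalable=*/true});

  ScalableSize FixedBytes = DL.getScalableTypeStoreSize(Fixed);       // {16, false}
  ScalableSize ScalableBytes = DL.getScalableTypeStoreSize(Scalable); // {16, true}

  assert(FixedBytes.getMinSize() == ScalableBytes.getMinSize());
  // A scalable size never compares equal to a fixed size, even when the
  // minimum sizes match.
  assert(FixedBytes != ScalableBytes);
}
// --- End of editor's sketch.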
+ uint64_t getMinTypeAllocSize(Type *Ty) const { + return alignTo(getMinTypeStoreSize(Ty), getABITypeAlignment(Ty)); + } + /// Returns the offset in bits between successive objects of the /// specified type, including alignment padding; always a multiple of 8. /// /// This is the amount that alloca reserves for this type. For example, /// returns 96 or 128 for x86_fp80, depending on alignment. + /// + /// An assert will occur if this is called on a scalable vector type. uint64_t getTypeAllocSizeInBits(Type *Ty) const { return 8 * getTypeAllocSize(Ty); } + /// Returns the offset in bits between successive object of the specified + /// type (including alignment padding), along with a boolean indicating + /// whether the runtime size written to is exactly that size (if false) or if + /// it's an integer multiple of that size (true). + ScalableSize getScalableTypeAllocSizeInBits(Type *Ty) const { + auto Bytes = getScalableTypeAllocSize(Ty); + return {Bytes.getMinSize() * 8, Bytes.isScalable()}; + } + + /// Returns the offset in bits between successive objects of the + /// specified type, including alignment padding. + /// If the type is scalable, this quantity represents the minimum size. If + /// not scalable, it represents the exact size. + uint64_t getMinTypeAllocSizeInBits(Type *Ty) const { + return 8 * getMinTypeAllocSize(Ty); + } + /// Returns the minimum ABI-required alignment for the specified type. unsigned getABITypeAlignment(Type *Ty) const; @@ -632,6 +720,8 @@ return 80; case Type::VectorTyID: { VectorType *VTy = cast(Ty); + assert(!VTy->isScalable() && + "Scalable vector sizes cannot be represented by a scalar"); return VTy->getNumElements() * getTypeSizeInBits(VTy->getElementType()); } default: @@ -639,6 +729,23 @@ } } +inline ScalableSize DataLayout::getScalableTypeSizeInBits(Type *Ty) const { + switch(Ty->getTypeID()) { + default: + return {getTypeSizeInBits(Ty), false}; + case Type::VectorTyID: { + VectorType *VTy = cast(Ty); + auto EltCnt = VTy->getElementCount(); + uint64_t MinBits = EltCnt.Min * getTypeSizeInBits(VTy->getElementType()); + return {MinBits, EltCnt.Scalable}; + } + } +} + +inline uint64_t DataLayout::getMinTypeSizeInBits(Type *Ty) const { + return getScalableTypeSizeInBits(Ty).getMinSize(); +} + } // end namespace llvm #endif // LLVM_IR_DATALAYOUT_H diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -975,7 +975,7 @@ static Type* makeCmpResultType(Type* opnd_type) { if (VectorType* vt = dyn_cast(opnd_type)) { return VectorType::get(Type::getInt1Ty(opnd_type->getContext()), - vt->getNumElements()); + vt->getElementCount()); } return Type::getInt1Ty(opnd_type->getContext()); } diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -21,6 +21,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -286,8 +287,22 @@ /// instance of the type is stored to memory. The DataLayout class provides /// additional query functions to provide this information. /// + /// An assert will occur if this is called on a scalable vector type. unsigned getPrimitiveSizeInBits() const LLVM_READONLY; + // Returns a ScalableSize for the type in question. 
This should be used in + // place of getPrimitiveSizeInBits in places where the type may be a + // VectorType with the Scalable flag set. + ScalableSize getScalableSizeInBits() const LLVM_READONLY; + + /// Returns the minimum known size in bits, ignoring whether the type might + /// be a scalable vector. + unsigned getMinSizeInBits() const LLVM_READONLY; + + /// Returns the minimum known size in bits, asserting if called on a scalable + /// vector type. + unsigned getFixedSizeInBits() const LLVM_READONLY; + /// If this is a vector type, return the getPrimitiveSizeInBits value for the /// element type. Otherwise return the getPrimitiveSizeInBits value for this /// type. diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h --- a/llvm/include/llvm/Support/MachineValueType.h +++ b/llvm/include/llvm/Support/MachineValueType.h @@ -634,7 +634,68 @@ return { getVectorNumElements(), isScalableVector() }; } + /// Returns the size of the specified MVT as a minimum number of bits and a + /// boolean indicating whether the runtime size is exactly that + /// size (if false) or if it's an integer multiple of that minimum (true). + ScalableSize getScalableSizeInBits() const { + switch (SimpleTy) { + default: return { getSizeInBits(), false }; + case nxv1i1: return { 1U, true }; + case nxv2i1: return { 2U, true }; + case nxv4i1: return { 4U, true }; + case nxv1i8: + case nxv8i1: return { 8U, true }; + case nxv16i1: + case nxv2i8: + case nxv1i16: return { 16U, true }; + case nxv32i1: + case nxv4i8: + case nxv2i16: + case nxv1i32: + case nxv2f16: + case nxv1f32: return { 32U, true }; + case nxv8i8: + case nxv4i16: + case nxv2i32: + case nxv1i64: + case nxv4f16: + case nxv2f32: + case nxv1f64: return { 64U, true }; + case nxv16i8: + case nxv8i16: + case nxv4i32: + case nxv2i64: + case nxv8f16: + case nxv4f32: + case nxv2f64: return { 128U, true }; + case nxv32i8: + case nxv16i16: + case nxv8i32: + case nxv4i64: + case nxv8f32: + case nxv4f64: return { 256U, true }; + case nxv32i16: + case nxv16i32: + case nxv8i64: + case nxv16f32: + case nxv8f64: return { 512U, true }; + case nxv32i32: + case nxv16i64: return { 1024U, true }; + case nxv32i64: return { 2048U, true }; + } + } + + /// Returns the size of the MVT in bits. If the type is scalable, this + /// quantity represents the minimum size. If the type is not scalable, + /// it represents the exact size. + unsigned getMinSizeInBits() const { + return getScalableSizeInBits().getMinSize(); + } + + /// Returns the size of the specified MVT in bits. + /// An assert will occur if this is called on a scalable vector type. 
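// --- Editor's illustrative sketch; not part of the patch. The MVT-level
// queries mirror the EVT ones: nxv4i32 and v4i32 share a 128-bit minimum
// width, but only the fixed type may use getSizeInBits().
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/ScalableSize.h"
#include <cassert>
using namespace llvm;

static void mvtSizeExamples() {
  MVT ScalableVT = MVT::nxv4i32;
  MVT FixedVT = MVT::v4i32;

  assert(ScalableVT.getScalableSizeInBits() == ScalableSize(128, true));
  assert(FixedVT.getScalableSizeInBits() == ScalableSize(128, false));
  assert(ScalableVT.getMinSizeInBits() == 128 &&
         FixedVT.getMinSizeInBits() == 128);
  // ScalableVT.getSizeInBits() would trigger the assertion added below.
}
// --- End of editor's sketch.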
unsigned getSizeInBits() const { + assert(!isScalableVector() && "getSizeInBits called on scalable vector"); switch (SimpleTy) { default: llvm_unreachable("getSizeInBits called on extended MVT."); @@ -654,25 +715,17 @@ case Metadata: llvm_unreachable("Value type is metadata."); case i1: - case v1i1: - case nxv1i1: return 1; - case v2i1: - case nxv2i1: return 2; - case v4i1: - case nxv4i1: return 4; + case v1i1: return 1; + case v2i1: return 2; + case v4i1: return 4; case i8 : case v1i8: - case v8i1: - case nxv1i8: - case nxv8i1: return 8; + case v8i1: return 8; case i16 : case f16: case v16i1: case v2i8: - case v1i16: - case nxv16i1: - case nxv2i8: - case nxv1i16: return 16; + case v1i16: return 16; case f32 : case i32 : case v32i1: @@ -680,13 +733,7 @@ case v2i16: case v2f16: case v1f32: - case v1i32: - case nxv32i1: - case nxv4i8: - case nxv2i16: - case nxv1i32: - case nxv2f16: - case nxv1f32: return 32; + case v1i32: return 32; case x86mmx: case f64 : case i64 : @@ -697,14 +744,7 @@ case v1i64: case v4f16: case v2f32: - case v1f64: - case nxv8i8: - case nxv4i16: - case nxv2i32: - case nxv1i64: - case nxv4f16: - case nxv2f32: - case nxv1f64: return 64; + case v1f64: return 64; case f80 : return 80; case v3i32: case v3f32: return 96; @@ -719,14 +759,7 @@ case v1i128: case v8f16: case v4f32: - case v2f64: - case nxv16i8: - case nxv8i16: - case nxv4i32: - case nxv2i64: - case nxv8f16: - case nxv4f32: - case nxv2f64: return 128; + case v2f64: return 128; case v5i32: case v5f32: return 160; case v32i8: @@ -734,39 +767,25 @@ case v8i32: case v4i64: case v8f32: - case v4f64: - case nxv32i8: - case nxv16i16: - case nxv8i32: - case nxv4i64: - case nxv8f32: - case nxv4f64: return 256; + case v4f64: return 256; case v512i1: case v64i8: case v32i16: case v16i32: case v8i64: case v16f32: - case v8f64: - case nxv32i16: - case nxv16i32: - case nxv8i64: - case nxv16f32: - case nxv8f64: return 512; + case v8f64: return 512; case v1024i1: case v128i8: case v64i16: case v32i32: case v16i64: - case v32f32: - case nxv32i32: - case nxv16i64: return 1024; + case v32f32: return 1024; case v256i8: case v128i16: case v64i32: case v32i64: - case v64f32: - case nxv32i64: return 2048; + case v64f32: return 2048; case v128i32: case v128f32: return 4096; case v256i32: @@ -791,30 +810,62 @@ return (getSizeInBits() + 7) / 8; } + /// Returns the minimum number of bytes overwritten by a store of the + /// specified value type, along with a boolean indicating whether the + /// runtime size written to is exactly that size (if false) or if it's an + /// integer multiple of that minimum (true). + ScalableSize getScalableStoreSize() const { + ScalableSize SizeInBits = getScalableSizeInBits(); + return { (SizeInBits.getMinSize() + 7) / 8, SizeInBits.isScalable() }; + } + + /// Returns the number of bytes overwritten by a store of the specified + /// value type. If the type is scalable, this quantity represents the + /// minimum size. If not scalable, it represents the exact size. + unsigned getMinStoreSize() const { + return getScalableStoreSize().getMinSize(); + } + /// Return the number of bits overwritten by a store of the specified value /// type. unsigned getStoreSizeInBits() const { return getStoreSize() * 8; } + /// Returns the minimum number of bits overwritten by a store of the + /// specified value type, along with a boolean indicating whether the + /// runtime size written to is exactly that size (if false) or if it's an + /// integer multiple of that minimum (true). 
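// --- Editor's illustrative sketch; not part of the patch. Store sizes round
// the bit width up to whole bytes before the scalable flag is applied, so
// nxv8i1 has a one-byte minimum store size and nxv4i32 a 16-byte one.
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/ScalableSize.h"
#include <cassert>
using namespace llvm;

static void mvtStoreSizeExamples() {
  assert(MVT(MVT::nxv8i1).getScalableStoreSize() == ScalableSize(1, true));
  assert(MVT(MVT::nxv4i32).getMinStoreSize() == 16u);
}
// --- End of editor's sketch.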
+ ScalableSize getScalableStoreSizeInBits() const { + ScalableSize SizeInBytes = getScalableStoreSize(); + return { SizeInBytes.getMinSize() * 8, SizeInBytes.isScalable() }; + } + + /// Returns the number of bits overwritten by a store of the specified + /// value type. If the type is scalable, this quantity represents the + /// minimum size. If not scalable, it represents the exact size. + unsigned getMinStoreSizeInBits() const { + return getScalableStoreSizeInBits().getMinSize(); + } + /// Return true if this has more bits than VT. bool bitsGT(MVT VT) const { - return getSizeInBits() > VT.getSizeInBits(); + return getScalableSizeInBits() > VT.getScalableSizeInBits(); } /// Return true if this has no less bits than VT. bool bitsGE(MVT VT) const { - return getSizeInBits() >= VT.getSizeInBits(); + return getScalableSizeInBits() >= VT.getScalableSizeInBits(); } /// Return true if this has less bits than VT. bool bitsLT(MVT VT) const { - return getSizeInBits() < VT.getSizeInBits(); + return getScalableSizeInBits() < VT.getScalableSizeInBits(); } /// Return true if this has no more bits than VT. bool bitsLE(MVT VT) const { - return getSizeInBits() <= VT.getSizeInBits(); + return getScalableSizeInBits() <= VT.getScalableSizeInBits(); } static MVT getFloatingPointVT(unsigned BitWidth) { diff --git a/llvm/include/llvm/Support/ScalableSize.h b/llvm/include/llvm/Support/ScalableSize.h --- a/llvm/include/llvm/Support/ScalableSize.h +++ b/llvm/include/llvm/Support/ScalableSize.h @@ -15,6 +15,8 @@ #ifndef LLVM_SUPPORT_SCALABLESIZE_H #define LLVM_SUPPORT_SCALABLESIZE_H +#include + namespace llvm { class ElementCount { @@ -38,6 +40,93 @@ } }; +// This class is used to represent the size of types. If the type is of fixed +// size, it will represent the exact size. If the type is a scalable vector, +// it will represent the known minimum size. +class ScalableSize { + uint64_t MinSize; // The known minimum size. + bool Scalable; // If true, then the runtime size is an integer multiple + // of MinSize. + +public: + constexpr ScalableSize(uint64_t MinSize, bool Scalable) + : MinSize(MinSize), Scalable(Scalable) {} + + // Scalable vector types with the same minimum size as a fixed size type are + // not guaranteed to be the same size at runtime, so they are never + // considered to be equal. + friend bool operator==(const ScalableSize &LHS, const ScalableSize &RHS) { + return std::tie(LHS.MinSize, LHS.Scalable) == + std::tie(RHS.MinSize, RHS.Scalable); + } + + friend bool operator!=(const ScalableSize &LHS, const ScalableSize &RHS) { + return !(LHS == RHS); + } + + // For many cases, size ordering between scalable and fixed size types cannot + // be determined at compile time, so such comparisons aren't allowed. + // + // e.g. could be bigger than <4 x i32> with a runtime + // vscale >= 5, equal sized with a vscale of 4, and smaller with + // a vscale <= 3. + // + // If the scalable flags match, just perform the requested comparison + // between the minimum sizes. 
+ friend bool operator<(const ScalableSize &LHS, const ScalableSize &RHS) { + assert(LHS.Scalable == RHS.Scalable && + "Ordering comparison of scalable and fixed types"); + + return LHS.MinSize < RHS.MinSize; + } + + friend bool operator>(const ScalableSize &LHS, const ScalableSize &RHS) { + return RHS < LHS; + } + + friend bool operator<=(const ScalableSize &LHS, const ScalableSize &RHS) { + return !(RHS < LHS); + } + + friend bool operator>=(const ScalableSize &LHS, const ScalableSize& RHS) { + return !(LHS < RHS); + } + + // Convenience operators to obtain relative sizes independently of + // the scalable flag. + ScalableSize operator*(unsigned RHS) const { + return { MinSize * RHS, Scalable }; + } + + friend ScalableSize operator*(const unsigned LHS, const ScalableSize &RHS) { + return { RHS.MinSize * LHS, RHS.Scalable }; + } + + ScalableSize operator/(unsigned RHS) const { + return { MinSize / RHS, Scalable }; + } + + // Return the minimum size with the assumption that the size is exact. + // Use in places where a scalable size doesn't make sense (e.g. non-vector + // types, or vectors in backends which don't support scalable vectors) + uint64_t getFixedSize() const { + assert(!Scalable && "Request for a fixed size on a scalable object"); + return MinSize; + } + + // Return the minimum size. Use in places where the scalable property doesn't + // matter (e.g. determining alignment) or in conjunction with the + // isScalable method below. + uint64_t getMinSize() const { + return MinSize; + } + + // Return whether or not the size is scalable. + bool isScalable() const { + return Scalable; + } +}; + } // end namespace llvm #endif // LLVM_SUPPORT_SCALABLESIZE_H diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -221,11 +221,13 @@ ForCodeSize = DAG.getMachineFunction().getFunction().hasOptSize(); MaximumLegalStoreInBits = 0; + // We use the minimum store size here, since that's all we can guarantee + // for the scalable vector types. for (MVT VT : MVT::all_valuetypes()) if (EVT(VT).isSimple() && VT != MVT::Other && TLI.isTypeLegal(EVT(VT)) && - VT.getSizeInBits() >= MaximumLegalStoreInBits) - MaximumLegalStoreInBits = VT.getSizeInBits(); + VT.getMinSizeInBits() >= MaximumLegalStoreInBits) + MaximumLegalStoreInBits = VT.getMinSizeInBits(); } void ConsiderForPruning(SDNode *N) { diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Support/ScalableSize.h" using namespace llvm; #define DEBUG_TYPE "legalize-types" @@ -4597,7 +4598,8 @@ unsigned Width, EVT WidenVT, unsigned Align = 0, unsigned WidenEx = 0) { EVT WidenEltVT = WidenVT.getVectorElementType(); - unsigned WidenWidth = WidenVT.getSizeInBits(); + const bool Scalable = WidenVT.isScalableVector(); + unsigned WidenWidth = WidenVT.getMinSizeInBits(); unsigned WidenEltWidth = WidenEltVT.getSizeInBits(); unsigned AlignInBits = Align*8; @@ -4608,23 +4610,27 @@ // See if there is larger legal integer than the element type to load/store. 
unsigned VT; - for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE; - VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) { - EVT MemVT((MVT::SimpleValueType) VT); - unsigned MemVTWidth = MemVT.getSizeInBits(); - if (MemVT.getSizeInBits() <= WidenEltWidth) - break; - auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT); - if ((Action == TargetLowering::TypeLegal || - Action == TargetLowering::TypePromoteInteger) && - (WidenWidth % MemVTWidth) == 0 && - isPowerOf2_32(WidenWidth / MemVTWidth) && - (MemVTWidth <= Width || - (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) { - if (MemVTWidth == WidenWidth) - return MemVT; - RetVT = MemVT; - break; + // Don't bother looking for an integer type if the vector is scalable, skip + // to vector types. + if (!Scalable) { + for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE; + VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) { + EVT MemVT((MVT::SimpleValueType) VT); + unsigned MemVTWidth = MemVT.getSizeInBits(); + if (MemVT.getSizeInBits() <= WidenEltWidth) + break; + auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT); + if ((Action == TargetLowering::TypeLegal || + Action == TargetLowering::TypePromoteInteger) && + (WidenWidth % MemVTWidth) == 0 && + isPowerOf2_32(WidenWidth / MemVTWidth) && + (MemVTWidth <= Width || + (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) { + if (MemVTWidth == WidenWidth) + return MemVT; + RetVT = MemVT; + break; + } } } @@ -4633,7 +4639,10 @@ for (VT = (unsigned)MVT::LAST_VECTOR_VALUETYPE; VT >= (unsigned)MVT::FIRST_VECTOR_VALUETYPE; --VT) { EVT MemVT = (MVT::SimpleValueType) VT; - unsigned MemVTWidth = MemVT.getSizeInBits(); + // Skip vector MVTs which don't match the scalable property of WidenVT. + if (Scalable != MemVT.isScalableVector()) + continue; + unsigned MemVTWidth = MemVT.getMinSizeInBits(); auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT); if ((Action == TargetLowering::TypeLegal || Action == TargetLowering::TypePromoteInteger) && diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -9174,7 +9174,7 @@ // if it isn't first piece, alignment must be 1 ISD::OutputArg MyFlags(Flags, Parts[j].getValueType(), VT, i < CLI.NumFixedArgs, - i, j*Parts[j].getValueType().getStoreSize()); + i, j*Parts[j].getValueType().getMinStoreSize()); if (NumParts > 1 && j == 0) MyFlags.Flags.setSplit(); else if (j != 0) { @@ -9643,8 +9643,11 @@ unsigned NumRegs = TLI->getNumRegistersForCallingConv( *CurDAG->getContext(), F.getCallingConv(), VT); for (unsigned i = 0; i != NumRegs; ++i) { + // For scalable vectors, use the minimum size; individual targets + // are responsible for handling scalable vector arguments and + // return values. 
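// --- Editor's illustrative sketch; not part of the patch. It shows why the
// argument-lowering code below switches to getMinStoreSize(): part offsets
// must be compile-time byte counts, so only the minimum per-part size can be
// used for scalable types. The helper name is hypothetical.
#include "llvm/Support/MachineValueType.h"
using namespace llvm;

static unsigned partOffsetInBytes(MVT RegisterVT, unsigned PartIdx) {
  // For nxv4i32 this yields PartIdx * 16; the real stride in memory is some
  // runtime multiple of that, determined by the target's vector length.
  return PartIdx * RegisterVT.getMinStoreSize();
}
// --- End of editor's sketch.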
ISD::InputArg MyFlags(Flags, RegisterVT, VT, isArgValueUsed, - ArgNo, PartBase+i*RegisterVT.getStoreSize()); + ArgNo, PartBase+i*RegisterVT.getMinStoreSize()); if (NumRegs > 1 && i == 0) MyFlags.Flags.setSplit(); // if it isn't first piece, alignment must be 1 @@ -9657,7 +9660,7 @@ } if (NeedsRegBlock && Value == NumValues - 1) Ins[Ins.size() - 1].Flags.setInConsecutiveRegsLast(); - PartBase += VT.getStoreSize(); + PartBase += VT.getMinStoreSize(); } } diff --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp --- a/llvm/lib/CodeGen/ValueTypes.cpp +++ b/llvm/lib/CodeGen/ValueTypes.cpp @@ -11,6 +11,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Type.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ScalableSize.h" using namespace llvm; EVT EVT::changeExtendedTypeToInteger() const { @@ -105,11 +106,26 @@ assert(isExtended() && "Type is not extended!"); if (IntegerType *ITy = dyn_cast(LLVMTy)) return ITy->getBitWidth(); - if (VectorType *VTy = dyn_cast(LLVMTy)) + if (VectorType *VTy = dyn_cast(LLVMTy)) { + assert(!VTy->isScalable() && + "Size of scalable type cannot be represented by a scalar."); return VTy->getBitWidth(); + } llvm_unreachable("Unrecognized extended type!"); } +ScalableSize EVT::getScalableExtendedSizeInBits() const { + assert(isExtended() && "Type is not extended!"); + if (VectorType *VTy = dyn_cast(LLVMTy)) + return VTy->getScalableSizeInBits(); + return { getExtendedSizeInBits(), false }; +} + +unsigned EVT::getMinExtendedSizeInBits() const { + assert(isExtended() && "Type is not extended!"); + return getScalableExtendedSizeInBits().getMinSize(); +} + /// getEVTString - This function returns value type as a string, e.g. "i32". std::string EVT::getEVTString() const { switch (V.SimpleTy) { diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp --- a/llvm/lib/IR/DataLayout.cpp +++ b/llvm/lib/IR/DataLayout.cpp @@ -29,6 +29,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -740,7 +741,9 @@ llvm_unreachable("Bad type for getAlignment!!!"); } - return getAlignmentInfo(AlignType, getTypeSizeInBits(Ty), abi_or_pref, Ty); + // If we're dealing with a scalable vector, we just need the minimum size for + // determining alignment. If not, we'll get the exact size. 
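// --- Editor's illustrative sketch; not part of the patch. With the change
// below, alignment queries no longer assert for scalable vectors: the ABI
// alignment of <vscale x 2 x i64> is derived from its 128-bit minimum size,
// just as it would be for a fixed <2 x i64>.
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

static unsigned scalableVectorABIAlign(LLVMContext &Ctx, const DataLayout &DL) {
  Type *I64 = Type::getInt64Ty(Ctx);
  VectorType *ScV2I64 = VectorType::get(I64, {2, /*Scalable=*/true});
  return DL.getABITypeAlignment(ScV2I64); // Internally uses the minimum size.
}
// --- End of editor's sketch.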
+ return getAlignmentInfo(AlignType, getMinTypeSizeInBits(Ty), abi_or_pref, Ty); } unsigned DataLayout::getABITypeAlignment(Type *Ty) const { diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -38,6 +38,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ScalableSize.h" #include #include #include @@ -1778,7 +1779,7 @@ const Twine &Name, Instruction *InsertBefore) : Instruction(VectorType::get(cast(V1->getType())->getElementType(), - cast(Mask->getType())->getNumElements()), + cast(Mask->getType())->getElementCount()), ShuffleVector, OperandTraits::op_begin(this), OperandTraits::operands(this), @@ -1795,7 +1796,7 @@ const Twine &Name, BasicBlock *InsertAtEnd) : Instruction(VectorType::get(cast(V1->getType())->getElementType(), - cast(Mask->getType())->getNumElements()), + cast(Mask->getType())->getElementCount()), ShuffleVector, OperandTraits::op_begin(this), OperandTraits::operands(this), @@ -2968,8 +2969,8 @@ } // Get the bit sizes, we'll need these - unsigned SrcBits = SrcTy->getPrimitiveSizeInBits(); // 0 for ptr - unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr + auto SrcBits = SrcTy->getScalableSizeInBits(); // 0 for ptr + auto DestBits = DestTy->getScalableSizeInBits(); // 0 for ptr // Run through the possibilities ... if (DestTy->isIntegerTy()) { // Casting to integral @@ -3030,12 +3031,12 @@ } } - unsigned SrcBits = SrcTy->getPrimitiveSizeInBits(); // 0 for ptr - unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr + auto SrcBits = SrcTy->getScalableSizeInBits(); // 0 for ptr + auto DestBits = DestTy->getScalableSizeInBits(); // 0 for ptr // Could still have vectors of pointers if the number of elements doesn't // match - if (SrcBits == 0 || DestBits == 0) + if (SrcBits.getMinSize() == 0 || DestBits.getMinSize() == 0) return false; if (SrcBits != DestBits) @@ -3245,7 +3246,7 @@ // For non-pointer cases, the cast is okay if the source and destination bit // widths are identical. if (!SrcPtrTy) - return SrcTy->getPrimitiveSizeInBits() == DstTy->getPrimitiveSizeInBits(); + return SrcTy->getScalableSizeInBits() == DstTy->getScalableSizeInBits(); // If both are pointers then the address spaces must match. 
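// --- Editor's illustrative sketch; not part of the patch. Effect of moving
// the cast checks onto ScalableSize: a bitcast between two scalable vectors
// with the same minimum size stays valid, while a bitcast between a scalable
// and a fixed vector is rejected even when the minimum sizes match.
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/LLVMContext.h"
#include <cassert>
using namespace llvm;

static void bitcastValidityExamples(LLVMContext &Ctx) {
  Type *I32 = Type::getInt32Ty(Ctx);
  Type *I64 = Type::getInt64Ty(Ctx);
  VectorType *ScV4I32 = VectorType::get(I32, {4, /*Scalable=*/true});
  VectorType *ScV2I64 = VectorType::get(I64, {2, /*Scalable=*/true});
  VectorType *V4I32 = VectorType::get(I32, {4, /*Scalable=*/false});

  assert(CastInst::isBitCastable(ScV4I32, ScV2I64)); // Both scalable, 128-bit minimum.
  assert(!CastInst::isBitCastable(ScV4I32, V4I32));  // Scalable vs. fixed: never equal.
}
// --- End of editor's sketch.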
if (SrcPtrTy->getAddressSpace() != DstPtrTy->getAddressSpace()) diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -26,6 +26,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Support/ScalableSize.h" #include #include @@ -121,7 +122,11 @@ case Type::PPC_FP128TyID: return 128; case Type::X86_MMXTyID: return 64; case Type::IntegerTyID: return cast(this)->getBitWidth(); - case Type::VectorTyID: return cast(this)->getBitWidth(); + case Type::VectorTyID: { + const VectorType *VTy = cast(this); + assert(!VTy->isScalable() && "Scalable vectors are not a primitive type"); + return VTy->getBitWidth(); + } default: return 0; } } @@ -130,6 +135,23 @@ return getScalarType()->getPrimitiveSizeInBits(); } +ScalableSize Type::getScalableSizeInBits() const { + if (auto *VTy = dyn_cast(this)) + return {VTy->getBitWidth(), VTy->isScalable()}; + + return {getPrimitiveSizeInBits(), false}; +} + +unsigned Type::getMinSizeInBits() const { + return getScalableSizeInBits().getMinSize(); +} + +unsigned Type::getFixedSizeInBits() const { + auto Size = getScalableSizeInBits(); + assert(!Size.isScalable() && "Request for a fixed size on a scalable vector"); + return Size.getMinSize(); +} + int Type::getFPMantissaWidth() const { if (auto *VTy = dyn_cast(this)) return VTy->getElementType()->getFPMantissaWidth(); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -307,6 +307,9 @@ setOperationAction(ISD::ROTL, MVT::i32, Expand); setOperationAction(ISD::ROTL, MVT::i64, Expand); for (MVT VT : MVT::vector_valuetypes()) { + // SVE types handled later + if (VT.isScalableVector()) + continue; setOperationAction(ISD::ROTL, VT, Expand); setOperationAction(ISD::ROTR, VT, Expand); } @@ -321,6 +324,9 @@ setOperationAction(ISD::SDIVREM, MVT::i32, Expand); setOperationAction(ISD::SDIVREM, MVT::i64, Expand); for (MVT VT : MVT::vector_valuetypes()) { + // SVE types handled later + if (VT.isScalableVector()) + continue; setOperationAction(ISD::SDIVREM, VT, Expand); setOperationAction(ISD::UDIVREM, VT, Expand); } @@ -753,6 +759,9 @@ // Likewise, narrowing and extending vector loads/stores aren't handled // directly. for (MVT VT : MVT::vector_valuetypes()) { + // SVE types handled later + if (VT.isScalableVector()) + continue; setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand); if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32) { @@ -769,6 +778,9 @@ setOperationAction(ISD::CTTZ, VT, Expand); for (MVT InnerVT : MVT::vector_valuetypes()) { + // Don't mix neon fixed length types with sve scalable types + if (InnerVT.isScalableVector()) + continue; setTruncStoreAction(VT, InnerVT, Expand); setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand); setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -135,6 +135,9 @@ } for (MVT VT : MVT::integer_vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. 
+ if (VT.isScalableVector()) + continue; setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i8, Expand); setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i8, Expand); setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::v2i8, Expand); diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -662,7 +662,12 @@ } for (MVT VT : MVT::vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; for (MVT InnerVT : MVT::vector_valuetypes()) { + if (InnerVT.isScalableVector()) + continue; setTruncStoreAction(VT, InnerVT, Expand); addAllExtLoads(VT, InnerVT, Expand); } @@ -868,6 +873,9 @@ for (MVT Ty : {MVT::v8i8, MVT::v4i8, MVT::v2i8, MVT::v4i16, MVT::v2i16, MVT::v2i32}) { for (MVT VT : MVT::integer_vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; setLoadExtAction(ISD::EXTLOAD, VT, Ty, Legal); setLoadExtAction(ISD::ZEXTLOAD, VT, Ty, Legal); setLoadExtAction(ISD::SEXTLOAD, VT, Ty, Legal); @@ -1011,6 +1019,9 @@ // ARM does not have ROTL. setOperationAction(ISD::ROTL, MVT::i32, Expand); for (MVT VT : MVT::vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; setOperationAction(ISD::ROTL, VT, Expand); setOperationAction(ISD::ROTR, VT, Expand); } diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp --- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp +++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp @@ -1435,11 +1435,17 @@ }; for (MVT VT : MVT::vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; + for (unsigned VectExpOp : VectExpOps) setOperationAction(VectExpOp, VT, Expand); // Expand all extending loads and truncating stores: for (MVT TargetVT : MVT::vector_valuetypes()) { + if (TargetVT.isScalableVector()) + continue; if (TargetVT == VT) continue; setLoadExtAction(ISD::EXTLOAD, TargetVT, VT, Expand); @@ -1850,7 +1856,7 @@ TargetLoweringBase::LegalizeTypeAction HexagonTargetLowering::getPreferredVectorAction(MVT VT) const { - if (VT.getVectorNumElements() == 1) + if (VT.getVectorNumElements() == 1 || VT.isScalableVector()) return TargetLoweringBase::TypeScalarizeVector; // Always widen vectors of i1. diff --git a/llvm/lib/Target/Hexagon/HexagonSubtarget.h b/llvm/lib/Target/Hexagon/HexagonSubtarget.h --- a/llvm/lib/Target/Hexagon/HexagonSubtarget.h +++ b/llvm/lib/Target/Hexagon/HexagonSubtarget.h @@ -228,7 +228,7 @@ } bool isHVXVectorType(MVT VecTy, bool IncludeBool = false) const { - if (!VecTy.isVector() || !useHVXOps()) + if (!VecTy.isVector() || !useHVXOps() || VecTy.isScalableVector()) return false; MVT ElemTy = VecTy.getVectorElementType(); if (!IncludeBool && ElemTy == MVT::i1) diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -45,6 +45,8 @@ bool HexagonTTIImpl::isTypeForHVX(Type *VecTy) const { assert(VecTy->isVectorTy()); + if (cast(VecTy)->isScalable()) + return false; // Avoid types like <2 x i32*>. 
if (!cast(VecTy)->getElementType()->isIntegerTy()) return false; diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp --- a/llvm/lib/Target/Mips/MipsISelLowering.cpp +++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp @@ -331,6 +331,9 @@ // Set LoadExtAction for f16 vectors to Expand for (MVT VT : MVT::fp_vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; MVT F16VT = MVT::getVectorVT(MVT::f16, VT.getVectorNumElements()); if (F16VT.isValid()) setLoadExtAction(ISD::EXTLOAD, VT, F16VT, Expand); diff --git a/llvm/lib/Target/Mips/MipsSEISelLowering.cpp b/llvm/lib/Target/Mips/MipsSEISelLowering.cpp --- a/llvm/lib/Target/Mips/MipsSEISelLowering.cpp +++ b/llvm/lib/Target/Mips/MipsSEISelLowering.cpp @@ -72,7 +72,12 @@ if (Subtarget.hasDSP() || Subtarget.hasMSA()) { // Expand all truncating stores and extending loads. for (MVT VT0 : MVT::vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT0.isScalableVector()) + continue; for (MVT VT1 : MVT::vector_valuetypes()) { + if (VT1.isScalableVector()) + continue; setTruncStoreAction(VT0, VT1, Expand); setLoadExtAction(ISD::SEXTLOAD, VT0, VT1, Expand); setLoadExtAction(ISD::ZEXTLOAD, VT0, VT1, Expand); diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp --- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp @@ -554,6 +554,10 @@ // First set operation action for all vector types to expand. Then we // will selectively turn on ones that can be effectively codegen'd. for (MVT VT : MVT::vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; + // add/sub are legal for all supported vector VT's. setOperationAction(ISD::ADD, VT, Legal); setOperationAction(ISD::SUB, VT, Legal); @@ -647,6 +651,9 @@ setOperationAction(ISD::ROTR, VT, Expand); for (MVT InnerVT : MVT::vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; setTruncStoreAction(VT, InnerVT, Expand); setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand); setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand); diff --git a/llvm/lib/Target/Sparc/SparcISelLowering.cpp b/llvm/lib/Target/Sparc/SparcISelLowering.cpp --- a/llvm/lib/Target/Sparc/SparcISelLowering.cpp +++ b/llvm/lib/Target/Sparc/SparcISelLowering.cpp @@ -1439,6 +1439,9 @@ } // Truncating/extending stores/loads are also not supported. for (MVT VT : MVT::integer_vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i32, Expand); setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::v2i32, Expand); setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i32, Expand); diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp --- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp +++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp @@ -295,6 +295,10 @@ setOperationAction(ISD::PREFETCH, MVT::Other, Custom); for (MVT VT : MVT::vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; + // Assume by default that all vector operations need to be expanded. 
for (unsigned Opcode = 0; Opcode < ISD::BUILTIN_OP_END; ++Opcode) if (getOperationAction(Opcode, VT) == Legal) @@ -302,6 +306,8 @@ // Likewise all truncating stores and extending loads. for (MVT InnerVT : MVT::vector_valuetypes()) { + if (InnerVT.isScalableVector()) + continue; setTruncStoreAction(VT, InnerVT, Expand); setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand); setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -205,8 +205,12 @@ for (auto T : {MVT::i8, MVT::i16, MVT::i32}) setOperationAction(ISD::SIGN_EXTEND_INREG, T, Action); } - for (auto T : MVT::integer_vector_valuetypes()) + for (auto T : MVT::integer_vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (T.isScalableVector()) + continue; setOperationAction(ISD::SIGN_EXTEND_INREG, T, Expand); + } // Dynamic stack allocation: use the default expansion. setOperationAction(ISD::STACKSAVE, MVT::Other, Expand); @@ -238,6 +242,9 @@ for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64, MVT::v4f32, MVT::v2f64}) { for (auto MemT : MVT::vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (MemT.isScalableVector()) + continue; if (MVT(T) != MemT) { setTruncStoreAction(T, MemT, Expand); for (auto Ext : {ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -717,6 +717,10 @@ // (for widening) or expand (for scalarization). Then we will selectively // turn on ones that can be effectively codegen'd. for (MVT VT : MVT::vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; + setOperationAction(ISD::SDIV, VT, Expand); setOperationAction(ISD::UDIV, VT, Expand); setOperationAction(ISD::SREM, VT, Expand); @@ -755,6 +759,9 @@ setOperationAction(ISD::ANY_EXTEND, VT, Expand); setOperationAction(ISD::SELECT_CC, VT, Expand); for (MVT InnerVT : MVT::vector_valuetypes()) { + if (InnerVT.isScalableVector()) + continue; + setTruncStoreAction(InnerVT, VT, Expand); setLoadExtAction(ISD::SEXTLOAD, InnerVT, VT, Expand); @@ -909,6 +916,9 @@ // scalars) and extend in-register to a legal 128-bit vector type. For sext // loads these must work with a single scalar load. for (MVT VT : MVT::integer_vector_valuetypes()) { + // Scalable vectors aren't supported on this backend. + if (VT.isScalableVector()) + continue; setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i8, Custom); setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i16, Custom); setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i32, Custom); @@ -1073,6 +1083,9 @@ // Avoid narrow result types when widening. The legal types are listed // in the next loop. for (MVT VT : MVT::integer_vector_valuetypes()) { + // Scalable vectors aren't supported on this backend.
+ if (VT.isScalableVector()) + continue; setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i8, Custom); setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i16, Custom); setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i32, Custom); diff --git a/llvm/test/Other/scalable-vectors-core-ir.ll b/llvm/test/Other/scalable-vectors-core-ir.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Other/scalable-vectors-core-ir.ll @@ -0,0 +1,393 @@ +; RUN: opt -S -verify < %s | FileCheck %s +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" +target triple = "aarch64--linux-gnu" + +;; Check supported instructions are accepted without dropping 'vscale'. +;; Same order as the LangRef + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Unary Operations +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + + +define @fneg( %val) { +; CHECK-LABEL: @fneg +; CHECK: %r = fneg %val +; CHECK-NEXT: ret %r + %r = fneg %val + ret %r +} + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Binary Operations +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +define @add( %a, %b) { +; CHECK-LABEL: @add +; CHECK: %r = add %a, %b +; CHECK-NEXT: ret %r + %r = add %a, %b + ret %r +} + +define @fadd( %a, %b) { +; CHECK-LABEL: @fadd +; CHECK: %r = fadd %a, %b +; CHECK-NEXT: ret %r + %r = fadd %a, %b + ret %r +} + +define @sub( %a, %b) { +; CHECK-LABEL: @sub +; CHECK: %r = sub %a, %b +; CHECK-NEXT: ret %r + %r = sub %a, %b + ret %r +} + +define @fsub( %a, %b) { +; CHECK-LABEL: @fsub +; CHECK: %r = fsub %a, %b +; CHECK-NEXT: ret %r + %r = fsub %a, %b + ret %r +} + +define @mul( %a, %b) { +; CHECK-LABEL: @mul +; CHECK: %r = mul %a, %b +; CHECK-NEXT: ret %r + %r = mul %a, %b + ret %r +} + +define @fmul( %a, %b) { +; CHECK-LABEL: @fmul +; CHECK: %r = fmul %a, %b +; CHECK-NEXT: ret %r + %r = fmul %a, %b + ret %r +} + +define @udiv( %a, %b) { +; CHECK-LABEL: @udiv +; CHECK: %r = udiv %a, %b +; CHECK-NEXT: ret %r + %r = udiv %a, %b + ret %r +} + +define @sdiv( %a, %b) { +; CHECK-LABEL: @sdiv +; CHECK: %r = sdiv %a, %b +; CHECK-NEXT: ret %r + %r = sdiv %a, %b + ret %r +} + +define @fdiv( %a, %b) { +; CHECK-LABEL: @fdiv +; CHECK: %r = fdiv %a, %b +; CHECK-NEXT: ret %r + %r = fdiv %a, %b + ret %r +} + +define @urem( %a, %b) { +; CHECK-LABEL: @urem +; CHECK: %r = urem %a, %b +; CHECK-NEXT: ret %r + %r = urem %a, %b + ret %r +} + +define @srem( %a, %b) { +; CHECK-LABEL: @srem +; CHECK: %r = srem %a, %b +; CHECK-NEXT: ret %r + %r = srem %a, %b + ret %r +} + +define @frem( %a, %b) { +; CHECK-LABEL: @frem +; CHECK: %r = frem %a, %b +; CHECK-NEXT: ret %r + %r = frem %a, %b + ret %r +} + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Bitwise Binary Operations +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +define @shl( %a, %b) { +; CHECK-LABEL: @shl +; CHECK: %r = shl %a, %b +; CHECK-NEXT: ret %r + %r = shl %a, %b + ret %r +} + +define @lshr( %a, %b) { +; CHECK-LABEL: @lshr +; CHECK: %r = lshr %a, %b +; CHECK-NEXT: ret %r + %r = lshr %a, %b + ret %r +} + +define @ashr( %a, %b) { +; CHECK-LABEL: @ashr +; CHECK: %r = ashr %a, %b +; CHECK-NEXT: ret %r + %r = ashr %a, %b + ret %r +} + +define @and( %a, %b) { +; CHECK-LABEL: @and +; CHECK: %r = and %a, %b +; CHECK-NEXT: ret %r + %r = and %a, %b + ret %r +} + +define @or( %a, %b) { +; CHECK-LABEL: @or +; CHECK: %r = or %a, %b +; CHECK-NEXT: ret %r + %r = or %a, %b + ret %r +} + +define @xor( 
%a, %b) { +; CHECK-LABEL: @xor +; CHECK: %r = xor %a, %b +; CHECK-NEXT: ret %r + %r = xor %a, %b + ret %r +} + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Vector Operations +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +define i64 @extractelement( %val) { +; CHECK-LABEL: @extractelement +; CHECK: %r = extractelement %val, i32 0 +; CHECK-NEXT: ret i64 %r + %r = extractelement %val, i32 0 + ret i64 %r +} + +define @insertelement( %vec, i8 %ins) { +; CHECK-LABEL: @insertelement +; CHECK: %r = insertelement %vec, i8 %ins, i32 0 +; CHECK-NEXT: ret %r + %r = insertelement %vec, i8 %ins, i32 0 + ret %r +} + +define @shufflevector(half %val) { +; CHECK-LABEL: @shufflevector +; CHECK: %insvec = insertelement undef, half %val, i32 0 +; CHECK-NEXT: %r = shufflevector %insvec, undef, zeroinitializer +; CHECK-NEXT: ret %r + %insvec = insertelement undef, half %val, i32 0 + %r = shufflevector %insvec, undef, zeroinitializer + ret %r +} + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Memory Access and Addressing Operations +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +define void @alloca() { +; CHECK-LABEL: @alloca +; CHECK: %vec = alloca +; CHECK-NEXT: ret void + %vec = alloca + ret void +} + +define @load(* %ptr) { +; CHECK-LABEL: @load +; CHECK: %r = load , * %ptr +; CHECK-NEXT: ret %r + %r = load , * %ptr + ret %r +} + +define void @store( %data, * %ptr) { +; CHECK-LABEL: @store +; CHECK: store %data, * %ptr +; CHECK-NEXT: ret void + store %data, * %ptr + ret void +} + +define * @getelementptr(* %base) { +; CHECK-LABEL: @getelementptr +; CHECK: %r = getelementptr , * %base, i64 0 +; CHECK-NEXT: ret * %r + %r = getelementptr , * %base, i64 0 + ret * %r +} + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Conversion Operations +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +define @truncto( %val) { +; CHECK-LABEL: @truncto +; CHECK: %r = trunc %val to +; CHECK-NEXT: ret %r + %r = trunc %val to + ret %r +} + +define @zextto( %val) { +; CHECK-LABEL: @zextto +; CHECK: %r = zext %val to +; CHECK-NEXT: ret %r + %r = zext %val to + ret %r +} + +define @sextto( %val) { +; CHECK-LABEL: @sextto +; CHECK: %r = sext %val to +; CHECK-NEXT: ret %r + %r = sext %val to + ret %r +} + +define @fptruncto( %val) { +; CHECK-LABEL: @fptruncto +; CHECK: %r = fptrunc %val to +; CHECK-NEXT: ret %r + %r = fptrunc %val to + ret %r +} + +define @fpextto( %val) { +; CHECK-LABEL: @fpextto +; CHECK: %r = fpext %val to +; CHECK-NEXT: ret %r + %r = fpext %val to + ret %r +} + +define @fptouito( %val) { +; CHECK-LABEL: @fptoui +; CHECK: %r = fptoui %val to +; CHECK-NEXT: ret %r + %r = fptoui %val to + ret %r +} + +define @fptosito( %val) { +; CHECK-LABEL: @fptosi +; CHECK: %r = fptosi %val to +; CHECK-NEXT: ret %r + %r = fptosi %val to + ret %r +} + +define @uitofpto( %val) { +; CHECK-LABEL: @uitofp +; CHECK: %r = uitofp %val to +; CHECK-NEXT: ret %r + %r = uitofp %val to + ret %r +} + +define @sitofpto( %val) { +; CHECK-LABEL: @sitofp +; CHECK: %r = sitofp %val to +; CHECK-NEXT: ret %r + %r = sitofp %val to + ret %r +} + +define @ptrtointto( %val) { +; CHECK-LABEL: @ptrtointto +; CHECK: %r = ptrtoint %val to +; CHECK-NEXT: ret %r + %r = ptrtoint %val to + ret %r +} + +define @inttoptrto( %val) { +; CHECK-LABEL: @inttoptrto +; CHECK: %r = inttoptr %val to +; CHECK-NEXT: ret %r + %r = inttoptr %val to + 
ret %r +} + +define @bitcastto( %a) { +; CHECK-LABEL: @bitcast +; CHECK: %r = bitcast %a to +; CHECK-NEXT: ret %r + %r = bitcast %a to + ret %r +} + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Other Operations +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +define @icmp( %a, %b) { +; CHECK-LABEL: @icmp +; CHECK: %r = icmp eq %a, %b +; CHECK-NEXT: ret %r + %r = icmp eq %a, %b + ret %r +} + +define @fcmp( %a, %b) { +; CHECK-LABEL: @fcmp +; CHECK: %r = fcmp une %a, %b +; CHECK-NEXT: ret %r + %r = fcmp une %a, %b + ret %r +} + +define @phi( %a, i32 %val) { +; CHECK-LABEL: @phi +; CHECK: %r = phi [ %a, %entry ], [ %added, %iszero ] +; CHECK-NEXT: ret %r +entry: + %cmp = icmp eq i32 %val, 0 + br i1 %cmp, label %iszero, label %end + +iszero: + %ins = insertelement undef, i8 1, i32 0 + %splatone = shufflevector %ins, undef, zeroinitializer + %added = add %a, %splatone + br label %end + +end: + %r = phi [ %a, %entry ], [ %added, %iszero ] + ret %r +} + +define @select( %a, %b, %sval) { +; CHECK-LABEL: @select +; CHECK: %r = select %sval, %a, %b +; CHECK-NEXT: ret %r + %r = select %sval, %a, %b + ret %r +} + +declare @callee() +define @call( %val) { +; CHECK-LABEL: @call +; CHECK: %r = call @callee( %val) +; CHECK-NEXT: ret %r + %r = call @callee( %val) + ret %r +} \ No newline at end of file diff --git a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp --- a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp +++ b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp @@ -120,4 +120,63 @@ ScV4Float64Ty->getElementType()); } +TEST(ScalableVectorMVTsTest, SizeQueries) { + LLVMContext Ctx; + + EVT nxv4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4, /*Scalable=*/ true); + EVT nxv2i32 = EVT::getVectorVT(Ctx, MVT::i32, 2, /*Scalable=*/ true); + EVT nxv2i64 = EVT::getVectorVT(Ctx, MVT::i64, 2, /*Scalable=*/ true); + EVT nxv2f64 = EVT::getVectorVT(Ctx, MVT::f64, 2, /*Scalable=*/ true); + + EVT v4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4); + EVT v2i32 = EVT::getVectorVT(Ctx, MVT::i32, 2); + EVT v2i64 = EVT::getVectorVT(Ctx, MVT::i64, 2); + EVT v2f64 = EVT::getVectorVT(Ctx, MVT::f64, 2); + + // Check equivalence and ordering on scalable types. + EXPECT_EQ(nxv4i32.getScalableSizeInBits(), nxv2i64.getScalableSizeInBits()); + EXPECT_EQ(nxv2f64.getScalableSizeInBits(), nxv2i64.getScalableSizeInBits()); + EXPECT_NE(nxv2i32.getScalableSizeInBits(), nxv4i32.getScalableSizeInBits()); + EXPECT_LT(nxv2i32.getScalableSizeInBits(), nxv2i64.getScalableSizeInBits()); + EXPECT_LE(nxv4i32.getScalableSizeInBits(), nxv2i64.getScalableSizeInBits()); + EXPECT_GT(nxv4i32.getScalableSizeInBits(), nxv2i32.getScalableSizeInBits()); + EXPECT_GE(nxv2i64.getScalableSizeInBits(), nxv4i32.getScalableSizeInBits()); + + // Check equivalence and ordering on fixed types. + EXPECT_EQ(v4i32.getScalableSizeInBits(), v2i64.getScalableSizeInBits()); + EXPECT_EQ(v2f64.getScalableSizeInBits(), v2i64.getScalableSizeInBits()); + EXPECT_NE(v2i32.getScalableSizeInBits(), v4i32.getScalableSizeInBits()); + EXPECT_LT(v2i32.getScalableSizeInBits(), v2i64.getScalableSizeInBits()); + EXPECT_LE(v4i32.getScalableSizeInBits(), v2i64.getScalableSizeInBits()); + EXPECT_GT(v4i32.getScalableSizeInBits(), v2i32.getScalableSizeInBits()); + EXPECT_GE(v2i64.getScalableSizeInBits(), v4i32.getScalableSizeInBits()); + + // Check that scalable and non-scalable types with the same minimum size + // are not considered equal. 
+  ASSERT_TRUE(v4i32.getScalableSizeInBits() != nxv4i32.getScalableSizeInBits());
+  ASSERT_FALSE(v2i64.getScalableSizeInBits() ==
+               nxv2f64.getScalableSizeInBits());
+
+  // Check that we can obtain a known-exact size from a non-scalable type.
+  EXPECT_EQ(v4i32.getSizeInBits(), 128U);
+  EXPECT_EQ(v2i64.getScalableSizeInBits().getFixedSize(), 128U);
+
+  // Check that we can query the known minimum size for both scalable and
+  // fixed length types.
+  EXPECT_EQ(nxv2i32.getMinSizeInBits(), 64U);
+  EXPECT_EQ(nxv2f64.getScalableSizeInBits().getMinSize(), 128U);
+  EXPECT_EQ(v2i32.getMinSizeInBits(), nxv2i32.getMinSizeInBits());
+
+  // Check scalable property.
+  ASSERT_FALSE(v4i32.getScalableSizeInBits().isScalable());
+  ASSERT_TRUE(nxv4i32.getScalableSizeInBits().isScalable());
+
+  // Check convenience size scaling methods.
+  EXPECT_EQ(v2i32.getScalableSizeInBits() * 2, v4i32.getScalableSizeInBits());
+  EXPECT_EQ(2 * nxv2i32.getScalableSizeInBits(),
+            nxv4i32.getScalableSizeInBits());
+  EXPECT_EQ(nxv2f64.getScalableSizeInBits() / 2,
+            nxv2i32.getScalableSizeInBits());
+}
+
 } // end anonymous namespace
diff --git a/llvm/unittests/IR/VectorTypesTest.cpp b/llvm/unittests/IR/VectorTypesTest.cpp
--- a/llvm/unittests/IR/VectorTypesTest.cpp
+++ b/llvm/unittests/IR/VectorTypesTest.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/Support/ScalableSize.h"
@@ -160,5 +161,117 @@
   EXPECT_EQ(EltCnt.Min, 8U);
   ASSERT_TRUE(EltCnt.Scalable);
 }
+TEST(VectorTypesTest, FixedLenComparisons) {
+  LLVMContext Ctx;
+  DataLayout DL("");
+
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  Type *Int64Ty = Type::getInt64Ty(Ctx);
+
+  VectorType *V2Int32Ty = VectorType::get(Int32Ty, 2);
+  VectorType *V4Int32Ty = VectorType::get(Int32Ty, 4);
+
+  VectorType *V2Int64Ty = VectorType::get(Int64Ty, 2);
+
+  ScalableSize V2I32Len = V2Int32Ty->getScalableSizeInBits();
+  EXPECT_EQ(V2I32Len.getMinSize(), 64U);
+  EXPECT_FALSE(V2I32Len.isScalable());
+
+  EXPECT_LT(V2Int32Ty->getScalableSizeInBits(),
+            V4Int32Ty->getScalableSizeInBits());
+  EXPECT_GT(V2Int64Ty->getScalableSizeInBits(),
+            V2Int32Ty->getScalableSizeInBits());
+  EXPECT_EQ(V4Int32Ty->getScalableSizeInBits(),
+            V2Int64Ty->getScalableSizeInBits());
+  EXPECT_NE(V2Int32Ty->getScalableSizeInBits(),
+            V2Int64Ty->getScalableSizeInBits());
+
+  // Check that a fixed-only comparison works for fixed size vectors.
+  EXPECT_EQ(V2Int64Ty->getFixedSizeInBits(),
+            V4Int32Ty->getFixedSizeInBits());
+
+  // Check the DataLayout interfaces.
+  EXPECT_EQ(DL.getScalableTypeSizeInBits(V2Int64Ty),
+            DL.getScalableTypeSizeInBits(V4Int32Ty));
+  EXPECT_EQ(DL.getMinTypeSizeInBits(V2Int32Ty), 64U);
+  EXPECT_EQ(DL.getTypeSizeInBits(V2Int64Ty), 128U);
+  EXPECT_EQ(DL.getScalableTypeStoreSize(V2Int64Ty),
+            DL.getScalableTypeStoreSize(V4Int32Ty));
+  EXPECT_NE(DL.getScalableTypeStoreSizeInBits(V2Int32Ty),
+            DL.getScalableTypeStoreSizeInBits(V2Int64Ty));
+  EXPECT_EQ(DL.getMinTypeStoreSizeInBits(V2Int32Ty), 64U);
+  EXPECT_EQ(DL.getMinTypeStoreSize(V2Int64Ty), 16U);
+  EXPECT_EQ(DL.getScalableTypeAllocSize(V4Int32Ty),
+            DL.getScalableTypeAllocSize(V2Int64Ty));
+  EXPECT_NE(DL.getScalableTypeAllocSizeInBits(V2Int32Ty),
+            DL.getScalableTypeAllocSizeInBits(V2Int64Ty));
+  EXPECT_EQ(DL.getMinTypeAllocSizeInBits(V4Int32Ty), 128U);
+  EXPECT_EQ(DL.getMinTypeAllocSize(V2Int32Ty), 8U);
+  ASSERT_TRUE(DL.typeSizeEqualsStoreSize(V4Int32Ty));
+}
+
+TEST(VectorTypesTest, ScalableComparisons) {
+  LLVMContext Ctx;
+  DataLayout DL("");
+
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  Type *Int64Ty = Type::getInt64Ty(Ctx);
+
+  VectorType *ScV2Int32Ty = VectorType::get(Int32Ty, {2, true});
+  VectorType *ScV4Int32Ty = VectorType::get(Int32Ty, {4, true});
+
+  VectorType *ScV2Int64Ty = VectorType::get(Int64Ty, {2, true});
+
+  ScalableSize ScV2I32Len = ScV2Int32Ty->getScalableSizeInBits();
+  EXPECT_EQ(ScV2I32Len.getMinSize(), 64U);
+  EXPECT_TRUE(ScV2I32Len.isScalable());
+
+  EXPECT_LT(ScV2Int32Ty->getScalableSizeInBits(),
+            ScV4Int32Ty->getScalableSizeInBits());
+  EXPECT_GT(ScV2Int64Ty->getScalableSizeInBits(),
+            ScV2Int32Ty->getScalableSizeInBits());
+  EXPECT_EQ(ScV4Int32Ty->getScalableSizeInBits(),
+            ScV2Int64Ty->getScalableSizeInBits());
+  EXPECT_NE(ScV2Int32Ty->getScalableSizeInBits(),
+            ScV2Int64Ty->getScalableSizeInBits());
+
+  // Check the DataLayout interfaces.
+  EXPECT_EQ(DL.getScalableTypeSizeInBits(ScV2Int64Ty),
+            DL.getScalableTypeSizeInBits(ScV4Int32Ty));
+  EXPECT_EQ(DL.getMinTypeSizeInBits(ScV2Int32Ty), 64U);
+  EXPECT_EQ(DL.getScalableTypeStoreSize(ScV2Int64Ty),
+            DL.getScalableTypeStoreSize(ScV4Int32Ty));
+  EXPECT_NE(DL.getScalableTypeStoreSizeInBits(ScV2Int32Ty),
+            DL.getScalableTypeStoreSizeInBits(ScV2Int64Ty));
+  EXPECT_EQ(DL.getMinTypeStoreSizeInBits(ScV2Int32Ty), 64U);
+  EXPECT_EQ(DL.getMinTypeStoreSize(ScV2Int64Ty), 16U);
+  EXPECT_EQ(DL.getScalableTypeAllocSize(ScV4Int32Ty),
+            DL.getScalableTypeAllocSize(ScV2Int64Ty));
+  EXPECT_NE(DL.getScalableTypeAllocSizeInBits(ScV2Int32Ty),
+            DL.getScalableTypeAllocSizeInBits(ScV2Int64Ty));
+  EXPECT_EQ(DL.getMinTypeAllocSizeInBits(ScV4Int32Ty), 128U);
+  EXPECT_EQ(DL.getMinTypeAllocSize(ScV2Int32Ty), 8U);
+  ASSERT_TRUE(DL.typeSizeEqualsStoreSize(ScV4Int32Ty));
+}
+
+TEST(VectorTypesTest, CrossComparisons) {
+  LLVMContext Ctx;
+
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+
+  VectorType *V4Int32Ty = VectorType::get(Int32Ty, {4, false});
+  VectorType *ScV4Int32Ty = VectorType::get(Int32Ty, {4, true});
+
+  // Even though the minimum size is the same, a scalable vector could be
+  // larger so we don't consider them to be the same size.
+  EXPECT_NE(V4Int32Ty->getScalableSizeInBits(),
+            ScV4Int32Ty->getScalableSizeInBits());
+  // If we are only checking the minimum, then they are the same size.
+  EXPECT_EQ(V4Int32Ty->getMinSizeInBits(),
+            ScV4Int32Ty->getMinSizeInBits());
+
+  // We can't use ordering comparisons (<,<=,>,>=) between scalable and
+  // non-scalable vector sizes.
+}
 
 } // end anonymous namespace
diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
--- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
+++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/ScalableSize.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 #include 
@@ -503,9 +504,16 @@
   }
 
   auto LT = [](MVT A, MVT B) -> bool {
+    // Always treat non-scalable MVTs as smaller than scalable MVTs for the
+    // purposes of ordering.
+    if (A.isScalableVector() && !B.isScalableVector())
+      return false;
+    if (!A.isScalableVector() && B.isScalableVector())
+      return true;
+
     return A.getScalarSizeInBits() < B.getScalarSizeInBits() ||
            (A.getScalarSizeInBits() == B.getScalarSizeInBits() &&
-            A.getSizeInBits() < B.getSizeInBits());
+            A.getScalableSizeInBits() < B.getScalableSizeInBits());
   };
   auto LE = [&LT](MVT A, MVT B) -> bool {
     // This function is used when removing elements: when a vector is compared
@@ -513,8 +521,13 @@
     if (A.isVector() != B.isVector())
       return false;
 
+    // We also don't want to remove elements when they're both vectors with the
+    // same minimum number of lanes, but one is scalable and the other not.
+    if (A.isScalableVector() != B.isScalableVector())
+      return false;
+
     return LT(A, B) || (A.getScalarSizeInBits() == B.getScalarSizeInBits() &&
-                        A.getSizeInBits() == B.getSizeInBits());
+                        A.getScalableSizeInBits() == B.getScalableSizeInBits());
   };
 
   for (unsigned M : Modes) {
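
Reviewer note (not part of the patch): the unit tests above rely on three rules for scalable sizes: two sizes compare equal only when both the minimum bit count and the scalable flag match, ordering comparisons are only meaningful between sizes with the same scalability, and a size can be scaled by an integer factor while keeping its flag. The standalone sketch below is a minimal model of those rules for readers who want to check them in isolation; SizeModel and its members are illustrative names only and do not reflect the actual ScalableSize implementation.

// sizemodel.cpp - illustrative model of the size-comparison rules above.
#include <cassert>
#include <cstdint>

struct SizeModel {
  uint64_t MinSize; // Known minimum size in bits.
  bool Scalable;    // True if the runtime size is an integer multiple of MinSize.

  // Equality requires both the minimum size and the scalable flag to match.
  friend bool operator==(SizeModel LHS, SizeModel RHS) {
    return LHS.MinSize == RHS.MinSize && LHS.Scalable == RHS.Scalable;
  }
  friend bool operator!=(SizeModel LHS, SizeModel RHS) { return !(LHS == RHS); }

  // Ordering is only defined when both operands have the same scalability.
  friend bool operator<(SizeModel LHS, SizeModel RHS) {
    assert(LHS.Scalable == RHS.Scalable && "incomparable sizes");
    return LHS.MinSize < RHS.MinSize;
  }

  // Scaling by an integer factor preserves the scalable flag.
  friend SizeModel operator*(SizeModel S, uint64_t Factor) {
    return {S.MinSize * Factor, S.Scalable};
  }
  friend SizeModel operator/(SizeModel S, uint64_t Factor) {
    return {S.MinSize / Factor, S.Scalable};
  }
};

int main() {
  SizeModel V4I32{128, false};  // e.g. <4 x i32>
  SizeModel NXV4I32{128, true}; // e.g. <vscale x 4 x i32>

  assert(V4I32 != NXV4I32);              // Same minimum, different scalability.
  assert((NXV4I32 / 2).MinSize == 64);   // Scaling only changes the minimum.
  assert((SizeModel{64, true} < SizeModel{128, true})); // Same scalability: ordered.
  return 0;
}

The model deliberately asserts on mixed scalable/fixed ordering, mirroring the comment at the end of the CrossComparisons test: such comparisons have no well-defined answer, so the tests only use equality across that boundary.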