diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h --- a/llvm/include/llvm/CodeGen/ValueTypes.h +++ b/llvm/include/llvm/CodeGen/ValueTypes.h @@ -18,6 +18,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/MachineValueType.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ScalableSize.h" #include <cassert> #include <cstdint> #include <string> @@ -209,7 +210,7 @@ /// Return true if the bit size is a multiple of 8. bool isByteSized() const { - return (getSizeInBits() & 7) == 0; + return (getKnownMinSizeInBits() & 7) == 0; } /// Return true if the size is a power-of-two number of bytes. @@ -221,31 +222,31 @@ /// Return true if this has the same number of bits as VT. bool bitsEq(EVT VT) const { if (EVT::operator==(VT)) return true; - return getSizeInBits() == VT.getSizeInBits(); + return getScalableSizeInBits() == VT.getScalableSizeInBits(); } /// Return true if this has more bits than VT. bool bitsGT(EVT VT) const { if (EVT::operator==(VT)) return false; - return getSizeInBits() > VT.getSizeInBits(); + return getScalableSizeInBits() > VT.getScalableSizeInBits(); } /// Return true if this has no less bits than VT. bool bitsGE(EVT VT) const { if (EVT::operator==(VT)) return true; - return getSizeInBits() >= VT.getSizeInBits(); + return getScalableSizeInBits() >= VT.getScalableSizeInBits(); } /// Return true if this has less bits than VT. bool bitsLT(EVT VT) const { if (EVT::operator==(VT)) return false; - return getSizeInBits() < VT.getSizeInBits(); + return getScalableSizeInBits() < VT.getScalableSizeInBits(); } /// Return true if this has no more bits than VT. bool bitsLE(EVT VT) const { if (EVT::operator==(VT)) return true; - return getSizeInBits() <= VT.getSizeInBits(); + return getScalableSizeInBits() <= VT.getScalableSizeInBits(); } /// Return the SimpleValueType held in the specified simple EVT. @@ -287,29 +288,79 @@ return {getExtendedVectorNumElements(), false}; } - /// Return the size of the specified value type in bits.
+ /// Return the size of the specified value type in bits. An assert will + /// occur if this is called on a scalable vector type. unsigned getSizeInBits() const { if (isSimple()) return V.getSizeInBits(); return getExtendedSizeInBits(); } + /// Returns the size of the specified value type as a minimum number of + /// bits and a boolean indicating whether the runtime size is exactly that + /// size (if false) or if it's an integer multiple of that minimum (true). + ScalableSize getScalableSizeInBits() const { + if (isSimple()) + return V.getScalableSizeInBits(); + return getScalableExtendedSizeInBits(); + } + + /// Returns the size of the type in bits. If the type is scalable, this + /// quantity represents the minimum size. If the type is not scalable, + /// it represents the exact size. + unsigned getKnownMinSizeInBits() const { + if (isSimple()) + return V.getKnownMinSizeInBits(); + return getKnownMinExtendedSizeInBits(); + } + unsigned getScalarSizeInBits() const { return getScalarType().getSizeInBits(); } /// Return the number of bytes overwritten by a store of the specified value - /// type. + /// type. An assert will occur if this is called on a scalable vector type. unsigned getStoreSize() const { return (getSizeInBits() + 7) / 8; } - /// Return the number of bits overwritten by a store of the specified value - /// type. + /// Returns the minimum number of bytes overwritten by a store of the + /// specified value type, along with a boolean indicating whether the + /// runtime size written to is exactly that size (if false) or if it's an + /// integer multiple of that minimum (true). + ScalableSize getScalableStoreSize() const { + ScalableSize SizeInBits = getScalableSizeInBits(); + return {(SizeInBits.getKnownMinSize() + 7) / 8, SizeInBits.isScalable()}; + } + + /// Returns the number of bytes overwritten by a store of the specified + /// value type. If the type is scalable, this quantity represents the + /// minimum size. 
If not scalable, it represents the exact size. + unsigned getKnownMinStoreSize() const { + return (getKnownMinSizeInBits() + 7) / 8; + } + + /// Returns the number of bits overwritten by a store of the specified value + /// type. An assert will occur if this is called on a scalable vector type. unsigned getStoreSizeInBits() const { return getStoreSize() * 8; } + /// Returns the minimum number of bits overwritten by a store of the + /// specified value type, along with a boolean indicating whether the + /// runtime size written to is exactly that size (if false) or if it's an + /// integer multiple of that minimum (true). + ScalableSize getScalableStoreSizeInBits() const { + return getScalableStoreSize() * 8; + } + + /// Returns the number of bits overwritten by a store of the specified value + /// type. If the type is scalable, this quantity represents the minimum + /// size. If not scalable, it represents the exact size. + unsigned getKnownMinStoreSizeInBits() const { + return getKnownMinStoreSize() * 8; + } + /// Rounds the bit-width of the given integer EVT up to the nearest power of /// two (and at least to eight), and returns the integer EVT with that /// number of bits. 
@@ -429,6 +480,8 @@ EVT getExtendedVectorElementType() const; unsigned getExtendedVectorNumElements() const LLVM_READONLY; unsigned getExtendedSizeInBits() const LLVM_READONLY; + ScalableSize getScalableExtendedSizeInBits() const LLVM_READONLY; + unsigned getKnownMinExtendedSizeInBits() const LLVM_READONLY; }; } // end namespace llvm diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h --- a/llvm/include/llvm/Support/MachineValueType.h +++ b/llvm/include/llvm/Support/MachineValueType.h @@ -656,7 +656,68 @@ return { getVectorNumElements(), isScalableVector() }; } + /// Returns the size of the specified MVT as a minimum number of bits and a + /// boolean indicating whether the runtime size is exactly that + /// size (if false) or if it's an integer multiple of that minimum (true). + ScalableSize getScalableSizeInBits() const { + switch (SimpleTy) { + default: return { getSizeInBits(), false }; + case nxv1i1: return { 1U, true }; + case nxv2i1: return { 2U, true }; + case nxv4i1: return { 4U, true }; + case nxv1i8: + case nxv8i1: return { 8U, true }; + case nxv16i1: + case nxv2i8: + case nxv1i16: return { 16U, true }; + case nxv32i1: + case nxv4i8: + case nxv2i16: + case nxv1i32: + case nxv2f16: + case nxv1f32: return { 32U, true }; + case nxv8i8: + case nxv4i16: + case nxv2i32: + case nxv1i64: + case nxv4f16: + case nxv2f32: + case nxv1f64: return { 64U, true }; + case nxv16i8: + case nxv8i16: + case nxv4i32: + case nxv2i64: + case nxv8f16: + case nxv4f32: + case nxv2f64: return { 128U, true }; + case nxv32i8: + case nxv16i16: + case nxv8i32: + case nxv4i64: + case nxv8f32: + case nxv4f64: return { 256U, true }; + case nxv32i16: + case nxv16i32: + case nxv8i64: + case nxv16f32: + case nxv8f64: return { 512U, true }; + case nxv32i32: + case nxv16i64: return { 1024U, true }; + case nxv32i64: return { 2048U, true }; + } + } + + /// Returns the size of the MVT in bits. 
If the type is scalable, this + /// quantity represents the minimum size. If the type is not scalable, + /// it represents the exact size. + unsigned getKnownMinSizeInBits() const { + return getScalableSizeInBits().getKnownMinSize(); + } + + /// Returns the size of the specified MVT in bits. + /// An assert will occur if this is called on a scalable vector type. unsigned getSizeInBits() const { + assert(!isScalableVector() && "getSizeInBits called on scalable vector"); switch (SimpleTy) { default: llvm_unreachable("getSizeInBits called on extended MVT."); @@ -676,25 +737,17 @@ case Metadata: llvm_unreachable("Value type is metadata."); case i1: - case v1i1: - case nxv1i1: return 1; - case v2i1: - case nxv2i1: return 2; - case v4i1: - case nxv4i1: return 4; + case v1i1: return 1; + case v2i1: return 2; + case v4i1: return 4; case i8 : case v1i8: - case v8i1: - case nxv1i8: - case nxv8i1: return 8; + case v8i1: return 8; case i16 : case f16: case v16i1: case v2i8: - case v1i16: - case nxv16i1: - case nxv2i8: - case nxv1i16: return 16; + case v1i16: return 16; case f32 : case i32 : case v32i1: @@ -702,13 +755,7 @@ case v2i16: case v2f16: case v1f32: - case v1i32: - case nxv32i1: - case nxv4i8: - case nxv2i16: - case nxv1i32: - case nxv2f16: - case nxv1f32: return 32; + case v1i32: return 32; case v3i16: case v3f16: return 48; case x86mmx: @@ -721,14 +768,7 @@ case v1i64: case v4f16: case v2f32: - case v1f64: - case nxv8i8: - case nxv4i16: - case nxv2i32: - case nxv1i64: - case nxv4f16: - case nxv2f32: - case nxv1f64: return 64; + case v1f64: return 64; case f80 : return 80; case v3i32: case v3f32: return 96; @@ -743,14 +783,7 @@ case v1i128: case v8f16: case v4f32: - case v2f64: - case nxv16i8: - case nxv8i16: - case nxv4i32: - case nxv2i64: - case nxv8f16: - case nxv4f32: - case nxv2f64: return 128; + case v2f64: return 128; case v5i32: case v5f32: return 160; case v32i8: @@ -759,13 +792,7 @@ case v4i64: case v16f16: case v8f32: - case v4f64: - case nxv32i8: - case 
nxv16i16: - case nxv8i32: - case nxv4i64: - case nxv8f32: - case nxv4f64: return 256; + case v4f64: return 256; case v512i1: case v64i8: case v32i16: @@ -773,26 +800,18 @@ case v8i64: case v32f16: case v16f32: - case v8f64: - case nxv32i16: - case nxv16i32: - case nxv8i64: - case nxv16f32: - case nxv8f64: return 512; + case v8f64: return 512; case v1024i1: case v128i8: case v64i16: case v32i32: case v16i64: - case v32f32: - case nxv32i32: - case nxv16i64: return 1024; + case v32f32: return 1024; case v256i8: case v128i16: case v64i32: case v32i64: - case v64f32: - case nxv32i64: return 2048; + case v64f32: return 2048; case v128i32: case v128f32: return 4096; case v256i32: @@ -817,30 +836,62 @@ return (getSizeInBits() + 7) / 8; } + /// Returns the minimum number of bytes overwritten by a store of the + /// specified value type, along with a boolean indicating whether the + /// runtime size written to is exactly that size (if false) or if it's an + /// integer multiple of that minimum (true). + ScalableSize getScalableStoreSize() const { + ScalableSize SizeInBits = getScalableSizeInBits(); + return {(SizeInBits.getKnownMinSize() + 7) / 8, SizeInBits.isScalable()}; + } + + /// Returns the number of bytes overwritten by a store of the specified + /// value type. If the type is scalable, this quantity represents the + /// minimum size. If not scalable, it represents the exact size. + unsigned getKnownMinStoreSize() const { + return getScalableStoreSize().getKnownMinSize(); + } + /// Return the number of bits overwritten by a store of the specified value /// type. unsigned getStoreSizeInBits() const { return getStoreSize() * 8; } + /// Returns the minimum number of bits overwritten by a store of the + /// specified value type, along with a boolean indicating whether the + /// runtime size written to is exactly that size (if false) or if it's an + /// integer multiple of that minimum (true). 
+ ScalableSize getScalableStoreSizeInBits() const { + ScalableSize SizeInBytes = getScalableStoreSize(); + return { SizeInBytes.getKnownMinSize() * 8, SizeInBytes.isScalable() }; + } + + /// Returns the number of bits overwritten by a store of the specified + /// value type. If the type is scalable, this quantity represents the + /// minimum size. If not scalable, it represents the exact size. + unsigned getKnownMinStoreSizeInBits() const { + return getScalableStoreSizeInBits().getKnownMinSize(); + } + /// Return true if this has more bits than VT. bool bitsGT(MVT VT) const { - return getSizeInBits() > VT.getSizeInBits(); + return getScalableSizeInBits() > VT.getScalableSizeInBits(); } /// Return true if this has no less bits than VT. bool bitsGE(MVT VT) const { - return getSizeInBits() >= VT.getSizeInBits(); + return getScalableSizeInBits() >= VT.getScalableSizeInBits(); } /// Return true if this has less bits than VT. bool bitsLT(MVT VT) const { - return getSizeInBits() < VT.getSizeInBits(); + return getScalableSizeInBits() < VT.getScalableSizeInBits(); } /// Return true if this has no more bits than VT. bool bitsLE(MVT VT) const { - return getSizeInBits() <= VT.getSizeInBits(); + return getScalableSizeInBits() <= VT.getScalableSizeInBits(); } static MVT getFloatingPointVT(unsigned BitWidth) { diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -220,11 +220,13 @@ ForCodeSize = DAG.getMachineFunction().getFunction().hasOptSize(); MaximumLegalStoreInBits = 0; + // We use the minimum store size here, since that's all we can guarantee + // for the scalable vector types. 
for (MVT VT : MVT::all_valuetypes()) if (EVT(VT).isSimple() && VT != MVT::Other && TLI.isTypeLegal(EVT(VT)) && - VT.getSizeInBits() >= MaximumLegalStoreInBits) - MaximumLegalStoreInBits = VT.getSizeInBits(); + VT.getKnownMinSizeInBits() >= MaximumLegalStoreInBits) + MaximumLegalStoreInBits = VT.getKnownMinSizeInBits(); } void ConsiderForPruning(SDNode *N) { diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Support/ScalableSize.h" using namespace llvm; #define DEBUG_TYPE "legalize-types" @@ -4632,7 +4633,8 @@ unsigned Width, EVT WidenVT, unsigned Align = 0, unsigned WidenEx = 0) { EVT WidenEltVT = WidenVT.getVectorElementType(); - unsigned WidenWidth = WidenVT.getSizeInBits(); + const bool Scalable = WidenVT.isScalableVector(); + unsigned WidenWidth = WidenVT.getKnownMinSizeInBits(); unsigned WidenEltWidth = WidenEltVT.getSizeInBits(); unsigned AlignInBits = Align*8; @@ -4643,23 +4645,27 @@ // See if there is larger legal integer than the element type to load/store. 
unsigned VT; - for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE; - VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) { - EVT MemVT((MVT::SimpleValueType) VT); - unsigned MemVTWidth = MemVT.getSizeInBits(); - if (MemVT.getSizeInBits() <= WidenEltWidth) - break; - auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT); - if ((Action == TargetLowering::TypeLegal || - Action == TargetLowering::TypePromoteInteger) && - (WidenWidth % MemVTWidth) == 0 && - isPowerOf2_32(WidenWidth / MemVTWidth) && - (MemVTWidth <= Width || - (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) { - if (MemVTWidth == WidenWidth) - return MemVT; - RetVT = MemVT; - break; + // Don't bother looking for an integer type if the vector is scalable, skip + // to vector types. + if (!Scalable) { + for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE; + VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) { + EVT MemVT((MVT::SimpleValueType) VT); + unsigned MemVTWidth = MemVT.getSizeInBits(); + if (MemVT.getSizeInBits() <= WidenEltWidth) + break; + auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT); + if ((Action == TargetLowering::TypeLegal || + Action == TargetLowering::TypePromoteInteger) && + (WidenWidth % MemVTWidth) == 0 && + isPowerOf2_32(WidenWidth / MemVTWidth) && + (MemVTWidth <= Width || + (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) { + if (MemVTWidth == WidenWidth) + return MemVT; + RetVT = MemVT; + break; + } } } @@ -4668,7 +4674,10 @@ for (VT = (unsigned)MVT::LAST_VECTOR_VALUETYPE; VT >= (unsigned)MVT::FIRST_VECTOR_VALUETYPE; --VT) { EVT MemVT = (MVT::SimpleValueType) VT; - unsigned MemVTWidth = MemVT.getSizeInBits(); + // Skip vector MVTs which don't match the scalable property of WidenVT. 
+ if (Scalable != MemVT.isScalableVector()) + continue; + unsigned MemVTWidth = MemVT.getKnownMinSizeInBits(); auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT); if ((Action == TargetLowering::TypeLegal || Action == TargetLowering::TypePromoteInteger) && diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -9199,8 +9199,8 @@ for (unsigned j = 0; j != NumParts; ++j) { // if it isn't first piece, alignment must be 1 ISD::OutputArg MyFlags(Flags, Parts[j].getValueType(), VT, - i < CLI.NumFixedArgs, - i, j*Parts[j].getValueType().getStoreSize()); + i < CLI.NumFixedArgs, + i, j*Parts[j].getValueType().getKnownMinStoreSize()); if (NumParts > 1 && j == 0) MyFlags.Flags.setSplit(); else if (j != 0) { @@ -9669,8 +9669,11 @@ unsigned NumRegs = TLI->getNumRegistersForCallingConv( *CurDAG->getContext(), F.getCallingConv(), VT); for (unsigned i = 0; i != NumRegs; ++i) { + // For scalable vectors, use the minimum size; individual targets + // are responsible for handling scalable vector arguments and + // return values. 
ISD::InputArg MyFlags(Flags, RegisterVT, VT, isArgValueUsed, - ArgNo, PartBase+i*RegisterVT.getStoreSize()); + ArgNo, PartBase+i*RegisterVT.getKnownMinStoreSize()); if (NumRegs > 1 && i == 0) MyFlags.Flags.setSplit(); // if it isn't first piece, alignment must be 1 @@ -9683,7 +9686,7 @@ } if (NeedsRegBlock && Value == NumValues - 1) Ins[Ins.size() - 1].Flags.setInConsecutiveRegsLast(); - PartBase += VT.getStoreSize(); + PartBase += VT.getKnownMinStoreSize(); } } diff --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp --- a/llvm/lib/CodeGen/ValueTypes.cpp +++ b/llvm/lib/CodeGen/ValueTypes.cpp @@ -11,6 +11,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Type.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ScalableSize.h" using namespace llvm; EVT EVT::changeExtendedTypeToInteger() const { @@ -105,11 +106,26 @@ assert(isExtended() && "Type is not extended!"); if (IntegerType *ITy = dyn_cast<IntegerType>(LLVMTy)) return ITy->getBitWidth(); - if (VectorType *VTy = dyn_cast<VectorType>(LLVMTy)) + if (VectorType *VTy = dyn_cast<VectorType>(LLVMTy)) { + assert(!VTy->isScalable() && + "Size of scalable type cannot be represented by a scalar."); return VTy->getBitWidth(); + } llvm_unreachable("Unrecognized extended type!"); } +ScalableSize EVT::getScalableExtendedSizeInBits() const { + assert(isExtended() && "Type is not extended!"); + if (VectorType *VTy = dyn_cast<VectorType>(LLVMTy)) + return VTy->getScalableSizeInBits(); + return { getExtendedSizeInBits(), false }; +} + +unsigned EVT::getKnownMinExtendedSizeInBits() const { + assert(isExtended() && "Type is not extended!"); + return getScalableExtendedSizeInBits().getKnownMinSize(); +} + /// getEVTString - This function returns value type as a string, e.g. "i32".
std::string EVT::getEVTString() const { switch (V.SimpleTy) { diff --git a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp --- a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp +++ b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp @@ -120,4 +120,63 @@ ScV4Float64Ty->getElementType()); } +TEST(ScalableVectorMVTsTest, SizeQueries) { + LLVMContext Ctx; + + EVT nxv4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4, /*Scalable=*/ true); + EVT nxv2i32 = EVT::getVectorVT(Ctx, MVT::i32, 2, /*Scalable=*/ true); + EVT nxv2i64 = EVT::getVectorVT(Ctx, MVT::i64, 2, /*Scalable=*/ true); + EVT nxv2f64 = EVT::getVectorVT(Ctx, MVT::f64, 2, /*Scalable=*/ true); + + EVT v4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4); + EVT v2i32 = EVT::getVectorVT(Ctx, MVT::i32, 2); + EVT v2i64 = EVT::getVectorVT(Ctx, MVT::i64, 2); + EVT v2f64 = EVT::getVectorVT(Ctx, MVT::f64, 2); + + // Check equivalence and ordering on scalable types. + EXPECT_EQ(nxv4i32.getScalableSizeInBits(), nxv2i64.getScalableSizeInBits()); + EXPECT_EQ(nxv2f64.getScalableSizeInBits(), nxv2i64.getScalableSizeInBits()); + EXPECT_NE(nxv2i32.getScalableSizeInBits(), nxv4i32.getScalableSizeInBits()); + EXPECT_LT(nxv2i32.getScalableSizeInBits(), nxv2i64.getScalableSizeInBits()); + EXPECT_LE(nxv4i32.getScalableSizeInBits(), nxv2i64.getScalableSizeInBits()); + EXPECT_GT(nxv4i32.getScalableSizeInBits(), nxv2i32.getScalableSizeInBits()); + EXPECT_GE(nxv2i64.getScalableSizeInBits(), nxv4i32.getScalableSizeInBits()); + + // Check equivalence and ordering on fixed types. 
+ EXPECT_EQ(v4i32.getScalableSizeInBits(), v2i64.getScalableSizeInBits()); + EXPECT_EQ(v2f64.getScalableSizeInBits(), v2i64.getScalableSizeInBits()); + EXPECT_NE(v2i32.getScalableSizeInBits(), v4i32.getScalableSizeInBits()); + EXPECT_LT(v2i32.getScalableSizeInBits(), v2i64.getScalableSizeInBits()); + EXPECT_LE(v4i32.getScalableSizeInBits(), v2i64.getScalableSizeInBits()); + EXPECT_GT(v4i32.getScalableSizeInBits(), v2i32.getScalableSizeInBits()); + EXPECT_GE(v2i64.getScalableSizeInBits(), v4i32.getScalableSizeInBits()); + + // Check that scalable and non-scalable types with the same minimum size + // are not considered equal. + ASSERT_TRUE(v4i32.getScalableSizeInBits() != nxv4i32.getScalableSizeInBits()); + ASSERT_FALSE(v2i64.getScalableSizeInBits() == + nxv2f64.getScalableSizeInBits()); + + // Check that we can obtain a known-exact size from a non-scalable type. + EXPECT_EQ(v4i32.getSizeInBits(), 128U); + EXPECT_EQ(v2i64.getScalableSizeInBits().getFixedSize(), 128U); + + // Check that we can query the known minimum size for both scalable and + // fixed length types. + EXPECT_EQ(nxv2i32.getKnownMinSizeInBits(), 64U); + EXPECT_EQ(nxv2f64.getScalableSizeInBits().getKnownMinSize(), 128U); + EXPECT_EQ(v2i32.getKnownMinSizeInBits(), nxv2i32.getKnownMinSizeInBits()); + + // Check scalable property. + ASSERT_FALSE(v4i32.getScalableSizeInBits().isScalable()); + ASSERT_TRUE(nxv4i32.getScalableSizeInBits().isScalable()); + + // Check convenience size scaling methods. 
+ EXPECT_EQ(v2i32.getScalableSizeInBits() * 2, v4i32.getScalableSizeInBits()); + EXPECT_EQ(2 * nxv2i32.getScalableSizeInBits(), + nxv4i32.getScalableSizeInBits()); + EXPECT_EQ(nxv2f64.getScalableSizeInBits() / 2, + nxv2i32.getScalableSizeInBits()); +} + } // end anonymous namespace diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp --- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp +++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ScalableSize.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include <algorithm> @@ -503,9 +504,16 @@ } auto LT = [](MVT A, MVT B) -> bool { + // Always treat non-scalable MVTs as smaller than scalable MVTs for the + // purposes of ordering. + if (A.isScalableVector() && !B.isScalableVector()) + return false; + if (!A.isScalableVector() && B.isScalableVector()) + return true; + return A.getScalarSizeInBits() < B.getScalarSizeInBits() || (A.getScalarSizeInBits() == B.getScalarSizeInBits() && - A.getSizeInBits() < B.getSizeInBits()); + A.getScalableSizeInBits() < B.getScalableSizeInBits()); }; auto LE = [&LT](MVT A, MVT B) -> bool { // This function is used when removing elements: when a vector is compared @@ -513,8 +521,13 @@ if (A.isVector() != B.isVector()) return false; + // We also don't want to remove elements when they're both vectors with the + // same minimum number of lanes, but one is scalable and the other not. + if (A.isScalableVector() != B.isScalableVector()) + return false; + return LT(A, B) || (A.getScalarSizeInBits() == B.getScalarSizeInBits() && - A.getSizeInBits() == B.getSizeInBits()); + A.getScalableSizeInBits() == B.getScalableSizeInBits()); }; for (unsigned M : Modes) {