[CodeGen] Return TypeSize from MVT/EVT/SDValue size queries

MVT and EVT getSizeInBits(), getScalarSizeInBits(), getStoreSize() and
getStoreSizeInBits(), plus SDValue/SDNode getValueSizeInBits(), now return
a TypeSize. For scalable vector types the scalable flag is set and the
value is the known minimum size; the runtime size is a positive integer
multiple of it.

Call sites that need a plain integer now either cast explicitly (fixed-size
only paths) or use getKnownMinSize() (paths that may see scalable vectors;
TODOs mark where MachineMemOperand is not yet scalable-aware). Also adds
isByteSized() to TypeSize and MVT, makes EVT::isRound() return false for
scalable vectors, keeps the LegalizeVectorTypes memory-type search and the
TableGen type-set ordering from mixing scalable and fixed-length types, and
adds a SizeQueries unit test.

Note: this copy was reflowed from a whitespace-collapsed dump; indentation
of context lines is best-effort, so applying it may need
`git apply --ignore-whitespace`.

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -42,6 +42,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MachineValueType.h"
+#include "llvm/Support/TypeSize.h"
 #include <algorithm>
 #include <cassert>
 #include <climits>
@@ -170,11 +171,15 @@
   }

   /// Returns the size of the value in bits.
-  unsigned getValueSizeInBits() const {
+  ///
+  /// If the value type is a scalable vector type, the scalable property will
+  /// be set and the runtime size will be a positive integer multiple of the
+  /// base size.
+  const TypeSize getValueSizeInBits() const {
     return getValueType().getSizeInBits();
   }

-  unsigned getScalarValueSizeInBits() const {
+  const TypeSize getScalarValueSizeInBits() const {
     return getValueType().getScalarType().getSizeInBits();
   }

@@ -1022,7 +1027,11 @@
   }

   /// Returns MVT::getSizeInBits(getValueType(ResNo)).
-  unsigned getValueSizeInBits(unsigned ResNo) const {
+  ///
+  /// If the value type is a scalable vector type, the scalable property will
+  /// be set and the runtime size will be a positive integer multiple of the
+  /// base size.
+  const TypeSize getValueSizeInBits(unsigned ResNo) const {
     return getValueType(ResNo).getSizeInBits();
   }

diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -18,6 +18,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/MachineValueType.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
 #include <cassert>
 #include <cstdint>
 #include <string>
@@ -209,11 +210,13 @@

     /// Return true if the bit size is a multiple of 8.
     bool isByteSized() const {
-      return (getSizeInBits() & 7) == 0;
+      return getSizeInBits().isByteSized();
     }

     /// Return true if the size is a power-of-two number of bytes.
     bool isRound() const {
+      if (isScalableVector())
+        return false;
       unsigned BitSize = getSizeInBits();
       return BitSize >= 8 && !(BitSize & (BitSize - 1));
     }
@@ -288,25 +291,38 @@
     }

     /// Return the size of the specified value type in bits.
-    unsigned getSizeInBits() const {
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    const TypeSize getSizeInBits() const {
       if (isSimple())
         return V.getSizeInBits();
       return getExtendedSizeInBits();
     }

-    unsigned getScalarSizeInBits() const {
+    const TypeSize getScalarSizeInBits() const {
       return getScalarType().getSizeInBits();
     }

     /// Return the number of bytes overwritten by a store of the specified value
     /// type.
-    unsigned getStoreSize() const {
-      return (getSizeInBits() + 7) / 8;
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    const TypeSize getStoreSize() const {
+      TypeSize BaseSize = getSizeInBits();
+      return {(BaseSize.getKnownMinSize() + 7) / 8, BaseSize.isScalable()};
     }

     /// Return the number of bits overwritten by a store of the specified value
     /// type.
-    unsigned getStoreSizeInBits() const {
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    const TypeSize getStoreSizeInBits() const {
       return getStoreSize() * 8;
     }

@@ -428,7 +444,7 @@
     bool isExtended2048BitVector() const LLVM_READONLY;
     EVT getExtendedVectorElementType() const;
     unsigned getExtendedVectorNumElements() const LLVM_READONLY;
-    unsigned getExtendedSizeInBits() const LLVM_READONLY;
+    const TypeSize getExtendedSizeInBits() const LLVM_READONLY;
   };

 } // end namespace llvm
diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -671,7 +671,12 @@
       return { getVectorNumElements(), isScalableVector() };
     }

-    unsigned getSizeInBits() const {
+    /// Returns the size of the specified MVT in bits.
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    const TypeSize getSizeInBits() const {
       switch (SimpleTy) {
       default:
         llvm_unreachable("getSizeInBits called on extended MVT.");
@@ -691,25 +696,25 @@
       case Metadata:
         llvm_unreachable("Value type is metadata.");
       case i1:
-      case v1i1:
-      case nxv1i1: return 1;
-      case v2i1:
-      case nxv2i1: return 2;
-      case v4i1:
-      case nxv4i1: return 4;
+      case v1i1: return TypeSize::Fixed(1);
+      case nxv1i1: return TypeSize::Scalable(1);
+      case v2i1: return TypeSize::Fixed(2);
+      case nxv2i1: return TypeSize::Scalable(2);
+      case v4i1: return TypeSize::Fixed(4);
+      case nxv4i1: return TypeSize::Scalable(4);
       case i8 :
       case v1i8:
-      case v8i1:
+      case v8i1: return TypeSize::Fixed(8);
       case nxv1i8:
-      case nxv8i1: return 8;
+      case nxv8i1: return TypeSize::Scalable(8);
       case i16 :
       case f16:
       case v16i1:
       case v2i8:
-      case v1i16:
+      case v1i16: return TypeSize::Fixed(16);
       case nxv16i1:
       case nxv2i8:
-      case nxv1i16: return 16;
+      case nxv1i16: return TypeSize::Scalable(16);
       case f32 :
       case i32 :
       case v32i1:
@@ -717,15 +722,15 @@
       case v2i16:
       case v2f16:
       case v1f32:
-      case v1i32:
+      case v1i32: return TypeSize::Fixed(32);
       case nxv32i1:
       case nxv4i8:
       case nxv2i16:
       case nxv1i32:
       case nxv2f16:
-      case nxv1f32: return 32;
+      case nxv1f32: return TypeSize::Scalable(32);
       case v3i16:
-      case v3f16: return 48;
+      case v3f16: return TypeSize::Fixed(48);
       case x86mmx:
       case f64 :
       case i64 :
@@ -736,17 +741,17 @@
       case v1i64:
       case v4f16:
       case v2f32:
-      case v1f64:
+      case v1f64: return TypeSize::Fixed(64);
       case nxv8i8:
       case nxv4i16:
       case nxv2i32:
       case nxv1i64:
       case nxv4f16:
       case nxv2f32:
-      case nxv1f64: return 64;
-      case f80 : return 80;
+      case nxv1f64: return TypeSize::Scalable(64);
+      case f80 : return TypeSize::Fixed(80);
       case v3i32:
-      case v3f32: return 96;
+      case v3f32: return TypeSize::Fixed(96);
       case f128:
       case ppcf128:
       case i128:
@@ -758,16 +763,16 @@
       case v1i128:
       case v8f16:
       case v4f32:
-      case v2f64:
+      case v2f64: return TypeSize::Fixed(128);
       case nxv16i8:
       case nxv8i16:
       case nxv4i32:
       case nxv2i64:
       case nxv8f16:
       case nxv4f32:
-      case nxv2f64: return 128;
+      case nxv2f64: return TypeSize::Scalable(128);
       case v5i32:
-      case v5f32: return 160;
+      case v5f32: return TypeSize::Fixed(160);
       case v256i1:
       case v32i8:
       case v16i16:
@@ -775,13 +780,13 @@
       case v4i64:
       case v16f16:
       case v8f32:
-      case v4f64:
+      case v4f64: return TypeSize::Fixed(256);
       case nxv32i8:
       case nxv16i16:
       case nxv8i32:
       case nxv4i64:
       case nxv8f32:
-      case nxv4f64: return 256;
+      case nxv4f64: return TypeSize::Scalable(256);
       case v512i1:
       case v64i8:
       case v32i16:
@@ -789,56 +794,71 @@
       case v8i64:
       case v32f16:
       case v16f32:
-      case v8f64:
+      case v8f64: return TypeSize::Fixed(512);
       case nxv32i16:
       case nxv16i32:
       case nxv8i64:
       case nxv16f32:
-      case nxv8f64: return 512;
+      case nxv8f64: return TypeSize::Scalable(512);
       case v1024i1:
       case v128i8:
       case v64i16:
       case v32i32:
       case v16i64:
-      case v32f32:
+      case v32f32: return TypeSize::Fixed(1024);
       case nxv32i32:
-      case nxv16i64: return 1024;
+      case nxv16i64: return TypeSize::Scalable(1024);
       case v256i8:
       case v128i16:
       case v64i32:
       case v32i64:
-      case v64f32:
-      case nxv32i64: return 2048;
+      case v64f32: return TypeSize::Fixed(2048);
+      case nxv32i64: return TypeSize::Scalable(2048);
       case v128i32:
-      case v128f32: return 4096;
+      case v128f32: return TypeSize::Fixed(4096);
       case v256i32:
-      case v256f32: return 8192;
+      case v256f32: return TypeSize::Fixed(8192);
       case v512i32:
-      case v512f32: return 16384;
+      case v512f32: return TypeSize::Fixed(16384);
       case v1024i32:
-      case v1024f32: return 32768;
+      case v1024f32: return TypeSize::Fixed(32768);
       case v2048i32:
-      case v2048f32: return 65536;
-      case exnref: return 0; // opaque type
+      case v2048f32: return TypeSize::Fixed(65536);
+      case exnref: return TypeSize::Fixed(0); // opaque type
       }
     }

-    unsigned getScalarSizeInBits() const {
+    const TypeSize getScalarSizeInBits() const {
       return getScalarType().getSizeInBits();
     }

     /// Return the number of bytes overwritten by a store of the specified value
     /// type.
-    unsigned getStoreSize() const {
-      return (getSizeInBits() + 7) / 8;
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    const TypeSize getStoreSize() const {
+      TypeSize BaseSize = getSizeInBits();
+      return {(BaseSize.getKnownMinSize() + 7) / 8, BaseSize.isScalable()};
     }

     /// Return the number of bits overwritten by a store of the specified value
     /// type.
-    unsigned getStoreSizeInBits() const {
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    const TypeSize getStoreSizeInBits() const {
       return getStoreSize() * 8;
     }

+    /// Returns true if the number of bits for the type is a multiple of an
+    /// 8-bit byte.
+    bool isByteSized() const {
+      return getSizeInBits().isByteSized();
+    }
+
     /// Return true if this has more bits than VT.
     bool bitsGT(MVT VT) const {
       return getSizeInBits() > VT.getSizeInBits();
diff --git a/llvm/include/llvm/Support/TypeSize.h b/llvm/include/llvm/Support/TypeSize.h
--- a/llvm/include/llvm/Support/TypeSize.h
+++ b/llvm/include/llvm/Support/TypeSize.h
@@ -138,6 +138,11 @@
     return IsScalable;
   }

+  // Returns true if the number of bits is a multiple of an 8-bit byte.
+  bool isByteSized() const {
+    return (MinSize & 7) == 0;
+  }
+
   // Casts to a uint64_t if this is a fixed-width size.
   //
   // NOTE: This interface is obsolete and will be removed in a future version
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -220,11 +220,13 @@
       ForCodeSize = DAG.getMachineFunction().getFunction().hasOptSize();

       MaximumLegalStoreInBits = 0;
+      // We use the minimum store size here, since that's all we can guarantee
+      // for the scalable vector types.
       for (MVT VT : MVT::all_valuetypes())
         if (EVT(VT).isSimple() && VT != MVT::Other &&
             TLI.isTypeLegal(EVT(VT)) &&
-            VT.getSizeInBits() >= MaximumLegalStoreInBits)
-          MaximumLegalStoreInBits = VT.getSizeInBits();
+            VT.getSizeInBits().getKnownMinSize() >= MaximumLegalStoreInBits)
+          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinSize();
     }

     void ConsiderForPruning(SDNode *N) {
@@ -13973,8 +13975,8 @@
   // the stored value). With Offset=n (for n > 0) the loaded value starts at the
   // n:th least significant byte of the stored value.
   if (DAG.getDataLayout().isBigEndian())
-    Offset = (STMemType.getStoreSizeInBits() -
-              LDMemType.getStoreSizeInBits()) / 8 - Offset;
+    Offset = ((int64_t)STMemType.getStoreSizeInBits() -
+              (int64_t)LDMemType.getStoreSizeInBits()) / 8 - Offset;

   // Check that the stored value cover all bits that are loaded.
   bool STCoversLD =
@@ -15131,7 +15133,7 @@
   // The latest Node in the DAG.
   SDLoc DL(StoreNodes[0].MemNode);

-  int64_t ElementSizeBits = MemVT.getStoreSizeInBits();
+  TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
   unsigned SizeInBits = NumStores * ElementSizeBits;
   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;

@@ -15516,7 +15518,7 @@
       Attribute::NoImplicitFloat);

   // This function cannot currently deal with non-byte-sized memory sizes.
-  if (ElementSizeBytes * 8 != MemVT.getSizeInBits())
+  if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
     return false;

   if (!MemVT.isSimple())
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -23,6 +23,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/TypeSize.h"
 using namespace llvm;

 #define DEBUG_TYPE "legalize-types"
@@ -4680,7 +4681,8 @@
                        unsigned Width, EVT WidenVT,
                        unsigned Align = 0, unsigned WidenEx = 0) {
   EVT WidenEltVT = WidenVT.getVectorElementType();
-  unsigned WidenWidth = WidenVT.getSizeInBits();
+  const bool Scalable = WidenVT.isScalableVector();
+  unsigned WidenWidth = WidenVT.getSizeInBits().getKnownMinSize();
   unsigned WidenEltWidth = WidenEltVT.getSizeInBits();
   unsigned AlignInBits = Align*8;

@@ -4691,23 +4693,27 @@

   // See if there is larger legal integer than the element type to load/store.
   unsigned VT;
-  for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE;
-       VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) {
-    EVT MemVT((MVT::SimpleValueType) VT);
-    unsigned MemVTWidth = MemVT.getSizeInBits();
-    if (MemVT.getSizeInBits() <= WidenEltWidth)
-      break;
-    auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT);
-    if ((Action == TargetLowering::TypeLegal ||
-         Action == TargetLowering::TypePromoteInteger) &&
-        (WidenWidth % MemVTWidth) == 0 &&
-        isPowerOf2_32(WidenWidth / MemVTWidth) &&
-        (MemVTWidth <= Width ||
-         (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) {
-      if (MemVTWidth == WidenWidth)
-        return MemVT;
-      RetVT = MemVT;
-      break;
+  // Don't bother looking for an integer type if the vector is scalable, skip
+  // to vector types.
+  if (!Scalable) {
+    for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE;
+         VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) {
+      EVT MemVT((MVT::SimpleValueType) VT);
+      unsigned MemVTWidth = MemVT.getSizeInBits();
+      if (MemVT.getSizeInBits() <= WidenEltWidth)
+        break;
+      auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT);
+      if ((Action == TargetLowering::TypeLegal ||
+           Action == TargetLowering::TypePromoteInteger) &&
+          (WidenWidth % MemVTWidth) == 0 &&
+          isPowerOf2_32(WidenWidth / MemVTWidth) &&
+          (MemVTWidth <= Width ||
+           (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) {
+        if (MemVTWidth == WidenWidth)
+          return MemVT;
+        RetVT = MemVT;
+        break;
+      }
     }
   }

@@ -4716,7 +4722,10 @@
   for (VT = (unsigned)MVT::LAST_VECTOR_VALUETYPE;
        VT >= (unsigned)MVT::FIRST_VECTOR_VALUETYPE; --VT) {
     EVT MemVT = (MVT::SimpleValueType) VT;
-    unsigned MemVTWidth = MemVT.getSizeInBits();
+    // Skip vector MVTs which don't match the scalable property of WidenVT.
+    if (Scalable != MemVT.isScalableVector())
+      continue;
+    unsigned MemVTWidth = MemVT.getSizeInBits().getKnownMinSize();
     auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT);
     if ((Action == TargetLowering::TypeLegal ||
          Action == TargetLowering::TypePromoteInteger) &&
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8842,7 +8842,9 @@
   // We check here that the size of the memory operand fits within the size of
   // the MMO. This is because the MMO might indicate only a possible address
   // range instead of specifying the affected memory addresses precisely.
-  assert(memvt.getStoreSize() <= MMO->getSize() && "Size mismatch!");
+  // TODO: Make MachineMemOperands aware of scalable vectors.
+  assert(memvt.getStoreSize().getKnownMinSize() <= MMO->getSize() &&
+         "Size mismatch!");
 }

 /// Profile - Gather unique data for the node.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4304,7 +4304,10 @@
   MachineMemOperand *MMO =
     DAG.getMachineFunction().
     getMachineMemOperand(MachinePointerInfo(PtrOperand),
-                          MachineMemOperand::MOStore, VT.getStoreSize(),
+                          MachineMemOperand::MOStore,
+                          // TODO: Make MachineMemOperands aware of scalable
+                          // vectors.
+                          VT.getStoreSize().getKnownMinSize(),
                           Alignment, AAInfo);
   SDValue StoreNode = DAG.getMaskedStore(getRoot(), sdl, Src0, Ptr, Mask, VT,
                                          MMO, false /* Truncating */,
@@ -4408,7 +4411,10 @@
   const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr;
   MachineMemOperand *MMO = DAG.getMachineFunction().
     getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
-                         MachineMemOperand::MOStore, VT.getStoreSize(),
+                         MachineMemOperand::MOStore,
+                         // TODO: Make MachineMemOperands aware of scalable
+                         // vectors.
+                         VT.getStoreSize().getKnownMinSize(),
                          Alignment, AAInfo);
   if (!UniformBase) {
     Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
@@ -4477,7 +4483,10 @@
   MachineMemOperand *MMO =
     DAG.getMachineFunction().
     getMachineMemOperand(MachinePointerInfo(PtrOperand),
-                          MachineMemOperand::MOLoad, VT.getStoreSize(),
+                          MachineMemOperand::MOLoad,
+                          // TODO: Make MachineMemOperands aware of scalable
+                          // vectors.
+                          VT.getStoreSize().getKnownMinSize(),
                           Alignment, AAInfo, Ranges);

   SDValue Load = DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Mask, Src0, VT, MMO,
@@ -4528,7 +4537,10 @@
   MachineMemOperand *MMO =
     DAG.getMachineFunction().
     getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : nullptr),
-                         MachineMemOperand::MOLoad, VT.getStoreSize(),
+                         MachineMemOperand::MOLoad,
+                         // TODO: Make MachineMemOperands aware of scalable
+                         // vectors.
+                         VT.getStoreSize().getKnownMinSize(),
                          Alignment, AAInfo, Ranges);

   if (!UniformBase) {
@@ -9249,9 +9261,11 @@

       for (unsigned j = 0; j != NumParts; ++j) {
         // if it isn't first piece, alignment must be 1
+        // For scalable vectors the scalable part is currently handled
+        // by individual targets, so we just use the known minimum size here.
         ISD::OutputArg MyFlags(Flags, Parts[j].getValueType(), VT,
-                               i < CLI.NumFixedArgs,
-                               i, j*Parts[j].getValueType().getStoreSize());
+                    i < CLI.NumFixedArgs, i,
+                    j*Parts[j].getValueType().getStoreSize().getKnownMinSize());
         if (NumParts > 1 && j == 0)
           MyFlags.Flags.setSplit();
         else if (j != 0) {
@@ -9720,8 +9734,11 @@
       unsigned NumRegs = TLI->getNumRegistersForCallingConv(
           *CurDAG->getContext(), F.getCallingConv(), VT);
       for (unsigned i = 0; i != NumRegs; ++i) {
+        // For scalable vectors, use the minimum size; individual targets
+        // are responsible for handling scalable vector arguments and
+        // return values.
         ISD::InputArg MyFlags(Flags, RegisterVT, VT, isArgValueUsed,
-                              ArgNo, PartBase+i*RegisterVT.getStoreSize());
+                 ArgNo, PartBase+i*RegisterVT.getStoreSize().getKnownMinSize());
         if (NumRegs > 1 && i == 0)
           MyFlags.Flags.setSplit();
         // if it isn't first piece, alignment must be 1
@@ -9734,7 +9751,7 @@
       }
       if (NeedsRegBlock && Value == NumValues - 1)
         Ins[Ins.size() - 1].Flags.setInConsecutiveRegsLast();
-      PartBase += VT.getStoreSize();
+      PartBase += VT.getStoreSize().getKnownMinSize();
     }
   }

diff --git a/llvm/lib/CodeGen/SelectionDAG/StatepointLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
@@ -384,7 +384,8 @@
     // can consider allowing spills of smaller values to larger slots
     // (i.e. change the '==' in the assert below to a '>=').
     MachineFrameInfo &MFI = Builder.DAG.getMachineFunction().getFrameInfo();
-    assert((MFI.getObjectSize(Index) * 8) == Incoming.getValueSizeInBits() &&
+    assert((MFI.getObjectSize(Index) * 8) ==
+           (int64_t)Incoming.getValueSizeInBits() &&
            "Bad spill: stack slot does not match!");

     // Note: Using the alignment of the spill slot (rather than the abi or
diff --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp
--- a/llvm/lib/CodeGen/ValueTypes.cpp
+++ b/llvm/lib/CodeGen/ValueTypes.cpp
@@ -11,6 +11,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/TypeSize.h"
 using namespace llvm;

 EVT EVT::changeExtendedTypeToInteger() const {
@@ -101,12 +102,12 @@
   return cast<VectorType>(LLVMTy)->getNumElements();
 }

-unsigned EVT::getExtendedSizeInBits() const {
+const TypeSize EVT::getExtendedSizeInBits() const {
   assert(isExtended() && "Type is not extended!");
   if (IntegerType *ITy = dyn_cast<IntegerType>(LLVMTy))
-    return ITy->getBitWidth();
+    return TypeSize::Fixed(ITy->getBitWidth());
   if (VectorType *VTy = dyn_cast<VectorType>(LLVMTy))
-    return VTy->getBitWidth();
+    return VTy->getPrimitiveSizeInBits();
   llvm_unreachable("Unrecognized extended type!");
 }

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9937,7 +9937,7 @@

   // Only interested in 64-bit vectors as the ultimate result.
   EVT VT = N->getValueType(0);
-  if (!VT.isVector())
+  if (!VT.isVector() || VT.isScalableVector())
     return SDValue();
   if (VT.getSimpleVT().getSizeInBits() != 64)
     return SDValue();
diff --git a/llvm/lib/Target/AArch64/AArch64StackOffset.h b/llvm/lib/Target/AArch64/AArch64StackOffset.h
--- a/llvm/lib/Target/AArch64/AArch64StackOffset.h
+++ b/llvm/lib/Target/AArch64/AArch64StackOffset.h
@@ -15,6 +15,7 @@
 #define LLVM_LIB_TARGET_AARCH64_AARCH64STACKOFFSET_H

 #include "llvm/Support/MachineValueType.h"
+#include "llvm/Support/TypeSize.h"

 namespace llvm {

@@ -45,8 +46,7 @@
   StackOffset() : Bytes(0), ScalableBytes(0) {}

   StackOffset(int64_t Offset, MVT::SimpleValueType T) : StackOffset() {
-    assert(MVT(T).getSizeInBits() % 8 == 0 &&
-           "Offset type is not a multiple of bytes");
+    assert(MVT(T).isByteSized() && "Offset type is not a multiple of bytes");
     *this += Part(Offset, T);
   }

@@ -56,11 +56,11 @@
   StackOffset &operator=(const StackOffset &) = default;

   StackOffset &operator+=(const StackOffset::Part &Other) {
-    int64_t OffsetInBytes = Other.first * (Other.second.getSizeInBits() / 8);
-    if (Other.second.isScalableVector())
-      ScalableBytes += OffsetInBytes;
+    const TypeSize Size = Other.second.getSizeInBits();
+    if (Size.isScalable())
+      ScalableBytes += Other.first * ((int64_t)Size.getKnownMinSize() / 8);
     else
-      Bytes += OffsetInBytes;
+      Bytes += Other.first * ((int64_t)Size.getFixedSize() / 8);
     return *this;
   }

diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -14879,7 +14879,7 @@
     V = -V;
   }

-  unsigned NumBytes = std::max(VT.getSizeInBits() / 8, 1U);
+  unsigned NumBytes = std::max((unsigned)VT.getSizeInBits() / 8, 1U);

   // MVE: size * imm7
   if (VT.isVector() && Subtarget->hasMVEIntegerOps()) {
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
@@ -475,7 +475,7 @@
       MemAddr = DAG.getNode(ISD::ADD, dl, MVT::i32, StackPtr, MemAddr);
       if (ArgAlign)
         LargestAlignSeen = std::max(LargestAlignSeen,
-                                    VA.getLocVT().getStoreSizeInBits() >> 3);
+                             (unsigned)VA.getLocVT().getStoreSizeInBits() >> 3);
       if (Flags.isByVal()) {
         // The argument is a struct passed by value. According to LLVM, "Arg"
         // is a pointer.
diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -125,7 +125,8 @@
                                                            CallingConv::ID CC,
                                                            EVT VT) const {
   if (VT.isVector())
-    return std::max((VT.getSizeInBits() / (Subtarget.isABI_O32() ? 32 : 64)),
+    return std::max(((unsigned)VT.getSizeInBits() /
+                     (Subtarget.isABI_O32() ? 32 : 64)),
                     1U);
   return MipsTargetLowering::getNumRegisters(Context, VT);
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -885,7 +885,7 @@
   MVT SimpleVT = LoadedVT.getSimpleVT();
   MVT ScalarVT = SimpleVT.getScalarType();
   // Read at least 8 bits (predicates are stored as 8-bit values)
-  unsigned fromTypeWidth = std::max(8U, ScalarVT.getSizeInBits());
+  unsigned fromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
   unsigned int fromType;

   // Vector Setting
@@ -1030,7 +1030,7 @@
   // Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
   MVT ScalarVT = SimpleVT.getScalarType();
   // Read at least 8 bits (predicates are stored as 8-bit values)
-  unsigned FromTypeWidth = std::max(8U, ScalarVT.getSizeInBits());
+  unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
   unsigned int FromType;
   // The last operand holds the original LoadSDNode::getExtensionType() value
   unsigned ExtensionType = cast<ConstantSDNode>(
diff --git a/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp b/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp
--- a/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp
@@ -1044,7 +1044,7 @@
       if (Use->isMachineOpcode())
         return 0;
       MaxTruncation =
-        std::max(MaxTruncation, Use->getValueType(0).getSizeInBits());
+        std::max(MaxTruncation, (unsigned)Use->getValueType(0).getSizeInBits());
       continue;
     case ISD::STORE: {
       if (Use->isMachineOpcode())
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -5834,7 +5834,7 @@
            "Expected VTs to be the same size!");
     unsigned Scale = VT.getScalarSizeInBits() / InVT.getScalarSizeInBits();
     In = extractSubVector(In, 0, DAG, DL,
-                          std::max(128U, VT.getSizeInBits() / Scale));
+                          std::max(128U, (unsigned)VT.getSizeInBits() / Scale));
     InVT = In.getValueType();
   }

@@ -8625,7 +8625,7 @@
       ImmH = DAG.getBitcast(MVT::v32i1, ImmH);
       DstVec = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, ImmL, ImmH);
     } else {
-      MVT ImmVT = MVT::getIntegerVT(std::max(VT.getSizeInBits(), 8U));
+      MVT ImmVT = MVT::getIntegerVT(std::max((unsigned)VT.getSizeInBits(), 8U));
       SDValue Imm = DAG.getConstant(Immediate, dl, ImmVT);
       MVT VecVT = VT.getSizeInBits() >= 8 ? VT : MVT::v8i1;
       DstVec = DAG.getBitcast(VecVT, Imm);
@@ -32842,7 +32842,8 @@
       Offset += Src.getConstantOperandVal(1);
       Src = Src.getOperand(0);
     }
-    WideSizeInBits = std::max(WideSizeInBits, Src.getValueSizeInBits());
+    WideSizeInBits = std::max(WideSizeInBits,
+                              (unsigned)Src.getValueSizeInBits());
     assert((Offset % BaseVT.getVectorNumElements()) == 0 &&
            "Unexpected subvector extraction");
     Offset /= BaseVT.getVectorNumElements();
@@ -35779,7 +35780,7 @@
                           const X86Subtarget &Subtarget) {
   // Find the appropriate width for the PSADBW.
   EVT InVT = Zext0.getOperand(0).getValueType();
-  unsigned RegSize = std::max(128u, InVT.getSizeInBits());
+  unsigned RegSize = std::max(128u, (unsigned)InVT.getSizeInBits());

   // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
   // fill in the missing vector elements with 0.
diff --git a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
--- a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
+++ b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
@@ -120,4 +120,61 @@
             ScV4Float64Ty->getElementType());
 }

+TEST(ScalableVectorMVTsTest, SizeQueries) {
+  LLVMContext Ctx;
+
+  EVT nxv4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4, /*Scalable=*/ true);
+  EVT nxv2i32 = EVT::getVectorVT(Ctx, MVT::i32, 2, /*Scalable=*/ true);
+  EVT nxv2i64 = EVT::getVectorVT(Ctx, MVT::i64, 2, /*Scalable=*/ true);
+  EVT nxv2f64 = EVT::getVectorVT(Ctx, MVT::f64, 2, /*Scalable=*/ true);
+
+  EVT v4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4);
+  EVT v2i32 = EVT::getVectorVT(Ctx, MVT::i32, 2);
+  EVT v2i64 = EVT::getVectorVT(Ctx, MVT::i64, 2);
+  EVT v2f64 = EVT::getVectorVT(Ctx, MVT::f64, 2);
+
+  // Check equivalence and ordering on scalable types.
+  EXPECT_EQ(nxv4i32.getSizeInBits(), nxv2i64.getSizeInBits());
+  EXPECT_EQ(nxv2f64.getSizeInBits(), nxv2i64.getSizeInBits());
+  EXPECT_NE(nxv2i32.getSizeInBits(), nxv4i32.getSizeInBits());
+  EXPECT_LT(nxv2i32.getSizeInBits(), nxv2i64.getSizeInBits());
+  EXPECT_LE(nxv4i32.getSizeInBits(), nxv2i64.getSizeInBits());
+  EXPECT_GT(nxv4i32.getSizeInBits(), nxv2i32.getSizeInBits());
+  EXPECT_GE(nxv2i64.getSizeInBits(), nxv4i32.getSizeInBits());
+
+  // Check equivalence and ordering on fixed types.
+  EXPECT_EQ(v4i32.getSizeInBits(), v2i64.getSizeInBits());
+  EXPECT_EQ(v2f64.getSizeInBits(), v2i64.getSizeInBits());
+  EXPECT_NE(v2i32.getSizeInBits(), v4i32.getSizeInBits());
+  EXPECT_LT(v2i32.getSizeInBits(), v2i64.getSizeInBits());
+  EXPECT_LE(v4i32.getSizeInBits(), v2i64.getSizeInBits());
+  EXPECT_GT(v4i32.getSizeInBits(), v2i32.getSizeInBits());
+  EXPECT_GE(v2i64.getSizeInBits(), v4i32.getSizeInBits());
+
+  // Check that scalable and non-scalable types with the same minimum size
+  // are not considered equal.
+  ASSERT_TRUE(v4i32.getSizeInBits() != nxv4i32.getSizeInBits());
+  ASSERT_FALSE(v2i64.getSizeInBits() == nxv2f64.getSizeInBits());
+
+  // Check that we can obtain a known-exact size from a non-scalable type.
+  EXPECT_EQ(v4i32.getSizeInBits(), 128U);
+  EXPECT_EQ(v2i64.getSizeInBits().getFixedSize(), 128U);
+
+  // Check that we can query the known minimum size for both scalable and
+  // fixed length types.
+  EXPECT_EQ(nxv2i32.getSizeInBits().getKnownMinSize(), 64U);
+  EXPECT_EQ(nxv2f64.getSizeInBits().getKnownMinSize(), 128U);
+  EXPECT_EQ(v2i32.getSizeInBits().getKnownMinSize(),
+            nxv2i32.getSizeInBits().getKnownMinSize());
+
+  // Check scalable property.
+  ASSERT_FALSE(v4i32.getSizeInBits().isScalable());
+  ASSERT_TRUE(nxv4i32.getSizeInBits().isScalable());
+
+  // Check convenience size scaling methods.
+  EXPECT_EQ(v2i32.getSizeInBits() * 2, v4i32.getSizeInBits());
+  EXPECT_EQ(2 * nxv2i32.getSizeInBits(), nxv4i32.getSizeInBits());
+  EXPECT_EQ(nxv2f64.getSizeInBits() / 2, nxv2i32.getSizeInBits());
+}
+
 } // end anonymous namespace
diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
--- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
+++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/TypeSize.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 #include <algorithm>
@@ -503,9 +504,15 @@
   }

   auto LT = [](MVT A, MVT B) -> bool {
-    return A.getScalarSizeInBits() < B.getScalarSizeInBits() ||
-           (A.getScalarSizeInBits() == B.getScalarSizeInBits() &&
-            A.getSizeInBits() < B.getSizeInBits());
+    // Always treat non-scalable MVTs as smaller than scalable MVTs for the
+    // purposes of ordering.
+    if (A.isScalableVector() && !B.isScalableVector())
+      return false;
+    if (!A.isScalableVector() && B.isScalableVector())
+      return true;
+
+    return std::tie(A.getScalarSizeInBits(), A.getSizeInBits()) <
+           std::tie(B.getScalarSizeInBits(), B.getSizeInBits());
   };
   auto LE = [&LT](MVT A, MVT B) -> bool {
     // This function is used when removing elements: when a vector is compared
@@ -513,8 +520,13 @@
     if (A.isVector() != B.isVector())
       return false;

-    return LT(A, B) || (A.getScalarSizeInBits() == B.getScalarSizeInBits() &&
-                        A.getSizeInBits() == B.getSizeInBits());
+    // We also don't want to remove elements when they're both vectors with the
+    // same minimum number of lanes, but one is scalable and the other not.
+    if (A.isScalableVector() != B.isScalableVector())
+      return false;
+
+    return std::tie(A.getScalarSizeInBits(), A.getSizeInBits()) <=
+           std::tie(B.getScalarSizeInBits(), B.getSizeInBits());
   };

   for (unsigned M : Modes) {