diff --git a/llvm/include/llvm/CodeGen/GlobalISel/InstructionSelectorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/InstructionSelectorImpl.h --- a/llvm/include/llvm/CodeGen/GlobalISel/InstructionSelectorImpl.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/InstructionSelectorImpl.h @@ -595,7 +595,7 @@ case GIM_CheckPointerToAny: { int64_t InsnID = MatchTable[CurrentIdx++]; int64_t OpIdx = MatchTable[CurrentIdx++]; - int64_t SizeInBits = MatchTable[CurrentIdx++]; + uint64_t SizeInBits = MatchTable[CurrentIdx++]; DEBUG_WITH_TYPE(TgtInstructionSelector::getName(), dbgs() << CurrentIdx << ": GIM_CheckPointerToAny(MIs[" diff --git a/llvm/include/llvm/Support/LowLevelTypeImpl.h b/llvm/include/llvm/Support/LowLevelTypeImpl.h --- a/llvm/include/llvm/Support/LowLevelTypeImpl.h +++ b/llvm/include/llvm/Support/LowLevelTypeImpl.h @@ -67,7 +67,7 @@ assert(!EC.isScalar() && "invalid number of vector elements"); assert(!ScalarTy.isVector() && "invalid vector element type"); return LLT{ScalarTy.isPointer(), /*isVector=*/true, EC, - ScalarTy.getSizeInBits(), + ScalarTy.getSizeInBits().getFixedSize(), ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0}; } @@ -100,12 +100,14 @@ return EC.isScalar() ? ScalarTy : LLT::vector(EC, ScalarTy); } - static LLT scalarOrVector(ElementCount EC, unsigned ScalarSize) { - return scalarOrVector(EC, LLT::scalar(ScalarSize)); + static LLT scalarOrVector(ElementCount EC, uint64_t ScalarSize) { + assert(ScalarSize <= std::numeric_limits<unsigned>::max() && + "Not enough bits in LLT to represent size"); + return scalarOrVector(EC, LLT::scalar(static_cast<unsigned>(ScalarSize))); } explicit LLT(bool isPointer, bool isVector, ElementCount EC, - unsigned SizeInBits, unsigned AddressSpace) { + uint64_t SizeInBits, unsigned AddressSpace) { init(isPointer, isVector, EC, SizeInBits, AddressSpace); } explicit LLT() : IsPointer(false), IsVector(false), RawData(0) {} @@ -148,18 +150,19 @@ } /// Returns the total size of the type. Must only be called on sized types.
- unsigned getSizeInBits() const { + TypeSize getSizeInBits() const { if (isPointer() || isScalar()) - return getScalarSizeInBits(); - // FIXME: This should return a TypeSize in order to work for scalable - // vectors. - return getScalarSizeInBits() * getElementCount().getKnownMinValue(); + return TypeSize::Fixed(getScalarSizeInBits()); + auto EC = getElementCount(); + return TypeSize(getScalarSizeInBits() * EC.getKnownMinValue(), + EC.isScalable()); } /// Returns the total size of the type in bytes, i.e. number of whole bytes /// needed to represent the size in bits. Must only be called on sized types. - unsigned getSizeInBytes() const { - return (getSizeInBits() + 7) / 8; + TypeSize getSizeInBytes() const { + TypeSize BaseSize = getSizeInBits(); + return {(BaseSize.getKnownMinSize() + 7) / 8, BaseSize.isScalable()}; } LLT getScalarType() const { @@ -199,11 +202,11 @@ getElementType()); } - assert(getSizeInBits() % Factor == 0); - return scalar(getSizeInBits() / Factor); + assert(getScalarSizeInBits() % Factor == 0); + return scalar(getScalarSizeInBits() / Factor); } - bool isByteSized() const { return (getSizeInBits() & 7) == 0; } + bool isByteSized() const { return getSizeInBits().isKnownMultipleOf(8); } unsigned getScalarSizeInBits() const { assert(RawData != 0 && "Invalid Type"); @@ -333,8 +336,10 @@ return getMask(FieldInfo) & (RawData >> FieldInfo[1]); } - void init(bool IsPointer, bool IsVector, ElementCount EC, unsigned SizeInBits, + void init(bool IsPointer, bool IsVector, ElementCount EC, uint64_t SizeInBits, unsigned AddressSpace) { + assert(SizeInBits <= std::numeric_limits<unsigned>::max() && + "Not enough bits in LLT to represent size"); this->IsPointer = IsPointer; this->IsVector = IsVector; if (!IsVector) { diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -1565,7 +1565,7 @@ Register SrcReg =
getOrCreateVReg(**AI); LLT SrcTy = MRI->getType(SrcReg); if (SrcTy.isPointer()) - MinPtrSize = std::min(SrcTy.getSizeInBits(), MinPtrSize); + MinPtrSize = std::min<uint64_t>(SrcTy.getSizeInBits(), MinPtrSize); SrcRegs.push_back(SrcReg); } diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp --- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp @@ -258,7 +258,8 @@ m_OneNonDBGUse(m_any_of(m_GAShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)), m_GLShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)))))) return false; - if (ShiftImm < 0 || ShiftImm + Width > Ty.getSizeInBits()) + if (ShiftImm < 0 || + static_cast<uint64_t>(ShiftImm + Width) > Ty.getSizeInBits()) return false; MatchInfo = [=](MachineIRBuilder &B) { auto Cst1 = B.buildConstant(Ty, ShiftImm); diff --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp --- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp +++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp @@ -81,6 +81,9 @@ EXPECT_EQ(EC, VTy.getElementCount()); if (!EC.isScalable()) EXPECT_EQ(S * EC.getFixedValue(), VTy.getSizeInBits()); + else + EXPECT_EQ(TypeSize::Scalable(S * EC.getKnownMinValue()), + VTy.getSizeInBits()); // Test equality operators. EXPECT_TRUE(VTy == VTy); diff --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp --- a/llvm/utils/TableGen/GlobalISelEmitter.cpp +++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp @@ -182,7 +182,13 @@ assert((!Ty.isVector() || Ty.isScalable() == Other.Ty.isScalable()) && "Unexpected mismatch of scalable property"); - return Ty.getSizeInBits() < Other.Ty.getSizeInBits(); + return Ty.isVector() ? 
std::make_tuple(Ty.isScalable(), + Ty.getSizeInBits().getKnownMinSize()) < + std::make_tuple(Other.Ty.isScalable(), + Other.Ty.getSizeInBits().getKnownMinSize()) + : Ty.getSizeInBits().getFixedSize() < + Other.Ty.getSizeInBits().getFixedSize(); } bool operator==(const LLTCodeGen &B) const { return Ty == B.Ty; } @@ -3788,7 +3794,8 @@ return None; // Align so unusual types like i1 don't get rounded down. - return llvm::alignTo(MemTyOrNone->get().getSizeInBits(), 8); + return llvm::alignTo( + static_cast<unsigned>(MemTyOrNone->get().getSizeInBits()), 8); } Expected<InstructionMatcher &> GlobalISelEmitter::addBuiltinPredicates(