diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -1038,7 +1038,8 @@
         },
         [=](const LegalityQuery &Query) {
           LLT VecTy = Query.Types[TypeIdx];
-          LLT NewTy = LLT::scalarOrVector(MaxElements, VecTy.getElementType());
+          LLT NewTy = LLT::scalarOrVector(ElementCount::getFixed(MaxElements),
+                                          VecTy.getElementType());
           return std::make_pair(TypeIdx, NewTy);
         });
   }
diff --git a/llvm/include/llvm/Support/LowLevelTypeImpl.h b/llvm/include/llvm/Support/LowLevelTypeImpl.h
--- a/llvm/include/llvm/Support/LowLevelTypeImpl.h
+++ b/llvm/include/llvm/Support/LowLevelTypeImpl.h
@@ -96,14 +96,12 @@
     return vector(ElementCount::getScalable(MinNumElements), ScalarTy);
   }
 
-  static LLT scalarOrVector(uint16_t NumElements, LLT ScalarTy) {
-    // FIXME: Migrate interface to use ElementCount
-    return NumElements == 1 ? ScalarTy
-                            : LLT::fixed_vector(NumElements, ScalarTy);
+  static LLT scalarOrVector(ElementCount EC, LLT ScalarTy) {
+    return EC.isScalar() ? ScalarTy : LLT::vector(EC, ScalarTy);
   }
 
-  static LLT scalarOrVector(uint16_t NumElements, unsigned ScalarSize) {
-    return scalarOrVector(NumElements, LLT::scalar(ScalarSize));
+  static LLT scalarOrVector(ElementCount EC, unsigned ScalarSize) {
+    return scalarOrVector(EC, LLT::scalar(ScalarSize));
   }
 
   explicit LLT(bool isPointer, bool isVector, ElementCount EC,
@@ -189,7 +187,8 @@
   LLT changeNumElements(unsigned NewNumElts) const {
     assert((!isVector() || !isScalable()) &&
            "Cannot use changeNumElements on a scalable vector");
-    return LLT::scalarOrVector(NewNumElts, getScalarType());
+    return LLT::scalarOrVector(ElementCount::getFixed(NewNumElts),
+                               getScalarType());
   }
 
   /// Return a type that is \p Factor times smaller. Reduces the number of
@@ -199,7 +198,8 @@
     assert(Factor != 1);
     if (isVector()) {
       assert(getNumElements() % Factor == 0);
-      return scalarOrVector(getNumElements() / Factor, getElementType());
+      return scalarOrVector(getElementCount().divideCoefficientBy(Factor),
+                            getElementType());
     }
 
     assert(getSizeInBits() % Factor == 0);
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -60,7 +60,8 @@
     unsigned EltSize = OrigTy.getScalarSizeInBits();
     if (LeftoverSize % EltSize != 0)
       return {-1, -1};
-    LeftoverTy = LLT::scalarOrVector(LeftoverSize / EltSize, EltSize);
+    LeftoverTy = LLT::scalarOrVector(
+        ElementCount::getFixed(LeftoverSize / EltSize), EltSize);
   } else {
     LeftoverTy = LLT::scalar(LeftoverSize);
   }
@@ -178,7 +179,8 @@
     unsigned EltSize = MainTy.getScalarSizeInBits();
     if (LeftoverSize % EltSize != 0)
       return false;
-    LeftoverTy = LLT::scalarOrVector(LeftoverSize / EltSize, EltSize);
+    LeftoverTy = LLT::scalarOrVector(
+        ElementCount::getFixed(LeftoverSize / EltSize), EltSize);
   } else {
     LeftoverTy = LLT::scalar(LeftoverSize);
   }
@@ -2572,7 +2574,8 @@
 
   // Type of the intermediate result vector.
   const unsigned NewEltsPerOldElt = NewNumElts / OldNumElts;
-  LLT MidTy = LLT::scalarOrVector(NewEltsPerOldElt, NewEltTy);
+  LLT MidTy =
+      LLT::scalarOrVector(ElementCount::getFixed(NewEltsPerOldElt), NewEltTy);
 
   auto NewEltsPerOldEltK = MIRBuilder.buildConstant(IdxTy, NewEltsPerOldElt);
 
@@ -3300,9 +3303,6 @@
     return UnableToLegalize;
 
   const LLT NarrowTy0 = NarrowTyArg;
-  const unsigned NewNumElts =
-      NarrowTy0.isVector() ? NarrowTy0.getNumElements() : 1;
-
   const Register DstReg = MI.getOperand(0).getReg();
   LLT DstTy = MRI.getType(DstReg);
   LLT LeftoverTy0;
@@ -3322,7 +3322,9 @@
   for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I) {
     Register SrcReg = MI.getOperand(I).getReg();
     LLT SrcTyI = MRI.getType(SrcReg);
-    LLT NarrowTyI = LLT::scalarOrVector(NewNumElts, SrcTyI.getScalarType());
+    const auto NewEC = NarrowTy0.isVector() ? NarrowTy0.getElementCount()
+                                            : ElementCount::getFixed(1);
+    LLT NarrowTyI = LLT::scalarOrVector(NewEC, SrcTyI.getScalarType());
     LLT LeftoverTyI;
 
     // Split this operand into the requested typed registers, and any leftover
@@ -3685,7 +3687,7 @@
   LLT ElementType = SrcTy.getElementType();
   LLT OverflowElementTy = MRI.getType(Overflow).getElementType();
 
-  const int NumResult = SrcTy.getNumElements();
+  const ElementCount NumResult = SrcTy.getElementCount();
   LLT GCDTy = getGCDType(SrcTy, NarrowTy);
 
   // Unmerge the operands to smaller parts of GCD type.
@@ -3693,7 +3695,7 @@
   auto UnmergeRHS = MIRBuilder.buildUnmerge(GCDTy, RHS);
 
   const int NumOps = UnmergeLHS->getNumOperands() - 1;
-  const int PartsPerUnmerge = NumResult / NumOps;
+  const ElementCount PartsPerUnmerge = NumResult.divideCoefficientBy(NumOps);
   LLT OverflowTy = LLT::scalarOrVector(PartsPerUnmerge, OverflowElementTy);
   LLT ResultTy = LLT::scalarOrVector(PartsPerUnmerge, ElementType);
 
@@ -3711,7 +3713,7 @@
 
   LLT ResultLCMTy = buildLCMMergePieces(SrcTy, NarrowTy, GCDTy, ResultParts);
   LLT OverflowLCMTy =
-      LLT::scalarOrVector(ResultLCMTy.getNumElements(), OverflowElementTy);
+      LLT::scalarOrVector(ResultLCMTy.getElementCount(), OverflowElementTy);
 
   // Recombine the pieces to the original result and overflow registers.
   buildWidenedRemergeToDst(Result, ResultLCMTy, ResultParts);
@@ -3957,8 +3959,6 @@
   SmallVector<Register, 8> ExtractedRegs[3];
   SmallVector<Register, 8> Parts;
 
-  unsigned NarrowElts = NarrowTy.isVector() ? NarrowTy.getNumElements() : 1;
-
   // Break down all the sources into NarrowTy pieces we can operate on. This may
   // involve creating merges to a wider type, padded with undef.
   for (int I = 0; I != NumSrcOps; ++I) {
@@ -3979,7 +3979,9 @@
         SrcReg = MIRBuilder.buildBitcast(SrcTy, SrcReg).getReg(0);
       }
     } else {
-      OpNarrowTy = LLT::scalarOrVector(NarrowElts, SrcTy.getScalarType());
+      auto NarrowEC = NarrowTy.isVector() ? NarrowTy.getElementCount()
+                                          : ElementCount::getFixed(1);
+      OpNarrowTy = LLT::scalarOrVector(NarrowEC, SrcTy.getScalarType());
     }
 
     LLT GCDTy = extractGCDType(ExtractedRegs[I], SrcTy, OpNarrowTy, SrcReg);
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -816,7 +816,7 @@
     if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
       int GCD = greatestCommonDivisor(OrigTy.getNumElements(),
                                       TargetTy.getNumElements());
-      return LLT::scalarOrVector(GCD, OrigElt);
+      return LLT::scalarOrVector(ElementCount::getFixed(GCD), OrigElt);
    }
  } else {
    // If the source is a vector of pointers, return a pointer element.
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -107,7 +107,9 @@
     unsigned Size = Ty.getSizeInBits();
     unsigned Pieces = (Size + 63) / 64;
     unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
-    return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
+    return std::make_pair(
+        TypeIdx,
+        LLT::scalarOrVector(ElementCount::getFixed(NewNumElts), EltTy));
   };
 }
 
@@ -139,7 +141,7 @@
     return LLT::scalar(Size);
   }
 
-  return LLT::scalarOrVector(Size / 32, 32);
+  return LLT::scalarOrVector(ElementCount::getFixed(Size / 32), 32);
 }
 
 static LegalizeMutation bitcastToRegisterType(unsigned TypeIdx) {
@@ -154,7 +156,8 @@
     const LLT Ty = Query.Types[TypeIdx];
     unsigned Size = Ty.getSizeInBits();
     assert(Size % 32 == 0);
-    return std::make_pair(TypeIdx, LLT::scalarOrVector(Size / 32, 32));
+    return std::make_pair(
+        TypeIdx, LLT::scalarOrVector(ElementCount::getFixed(Size / 32), 32));
   };
 }
 
@@ -1214,7 +1217,8 @@
 
               if (MaxSize % EltSize == 0) {
                 return std::make_pair(
-                    0, LLT::scalarOrVector(MaxSize / EltSize, EltTy));
+                    0, LLT::scalarOrVector(
+                           ElementCount::getFixed(MaxSize / EltSize), EltTy));
               }
 
               unsigned NumPieces = Query.MMODescrs[0].SizeInBits / MaxSize;
@@ -1242,7 +1246,8 @@
               // should be OK, since the new parts will be further legalized.
               unsigned FloorSize = PowerOf2Floor(DstSize);
               return std::make_pair(
-                  0, LLT::scalarOrVector(FloorSize / EltSize, EltTy));
+                  0, LLT::scalarOrVector(
+                         ElementCount::getFixed(FloorSize / EltSize), EltTy));
             }
 
             // Need to split because of alignment.
@@ -4448,14 +4453,16 @@
   LLT RegTy;
 
   if (IsD16 && ST.hasUnpackedD16VMem()) {
-    RoundedTy = LLT::scalarOrVector(AdjustedNumElts, 32);
+    RoundedTy =
+        LLT::scalarOrVector(ElementCount::getFixed(AdjustedNumElts), 32);
     TFETy = LLT::fixed_vector(AdjustedNumElts + 1, 32);
     RegTy = S32;
   } else {
     unsigned EltSize = EltTy.getSizeInBits();
     unsigned RoundedElts = (AdjustedTy.getSizeInBits() + 31) / 32;
     unsigned RoundedSize = 32 * RoundedElts;
-    RoundedTy = LLT::scalarOrVector(RoundedSize / EltSize, EltSize);
+    RoundedTy = LLT::scalarOrVector(
+        ElementCount::getFixed(RoundedSize / EltSize), EltSize);
     TFETy = LLT::fixed_vector(RoundedSize / 32 + 1, S32);
     RegTy = !IsTFE && EltSize == 16 ? V2S16 : S32;
   }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -675,12 +675,13 @@
 
 static LLT getHalfSizedType(LLT Ty) {
   if (Ty.isVector()) {
-    assert(Ty.getNumElements() % 2 == 0);
-    return LLT::scalarOrVector(Ty.getNumElements() / 2, Ty.getElementType());
+    assert(Ty.getElementCount().isKnownMultipleOf(2));
+    return LLT::scalarOrVector(Ty.getElementCount().divideCoefficientBy(2),
+                               Ty.getElementType());
   }
 
-  assert(Ty.getSizeInBits() % 2 == 0);
-  return LLT::scalar(Ty.getSizeInBits() / 2);
+  assert(Ty.getScalarSizeInBits() % 2 == 0);
+  return LLT::scalar(Ty.getScalarSizeInBits() / 2);
 }
 
 /// Legalize instruction \p MI where operands in \p OpIndices must be SGPRs. If
@@ -1123,8 +1124,8 @@
   unsigned FirstPartNumElts = FirstSize / EltSize;
   unsigned RemainderElts = (TotalSize - FirstSize) / EltSize;
 
-  return {LLT::scalarOrVector(FirstPartNumElts, EltTy),
-          LLT::scalarOrVector(RemainderElts, EltTy)};
+  return {LLT::scalarOrVector(ElementCount::getFixed(FirstPartNumElts), EltTy),
+          LLT::scalarOrVector(ElementCount::getFixed(RemainderElts), EltTy)};
 }
 
 static LLT widen96To128(LLT Ty) {
diff --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
--- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
+++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
@@ -100,17 +100,25 @@
 
 TEST(LowLevelTypeTest, ScalarOrVector) {
   // Test version with number of bits for scalar type.
-  EXPECT_EQ(LLT::scalar(32), LLT::scalarOrVector(1, 32));
-  EXPECT_EQ(LLT::fixed_vector(2, 32), LLT::scalarOrVector(2, 32));
+  EXPECT_EQ(LLT::scalar(32),
+            LLT::scalarOrVector(ElementCount::getFixed(1), 32));
+  EXPECT_EQ(LLT::fixed_vector(2, 32),
+            LLT::scalarOrVector(ElementCount::getFixed(2), 32));
+  EXPECT_EQ(LLT::scalable_vector(1, 32),
+            LLT::scalarOrVector(ElementCount::getScalable(1), 32));
 
   // Test version with LLT for scalar type.
-  EXPECT_EQ(LLT::scalar(32), LLT::scalarOrVector(1, LLT::scalar(32)));
-  EXPECT_EQ(LLT::fixed_vector(2, 32), LLT::scalarOrVector(2, LLT::scalar(32)));
+  EXPECT_EQ(LLT::scalar(32),
+            LLT::scalarOrVector(ElementCount::getFixed(1), LLT::scalar(32)));
+  EXPECT_EQ(LLT::fixed_vector(2, 32),
+            LLT::scalarOrVector(ElementCount::getFixed(2), LLT::scalar(32)));
 
   // Test with pointer elements.
-  EXPECT_EQ(LLT::pointer(1, 32), LLT::scalarOrVector(1, LLT::pointer(1, 32)));
-  EXPECT_EQ(LLT::fixed_vector(2, LLT::pointer(1, 32)),
-            LLT::scalarOrVector(2, LLT::pointer(1, 32)));
+  EXPECT_EQ(LLT::pointer(1, 32), LLT::scalarOrVector(ElementCount::getFixed(1),
+                                                     LLT::pointer(1, 32)));
+  EXPECT_EQ(
+      LLT::fixed_vector(2, LLT::pointer(1, 32)),
+      LLT::scalarOrVector(ElementCount::getFixed(2), LLT::pointer(1, 32)));
 }
 
 TEST(LowLevelTypeTest, ChangeElementType) {
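
A minimal standalone sketch of the migrated API's behavior, mirroring the
LowLevelTypeTest expectations above. It is illustrative only, not part of the
patch: the file name and main() driver are hypothetical, and it assumes an
LLVM build where these headers are available.

// scalar_or_vector_demo.cpp (hypothetical): the ElementCount-based
// scalarOrVector folds a fixed count of 1 to a scalar, but keeps scalable
// counts as vectors.
#include "llvm/Support/LowLevelTypeImpl.h"
#include "llvm/Support/TypeSize.h"
#include <cassert>

using namespace llvm;

int main() {
  // A fixed element count of 1 folds to the scalar type itself.
  assert(LLT::scalarOrVector(ElementCount::getFixed(1), 32) ==
         LLT::scalar(32));

  // A fixed count greater than 1 yields a fixed-length vector.
  assert(LLT::scalarOrVector(ElementCount::getFixed(4), 32) ==
         LLT::fixed_vector(4, 32));

  // Unlike the old uint16_t overload, a scalable count of 1 is not a scalar:
  // EC.isScalar() is false for scalable counts, so this stays a vector.
  assert(LLT::scalarOrVector(ElementCount::getScalable(1), 32) ==
         LLT::scalable_vector(1, 32));
  return 0;
}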