Replace LLT::changeNumElements(unsigned) with LLT::changeElementCount(ElementCount).

NOTE(review): the line breaks of this patch appear to have been collapsed,
so several logical diff lines share one physical line. In that state it is
not a well-formed unified diff and will not apply with `git apply` or
`patch`. Before applying, restore one diff line per physical line,
including the blank context lines and the leading indentation of context
lines, so that each hunk matches the line counts in its "@@ -a,b +c,d @@"
header.

What the patch changes, per file:
* llvm/include/llvm/Support/LowLevelTypeImpl.h
    Removes LLT::changeNumElements(unsigned NewNumElts), which asserted
    "Cannot use changeNumElements on a scalable vector". Adds
    LLT::changeElementCount(ElementCount EC), which returns
    LLT::scalarOrVector(EC, getScalarType()) with no scalability assert.
    LLT::divide now asserts getElementCount().isKnownMultipleOf(Factor)
    instead of getNumElements() % Factor == 0.
* LegalizationArtifactCombiner.h, AArch64PostLegalizerLowering.cpp,
  AMDGPULegalizerInfo.cpp (getPow2VectorType, widenToNextPowerOf2 and the
  AdjustedTy computation)
    Mechanical migration: the previous unsigned element count is wrapped in
    ElementCount::getFixed(...). These sites therefore still build
    fixed-width types, just as the removed changeNumElements did internally.
* CallLowering.cpp
    Doubles the count with PartLLT.getElementCount() * 2.
    NOTE(review): relies on ElementCount supporting multiplication by an
    integer; confirm against llvm/Support/TypeSize.h in the target tree.
* LegalizerHelper.cpp (shuffle split), AArch64LegalizerInfo.cpp (source split)
    Halve the count with getElementCount().divideCoefficientBy(2) instead
    of getNumElements() / 2.
* llvm/unittests/CodeGen/LowLevelTypeTest.cpp
    Existing changeNumElements cases are rewritten with
    ElementCount::getFixed. New cases cover scalable vector -> scalar,
    fixed-width -> scalable, scalable -> fixed-width and scalar -> scalable.

Behavioral note: changeElementCount has no assert. Calling it with
ElementCount::getFixed(N) on a scalable vector now yields a fixed-width
type (the new test expects NXV2P0 -> V3P0). The old changeNumElements
tripped its assert on scalable input in assert-enabled builds.

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h --- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h @@ -384,7 +384,8 @@ unsigned UnmergeNumElts = DestTy.isVector() ? CastSrcTy.getNumElements() / NumDefs : 1; - LLT UnmergeTy = CastSrcTy.changeNumElements(UnmergeNumElts); + LLT UnmergeTy = CastSrcTy.changeElementCount( + ElementCount::getFixed(UnmergeNumElts)); if (isInstUnsupported( {TargetOpcode::G_UNMERGE_VALUES, {UnmergeTy, CastSrcTy}})) diff --git a/llvm/include/llvm/Support/LowLevelTypeImpl.h b/llvm/include/llvm/Support/LowLevelTypeImpl.h --- a/llvm/include/llvm/Support/LowLevelTypeImpl.h +++ b/llvm/include/llvm/Support/LowLevelTypeImpl.h @@ -182,13 +182,10 @@ : LLT::scalar(NewEltSize); } - /// Return a vector or scalar with the same element type and the new number of - /// elements. - LLT changeNumElements(unsigned NewNumElts) const { - assert((!isVector() || !isScalable()) && - "Cannot use changeNumElements on a scalable vector"); - return LLT::scalarOrVector(ElementCount::getFixed(NewNumElts), - getScalarType()); + /// Return a vector or scalar with the same element type and the new element + /// count. + LLT changeElementCount(ElementCount EC) const { + return LLT::scalarOrVector(EC, getScalarType()); } /// Return a type that is \p Factor times smaller. 
Reduces the number of @@ -197,7 +194,7 @@ LLT divide(int Factor) const { assert(Factor != 1); if (isVector()) { - assert(getNumElements() % Factor == 0); + assert(getElementCount().isKnownMultipleOf(Factor)); return scalarOrVector(getElementCount().divideCoefficientBy(Factor), getElementType()); } diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp --- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp @@ -377,7 +377,7 @@ PartLLT.getScalarSizeInBits() == LLTy.getScalarSizeInBits() * 2 && Regs.size() == 1) { LLT NewTy = PartLLT.changeElementType(LLTy.getElementType()) - .changeNumElements(PartLLT.getNumElements() * 2); + .changeElementCount(PartLLT.getElementCount() * 2); CastRegs[0] = B.buildBitcast(NewTy, Regs[0]).getReg(0); PartLLT = NewTy; } diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp --- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp @@ -4237,7 +4237,8 @@ // We only support splitting a shuffle into 2, so adjust NarrowTy accordingly. // Further legalization attempts will be needed to do split further. - NarrowTy = DstTy.changeNumElements(DstTy.getNumElements() / 2); + NarrowTy = + DstTy.changeElementCount(DstTy.getElementCount().divideCoefficientBy(2)); unsigned NewElts = NarrowTy.getNumElements(); SmallVector SplitSrc1Regs, SplitSrc2Regs; diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp --- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp @@ -850,7 +850,8 @@ isPowerOf2_32(SrcTy.getSizeInBits())); // Split input type. 
- LLT SplitSrcTy = SrcTy.changeNumElements(SrcTy.getNumElements() / 2); + LLT SplitSrcTy = + SrcTy.changeElementCount(SrcTy.getElementCount().divideCoefficientBy(2)); // First, split the source into two smaller vectors. SmallVector SplitSrcs; extractParts(SrcReg, MRI, MIRBuilder, SplitSrcTy, 2, SplitSrcs); diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp --- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp @@ -709,8 +709,9 @@ assert(MRI.getType(MI.getOperand(0).getReg()).getNumElements() == 2 && "Unexpected dest elements"); auto Undef = B.buildUndef(SrcTy); - DupSrc = B.buildConcatVectors(SrcTy.changeNumElements(4), - {Src1Reg, Undef.getReg(0)}) + DupSrc = B.buildConcatVectors( + SrcTy.changeElementCount(ElementCount::getFixed(4)), + {Src1Reg, Undef.getReg(0)}) .getReg(0); } B.buildInstr(MatchInfo.first, {MI.getOperand(0).getReg()}, {DupSrc, Lane}); diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp @@ -49,7 +49,7 @@ static LLT getPow2VectorType(LLT Ty) { unsigned NElts = Ty.getNumElements(); unsigned Pow2NElts = 1 << Log2_32_Ceil(NElts); - return Ty.changeNumElements(Pow2NElts); + return Ty.changeElementCount(ElementCount::getFixed(Pow2NElts)); } // Round the number of bits to the next power of two bits @@ -2445,7 +2445,8 @@ static LLT widenToNextPowerOf2(LLT Ty) { if (Ty.isVector()) - return Ty.changeNumElements(PowerOf2Ceil(Ty.getNumElements())); + return Ty.changeElementCount( + ElementCount::getFixed(PowerOf2Ceil(Ty.getNumElements()))); return LLT::scalar(PowerOf2Ceil(Ty.getSizeInBits())); } @@ -4439,7 +4440,8 @@ return false; const unsigned AdjustedNumElts = DMaskLanes == 0 ? 
1 : DMaskLanes; - const LLT AdjustedTy = Ty.changeNumElements(AdjustedNumElts); + const LLT AdjustedTy = + Ty.changeElementCount(ElementCount::getFixed(AdjustedNumElts)); // The raw dword aligned data component of the load. The only legal cases // where this matters should be when using the packed D16 format, for diff --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp --- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp +++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp @@ -179,17 +179,35 @@ const LLT V3S64 = LLT::fixed_vector(3, 64); // Vector to scalar - EXPECT_EQ(S64, V2S64.changeNumElements(1)); + EXPECT_EQ(S64, V2S64.changeElementCount(ElementCount::getFixed(1))); // Vector to vector - EXPECT_EQ(V3S64, V2S64.changeNumElements(3)); + EXPECT_EQ(V3S64, V2S64.changeElementCount(ElementCount::getFixed(3))); // Scalar to vector - EXPECT_EQ(V2S64, S64.changeNumElements(2)); + EXPECT_EQ(V2S64, S64.changeElementCount(ElementCount::getFixed(2))); - EXPECT_EQ(P0, V2P0.changeNumElements(1)); - EXPECT_EQ(V3P0, V2P0.changeNumElements(3)); - EXPECT_EQ(V2P0, P0.changeNumElements(2)); + EXPECT_EQ(P0, V2P0.changeElementCount(ElementCount::getFixed(1))); + EXPECT_EQ(V3P0, V2P0.changeElementCount(ElementCount::getFixed(3))); + EXPECT_EQ(V2P0, P0.changeElementCount(ElementCount::getFixed(2))); + + const LLT NXV2S64 = LLT::scalable_vector(2, 64); + const LLT NXV3S64 = LLT::scalable_vector(3, 64); + const LLT NXV2P0 = LLT::scalable_vector(2, P0); + + // Scalable vector to scalar + EXPECT_EQ(S64, NXV2S64.changeElementCount(ElementCount::getFixed(1))); + EXPECT_EQ(P0, NXV2P0.changeElementCount(ElementCount::getFixed(1))); + + // Fixed-width vector to scalable vector + EXPECT_EQ(NXV3S64, V2S64.changeElementCount(ElementCount::getScalable(3))); + + // Scalable vector to fixed-width vector + EXPECT_EQ(V3P0, NXV2P0.changeElementCount(ElementCount::getFixed(3))); + + // Scalar to scalable vector + EXPECT_EQ(NXV2S64, 
S64.changeElementCount(ElementCount::getScalable(2))); + EXPECT_EQ(NXV2P0, P0.changeElementCount(ElementCount::getScalable(2))); } #ifdef GTEST_HAS_DEATH_TEST