diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3238,16 +3238,18 @@
     return false;
 
   // Get the size of the types in bits, we'll need this later
-  unsigned SrcBitSize = SrcTy->getScalarSizeInBits();
-  unsigned DstBitSize = DstTy->getScalarSizeInBits();
+  TypeSize SrcBitSize = SrcTy->getScalarType()->getPrimitiveSizeInBits();
+  TypeSize DstBitSize = DstTy->getScalarType()->getPrimitiveSizeInBits();
 
   // If these are vector types, get the lengths of the vectors (using zero for
   // scalar types means that checking that vector lengths match also checks that
   // scalars are not being converted to vectors or vectors to scalars).
-  unsigned SrcLength = SrcTy->isVectorTy() ?
-    cast<VectorType>(SrcTy)->getNumElements() : 0;
-  unsigned DstLength = DstTy->isVectorTy() ?
-    cast<VectorType>(DstTy)->getNumElements() : 0;
+  ElementCount SrcLength = SrcTy->isVectorTy()
+                               ? cast<VectorType>(SrcTy)->getElementCount()
+                               : ElementCount(0, false);
+  ElementCount DstLength = DstTy->isVectorTy()
+                               ? cast<VectorType>(DstTy)->getElementCount()
+                               : ElementCount(0, false);
 
   // Switch on the opcode provided
   switch (op) {
@@ -3279,14 +3281,14 @@
     if (isa<VectorType>(SrcTy) != isa<VectorType>(DstTy))
       return false;
     if (VectorType *VT = dyn_cast<VectorType>(SrcTy))
-      if (VT->getNumElements() != cast<VectorType>(DstTy)->getNumElements())
+      if (VT->getElementCount() != cast<VectorType>(DstTy)->getElementCount())
         return false;
     return SrcTy->isPtrOrPtrVectorTy() && DstTy->isIntOrIntVectorTy();
   case Instruction::IntToPtr:
     if (isa<VectorType>(SrcTy) != isa<VectorType>(DstTy))
       return false;
     if (VectorType *VT = dyn_cast<VectorType>(SrcTy))
-      if (VT->getNumElements() != cast<VectorType>(DstTy)->getNumElements())
+      if (VT->getElementCount() != cast<VectorType>(DstTy)->getElementCount())
         return false;
     return SrcTy->isIntOrIntVectorTy() && DstTy->isPtrOrPtrVectorTy();
   case Instruction::BitCast: {
@@ -3311,11 +3313,11 @@
     VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy);
     VectorType *DstVecTy = dyn_cast<VectorType>(DstTy);
     if (SrcVecTy && DstVecTy)
-      return (SrcVecTy->getNumElements() == DstVecTy->getNumElements());
+      return (SrcVecTy->getElementCount() == DstVecTy->getElementCount());
     if (SrcVecTy)
-      return SrcVecTy->getNumElements() == 1;
+      return SrcVecTy->getElementCount() == ElementCount(1, false);
     if (DstVecTy)
-      return DstVecTy->getNumElements() == 1;
+      return DstVecTy->getElementCount() == ElementCount(1, false);
 
     return true;
   }
@@ -3333,7 +3335,7 @@
 
     if (VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy)) {
       if (VectorType *DstVecTy = dyn_cast<VectorType>(DstTy))
-        return (SrcVecTy->getNumElements() == DstVecTy->getNumElements());
+        return (SrcVecTy->getElementCount() == DstVecTy->getElementCount());
 
       return false;
     }
diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -197,6 +197,12 @@
   Type *V2Int32Ty = VectorType::get(Int32Ty, 2);
   Type *V2Int64Ty = VectorType::get(Int64Ty, 2);
   Type *V4Int16Ty = VectorType::get(Int16Ty, 4);
+  Type *V1Int16Ty = VectorType::get(Int16Ty, 1);
+
+  Type *VScaleV2Int32Ty = VectorType::get(Int32Ty, 2, true);
+  Type *VScaleV2Int64Ty = VectorType::get(Int64Ty, 2, true);
+  Type *VScaleV4Int16Ty = VectorType::get(Int16Ty, 4, true);
+  Type *VScaleV1Int16Ty = VectorType::get(Int16Ty, 1, true);
 
   Type *Int32PtrTy = PointerType::get(Int32Ty, 0);
   Type *Int64PtrTy = PointerType::get(Int64Ty, 0);
@@ -207,11 +213,15 @@
   Type *V2Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 2);
   Type *V2Int64PtrAS1Ty = VectorType::get(Int64PtrAS1Ty, 2);
   Type *V4Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 4);
+  Type *VScaleV4Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 4, true);
   Type *V4Int64PtrAS1Ty = VectorType::get(Int64PtrAS1Ty, 4);
 
   Type *V2Int64PtrTy = VectorType::get(Int64PtrTy, 2);
   Type *V2Int32PtrTy = VectorType::get(Int32PtrTy, 2);
+  Type *VScaleV2Int32PtrTy = VectorType::get(Int32PtrTy, 2, true);
   Type *V4Int32PtrTy = VectorType::get(Int32PtrTy, 4);
+  Type *VScaleV4Int32PtrTy = VectorType::get(Int32PtrTy, 4, true);
+  Type *VScaleV4Int64PtrTy = VectorType::get(Int64PtrTy, 4, true);
 
   const Constant* c8 = Constant::getNullValue(V8x8Ty);
   const Constant* c64 = Constant::getNullValue(V8x64Ty);
@@ -286,6 +296,72 @@
                                      Constant::getNullValue(V2Int32PtrTy),
                                      V4Int32PtrAS1Ty));
 
+  // Address space cast of fixed/scalable vectors of pointers to scalable/fixed
+  // vector of pointers.
+  EXPECT_FALSE(CastInst::castIsValid(
+      Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty),
+      V4Int32PtrTy));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+                                     Constant::getNullValue(V4Int32PtrTy),
+                                     VScaleV4Int32PtrAS1Ty));
+  // Address space cast of scalable vectors of pointers to scalable vector of
+  // pointers.
+  EXPECT_FALSE(CastInst::castIsValid(
+      Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty),
+      VScaleV2Int32PtrTy));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+                                     Constant::getNullValue(VScaleV2Int32PtrTy),
+                                     VScaleV4Int32PtrAS1Ty));
+  // Same number of lanes, different address space.
+  EXPECT_TRUE(CastInst::castIsValid(
+      Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty),
+      VScaleV4Int32PtrTy));
+  // Same number of lanes, same address space.
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+                                     Constant::getNullValue(VScaleV4Int64PtrTy),
+                                     VScaleV4Int32PtrTy));
+
+  // Bit casting fixed/scalable vector to scalable/fixed vectors.
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(V2Int32Ty),
+                                     VScaleV2Int32Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(V2Int64Ty),
+                                     VScaleV2Int64Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(V4Int16Ty),
+                                     VScaleV4Int16Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV2Int32Ty),
+                                     V2Int32Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV2Int64Ty),
+                                     V2Int64Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV4Int16Ty),
+                                     V4Int16Ty));
+
+  // Bit casting scalable vectors to scalable vectors.
+  EXPECT_TRUE(CastInst::castIsValid(Instruction::BitCast,
+                                    Constant::getNullValue(VScaleV4Int16Ty),
+                                    VScaleV2Int32Ty));
+  EXPECT_TRUE(CastInst::castIsValid(Instruction::BitCast,
+                                    Constant::getNullValue(VScaleV2Int32Ty),
+                                    VScaleV4Int16Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV2Int64Ty),
+                                     VScaleV2Int32Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV2Int32Ty),
+                                     VScaleV2Int64Ty));
+
+  // Bitcasting to/from <vscale x 1 x Ty>
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV1Int16Ty),
+                                     V1Int16Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(V1Int16Ty),
+                                     VScaleV1Int16Ty));
   // Check that assertion is not hit when creating a cast with a vector of
   // pointers