diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -3237,57 +3237,58 @@ SrcTy->isAggregateType() || DstTy->isAggregateType()) return false; - // Get the size of the types in bits, we'll need this later - unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); - unsigned DstBitSize = DstTy->getScalarSizeInBits(); + // Get the size of the types in bits, and whether we are dealing + // with vector types, we'll need this later. + bool SrcIsVec = isa(SrcTy); + bool DstIsVec = isa(DstTy); + unsigned SrcScalarBitSize = SrcTy->getScalarSizeInBits(); + unsigned DstScalarBitSize = DstTy->getScalarSizeInBits(); // If these are vector types, get the lengths of the vectors (using zero for // scalar types means that checking that vector lengths match also checks that // scalars are not being converted to vectors or vectors to scalars). - unsigned SrcLength = SrcTy->isVectorTy() ? - cast(SrcTy)->getNumElements() : 0; - unsigned DstLength = DstTy->isVectorTy() ? - cast(DstTy)->getNumElements() : 0; + ElementCount SrcEC = SrcIsVec ? cast(SrcTy)->getElementCount() + : ElementCount(0, false); + ElementCount DstEC = DstIsVec ? cast(DstTy)->getElementCount() + : ElementCount(0, false); // Switch on the opcode provided switch (op) { default: return false; // This is an input error case Instruction::Trunc: return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() && - SrcLength == DstLength && SrcBitSize > DstBitSize; + SrcEC == DstEC && SrcScalarBitSize > DstScalarBitSize; case Instruction::ZExt: return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() && - SrcLength == DstLength && SrcBitSize < DstBitSize; + SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize; case Instruction::SExt: return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() && - SrcLength == DstLength && SrcBitSize < DstBitSize; + SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize; case Instruction::FPTrunc: return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() && - SrcLength == DstLength && SrcBitSize > DstBitSize; + SrcEC == DstEC && SrcScalarBitSize > DstScalarBitSize; case Instruction::FPExt: return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() && - SrcLength == DstLength && SrcBitSize < DstBitSize; + SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize; case Instruction::UIToFP: case Instruction::SIToFP: return SrcTy->isIntOrIntVectorTy() && DstTy->isFPOrFPVectorTy() && - SrcLength == DstLength; + SrcEC == DstEC; case Instruction::FPToUI: case Instruction::FPToSI: return SrcTy->isFPOrFPVectorTy() && DstTy->isIntOrIntVectorTy() && - SrcLength == DstLength; + SrcEC == DstEC; case Instruction::PtrToInt: - if (isa(SrcTy) != isa(DstTy)) + if (SrcIsVec != DstIsVec) + return false; + if (SrcEC != DstEC) return false; - if (VectorType *VT = dyn_cast(SrcTy)) - if (VT->getNumElements() != cast(DstTy)->getNumElements()) - return false; return SrcTy->isPtrOrPtrVectorTy() && DstTy->isIntOrIntVectorTy(); case Instruction::IntToPtr: - if (isa(SrcTy) != isa(DstTy)) + if (SrcIsVec != DstIsVec) + return false; + if (SrcEC != DstEC) return false; - if (VectorType *VT = dyn_cast(SrcTy)) - if (VT->getNumElements() != cast(DstTy)->getNumElements()) - return false; return SrcTy->isIntOrIntVectorTy() && DstTy->isPtrOrPtrVectorTy(); case Instruction::BitCast: { PointerType *SrcPtrTy = dyn_cast(SrcTy->getScalarType()); @@ -3308,14 +3309,12 @@ return false; // A vector of pointers must have the same number of elements. - VectorType *SrcVecTy = dyn_cast(SrcTy); - VectorType *DstVecTy = dyn_cast(DstTy); - if (SrcVecTy && DstVecTy) - return (SrcVecTy->getNumElements() == DstVecTy->getNumElements()); - if (SrcVecTy) - return SrcVecTy->getNumElements() == 1; - if (DstVecTy) - return DstVecTy->getNumElements() == 1; + if (SrcIsVec && DstIsVec) + return SrcEC == DstEC; + if (SrcIsVec) + return SrcEC == ElementCount(1, false); + if (DstIsVec) + return DstEC == ElementCount(1, false); return true; } @@ -3331,9 +3330,9 @@ if (SrcPtrTy->getAddressSpace() == DstPtrTy->getAddressSpace()) return false; - if (VectorType *SrcVecTy = dyn_cast(SrcTy)) { - if (VectorType *DstVecTy = dyn_cast(DstTy)) - return (SrcVecTy->getNumElements() == DstVecTy->getNumElements()); + if (SrcIsVec) { + if (DstIsVec) + return SrcEC == DstEC; return false; } diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp --- a/llvm/unittests/IR/InstructionsTest.cpp +++ b/llvm/unittests/IR/InstructionsTest.cpp @@ -197,6 +197,12 @@ Type *V2Int32Ty = VectorType::get(Int32Ty, 2); Type *V2Int64Ty = VectorType::get(Int64Ty, 2); Type *V4Int16Ty = VectorType::get(Int16Ty, 4); + Type *V1Int16Ty = VectorType::get(Int16Ty, 1); + + Type *VScaleV2Int32Ty = VectorType::get(Int32Ty, 2, true); + Type *VScaleV2Int64Ty = VectorType::get(Int64Ty, 2, true); + Type *VScaleV4Int16Ty = VectorType::get(Int16Ty, 4, true); + Type *VScaleV1Int16Ty = VectorType::get(Int16Ty, 1, true); Type *Int32PtrTy = PointerType::get(Int32Ty, 0); Type *Int64PtrTy = PointerType::get(Int64Ty, 0); @@ -207,11 +213,15 @@ Type *V2Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 2); Type *V2Int64PtrAS1Ty = VectorType::get(Int64PtrAS1Ty, 2); Type *V4Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 4); + Type *VScaleV4Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 4, true); Type *V4Int64PtrAS1Ty = VectorType::get(Int64PtrAS1Ty, 4); Type *V2Int64PtrTy = VectorType::get(Int64PtrTy, 2); Type *V2Int32PtrTy = VectorType::get(Int32PtrTy, 2); + Type *VScaleV2Int32PtrTy = VectorType::get(Int32PtrTy, 2, true); Type *V4Int32PtrTy = VectorType::get(Int32PtrTy, 4); + Type *VScaleV4Int32PtrTy = VectorType::get(Int32PtrTy, 4, true); + Type *VScaleV4Int64PtrTy = VectorType::get(Int64PtrTy, 4, true); const Constant* c8 = Constant::getNullValue(V8x8Ty); const Constant* c64 = Constant::getNullValue(V8x64Ty); @@ -286,6 +296,75 @@ Constant::getNullValue(V2Int32PtrTy), V4Int32PtrAS1Ty)); + // Address space cast of fixed/scalable vectors of pointers to scalable/fixed + // vector of pointers. + EXPECT_FALSE(CastInst::castIsValid( + Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty), + V4Int32PtrTy)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast, + Constant::getNullValue(V4Int32PtrTy), + VScaleV4Int32PtrAS1Ty)); + // Address space cast of scalable vectors of pointers to scalable vector of + // pointers. + EXPECT_FALSE(CastInst::castIsValid( + Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty), + VScaleV2Int32PtrTy)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast, + Constant::getNullValue(VScaleV2Int32PtrTy), + VScaleV4Int32PtrAS1Ty)); + EXPECT_TRUE(CastInst::castIsValid(Instruction::AddrSpaceCast, + Constant::getNullValue(VScaleV4Int64PtrTy), + VScaleV4Int32PtrAS1Ty)); + // Same number of lanes, different address space. + EXPECT_TRUE(CastInst::castIsValid( + Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty), + VScaleV4Int32PtrTy)); + // Same number of lanes, same address space. + EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast, + Constant::getNullValue(VScaleV4Int64PtrTy), + VScaleV4Int32PtrTy)); + + // Bit casting fixed/scalable vector to scalable/fixed vectors. + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(V2Int32Ty), + VScaleV2Int32Ty)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(V2Int64Ty), + VScaleV2Int64Ty)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(V4Int16Ty), + VScaleV4Int16Ty)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(VScaleV2Int32Ty), + V2Int32Ty)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(VScaleV2Int64Ty), + V2Int64Ty)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(VScaleV4Int16Ty), + V4Int16Ty)); + + // Bit casting scalable vectors to scalable vectors. + EXPECT_TRUE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(VScaleV4Int16Ty), + VScaleV2Int32Ty)); + EXPECT_TRUE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(VScaleV2Int32Ty), + VScaleV4Int16Ty)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(VScaleV2Int64Ty), + VScaleV2Int32Ty)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(VScaleV2Int32Ty), + VScaleV2Int64Ty)); + + // Bitcasting to/from + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(VScaleV1Int16Ty), + V1Int16Ty)); + EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast, + Constant::getNullValue(V1Int16Ty), + VScaleV1Int16Ty)); // Check that assertion is not hit when creating a cast with a vector of // pointers