Index: llvm/include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1149,9 +1149,7 @@
                                                     OpPropsBW);
       // For non-rotates (X != Y) we must add shift-by-zero handling costs.
       if (X != Y) {
-        Type *CondTy = Type::getInt1Ty(RetTy->getContext());
-        if (RetVF > 1)
-          CondTy = VectorType::get(CondTy, RetVF);
+        Type *CondTy = RetTy->getWithNewBitWidth(1);
         Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::ICmp, RetTy,
                                                 CondTy, nullptr);
         Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::Select, RetTy,
@@ -1169,7 +1167,6 @@
   unsigned getIntrinsicInstrCost(
       Intrinsic::ID IID, Type *RetTy, ArrayRef<Type *> Tys, FastMathFlags FMF,
       unsigned ScalarizationCostPassed = std::numeric_limits<unsigned>::max()) {
-    unsigned RetVF = (RetTy->isVectorTy() ? RetTy->getVectorNumElements() : 1);
     auto *ConcreteTTI = static_cast<T *>(this);
 
     SmallVector<unsigned, 2> ISDs;
@@ -1326,9 +1323,7 @@
           /*IsUnsigned=*/false);
     case Intrinsic::sadd_sat:
     case Intrinsic::ssub_sat: {
-      Type *CondTy = Type::getInt1Ty(RetTy->getContext());
-      if (RetVF > 1)
-        CondTy = VectorType::get(CondTy, RetVF);
+      Type *CondTy = RetTy->getWithNewBitWidth(1);
 
       Type *OpTy = StructType::create({RetTy, CondTy});
       Intrinsic::ID OverflowOp = IID == Intrinsic::sadd_sat
@@ -1348,9 +1343,7 @@
     }
     case Intrinsic::uadd_sat:
     case Intrinsic::usub_sat: {
-      Type *CondTy = Type::getInt1Ty(RetTy->getContext());
-      if (RetVF > 1)
-        CondTy = VectorType::get(CondTy, RetVF);
+      Type *CondTy = RetTy->getWithNewBitWidth(1);
 
       Type *OpTy = StructType::create({RetTy, CondTy});
       Intrinsic::ID OverflowOp = IID == Intrinsic::uadd_sat
@@ -1367,9 +1360,7 @@
     case Intrinsic::smul_fix:
     case Intrinsic::umul_fix: {
       unsigned ExtSize = RetTy->getScalarSizeInBits() * 2;
-      Type *ExtTy = Type::getIntNTy(RetTy->getContext(), ExtSize);
-      if (RetVF > 1)
-        ExtTy = VectorType::get(ExtTy, RetVF);
+      Type *ExtTy = RetTy->getWithNewBitWidth(ExtSize);
 
       unsigned ExtOp =
           IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
@@ -1433,9 +1424,7 @@
       Type *MulTy = RetTy->getContainedType(0);
       Type *OverflowTy = RetTy->getContainedType(1);
       unsigned ExtSize = MulTy->getScalarSizeInBits() * 2;
-      Type *ExtTy = Type::getIntNTy(RetTy->getContext(), ExtSize);
-      if (MulTy->isVectorTy())
-        ExtTy = VectorType::get(ExtTy, MulTy->getVectorNumElements() );
+      Type *ExtTy = MulTy->getWithNewBitWidth(ExtSize);
 
       unsigned ExtOp =
           IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
Index: llvm/include/llvm/IR/DerivedTypes.h
===================================================================
--- llvm/include/llvm/IR/DerivedTypes.h
+++ llvm/include/llvm/IR/DerivedTypes.h
@@ -571,6 +571,10 @@
   return cast<VectorType>(this)->isScalable();
 }
 
+ElementCount Type::getVectorElementCount() const {
+  return cast<VectorType>(this)->getElementCount();
+}
+
 /// Class to represent pointers.
 class PointerType : public Type {
   explicit PointerType(Type *ElType, unsigned AddrSpace);
@@ -618,6 +622,16 @@
   return cast<IntegerType>(this)->getExtendedType();
 }
 
+Type *Type::getWithNewBitWidth(unsigned NewBitWidth) const {
+  assert(
+      isIntOrIntVectorTy() &&
+      "Original type expected to be a vector of integers or a scalar integer.");
+  Type *NewType = getIntNTy(getContext(), NewBitWidth);
+  if (isVectorTy())
+    NewType = VectorType::get(NewType, getVectorElementCount());
+  return NewType;
+}
+
 unsigned Type::getPointerAddressSpace() const {
   return cast<PointerType>(getScalarType())->getAddressSpace();
 }
Index: llvm/include/llvm/IR/Type.h
===================================================================
--- llvm/include/llvm/IR/Type.h
+++ llvm/include/llvm/IR/Type.h
@@ -372,6 +372,7 @@
 
   inline bool getVectorIsScalable() const;
   inline unsigned getVectorNumElements() const;
+  inline ElementCount getVectorElementCount() const;
   Type *getVectorElementType() const {
     assert(getTypeID() == VectorTyID);
     return ContainedTys[0];
@@ -382,6 +383,10 @@
     return ContainedTys[0];
   }
 
+  /// Given an integer or integer vector type, change the lane bitwidth to
+  /// NewBitWidth, whilst keeping the old number of lanes.
+  inline Type *getWithNewBitWidth(unsigned NewBitWidth) const;
+
   /// Given scalar/vector integer type, returns a type with elements twice as
   /// wide as in the original type. For vectors, preserves element count.
   inline Type *getExtendedType() const;