diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -4494,8 +4494,8 @@ } Value *CodeGenFunction::EmitNeonSplat(Value *V, Constant *C) { - unsigned nElts = V->getType()->getVectorNumElements(); - Value* SV = llvm::ConstantVector::getSplat(nElts, C); + auto EC = cast(V->getType())->getElementCount(); + Value *SV = llvm::ConstantVector::getSplat(EC, C); return Builder.CreateShuffleVector(V, V, SV, "lane"); } @@ -8633,7 +8633,7 @@ llvm::VectorType::get(VTy->getElementType(), VTy->getNumElements() / 2) : VTy; llvm::Constant *cst = cast(Ops[3]); - Value *SV = llvm::ConstantVector::getSplat(VTy->getNumElements(), cst); + Value *SV = llvm::ConstantVector::getSplat(VTy->getElementCount(), cst); Ops[1] = Builder.CreateBitCast(Ops[1], SourceTy); Ops[1] = Builder.CreateShuffleVector(Ops[1], Ops[1], SV, "lane"); @@ -8662,7 +8662,7 @@ llvm::Type *STy = llvm::VectorType::get(VTy->getElementType(), VTy->getNumElements() * 2); Ops[2] = Builder.CreateBitCast(Ops[2], STy); - Value* SV = llvm::ConstantVector::getSplat(VTy->getNumElements(), + Value *SV = llvm::ConstantVector::getSplat(VTy->getElementCount(), cast(Ops[3])); Ops[2] = Builder.CreateShuffleVector(Ops[2], Ops[2], SV, "lane"); diff --git a/llvm/include/llvm/Analysis/Utils/Local.h b/llvm/include/llvm/Analysis/Utils/Local.h --- a/llvm/include/llvm/Analysis/Utils/Local.h +++ b/llvm/include/llvm/Analysis/Utils/Local.h @@ -63,7 +63,8 @@ // Splat the constant if needed. if (IntIdxTy->isVectorTy() && !OpC->getType()->isVectorTy()) - OpC = ConstantVector::getSplat(IntIdxTy->getVectorNumElements(), OpC); + OpC = ConstantVector::getSplat( + cast(IntIdxTy)->getElementCount(), OpC); Constant *Scale = ConstantInt::get(IntIdxTy, Size); Constant *OC = ConstantExpr::getIntegerCast(OpC, IntIdxTy, true /*SExt*/); diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -518,7 +518,7 @@ public: /// Return a ConstantVector with the specified constant in each element. - static Constant *getSplat(unsigned NumElts, Constant *Elt); + static Constant *getSplat(ElementCount EC, Constant *Elt); /// Specialize the getType() method to always return a VectorType, /// which reduces the amount of casting needed in parts of the compiler. diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -707,9 +707,8 @@ Offset = Offset.sextOrTrunc(IntIdxTy->getIntegerBitWidth()); Constant *OffsetIntPtr = ConstantInt::get(IntIdxTy, Offset); - if (V->getType()->isVectorTy()) - return ConstantVector::getSplat(V->getType()->getVectorNumElements(), - OffsetIntPtr); + if (VectorType *VecTy = dyn_cast(V->getType())) + return ConstantVector::getSplat(VecTy->getElementCount(), OffsetIntPtr); return OffsetIntPtr; } diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -6539,13 +6539,13 @@ UseSplat = true; } - unsigned End = getTransitionType()->getVectorNumElements(); + auto EC = cast(getTransitionType())->getElementCount(); if (UseSplat) - return ConstantVector::getSplat(End, Val); + return ConstantVector::getSplat(EC, Val); SmallVector ConstVec; UndefValue *UndefVal = UndefValue::get(Val->getType()); - for (unsigned Idx = 0; Idx != End; ++Idx) { + for (unsigned Idx = 0; Idx != EC.Min; ++Idx) { if (Idx == ExtractIdx) ConstVec.push_back(Val); else diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -2226,7 +2226,7 @@ if (Idxs.size() == 1 && (Idx0->isNullValue() || isa(Idx0))) return GEPTy->isVectorTy() && !C->getType()->isVectorTy() ? ConstantVector::getSplat( - cast(GEPTy)->getNumElements(), C) + cast(GEPTy)->getElementCount(), C) : C; if (C->isNullValue()) { diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -370,7 +370,7 @@ // Broadcast a scalar to a vector, if necessary. if (VectorType *VTy = dyn_cast(Ty)) - C = ConstantVector::getSplat(VTy->getNumElements(), C); + C = ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -387,7 +387,7 @@ } VectorType *VTy = cast(Ty); - return ConstantVector::getSplat(VTy->getNumElements(), + return ConstantVector::getSplat(VTy->getElementCount(), getAllOnesValue(VTy->getElementType())); } @@ -681,7 +681,7 @@ assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1."); ConstantInt *TrueC = ConstantInt::getTrue(Ty->getContext()); if (auto *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), TrueC); + return ConstantVector::getSplat(VTy->getElementCount(), TrueC); return TrueC; } @@ -689,7 +689,7 @@ assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1."); ConstantInt *FalseC = ConstantInt::getFalse(Ty->getContext()); if (auto *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), FalseC); + return ConstantVector::getSplat(VTy->getElementCount(), FalseC); return FalseC; } @@ -712,7 +712,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -736,7 +736,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -781,7 +781,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -793,7 +793,7 @@ // For vectors, broadcast the value. if (auto *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -806,7 +806,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -817,7 +817,7 @@ Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -828,8 +828,8 @@ Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); - + return ConstantVector::getSplat(VTy->getElementCount(), C); + return C; } @@ -839,8 +839,8 @@ Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); - + return ConstantVector::getSplat(VTy->getElementCount(), C); + return C; } @@ -850,7 +850,7 @@ Constant *C = get(Ty->getContext(), NegZero); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -898,7 +898,7 @@ Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative)); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -1204,15 +1204,35 @@ return nullptr; } -Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) { - // If this splat is compatible with ConstantDataVector, use it instead of - // ConstantVector. - if ((isa(V) || isa(V)) && - ConstantDataSequential::isElementTypeCompatible(V->getType())) - return ConstantDataVector::getSplat(NumElts, V); +Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) { + if (!EC.Scalable) { + // If this splat is compatible with ConstantDataVector, use it instead of + // ConstantVector. + if ((isa(V) || isa(V)) && + ConstantDataSequential::isElementTypeCompatible(V->getType())) + return ConstantDataVector::getSplat(EC.Min, V); + + SmallVector Elts(EC.Min, V); + return get(Elts); + } + + Type *VTy = VectorType::get(V->getType(), EC); + + if (V->isNullValue()) + return ConstantAggregateZero::get(VTy); + else if (isa(V)) + return UndefValue::get(VTy); + + Type *I32Ty = Type::getInt32Ty(VTy->getContext()); - SmallVector Elts(NumElts, V); - return get(Elts); + // Move scalar into vector. + Constant *UndefV = UndefValue::get(VTy); + V = ConstantExpr::getInsertElement(UndefV, V, ConstantInt::get(I32Ty, 0)); + // Build shuffle mask to perform the splat. + Type *MaskTy = VectorType::get(I32Ty, EC); + Constant *Zeros = ConstantAggregateZero::get(MaskTy); + // Splat. + return ConstantExpr::getShuffleVector(V, UndefV, Zeros); } ConstantTokenNone *ConstantTokenNone::get(LLVMContext &Context) { @@ -2098,15 +2118,15 @@ unsigned AS = C->getType()->getPointerAddressSpace(); Type *ReqTy = DestTy->getPointerTo(AS); - unsigned NumVecElts = 0; - if (C->getType()->isVectorTy()) - NumVecElts = C->getType()->getVectorNumElements(); + ElementCount EltCount = {0, false}; + if (VectorType *VecTy = dyn_cast(C->getType())) + EltCount = VecTy->getElementCount(); else for (auto Idx : Idxs) - if (Idx->getType()->isVectorTy()) - NumVecElts = Idx->getType()->getVectorNumElements(); + if (VectorType *VecTy = dyn_cast(Idx->getType())) + EltCount = VecTy->getElementCount(); - if (NumVecElts) - ReqTy = VectorType::get(ReqTy, NumVecElts); + if (EltCount.Min != 0) + ReqTy = VectorType::get(ReqTy, EltCount); if (OnlyIfReducedTy == ReqTy) return nullptr; @@ -2116,13 +2136,14 @@ ArgVec.reserve(1 + Idxs.size()); ArgVec.push_back(C); for (unsigned i = 0, e = Idxs.size(); i != e; ++i) { - assert((!Idxs[i]->getType()->isVectorTy() || - Idxs[i]->getType()->getVectorNumElements() == NumVecElts) && - "getelementptr index type missmatch"); + assert( + (!Idxs[i]->getType()->isVectorTy() || + cast(Idxs[i]->getType())->getElementCount() == EltCount) && + "getelementptr index type missmatch"); Constant *Idx = cast(Idxs[i]); - if (NumVecElts && !Idxs[i]->getType()->isVectorTy()) - Idx = ConstantVector::getSplat(NumVecElts, Idx); + if (EltCount.Min != 0 && !Idxs[i]->getType()->isVectorTy()) + Idx = ConstantVector::getSplat(EltCount, Idx); ArgVec.push_back(Idx); } @@ -2759,7 +2780,7 @@ return getFP(V->getContext(), Elts); } } - return ConstantVector::getSplat(NumElts, V); + return ConstantVector::getSplat({NumElts, false}, V); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -5365,8 +5365,10 @@ if (ScalarC && ScalarM) { // We allow undefs in matching, but this transform removes those for safety. // Demanded elements analysis should be able to recover some/all of that. - C = ConstantVector::getSplat(V1Ty->getVectorNumElements(), ScalarC); - M = ConstantVector::getSplat(M->getType()->getVectorNumElements(), ScalarM); + C = ConstantVector::getSplat(cast(V1Ty)->getElementCount(), + ScalarC); + M = ConstantVector::getSplat( + cast(M->getType())->getElementCount(), ScalarM); Value *NewCmp = IsFP ? Builder.CreateFCmp(Pred, V1, C) : Builder.CreateICmp(Pred, V1, C); return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -761,7 +761,7 @@ APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } @@ -796,7 +796,7 @@ APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1726,9 +1726,10 @@ // FIXME: If the step is non-constant, we create the vector splat with // IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't // handle a constant vector splat. - Value *SplatVF = isa(Mul) - ? ConstantVector::getSplat(VF, cast(Mul)) - : Builder.CreateVectorSplat(VF, Mul); + Value *SplatVF = + isa(Mul) + ? ConstantVector::getSplat({VF, false}, cast(Mul)) + : Builder.CreateVectorSplat(VF, Mul); Builder.restoreIP(CurrIP); // We may need to add the step a number of times, depending on the unroll @@ -3738,7 +3739,7 @@ // incoming scalar reduction. VectorStart = ReductionStartValue; } else { - Identity = ConstantVector::getSplat(VF, Iden); + Identity = ConstantVector::getSplat({VF, false}, Iden); // This vector is the Identity vector where the first element is the // incoming scalar reduction. diff --git a/llvm/unittests/FuzzMutate/OperationsTest.cpp b/llvm/unittests/FuzzMutate/OperationsTest.cpp --- a/llvm/unittests/FuzzMutate/OperationsTest.cpp +++ b/llvm/unittests/FuzzMutate/OperationsTest.cpp @@ -92,8 +92,8 @@ ConstantStruct::get(StructType::create(Ctx, "OpaqueStruct")); Constant *a = ConstantArray::get(ArrayType::get(i32->getType(), 2), {i32, i32}); - Constant *v8i8 = ConstantVector::getSplat(8, i8); - Constant *v4f16 = ConstantVector::getSplat(4, f16); + Constant *v8i8 = ConstantVector::getSplat({8, false}, i8); + Constant *v4f16 = ConstantVector::getSplat({4, false}, f16); Constant *p0i32 = ConstantPointerNull::get(PointerType::get(i32->getType(), 0)); diff --git a/llvm/unittests/IR/VerifierTest.cpp b/llvm/unittests/IR/VerifierTest.cpp --- a/llvm/unittests/IR/VerifierTest.cpp +++ b/llvm/unittests/IR/VerifierTest.cpp @@ -57,7 +57,7 @@ ConstantInt *CI = ConstantInt::get(ITy, 0); // Valid type : freeze(<2 x i32>) - Constant *CV = ConstantVector::getSplat(2, CI); + Constant *CV = ConstantVector::getSplat({2, false}, CI); FreezeInst *FI_vec = new FreezeInst(CV); FI_vec->insertBefore(RI);