diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1518,6 +1518,28 @@
     return CV->getSplatValue();
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
     return CV->getSplatValue(AllowUndefs);
+
+  // Check if this is a constant expression splat of the form returned by
+  // ConstantVector::getSplat()
+  const auto *Shuf = dyn_cast<ConstantExpr>(this);
+  if (Shuf && Shuf->getOpcode() == Instruction::ShuffleVector &&
+      isa<UndefValue>(Shuf->getOperand(1))) {
+
+    const auto *IElt = dyn_cast<ConstantExpr>(Shuf->getOperand(0));
+    if (IElt && IElt->getOpcode() == Instruction::InsertElement &&
+        isa<UndefValue>(IElt->getOperand(0))) {
+
+      ArrayRef<int> Mask = Shuf->getShuffleMask();
+      Constant *SplatVal = IElt->getOperand(1);
+      // The insertion index need not be a ConstantInt (e.g. a ConstantExpr).
+      ConstantInt *Index = dyn_cast<ConstantInt>(IElt->getOperand(2));
+
+      if (Index && Index->getValue() == 0 &&
+          llvm::all_of(Mask, [](int I) { return I == 0; }))
+        return SplatVal;
+    }
+  }
+
   return nullptr;
 }
 
@@ -1551,9 +1573,8 @@
 const APInt &Constant::getUniqueInteger() const {
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(this))
     return CI->getValue();
-  assert(this->getSplatValue() && "Doesn't contain a unique integer!");
-  const Constant *C = this->getAggregateElement(0U);
-  assert(C && isa<ConstantInt>(C) && "Not a vector of numbers!");
+  Constant *C = this->getSplatValue();
+  assert(C && isa<ConstantInt>(C) && "Doesn't contain a unique integer!");
   return cast<ConstantInt>(C)->getValue();
 }
 
diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp
--- a/llvm/lib/IR/DataLayout.cpp
+++ b/llvm/lib/IR/DataLayout.cpp
@@ -814,7 +814,7 @@
   unsigned NumBits = getIndexTypeSizeInBits(Ty);
   IntegerType *IntTy = IntegerType::get(Ty->getContext(), NumBits);
   if (VectorType *VecTy = dyn_cast<VectorType>(Ty))
-    return FixedVectorType::get(IntTy, VecTy->getNumElements());
+    return VectorType::get(IntTy, VecTy);
   return IntTy;
 }
 
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3220,7 +3220,7 @@
   // FIXME: Check address space sizes here
   if (VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy))
     if (VectorType *DestVecTy = dyn_cast<VectorType>(DestTy))
-      if (SrcVecTy->getNumElements() == DestVecTy->getNumElements()) {
+      if (SrcVecTy->getElementCount() == DestVecTy->getElementCount()) {
         // An element by element cast. Find the appropriate opcode based on the
         // element types.
         SrcTy = SrcVecTy->getElementType();
diff --git a/llvm/test/Transforms/InstSimplify/gep.ll b/llvm/test/Transforms/InstSimplify/gep.ll
--- a/llvm/test/Transforms/InstSimplify/gep.ll
+++ b/llvm/test/Transforms/InstSimplify/gep.ll
@@ -168,4 +168,12 @@
   ret %gep
 }
 
+define @scalable_vector_idx_mix_scalar_vector() {
+; CHECK-LABEL: @scalable_vector_idx_mix_scalar_vector(
+; CHECK-NEXT:    ret getelementptr (%struct, zeroinitializer, zeroinitializer, shufflevector ( insertelement ( undef, i32 1, i32 0), undef, zeroinitializer))
+;
+  %gep = getelementptr %struct, zeroinitializer, i32 0, shufflevector ( insertelement ( undef, i32 1, i32 0), undef, zeroinitializer)
+  ret %gep
+}
+
 ; Check ConstantExpr::getGetElementPtr() using ElementCount for size queries - end.