diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -208,6 +208,18 @@
   return isa<Constant>(V) && !isa<ConstantExpr>(V) && !isa<GlobalValue>(V);
 }
 
+/// \returns the scalar type of \p V. Typically VL[0] is fed into this.
+static Type *getScalarType(const Value *V) {
+  Type *ScalarTy = V->getType();
+  if (const auto *SI = dyn_cast<StoreInst>(V))
+    ScalarTy = SI->getValueOperand()->getType();
+  else if (const auto *CI = dyn_cast<CmpInst>(V))
+    ScalarTy = CI->getOperand(0)->getType();
+  else if (const auto *IE = dyn_cast<InsertElementInst>(V))
+    ScalarTy = IE->getOperand(1)->getType();
+  return ScalarTy;
+}
+
 /// Checks if \p V is one of vector-like instructions, i.e. undef,
 /// insertelement/extractelement with constant indices for fixed vector type or
 /// extractvalue instruction.
@@ -3491,7 +3503,7 @@
   // treats loading/storing it as an i8 struct. If we vectorize loads/stores
   // from such a struct, we read/write packed bits disagreeing with the
   // unvectorized version.
-  Type *ScalarTy = VL0->getType();
+  Type *ScalarTy = getScalarType(VL0);
   if (DL.getTypeSizeInBits(ScalarTy) != DL.getTypeAllocSizeInBits(ScalarTy))
     return LoadsState::Gather;
 
@@ -3609,7 +3621,7 @@
 Optional<BoUpSLP::OrdersType>
 BoUpSLP::findPartiallyOrderedLoads(const BoUpSLP::TreeEntry &TE) {
   assert(TE.State == TreeEntry::NeedToGather && "Expected gather node only.");
-  Type *ScalarTy = TE.Scalars[0]->getType();
+  Type *ScalarTy = getScalarType(TE.Scalars[0]);
   SmallVector<Value *> Ptrs;
   Ptrs.reserve(TE.Scalars.size());
 
@@ -4906,7 +4918,7 @@
       newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
                    ReuseShuffleIndicies);
 #ifndef NDEBUG
-      Type *ScalarTy = VL0->getType();
+      Type *ScalarTy = getScalarType(VL0);
       if (DL->getTypeSizeInBits(ScalarTy) !=
           DL->getTypeAllocSizeInBits(ScalarTy))
         LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n");
@@ -5140,7 +5152,7 @@
     }
     case Instruction::Store: {
       // Check if the stores are consecutive or if we need to swizzle them.
-      llvm::Type *ScalarTy = cast<StoreInst>(VL0)->getValueOperand()->getType();
+      llvm::Type *ScalarTy = getScalarType(VL0);
       // Avoid types that are padded when being allocated as scalars, while
       // being packed together in a vector (such as i1).
       if (DL->getTypeSizeInBits(ScalarTy) !=
@@ -5660,13 +5672,7 @@
                                       ArrayRef<Value *> VectorizedVals) {
   ArrayRef<Value *> VL = E->Scalars;
 
-  Type *ScalarTy = VL[0]->getType();
-  if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
-    ScalarTy = SI->getValueOperand()->getType();
-  else if (CmpInst *CI = dyn_cast<CmpInst>(VL[0]))
-    ScalarTy = CI->getOperand(0)->getType();
-  else if (auto *IE = dyn_cast<InsertElementInst>(VL[0]))
-    ScalarTy = IE->getOperand(1)->getType();
+  Type *ScalarTy = getScalarType(VL[0]);
   auto *VecTy = FixedVectorType::get(ScalarTy, VL.size());
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
 
@@ -7187,9 +7193,7 @@
 
 InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const {
   // Find the type of the operands in VL.
-  Type *ScalarTy = VL[0]->getType();
-  if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
-    ScalarTy = SI->getValueOperand()->getType();
+  Type *ScalarTy = getScalarType(VL[0]);
   auto *VecTy = FixedVectorType::get(ScalarTy, VL.size());
   bool DuplicateNonConst = false;
   // Find the cost of inserting/extracting values from the vector.
@@ -7622,11 +7626,7 @@
   unsigned ShuffleOrOp =
       E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode();
   Instruction *VL0 = E->getMainOp();
-  Type *ScalarTy = VL0->getType();
-  if (auto *Store = dyn_cast<StoreInst>(VL0))
-    ScalarTy = Store->getValueOperand()->getType();
-  else if (auto *IE = dyn_cast<InsertElementInst>(VL0))
-    ScalarTy = IE->getOperand(1)->getType();
+  Type *ScalarTy = getScalarType(VL0);
   auto *VecTy = FixedVectorType::get(ScalarTy, E->Scalars.size());
   switch (ShuffleOrOp) {
   case Instruction::PHI: {
@@ -10068,9 +10068,7 @@
   bool Changed = false;
   bool CandidateFound = false;
   InstructionCost MinCost = SLPCostThreshold.getValue();
-  Type *ScalarTy = VL[0]->getType();
-  if (auto *IE = dyn_cast<InsertElementInst>(VL[0]))
-    ScalarTy = IE->getOperand(1)->getType();
+  Type *ScalarTy = getScalarType(VL[0]);
 
   unsigned NextInst = 0, MaxInst = VL.size();
   for (unsigned VF = MaxVF; NextInst + 1 < MaxInst && VF >= MinVF; VF /= 2) {
@@ -11144,7 +11142,10 @@
                                    unsigned ReduxWidth, FastMathFlags FMF) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Value *FirstReducedVal = ReducedVals.front();
-    Type *ScalarTy = FirstReducedVal->getType();
+    // NOTE(review): deliberately *not* getScalarType() here. Reduced values of
+    // logical and/or reductions can be CmpInsts; the reduction operates on
+    // their i1 results, not on the compare operand type.
+    Type *ScalarTy = FirstReducedVal->getType();
     FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since