diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -78,9 +78,26 @@
   using BaseT = TargetTransformInfoImplCRTPBase<T>;
   using TTI = TargetTransformInfo;
 
+  /// Theoretically, scalable vectors should be supported for many (if not all)
+  /// of these functions that take VectorType parameters. For instance, we
+  /// can do a shufflevector on a scalable vector, so we should also be able
+  /// to reason about the cost of this operation. Since this is the case, we
+  /// have decided to expose a base VectorType interface.
+  ///
+  /// This function automates the business of writing "FIXME: support scalable
+  /// vectors" everywhere, asserting, and casting to FixedVectorType in order
+  /// to be able to call getNumElements().
+  FixedVectorType *FIXME_ScalableVectorNotSupported(VectorType *VTy) {
+    assert(isa<FixedVectorType>(VTy) &&
+           "Scalable vectors not yet supported in this context");
+    return cast<FixedVectorType>(VTy);
+  }
+
   /// Estimate a cost of Broadcast as an extract and sequence of insert
   /// operations.
-  unsigned getBroadcastShuffleOverhead(VectorType *VTy) {
+  unsigned getBroadcastShuffleOverhead(VectorType *InVTy) {
+    auto *VTy = cast<FixedVectorType>(InVTy);
+
     unsigned Cost = 0;
     // Broadcast cost is equal to the cost of extracting the zero'th element
     // plus the cost of inserting it into every element of the result vector.
@@ -96,7 +113,9 @@
 
   /// Estimate a cost of shuffle as a sequence of extract and insert
   /// operations.
-  unsigned getPermuteShuffleOverhead(VectorType *VTy) {
+  unsigned getPermuteShuffleOverhead(VectorType *InVTy) {
+    auto *VTy = cast<FixedVectorType>(InVTy);
+
     unsigned Cost = 0;
     // Shuffle cost is equal to the cost of extracting element from its argument
     // plus the cost of inserting them onto the result vector.
@@ -116,8 +135,11 @@
 
   /// Estimate a cost of subvector extraction as a sequence of extract and
   /// insert operations.
-  unsigned getExtractSubvectorOverhead(VectorType *VTy, int Index,
-                                       VectorType *SubVTy) {
+  unsigned getExtractSubvectorOverhead(VectorType *InVTy, int Index,
+                                       VectorType *InSubVTy) {
+    auto *VTy = cast<FixedVectorType>(InVTy);
+    auto *SubVTy = cast<FixedVectorType>(InSubVTy);
+
     assert(VTy && SubVTy &&
           "Can only extract subvectors from vectors");
     int NumSubElts = SubVTy->getNumElements();
@@ -139,8 +161,11 @@
 
   /// Estimate a cost of subvector insertion as a sequence of extract and
   /// insert operations.
-  unsigned getInsertSubvectorOverhead(VectorType *VTy, int Index,
-                                      VectorType *SubVTy) {
+  unsigned getInsertSubvectorOverhead(VectorType *InVTy, int Index,
+                                      VectorType *InSubVTy) {
+    auto *VTy = cast<FixedVectorType>(InVTy);
+    auto *SubVTy = cast<FixedVectorType>(InSubVTy);
+
     assert(VTy && SubVTy &&
           "Can only insert subvectors into vectors");
     int NumSubElts = SubVTy->getNumElements();
@@ -525,8 +550,12 @@
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
-  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *InTy, const APInt &DemandedElts,
                                     bool Insert, bool Extract) {
+    /// FIXME: a bitfield is not a reasonable abstraction for talking about
+    /// which elements are needed from a scalable vector
+    auto *Ty = cast<FixedVectorType>(InTy);
+
     assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
            "Vector size mismatch");
 
@@ -547,7 +576,10 @@
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
-  unsigned getScalarizationOverhead(VectorType *Ty, bool Insert, bool Extract) {
+  unsigned getScalarizationOverhead(VectorType *InTy, bool Insert,
+                                    bool Extract) {
+    auto *Ty = cast<FixedVectorType>(InTy);
+
     APInt DemandedElts = APInt::getAllOnesValue(Ty->getNumElements());
     return static_cast<T *>(this)->getScalarizationOverhead(Ty, DemandedElts,
                                                             Insert, Extract);
@@ -565,11 +597,12 @@
         auto *VecTy = dyn_cast<VectorType>(A->getType());
         if (VecTy) {
           // If A is a vector operand, VF should be 1 or correspond to A.
-          assert((VF == 1 || VF == VecTy->getNumElements()) &&
+          assert((VF == 1 ||
+                  VF == cast<FixedVectorType>(VecTy)->getNumElements()) &&
                  "Vector argument does not match VF");
         }
         else
-          VecTy = VectorType::get(A->getType(), VF);
+          VecTy = FixedVectorType::get(A->getType(), VF);
 
         Cost += getScalarizationOverhead(VecTy, false, true);
       }
@@ -578,7 +611,10 @@
     return Cost;
   }
 
-  unsigned getScalarizationOverhead(VectorType *Ty, ArrayRef<const Value *> Args) {
+  unsigned getScalarizationOverhead(VectorType *InTy,
+                                    ArrayRef<const Value *> Args) {
+    auto *Ty = cast<FixedVectorType>(InTy);
+
     unsigned Cost = 0;
 
     Cost += getScalarizationOverhead(Ty, true, false);
@@ -631,7 +667,7 @@
     // TODO: If one of the types get legalized by splitting, handle this
     // similarly to what getCastInstrCost() does.
     if (auto *VTy = dyn_cast<VectorType>(Ty)) {
-      unsigned Num = VTy->getNumElements();
+      unsigned Num = cast<FixedVectorType>(VTy)->getNumElements();
       unsigned Cost = static_cast<T *>(this)->getArithmeticInstrCost(
           Opcode, VTy->getScalarType(), CostKind);
       // Return the cost of multiple scalar invocation plus the cost of
@@ -643,8 +679,11 @@
     return OpCost;
   }
 
-  unsigned getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, int Index,
-                          VectorType *SubTp) {
+  unsigned getShuffleCost(TTI::ShuffleKind Kind, VectorType *InTp, int Index,
+                          VectorType *InSubTp) {
+    auto *Tp = cast<FixedVectorType>(InTp);
+    auto *SubTp = cast<FixedVectorType>(InSubTp);
+
     switch (Kind) {
     case TTI::SK_Broadcast:
       return getBroadcastShuffleOverhead(Tp);
@@ -777,12 +816,11 @@
       bool SplitDst =
          TLI->getTypeAction(Dst->getContext(), TLI->getValueType(DL, Dst)) ==
          TargetLowering::TypeSplitVector;
-      if ((SplitSrc || SplitDst) && SrcVTy->getNumElements() > 1 &&
-          DstVTy->getNumElements() > 1) {
-        Type *SplitDstTy = VectorType::get(DstVTy->getElementType(),
-                                           DstVTy->getNumElements() / 2);
-        Type *SplitSrcTy = VectorType::get(SrcVTy->getElementType(),
-                                           SrcVTy->getNumElements() / 2);
+      if ((SplitSrc || SplitDst) &&
+          cast<FixedVectorType>(SrcVTy)->getNumElements() > 1 &&
+          cast<FixedVectorType>(DstVTy)->getNumElements() > 1) {
+        Type *SplitDstTy = VectorType::getHalfElementsVectorType(DstVTy);
+        Type *SplitSrcTy = VectorType::getHalfElementsVectorType(SrcVTy);
         T *TTI = static_cast<T *>(this);
         // If both types need to be split then the split is free.
         unsigned SplitCost =
@@ -794,7 +832,7 @@
 
       // In other cases where the source or destination are illegal, assume
       // the operation will get scalarized.
-      unsigned Num = DstVTy->getNumElements();
+      unsigned Num = cast<FixedVectorType>(DstVTy)->getNumElements();
       unsigned Cost = static_cast<T *>(this)->getCastInstrCost(
           Opcode, Dst->getScalarType(), Src->getScalarType(), CostKind, I);
 
@@ -861,7 +899,7 @@
     // TODO: If one of the types get legalized by splitting, handle this
     // similarly to what getCastInstrCost() does.
     if (auto *ValVTy = dyn_cast<VectorType>(ValTy)) {
-      unsigned Num = ValVTy->getNumElements();
+      unsigned Num = cast<FixedVectorType>(ValVTy)->getNumElements();
       if (CondTy)
         CondTy = CondTy->getScalarType();
       unsigned Cost = static_cast<T *>(this)->getCmpSelInstrCost(
@@ -929,13 +967,13 @@
                                       TTI::TargetCostKind CostKind,
                                       bool UseMaskForCond = false,
                                       bool UseMaskForGaps = false) {
-    auto *VT = cast<VectorType>(VecTy);
+    auto *VT = cast<FixedVectorType>(VecTy);
 
     unsigned NumElts = VT->getNumElements();
     assert(Factor > 1 && NumElts % Factor == 0 && "Invalid interleave factor");
 
     unsigned NumSubElts = NumElts / Factor;
-    VectorType *SubVT = VectorType::get(VT->getElementType(), NumSubElts);
+    auto *SubVT = FixedVectorType::get(VT->getElementType(), NumSubElts);
 
     // Firstly, the cost of load/store operation.
     unsigned Cost;
@@ -1044,8 +1082,8 @@
       return Cost;
 
     Type *I8Type = Type::getInt8Ty(VT->getContext());
-    VectorType *MaskVT = VectorType::get(I8Type, NumElts);
-    SubVT = VectorType::get(I8Type, NumSubElts);
+    auto *MaskVT = FixedVectorType::get(I8Type, NumElts);
+    SubVT = FixedVectorType::get(I8Type, NumSubElts);
 
     // The Mask shuffling cost is extract all the elements of the Mask
     // and insert each of them Factor times into the wide vector:
@@ -1113,7 +1151,8 @@
     Type *RetTy = ICA.getReturnType();
     unsigned VF = ICA.getVectorFactor();
     unsigned RetVF =
-        (RetTy->isVectorTy() ? cast<VectorType>(RetTy)->getNumElements() : 1);
+        (RetTy->isVectorTy() ? cast<FixedVectorType>(RetTy)->getNumElements()
+                             : 1);
     assert((RetVF == 1 || VF == 1) && "VF > 1 and RetVF is a vector type");
     const IntrinsicInst *I = ICA.getInst();
     const SmallVectorImpl<const Value *> &Args = ICA.getArgs();
@@ -1126,11 +1165,11 @@
     for (Value *Op : Args) {
       Type *OpTy = Op->getType();
       assert(VF == 1 || !OpTy->isVectorTy());
-      Types.push_back(VF == 1 ? OpTy : VectorType::get(OpTy, VF));
+      Types.push_back(VF == 1 ? OpTy : FixedVectorType::get(OpTy, VF));
     }
 
     if (VF > 1 && !RetTy->isVoidTy())
-      RetTy = VectorType::get(RetTy, VF);
+      RetTy = FixedVectorType::get(RetTy, VF);
 
     // Compute the scalarization overhead based on Args for a vector
     // intrinsic. A vectorizer will pass a scalar RetTy and VF > 1, while
@@ -1256,7 +1295,8 @@
     if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
       if (!SkipScalarizationCost)
         ScalarizationCost = getScalarizationOverhead(RetVTy, true, false);
-      ScalarCalls = std::max(ScalarCalls, RetVTy->getNumElements());
+      ScalarCalls = std::max(ScalarCalls,
+                             cast<FixedVectorType>(RetVTy)->getNumElements());
       ScalarRetTy = RetTy->getScalarType();
     }
     SmallVector<Type *, 4> ScalarTys;
@@ -1265,7 +1305,8 @@
       if (auto *VTy = dyn_cast<VectorType>(Ty)) {
        if (!SkipScalarizationCost)
          ScalarizationCost += getScalarizationOverhead(VTy, false, true);
-        ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
+        ScalarCalls = std::max(ScalarCalls,
+                               cast<FixedVectorType>(VTy)->getNumElements());
         Ty = Ty->getScalarType();
       }
       ScalarTys.push_back(Ty);
@@ -1623,7 +1664,7 @@
       unsigned ScalarizationCost = SkipScalarizationCost ?
         ScalarizationCostPassed : getScalarizationOverhead(RetVTy, true, false);
 
-      unsigned ScalarCalls = RetVTy->getNumElements();
+      unsigned ScalarCalls = cast<FixedVectorType>(RetVTy)->getNumElements();
       SmallVector<Type *, 4> ScalarTys;
       for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
         Type *Ty = Tys[i];
@@ -1637,7 +1678,8 @@
        if (auto *VTy = dyn_cast<VectorType>(Tys[i])) {
          if (!ICA.skipScalarizationCost())
            ScalarizationCost += getScalarizationOverhead(VTy, false, true);
-          ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
+          ScalarCalls = std::max(ScalarCalls,
+                                 cast<FixedVectorType>(VTy)->getNumElements());
         }
       }
       return ScalarCalls * ScalarCost + ScalarizationCost;
@@ -1712,7 +1754,7 @@
                                       bool IsPairwise,
                                       TTI::TargetCostKind CostKind) {
     Type *ScalarTy = Ty->getElementType();
-    unsigned NumVecElts = Ty->getNumElements();
+    unsigned NumVecElts = cast<FixedVectorType>(Ty)->getNumElements();
     unsigned NumReduxLevels = Log2_32(NumVecElts);
     unsigned ArithCost = 0;
     unsigned ShuffleCost = 0;
@@ -1724,7 +1766,7 @@
         LT.second.isVector() ? LT.second.getVectorNumElements() : 1;
     while (NumVecElts > MVTLen) {
       NumVecElts /= 2;
-      VectorType *SubTy = VectorType::get(ScalarTy, NumVecElts);
+      VectorType *SubTy = FixedVectorType::get(ScalarTy, NumVecElts);
       // Assume the pairwise shuffles add a cost.
       ShuffleCost += (IsPairwise + 1) *
                      ConcreteTTI->getShuffleCost(TTI::SK_ExtractSubvector, Ty,
@@ -1763,7 +1805,7 @@
                                   TTI::TargetCostKind CostKind) {
     Type *ScalarTy = Ty->getElementType();
     Type *ScalarCondTy = CondTy->getElementType();
-    unsigned NumVecElts = Ty->getNumElements();
+    unsigned NumVecElts = cast<FixedVectorType>(Ty)->getNumElements();
     unsigned NumReduxLevels = Log2_32(NumVecElts);
     unsigned CmpOpcode;
     if (Ty->isFPOrFPVectorTy()) {
@@ -1783,8 +1825,8 @@
         LT.second.isVector() ? LT.second.getVectorNumElements() : 1;
     while (NumVecElts > MVTLen) {
       NumVecElts /= 2;
-      VectorType *SubTy = VectorType::get(ScalarTy, NumVecElts);
-      CondTy = VectorType::get(ScalarCondTy, NumVecElts);
+      auto *SubTy = FixedVectorType::get(ScalarTy, NumVecElts);
+      CondTy = FixedVectorType::get(ScalarCondTy, NumVecElts);
       // Assume the pairwise shuffles add a cost.
       ShuffleCost += (IsPairwise + 1) *
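
Illustrative sketch (not part of the patch): the recurring idiom above takes a base VectorType pointer and casts it to FixedVectorType before asking for an element count, so scalable vectors trip an assertion instead of being silently mis-costed. The helper name below is hypothetical and only demonstrates the pattern.

#include "llvm/IR/DerivedTypes.h"  // VectorType, FixedVectorType
#include "llvm/Support/Casting.h"  // cast<>

using namespace llvm;

// Hypothetical free function mirroring the cast-and-count idiom used in the
// patch; llvm::cast<> asserts (in builds with assertions enabled) if VTy is
// actually a scalable vector.
static unsigned getFixedNumElements(VectorType *VTy) {
  return cast<FixedVectorType>(VTy)->getNumElements();
}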