diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -78,9 +78,26 @@ using BaseT = TargetTransformInfoImplCRTPBase<T>; using TTI = TargetTransformInfo; + /// Theoretically, scalable vectors should be supported for many (if not all) + /// of these functions that take VectorType parameters. For instance, we + /// can do a shufflevector on a scalable vector, so we should also be able + /// to reason about the cost of this operation. Since this is the case, we + /// have decided to expose a base VectorType interface. + /// + /// This function automates the business of writing "FIXME: support scalable + /// vectors" everywhere, asserting, and casting to FixedVectorType in order + /// to be able to call getNumElements(). + FixedVectorType *FIXME_ScalableVectorNotSupported(VectorType *VTy) { + assert(isa<FixedVectorType>(VTy) && + "Scalable vectors not yet supported in this context"); + return cast<FixedVectorType>(VTy); + } + /// Estimate a cost of Broadcast as an extract and sequence of insert /// operations. - unsigned getBroadcastShuffleOverhead(VectorType *VTy) { + unsigned getBroadcastShuffleOverhead(VectorType *InVTy) { + auto *VTy = FIXME_ScalableVectorNotSupported(InVTy); + unsigned Cost = 0; // Broadcast cost is equal to the cost of extracting the zero'th element // plus the cost of inserting it into every element of the result vector. @@ -96,7 +113,9 @@ /// Estimate a cost of shuffle as a sequence of extract and insert /// operations. - unsigned getPermuteShuffleOverhead(VectorType *VTy) { + unsigned getPermuteShuffleOverhead(VectorType *InVTy) { + auto *VTy = FIXME_ScalableVectorNotSupported(InVTy); + unsigned Cost = 0; // Shuffle cost is equal to the cost of extracting element from its argument // plus the cost of inserting them onto the result vector. 
@@ -116,8 +135,11 @@ /// Estimate a cost of subvector extraction as a sequence of extract and /// insert operations. - unsigned getExtractSubvectorOverhead(VectorType *VTy, int Index, - VectorType *SubVTy) { + unsigned getExtractSubvectorOverhead(VectorType *InVTy, int Index, + VectorType *InSubVTy) { + auto *VTy = FIXME_ScalableVectorNotSupported(InVTy); + auto *SubVTy = FIXME_ScalableVectorNotSupported(InSubVTy); + assert(VTy && SubVTy && "Can only extract subvectors from vectors"); int NumSubElts = SubVTy->getNumElements(); @@ -139,8 +161,11 @@ /// Estimate a cost of subvector insertion as a sequence of extract and /// insert operations. - unsigned getInsertSubvectorOverhead(VectorType *VTy, int Index, - VectorType *SubVTy) { + unsigned getInsertSubvectorOverhead(VectorType *InVTy, int Index, + VectorType *InSubVTy) { + auto *VTy = FIXME_ScalableVectorNotSupported(InVTy); + auto *SubVTy = FIXME_ScalableVectorNotSupported(InSubVTy); + assert(VTy && SubVTy && "Can only insert subvectors into vectors"); int NumSubElts = SubVTy->getNumElements(); @@ -525,8 +550,12 @@ /// Estimate the overhead of scalarizing an instruction. Insert and Extract /// are set if the demanded result elements need to be inserted and/or /// extracted from vectors. - unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, + unsigned getScalarizationOverhead(VectorType *InTy, const APInt &DemandedElts, bool Insert, bool Extract) { + /// FIXME: a bitfield is not a reasonable abstraction for talking about + /// which elements are needed from a scalable vector + auto *Ty = FIXME_ScalableVectorNotSupported(InTy); + assert(DemandedElts.getBitWidth() == Ty->getNumElements() && "Vector size mismatch"); @@ -547,7 +576,10 @@ } /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead. 
- unsigned getScalarizationOverhead(VectorType *Ty, bool Insert, bool Extract) { + unsigned getScalarizationOverhead(VectorType *InTy, bool Insert, + bool Extract) { + auto *Ty = FIXME_ScalableVectorNotSupported(InTy); + APInt DemandedElts = APInt::getAllOnesValue(Ty->getNumElements()); return static_cast<T *>(this)->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract); @@ -565,11 +597,12 @@ auto *VecTy = dyn_cast<VectorType>(A->getType()); if (VecTy) { // If A is a vector operand, VF should be 1 or correspond to A. - assert((VF == 1 || VF == VecTy->getNumElements()) && + assert((VF == 1 || VF == FIXME_ScalableVectorNotSupported(VecTy) + ->getNumElements()) && "Vector argument does not match VF"); } else - VecTy = VectorType::get(A->getType(), VF); + VecTy = FixedVectorType::get(A->getType(), VF); Cost += getScalarizationOverhead(VecTy, false, true); } @@ -578,7 +611,10 @@ return Cost; } - unsigned getScalarizationOverhead(VectorType *Ty, ArrayRef<const Value *> Args) { + unsigned getScalarizationOverhead(VectorType *InTy, + ArrayRef<const Value *> Args) { + auto *Ty = FIXME_ScalableVectorNotSupported(InTy); + unsigned Cost = 0; Cost += getScalarizationOverhead(Ty, true, false); @@ -631,7 +667,7 @@ // TODO: If one of the types get legalized by splitting, handle this // similarly to what getCastInstrCost() does. 
if (auto *VTy = dyn_cast<VectorType>(Ty)) { - unsigned Num = VTy->getNumElements(); + unsigned Num = FIXME_ScalableVectorNotSupported(VTy)->getNumElements(); unsigned Cost = static_cast<T *>(this)->getArithmeticInstrCost( Opcode, VTy->getScalarType(), CostKind); // Return the cost of multiple scalar invocation plus the cost of @@ -643,8 +679,11 @@ return OpCost; } - unsigned getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, int Index, - VectorType *SubTp) { + unsigned getShuffleCost(TTI::ShuffleKind Kind, VectorType *InTp, int Index, + VectorType *InSubTp) { + auto *Tp = FIXME_ScalableVectorNotSupported(InTp); + auto *SubTp = FIXME_ScalableVectorNotSupported(InSubTp); + switch (Kind) { case TTI::SK_Broadcast: return getBroadcastShuffleOverhead(Tp); @@ -777,12 +816,11 @@ bool SplitDst = TLI->getTypeAction(Dst->getContext(), TLI->getValueType(DL, Dst)) == TargetLowering::TypeSplitVector; - if ((SplitSrc || SplitDst) && SrcVTy->getNumElements() > 1 && - DstVTy->getNumElements() > 1) { - Type *SplitDstTy = VectorType::get(DstVTy->getElementType(), - DstVTy->getNumElements() / 2); - Type *SplitSrcTy = VectorType::get(SrcVTy->getElementType(), - SrcVTy->getNumElements() / 2); + if ((SplitSrc || SplitDst) && + FIXME_ScalableVectorNotSupported(SrcVTy)->getNumElements() > 1 && + FIXME_ScalableVectorNotSupported(DstVTy)->getNumElements() > 1) { + Type *SplitDstTy = VectorType::getHalfElementsVectorType(DstVTy); + Type *SplitSrcTy = VectorType::getHalfElementsVectorType(SrcVTy); T *TTI = static_cast<T *>(this); // If both types need to be split then the split is free. unsigned SplitCost = @@ -794,7 +832,7 @@ // In other cases where the source or destination are illegal, assume // the operation will get scalarized. 
- unsigned Num = DstVTy->getNumElements(); + unsigned Num = FIXME_ScalableVectorNotSupported(DstVTy)->getNumElements(); unsigned Cost = static_cast<T *>(this)->getCastInstrCost( Opcode, Dst->getScalarType(), Src->getScalarType(), CostKind, I); @@ -861,7 +899,7 @@ // TODO: If one of the types get legalized by splitting, handle this // similarly to what getCastInstrCost() does. if (auto *ValVTy = dyn_cast<VectorType>(ValTy)) { - unsigned Num = ValVTy->getNumElements(); + unsigned Num = FIXME_ScalableVectorNotSupported(ValVTy)->getNumElements(); if (CondTy) CondTy = CondTy->getScalarType(); unsigned Cost = static_cast<T *>(this)->getCmpSelInstrCost( @@ -929,7 +967,7 @@ TTI::TargetCostKind CostKind, bool UseMaskForCond = false, bool UseMaskForGaps = false) { - auto *VT = cast<VectorType>(VecTy); + auto *VT = FIXME_ScalableVectorNotSupported(cast<VectorType>(VecTy)); unsigned NumElts = VT->getNumElements(); assert(Factor > 1 && NumElts % Factor == 0 && "Invalid interleave factor"); @@ -1113,7 +1151,10 @@ Type *RetTy = ICA.getReturnType(); unsigned VF = ICA.getVectorFactor(); unsigned RetVF = - (RetTy->isVectorTy() ? cast<VectorType>(RetTy)->getNumElements() : 1); + (RetTy->isVectorTy() + ? 
FIXME_ScalableVectorNotSupported(cast<VectorType>(RetTy)) + ->getNumElements() + : 1); assert((RetVF == 1 || VF == 1) && "VF > 1 and RetVF is a vector type"); const IntrinsicInst *I = ICA.getInst(); const SmallVectorImpl<const Value *> &Args = ICA.getArgs(); @@ -1256,7 +1297,9 @@ if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) { if (!SkipScalarizationCost) ScalarizationCost = getScalarizationOverhead(RetVTy, true, false); - ScalarCalls = std::max(ScalarCalls, RetVTy->getNumElements()); + ScalarCalls = std::max( + ScalarCalls, + FIXME_ScalableVectorNotSupported(RetVTy)->getNumElements()); ScalarRetTy = RetTy->getScalarType(); } SmallVector<Type *, 4> ScalarTys; @@ -1265,7 +1308,9 @@ if (auto *VTy = dyn_cast<VectorType>(Ty)) { if (!SkipScalarizationCost) ScalarizationCost += getScalarizationOverhead(VTy, false, true); - ScalarCalls = std::max(ScalarCalls, VTy->getNumElements()); + ScalarCalls = + std::max(ScalarCalls, + FIXME_ScalableVectorNotSupported(VTy)->getNumElements()); Ty = Ty->getScalarType(); } ScalarTys.push_back(Ty); @@ -1623,7 +1668,8 @@ unsigned ScalarizationCost = SkipScalarizationCost ? 
ScalarizationCostPassed : getScalarizationOverhead(RetVTy, true, false); - unsigned ScalarCalls = RetVTy->getNumElements(); + unsigned ScalarCalls = + FIXME_ScalableVectorNotSupported(RetVTy)->getNumElements(); SmallVector<Type *, 4> ScalarTys; for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) { Type *Ty = Tys[i]; @@ -1637,7 +1683,9 @@ if (auto *VTy = dyn_cast<VectorType>(Tys[i])) { if (!ICA.skipScalarizationCost()) ScalarizationCost += getScalarizationOverhead(VTy, false, true); - ScalarCalls = std::max(ScalarCalls, VTy->getNumElements()); + ScalarCalls = + std::max(ScalarCalls, + FIXME_ScalableVectorNotSupported(VTy)->getNumElements()); } } return ScalarCalls * ScalarCost + ScalarizationCost; @@ -1712,7 +1760,8 @@ bool IsPairwise, TTI::TargetCostKind CostKind) { Type *ScalarTy = Ty->getElementType(); - unsigned NumVecElts = Ty->getNumElements(); + unsigned NumVecElts = + FIXME_ScalableVectorNotSupported(Ty)->getNumElements(); unsigned NumReduxLevels = Log2_32(NumVecElts); unsigned ArithCost = 0; unsigned ShuffleCost = 0; @@ -1763,7 +1812,8 @@ TTI::TargetCostKind CostKind) { Type *ScalarTy = Ty->getElementType(); Type *ScalarCondTy = CondTy->getElementType(); - unsigned NumVecElts = Ty->getNumElements(); + unsigned NumVecElts = + FIXME_ScalableVectorNotSupported(Ty)->getNumElements(); unsigned NumReduxLevels = Log2_32(NumVecElts); unsigned CmpOpcode; if (Ty->isFPOrFPVectorTy()) {