diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -614,7 +614,7 @@ /// Estimate the overhead of scalarizing an instruction. Insert and Extract /// are set if the demanded result elements need to be inserted and/or /// extracted from vectors. - unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts, + unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract) const; /// Estimate the overhead of scalarizing an instructions unique @@ -1238,7 +1238,8 @@ virtual bool shouldBuildLookupTables() = 0; virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0; virtual bool useColdCCForColdCall(Function &F) = 0; - virtual unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts, + virtual unsigned getScalarizationOverhead(VectorType *Ty, + const APInt &DemandedElts, bool Insert, bool Extract) = 0; virtual unsigned getOperandsScalarizationOverhead(ArrayRef Args, @@ -1563,7 +1564,7 @@ return Impl.useColdCCForColdCall(F); } - unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts, + unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract) override { return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract); } diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -239,7 +239,7 @@ bool useColdCCForColdCall(Function &F) { return false; } - unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts, + unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract) { return 0; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -550,32 +550,30 @@ /// Estimate the overhead of scalarizing an instruction. Insert and Extract /// are set if the demanded result elements need to be inserted and/or /// extracted from vectors. - unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts, + unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract) { - auto *VTy = cast(Ty); - assert(DemandedElts.getBitWidth() == VTy->getNumElements() && + assert(DemandedElts.getBitWidth() == Ty->getNumElements() && "Vector size mismatch"); unsigned Cost = 0; - for (int i = 0, e = VTy->getNumElements(); i < e; ++i) { + for (int i = 0, e = Ty->getNumElements(); i < e; ++i) { if (!DemandedElts[i]) continue; if (Insert) Cost += static_cast(this)->getVectorInstrCost( - Instruction::InsertElement, VTy, i); + Instruction::InsertElement, Ty, i); if (Extract) Cost += static_cast(this)->getVectorInstrCost( - Instruction::ExtractElement, VTy, i); + Instruction::ExtractElement, Ty, i); } return Cost; } /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead. - unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) { - auto *VTy = cast(Ty); - APInt DemandedElts = APInt::getAllOnesValue(VTy->getNumElements()); + unsigned getScalarizationOverhead(VectorType *Ty, bool Insert, bool Extract) { + APInt DemandedElts = APInt::getAllOnesValue(Ty->getNumElements()); return static_cast(this)->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract); } @@ -589,11 +587,11 @@ SmallPtrSet UniqueOperands; for (const Value *A : Args) { if (!isa(A) && UniqueOperands.insert(A).second) { - Type *VecTy = nullptr; + VectorType *VecTy = nullptr; if (A->getType()->isVectorTy()) { - VecTy = A->getType(); + VecTy = cast(A->getType()); // If A is a vector operand, VF should be 1 or correspond to A. - assert((VF == 1 || VF == cast(VecTy)->getNumElements()) && + assert((VF == 1 || VF == VecTy->getNumElements()) && "Vector argument does not match VF"); } else @@ -606,17 +604,16 @@ return Cost; } - unsigned getScalarizationOverhead(Type *VecTy, ArrayRef Args) { + unsigned getScalarizationOverhead(VectorType *Ty, ArrayRef Args) { unsigned Cost = 0; - auto *VecVTy = cast(VecTy); - Cost += getScalarizationOverhead(VecVTy, true, false); + Cost += getScalarizationOverhead(Ty, true, false); if (!Args.empty()) - Cost += getOperandsScalarizationOverhead(Args, VecVTy->getNumElements()); + Cost += getOperandsScalarizationOverhead(Args, Ty->getNumElements()); else // When no information on arguments is provided, we add the cost // associated with one argument as a heuristic. - Cost += getScalarizationOverhead(VecVTy, false, true); + Cost += getScalarizationOverhead(Ty, false, true); return Cost; } @@ -813,19 +810,20 @@ // Return the cost of multiple scalar invocation plus the cost of // inserting and extracting the values. - return getScalarizationOverhead(Dst, true, true) + Num * Cost; + return getScalarizationOverhead(DstVTy, true, true) + Num * Cost; } // We already handled vector-to-vector and scalar-to-scalar conversions. // This // is where we handle bitcast between vectors and scalars. We need to assume // that the conversion is scalarized in one way or another. - if (Opcode == Instruction::BitCast) + if (Opcode == Instruction::BitCast) { // Illegal bitcasts are done by storing and loading from a stack slot. - return (Src->isVectorTy() ? getScalarizationOverhead(Src, false, true) - : 0) + - (Dst->isVectorTy() ? getScalarizationOverhead(Dst, true, false) - : 0); + auto *SrcVTy = dyn_cast(Src); + auto *DstVTy = dyn_cast(Dst); + return (SrcVTy ? getScalarizationOverhead(SrcVTy, false, true) : 0) + + (DstVTy ? getScalarizationOverhead(DstVTy, true, false) : 0); + } llvm_unreachable("Unhandled cast"); } @@ -914,7 +912,8 @@ if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) { // This is a vector load/store for some illegal type that is scalarized. // We must account for the cost of building or decomposing the vector. - Cost += getScalarizationOverhead(Src, Opcode != Instruction::Store, + Cost += getScalarizationOverhead(cast(Src), + Opcode != Instruction::Store, Opcode == Instruction::Store); } } @@ -1106,7 +1105,8 @@ if (RetVF > 1 || VF > 1) { ScalarizationCost = 0; if (!RetTy->isVoidTy()) - ScalarizationCost += getScalarizationOverhead(RetTy, true, false); + ScalarizationCost += + getScalarizationOverhead(cast(RetTy), true, false); ScalarizationCost += getOperandsScalarizationOverhead(Args, VF); } @@ -1204,20 +1204,20 @@ unsigned ScalarCalls = 1; Type *ScalarRetTy = RetTy; if (RetTy->isVectorTy()) { + auto *RetVTy = cast(RetTy); if (ScalarizationCostPassed == std::numeric_limits::max()) - ScalarizationCost = getScalarizationOverhead(RetTy, true, false); - ScalarCalls = - std::max(ScalarCalls, cast(RetTy)->getNumElements()); + ScalarizationCost = getScalarizationOverhead(RetVTy, true, false); + ScalarCalls = std::max(ScalarCalls, RetVTy->getNumElements()); ScalarRetTy = RetTy->getScalarType(); } SmallVector ScalarTys; for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) { Type *Ty = Tys[i]; if (Ty->isVectorTy()) { + auto *VTy = cast(Ty); if (ScalarizationCostPassed == std::numeric_limits::max()) - ScalarizationCost += getScalarizationOverhead(Ty, false, true); - ScalarCalls = - std::max(ScalarCalls, cast(Ty)->getNumElements()); + ScalarizationCost += getScalarizationOverhead(VTy, false, true); + ScalarCalls = std::max(ScalarCalls, VTy->getNumElements()); Ty = Ty->getScalarType(); } ScalarTys.push_back(Ty); @@ -1547,11 +1547,12 @@ // this will emit a costly libcall, adding call overhead and spills. Make it // very expensive. if (RetTy->isVectorTy()) { + auto *RetVTy = cast(RetTy); unsigned ScalarizationCost = ((ScalarizationCostPassed != std::numeric_limits::max()) ? ScalarizationCostPassed - : getScalarizationOverhead(RetTy, true, false)); - unsigned ScalarCalls = cast(RetTy)->getNumElements(); + : getScalarizationOverhead(RetVTy, true, false)); + unsigned ScalarCalls = RetVTy->getNumElements(); SmallVector ScalarTys; for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) { Type *Ty = Tys[i]; @@ -1563,10 +1564,10 @@ IID, RetTy->getScalarType(), ScalarTys, FMF); for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) { if (Tys[i]->isVectorTy()) { + auto *VTy = cast(Tys[i]); if (ScalarizationCostPassed == std::numeric_limits::max()) - ScalarizationCost += getScalarizationOverhead(Tys[i], false, true); - ScalarCalls = - std::max(ScalarCalls, cast(Tys[i])->getNumElements()); + ScalarizationCost += getScalarizationOverhead(VTy, false, true); + ScalarCalls = std::max(ScalarCalls, VTy->getNumElements()); } } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -368,8 +368,10 @@ return TTIImpl->useColdCCForColdCall(F); } -unsigned TargetTransformInfo::getScalarizationOverhead( - Type *Ty, const APInt &DemandedElts, bool Insert, bool Extract) const { +unsigned +TargetTransformInfo::getScalarizationOverhead(VectorType *Ty, + const APInt &DemandedElts, + bool Insert, bool Extract) const { return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract); } diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -798,7 +798,7 @@ unsigned Cost = getArithmeticInstrCost(Opcode, Ty->getScalarType()); // Return the cost of multiple scalar invocation plus the cost of // inserting and extracting the values. - return BaseT::getScalarizationOverhead(Ty, Args) + Num * Cost; + return BaseT::getScalarizationOverhead(VTy, Args) + Num * Cost; } return BaseCost; @@ -887,7 +887,7 @@ // The scalarization cost should be a lot higher. We use the number of vector // elements plus the scalarization overhead. unsigned ScalarCost = - NumElems * LT.first + BaseT::getScalarizationOverhead(DataTy, {}); + NumElems * LT.first + BaseT::getScalarizationOverhead(VTy, {}); if (Alignment < EltSize / 8) return ScalarCost; diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -101,7 +101,7 @@ return true; } - unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts, + unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract); unsigned getOperandsScalarizationOverhead(ArrayRef Args, unsigned VF); diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -115,7 +115,7 @@ return (8 * ST.getVectorLength()) / ElemWidth; } -unsigned HexagonTTIImpl::getScalarizationOverhead(Type *Ty, +unsigned HexagonTTIImpl::getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract) { return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract); diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp --- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp +++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp @@ -459,6 +459,7 @@ return DivInstrCost; } else if (ST->hasVector()) { + auto VTy = cast(Ty); unsigned VF = cast(Ty)->getNumElements(); unsigned NumVectors = getNumVectorRegs(Ty); @@ -472,7 +473,7 @@ if (DivRemConstPow2) return (NumVectors * (SignedDivRem ? SDivPow2Cost : 1)); if (DivRemConst) - return VF * DivMulSeqCost + getScalarizationOverhead(Ty, Args); + return VF * DivMulSeqCost + getScalarizationOverhead(VTy, Args); if ((SignedDivRem || UnsignedDivRem) && VF > 4) // Temporary hack: disable high vectorization factors with integer // division/remainder, which will get scalarized and handled with @@ -495,7 +496,7 @@ // inserting and extracting the values. unsigned ScalarCost = getArithmeticInstrCost(Opcode, Ty->getScalarType()); - unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(Ty, Args); + unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(VTy, Args); // FIXME: VF 2 for these FP operations are currently just as // expensive as for VF 4. if (VF == 2) @@ -512,7 +513,7 @@ // There is no native support for FRem. if (Opcode == Instruction::FRem) { - unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(Ty, Args); + unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(VTy, Args); // FIXME: VF 2 for float is currently just as expensive as for VF 4. if (VF == 2 && ScalarBits == 32) Cost *= 2; @@ -718,8 +719,9 @@ } } else if (ST->hasVector()) { - assert (Dst->isVectorTy()); - unsigned VF = cast(Src)->getNumElements(); + auto SrcVecTy = cast(Src); + auto DstVecTy = cast(Dst); + unsigned VF = SrcVecTy->getNumElements(); unsigned NumDstVectors = getNumVectorRegs(Dst); unsigned NumSrcVectors = getNumVectorRegs(Src); @@ -775,8 +777,8 @@ (Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI)) NeedsExtracts = false; - TotCost += getScalarizationOverhead(Src, false, NeedsExtracts); - TotCost += getScalarizationOverhead(Dst, NeedsInserts, false); + TotCost += getScalarizationOverhead(SrcVecTy, false, NeedsExtracts); + TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts, false); // FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4. if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32) @@ -787,7 +789,8 @@ if (Opcode == Instruction::FPTrunc) { if (SrcScalarBits == 128) // fp128 -> double/float + inserts of elements. - return VF /*ldxbr/lexbr*/ + getScalarizationOverhead(Dst, true, false); + return VF /*ldxbr/lexbr*/ + + getScalarizationOverhead(DstVecTy, true, false); else // double -> float return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/); } @@ -800,7 +803,7 @@ return VF * 2; } // -> fp128. VF * lxdb/lxeb + extraction of elements. - return VF + getScalarizationOverhead(Src, false, true); + return VF + getScalarizationOverhead(SrcVecTy, false, true); } } diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h --- a/llvm/lib/Target/X86/X86TargetTransformInfo.h +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h @@ -132,7 +132,7 @@ int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, const Instruction *I = nullptr); int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); - unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts, + unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract); int getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment, unsigned AddressSpace, const Instruction *I = nullptr); diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -2873,10 +2873,9 @@ return BaseT::getVectorInstrCost(Opcode, Val, Index) + RegisterFileMoveCost; } -unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty, +unsigned X86TTIImpl::getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract) { - auto* VecTy = cast(Ty); unsigned Cost = 0; // For insertions, a ISD::BUILD_VECTOR style vector initialization can be much @@ -2902,7 +2901,7 @@ // 128-bit vector is free. // NOTE: This assumes legalization widens vXf32 vectors. if (MScalarTy == MVT::f32) - for (unsigned i = 0, e = VecTy->getNumElements(); i < e; i += 4) + for (unsigned i = 0, e = Ty->getNumElements(); i < e; i += 4) if (DemandedElts[i]) Cost--; } @@ -2918,7 +2917,7 @@ // vector elements, which represents the number of unpacks we'll end up // performing. unsigned NumElts = LT.second.getVectorNumElements(); - unsigned Pow2Elts = PowerOf2Ceil(VecTy->getNumElements()); + unsigned Pow2Elts = PowerOf2Ceil(Ty->getNumElements()); Cost += (std::min(NumElts, Pow2Elts) - 1) * LT.first; } } @@ -2954,7 +2953,7 @@ APInt DemandedElts = APInt::getAllOnesValue(NumElem); int Cost = BaseT::getMemoryOpCost(Opcode, VTy->getScalarType(), Alignment, AddressSpace); - int SplitCost = getScalarizationOverhead(Src, DemandedElts, + int SplitCost = getScalarizationOverhead(VTy, DemandedElts, Opcode == Instruction::Load, Opcode == Instruction::Store); return NumElem * Cost + SplitCost; diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -5698,9 +5698,9 @@ // Compute the scalarization overhead of needed insertelement instructions // and phi nodes. if (isScalarWithPredication(I) && !I->getType()->isVoidTy()) { - ScalarCost += - TTI.getScalarizationOverhead(ToVectorTy(I->getType(), VF), - APInt::getAllOnesValue(VF), true, false); + ScalarCost += TTI.getScalarizationOverhead( + cast(ToVectorTy(I->getType(), VF)), + APInt::getAllOnesValue(VF), true, false); ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI); } @@ -5716,8 +5716,8 @@ Worklist.push_back(J); else if (needsExtract(J, VF)) ScalarCost += TTI.getScalarizationOverhead( - ToVectorTy(J->getType(), VF), APInt::getAllOnesValue(VF), false, - true); + cast(ToVectorTy(J->getType(), VF)), + APInt::getAllOnesValue(VF), false, true); } // Scale the total scalar cost by block probability. @@ -6001,8 +6001,8 @@ Type *RetTy = ToVectorTy(I->getType(), VF); if (!RetTy->isVoidTy() && (!isa(I) || !TTI.supportsEfficientVectorElementLoadStore())) - Cost += TTI.getScalarizationOverhead(RetTy, APInt::getAllOnesValue(VF), - true, false); + Cost += TTI.getScalarizationOverhead( + cast(RetTy), APInt::getAllOnesValue(VF), true, false); // Some targets keep addresses scalar. if (isa(I) && !TTI.prefersVectorizedAddressing()) @@ -6206,7 +6206,7 @@ if (ScalarPredicatedBB) { // Return cost for branches around scalarized and predicated blocks. - Type *Vec_i1Ty = + VectorType *Vec_i1Ty = VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF); return (TTI.getScalarizationOverhead(Vec_i1Ty, APInt::getAllOnesValue(VF), false, true) +