diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -620,7 +620,7 @@
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract) const;
 
   /// Estimate the overhead of scalarizing an instructions unique
@@ -1261,7 +1261,8 @@
   virtual bool shouldBuildLookupTables() = 0;
   virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
   virtual bool useColdCCForColdCall(Function &F) = 0;
-  virtual unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  virtual unsigned getScalarizationOverhead(VectorType *Ty,
+                                            const APInt &DemandedElts,
                                             bool Insert, bool Extract) = 0;
   virtual unsigned getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
@@ -1609,7 +1610,7 @@
     return Impl.useColdCCForColdCall(F);
   }
 
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract) override {
     return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
   }
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -240,7 +240,7 @@
   bool useColdCCForColdCall(Function &F) { return false; }
 
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract) {
     return 0;
   }
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -552,32 +552,30 @@
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract) {
-    auto *VTy = cast<VectorType>(Ty);
-    assert(DemandedElts.getBitWidth() == VTy->getNumElements() &&
+    assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
            "Vector size mismatch");
 
     unsigned Cost = 0;
 
-    for (int i = 0, e = VTy->getNumElements(); i < e; ++i) {
+    for (int i = 0, e = Ty->getNumElements(); i < e; ++i) {
       if (!DemandedElts[i])
         continue;
       if (Insert)
         Cost += static_cast<T *>(this)->getVectorInstrCost(
-            Instruction::InsertElement, VTy, i);
+            Instruction::InsertElement, Ty, i);
       if (Extract)
         Cost += static_cast<T *>(this)->getVectorInstrCost(
-            Instruction::ExtractElement, VTy, i);
+            Instruction::ExtractElement, Ty, i);
     }
 
     return Cost;
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
-  unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) {
-    auto *VTy = cast<VectorType>(Ty);
-    APInt DemandedElts = APInt::getAllOnesValue(VTy->getNumElements());
+  unsigned getScalarizationOverhead(VectorType *Ty, bool Insert, bool Extract) {
+    APInt DemandedElts = APInt::getAllOnesValue(Ty->getNumElements());
     return static_cast<T *>(this)->getScalarizationOverhead(Ty, DemandedElts,
                                                             Insert, Extract);
   }
 
@@ -591,11 +589,10 @@
     SmallPtrSet<const Value *, 4> UniqueOperands;
     for (const Value *A : Args) {
       if (!isa<UndefValue>(A) && UniqueOperands.insert(A).second) {
-        Type *VecTy = nullptr;
-        if (A->getType()->isVectorTy()) {
-          VecTy = A->getType();
+        auto *VecTy = dyn_cast<VectorType>(A->getType());
+        if (VecTy) {
           // If A is a vector operand, VF should be 1 or correspond to A.
-          assert((VF == 1 || VF == cast<VectorType>(VecTy)->getNumElements()) &&
+          assert((VF == 1 || VF == VecTy->getNumElements()) &&
                  "Vector argument does not match VF");
         }
         else
@@ -608,17 +605,16 @@
     return Cost;
   }
 
-  unsigned getScalarizationOverhead(Type *VecTy, ArrayRef<const Value *> Args) {
+  unsigned getScalarizationOverhead(VectorType *Ty, ArrayRef<const Value *> Args) {
     unsigned Cost = 0;
 
-    auto *VecVTy = cast<VectorType>(VecTy);
-    Cost += getScalarizationOverhead(VecVTy, true, false);
+    Cost += getScalarizationOverhead(Ty, true, false);
     if (!Args.empty())
-      Cost += getOperandsScalarizationOverhead(Args, VecVTy->getNumElements());
+      Cost += getOperandsScalarizationOverhead(Args, Ty->getNumElements());
    else
       // When no information on arguments is provided, we add the cost
       // associated with one argument as a heuristic.
-      Cost += getScalarizationOverhead(VecVTy, false, true);
+      Cost += getScalarizationOverhead(Ty, false, true);
 
     return Cost;
   }
@@ -742,13 +738,16 @@
       break;
     }
 
+    auto *SrcVTy = dyn_cast<VectorType>(Src);
+    auto *DstVTy = dyn_cast<VectorType>(Dst);
+
     // If the cast is marked as legal (or promote) then assume low cost.
     if (SrcLT.first == DstLT.first &&
        TLI->isOperationLegalOrPromote(ISD, DstLT.second))
      return SrcLT.first;
 
     // Handle scalar conversions.
-    if (!Src->isVectorTy() && !Dst->isVectorTy()) {
+    if (!SrcVTy && !DstVTy) {
       // Scalar bitcasts are usually free.
       if (Opcode == Instruction::BitCast)
         return 0;
@@ -763,9 +762,7 @@
     }
 
     // Check vector-to-vector casts.
-    if (Dst->isVectorTy() && Src->isVectorTy()) {
-      auto *SrcVTy = cast<VectorType>(Src);
-      auto *DstVTy = cast<VectorType>(Dst);
+    if (DstVTy && SrcVTy) {
       // If the cast is between same-sized registers, then the check is simple.
       if (SrcLT.first == DstLT.first &&
           SrcLT.second.getSizeInBits() == DstLT.second.getSizeInBits()) {
@@ -819,19 +816,18 @@
       // Return the cost of multiple scalar invocation plus the cost of
       // inserting and extracting the values.
-      return getScalarizationOverhead(Dst, true, true) + Num * Cost;
+      return getScalarizationOverhead(DstVTy, true, true) + Num * Cost;
     }
 
     // We already handled vector-to-vector and scalar-to-scalar conversions.
     // This
     // is where we handle bitcast between vectors and scalars. We need to assume
     // that the conversion is scalarized in one way or another.
-    if (Opcode == Instruction::BitCast)
+    if (Opcode == Instruction::BitCast) {
       // Illegal bitcasts are done by storing and loading from a stack slot.
-      return (Src->isVectorTy() ? getScalarizationOverhead(Src, false, true)
-                                : 0) +
-             (Dst->isVectorTy() ? getScalarizationOverhead(Dst, true, false)
-                                : 0);
+      return (SrcVTy ? getScalarizationOverhead(SrcVTy, false, true) : 0) +
+             (DstVTy ? getScalarizationOverhead(DstVTy, true, false) : 0);
+    }
 
     llvm_unreachable("Unhandled cast");
   }
@@ -923,7 +919,8 @@
     if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) {
       // This is a vector load/store for some illegal type that is scalarized.
       // We must account for the cost of building or decomposing the vector.
-      Cost += getScalarizationOverhead(Src, Opcode != Instruction::Store,
+      Cost += getScalarizationOverhead(cast<VectorType>(Src),
+                                       Opcode != Instruction::Store,
                                        Opcode == Instruction::Store);
     }
   }
@@ -1118,7 +1115,8 @@
     if (RetVF > 1 || VF > 1) {
       ScalarizationCost = 0;
       if (!RetTy->isVoidTy())
-        ScalarizationCost += getScalarizationOverhead(RetTy, true, false);
+        ScalarizationCost +=
+            getScalarizationOverhead(cast<VectorType>(RetTy), true, false);
       ScalarizationCost += getOperandsScalarizationOverhead(Args, VF);
     }
@@ -1224,21 +1222,19 @@
       unsigned ScalarizationCost = ScalarizationCostPassed;
       unsigned ScalarCalls = 1;
       Type *ScalarRetTy = RetTy;
-      if (RetTy->isVectorTy()) {
+      if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
         if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
-          ScalarizationCost = getScalarizationOverhead(RetTy, true, false);
-        ScalarCalls =
-            std::max(ScalarCalls, cast<VectorType>(RetTy)->getNumElements());
+          ScalarizationCost = getScalarizationOverhead(RetVTy, true, false);
+        ScalarCalls = std::max(ScalarCalls, RetVTy->getNumElements());
         ScalarRetTy = RetTy->getScalarType();
       }
       SmallVector<Type *, 4> ScalarTys;
       for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
         Type *Ty = Tys[i];
-        if (Ty->isVectorTy()) {
+        if (auto *VTy = dyn_cast<VectorType>(Ty)) {
           if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
-            ScalarizationCost += getScalarizationOverhead(Ty, false, true);
-          ScalarCalls =
-              std::max(ScalarCalls, cast<VectorType>(Ty)->getNumElements());
+            ScalarizationCost += getScalarizationOverhead(VTy, false, true);
+          ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
           Ty = Ty->getScalarType();
         }
         ScalarTys.push_back(Ty);
@@ -1588,12 +1584,12 @@
     // Else, assume that we need to scalarize this intrinsic. For math builtins
     // this will emit a costly libcall, adding call overhead and spills. Make it
     // very expensive.
-    if (RetTy->isVectorTy()) {
+    if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
       unsigned ScalarizationCost =
          ((ScalarizationCostPassed != std::numeric_limits<unsigned>::max())
               ? ScalarizationCostPassed
-              : getScalarizationOverhead(RetTy, true, false));
-      unsigned ScalarCalls = cast<VectorType>(RetTy)->getNumElements();
+              : getScalarizationOverhead(RetVTy, true, false));
+      unsigned ScalarCalls = RetVTy->getNumElements();
       SmallVector<Type *, 4> ScalarTys;
       for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
         Type *Ty = Tys[i];
@@ -1604,14 +1600,12 @@
       unsigned ScalarCost = ConcreteTTI->getIntrinsicInstrCost(
          IID, RetTy->getScalarType(), ScalarTys, FMF, CostKind);
       for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
-        if (Tys[i]->isVectorTy()) {
+        if (auto *VTy = dyn_cast<VectorType>(Tys[i])) {
           if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
-            ScalarizationCost += getScalarizationOverhead(Tys[i], false, true);
-          ScalarCalls =
-              std::max(ScalarCalls, cast<VectorType>(Tys[i])->getNumElements());
+            ScalarizationCost += getScalarizationOverhead(VTy, false, true);
+          ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
         }
       }
-
      return ScalarCalls * ScalarCost + ScalarizationCost;
     }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -370,8 +370,10 @@
   return TTIImpl->useColdCCForColdCall(F);
 }
 
-unsigned TargetTransformInfo::getScalarizationOverhead(
-    Type *Ty, const APInt &DemandedElts, bool Insert, bool Extract) const {
+unsigned
+TargetTransformInfo::getScalarizationOverhead(VectorType *Ty,
+                                              const APInt &DemandedElts,
+                                              bool Insert, bool Extract) const {
   return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
 }
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -807,7 +807,7 @@
                                            CostKind);
     // Return the cost of multiple scalar invocation plus the cost of
     // inserting and extracting the values.
-    return BaseT::getScalarizationOverhead(Ty, Args) + Num * Cost;
+    return BaseT::getScalarizationOverhead(VTy, Args) + Num * Cost;
   }
 
   return BaseCost;
@@ -899,7 +899,7 @@
   // The scalarization cost should be a lot higher. We use the number of vector
   // elements plus the scalarization overhead.
   unsigned ScalarCost =
-      NumElems * LT.first + BaseT::getScalarizationOverhead(DataTy, {});
+      NumElems * LT.first + BaseT::getScalarizationOverhead(VTy, {});
 
   if (Alignment < EltSize / 8)
     return ScalarCost;
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -101,7 +101,7 @@
     return true;
   }
 
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract);
   unsigned getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
                                             unsigned VF);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -115,7 +115,7 @@
   return (8 * ST.getVectorLength()) / ElemWidth;
 }
 
-unsigned HexagonTTIImpl::getScalarizationOverhead(Type *Ty,
+unsigned HexagonTTIImpl::getScalarizationOverhead(VectorType *Ty,
                                                   const APInt &DemandedElts,
                                                   bool Insert, bool Extract) {
   return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -464,7 +464,8 @@
       return DivInstrCost;
   }
   else if (ST->hasVector()) {
-    unsigned VF = cast<VectorType>(Ty)->getNumElements();
+    auto *VTy = cast<VectorType>(Ty);
+    unsigned VF = VTy->getNumElements();
     unsigned NumVectors = getNumVectorRegs(Ty);
 
     // These vector operations are custom handled, but are still supported
@@ -477,7 +478,7 @@
       if (DivRemConstPow2)
         return (NumVectors * (SignedDivRem ? SDivPow2Cost : 1));
       if (DivRemConst)
-        return VF * DivMulSeqCost + getScalarizationOverhead(Ty, Args);
+        return VF * DivMulSeqCost + getScalarizationOverhead(VTy, Args);
       if ((SignedDivRem || UnsignedDivRem) && VF > 4)
         // Temporary hack: disable high vectorization factors with integer
         // division/remainder, which will get scalarized and handled with
@@ -500,7 +501,7 @@
       // inserting and extracting the values.
       unsigned ScalarCost =
           getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind);
-      unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(Ty, Args);
+      unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(VTy, Args);
       // FIXME: VF 2 for these FP operations are currently just as
       // expensive as for VF 4.
       if (VF == 2)
@@ -517,7 +518,7 @@
 
     // There is no native support for FRem.
     if (Opcode == Instruction::FRem) {
-      unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(Ty, Args);
+      unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(VTy, Args);
       // FIXME: VF 2 for float is currently just as expensive as for VF 4.
      if (VF == 2 && ScalarBits == 32)
        Cost *= 2;
@@ -724,8 +725,9 @@
       }
     }
   }
   else if (ST->hasVector()) {
-    assert (Dst->isVectorTy());
-    unsigned VF = cast<VectorType>(Src)->getNumElements();
+    auto *SrcVecTy = cast<VectorType>(Src);
+    auto *DstVecTy = cast<VectorType>(Dst);
+    unsigned VF = SrcVecTy->getNumElements();
     unsigned NumDstVectors = getNumVectorRegs(Dst);
     unsigned NumSrcVectors = getNumVectorRegs(Src);
 
@@ -781,8 +783,8 @@
           (Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI))
         NeedsExtracts = false;
 
-      TotCost += getScalarizationOverhead(Src, false, NeedsExtracts);
-      TotCost += getScalarizationOverhead(Dst, NeedsInserts, false);
+      TotCost += getScalarizationOverhead(SrcVecTy, false, NeedsExtracts);
+      TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts, false);
 
       // FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4.
       if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32)
@@ -793,7 +795,8 @@
     if (Opcode == Instruction::FPTrunc) {
       if (SrcScalarBits == 128)  // fp128 -> double/float + inserts of elements.
-        return VF /*ldxbr/lexbr*/ + getScalarizationOverhead(Dst, true, false);
+        return VF /*ldxbr/lexbr*/ +
+               getScalarizationOverhead(DstVecTy, true, false);
       else // double -> float
         return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/);
     }
@@ -806,7 +809,7 @@
         return VF * 2;
       }
       // -> fp128. VF * lxdb/lxeb + extraction of elements.
-      return VF + getScalarizationOverhead(Src, false, true);
+      return VF + getScalarizationOverhead(SrcVecTy, false, true);
     }
   }
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -135,7 +135,7 @@
                        TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract);
   int getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment,
                       unsigned AddressSpace,
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -2888,10 +2888,9 @@
   return BaseT::getVectorInstrCost(Opcode, Val, Index) + RegisterFileMoveCost;
 }
 
-unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
+unsigned X86TTIImpl::getScalarizationOverhead(VectorType *Ty,
                                               const APInt &DemandedElts,
                                               bool Insert, bool Extract) {
-  auto* VecTy = cast<VectorType>(Ty);
   unsigned Cost = 0;
 
   // For insertions, a ISD::BUILD_VECTOR style vector initialization can be much
@@ -2917,7 +2916,7 @@
         // 128-bit vector is free.
         // NOTE: This assumes legalization widens vXf32 vectors.
         if (MScalarTy == MVT::f32)
-          for (unsigned i = 0, e = VecTy->getNumElements(); i < e; i += 4)
+          for (unsigned i = 0, e = Ty->getNumElements(); i < e; i += 4)
            if (DemandedElts[i])
              Cost--;
       }
@@ -2933,7 +2932,7 @@
       // vector elements, which represents the number of unpacks we'll end up
       // performing.
       unsigned NumElts = LT.second.getVectorNumElements();
-      unsigned Pow2Elts = PowerOf2Ceil(VecTy->getNumElements());
+      unsigned Pow2Elts = PowerOf2Ceil(Ty->getNumElements());
       Cost += (std::min(NumElts, Pow2Elts) - 1) * LT.first;
     }
   }
 
@@ -2970,7 +2969,7 @@
     APInt DemandedElts = APInt::getAllOnesValue(NumElem);
     int Cost = BaseT::getMemoryOpCost(Opcode, VTy->getScalarType(), Alignment,
                                       AddressSpace, CostKind);
-    int SplitCost = getScalarizationOverhead(Src, DemandedElts,
+    int SplitCost = getScalarizationOverhead(VTy, DemandedElts,
                                              Opcode == Instruction::Load,
                                              Opcode == Instruction::Store);
     return NumElem * Cost + SplitCost;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5702,9 +5702,9 @@
     // Compute the scalarization overhead of needed insertelement instructions
     // and phi nodes.
     if (isScalarWithPredication(I) && !I->getType()->isVoidTy()) {
-      ScalarCost +=
-          TTI.getScalarizationOverhead(ToVectorTy(I->getType(), VF),
-                                       APInt::getAllOnesValue(VF), true, false);
+      ScalarCost += TTI.getScalarizationOverhead(
+          cast<VectorType>(ToVectorTy(I->getType(), VF)),
+          APInt::getAllOnesValue(VF), true, false);
       ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI);
     }
 
@@ -5720,8 +5720,8 @@
           Worklist.push_back(J);
         else if (needsExtract(J, VF))
           ScalarCost += TTI.getScalarizationOverhead(
-              ToVectorTy(J->getType(), VF), APInt::getAllOnesValue(VF), false,
-              true);
+              cast<VectorType>(ToVectorTy(J->getType(), VF)),
+              APInt::getAllOnesValue(VF), false, true);
       }
 
     // Scale the total scalar cost by block probability.
@@ -6016,8 +6016,8 @@
   Type *RetTy = ToVectorTy(I->getType(), VF);
   if (!RetTy->isVoidTy() &&
       (!isa<StoreInst>(I) || !TTI.supportsEfficientVectorElementLoadStore()))
-    Cost += TTI.getScalarizationOverhead(RetTy, APInt::getAllOnesValue(VF),
-                                         true, false);
+    Cost += TTI.getScalarizationOverhead(
+        cast<VectorType>(RetTy), APInt::getAllOnesValue(VF), true, false);
 
   // Some targets keep addresses scalar.
   if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing())
@@ -6222,7 +6222,7 @@
     if (ScalarPredicatedBB) {
       // Return cost for branches around scalarized and predicated blocks.
-      Type *Vec_i1Ty =
+      VectorType *Vec_i1Ty =
           VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF);
       return (TTI.getScalarizationOverhead(Vec_i1Ty, APInt::getAllOnesValue(VF),
                                            false, true) +
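Note (illustrative only, not part of the patch): after this change, code that previously handed getScalarizationOverhead a plain Type * must perform the VectorType cast itself at the call site. A minimal sketch of an adapted caller, assuming a generic TTI consumer; the helper name costOfScalarizingAll is hypothetical.

#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DerivedTypes.h"
using namespace llvm;

// Hypothetical helper: cost of inserting every element of a vector value.
static unsigned costOfScalarizingAll(const TargetTransformInfo &TTI, Type *Ty) {
  // The cast is now the caller's responsibility; the API no longer accepts a
  // plain Type * and casts internally.
  auto *VecTy = cast<VectorType>(Ty);
  APInt DemandedElts = APInt::getAllOnesValue(VecTy->getNumElements());
  return TTI.getScalarizationOverhead(VecTy, DemandedElts,
                                      /*Insert=*/true, /*Extract=*/false);
}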