diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1021,10 +1021,47 @@
   int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index = 0,
                      VectorType *SubTp = nullptr) const;
 
+  /// Represents a hint about the context in which a cast is used.
+  ///
+  /// For zext/sext, the context of the cast is the operand, which must be a
+  /// load of some kind. For trunc, the context of the cast is the single
+  /// user of the instruction, which must be a store of some kind.
+  ///
+  /// This enum allows the vectorizer to give getCastInstrCost an idea of the
+  /// type of cast it's dealing with, as not every cast is equal. For
+  /// instance, the zext of a load may be free, but the zext of an
+  /// interleaved load can be (very) expensive!
+  ///
+  /// See \c getCastContextHint to compute a CastContextHint from a cast
+  /// Instruction*. Callers can use it if they don't need to override the
+  /// context and just want it to be calculated from the instruction.
+  ///
+  /// FIXME: This handles the types of load/store that the vectorizer can
+  /// produce, which are the cases where the context instruction is most
+  /// likely to be incorrect. There are other situations where that can
+  /// happen too, which might be handled here but in the long run a more
+  /// general solution of costing multiple instructions at the same time may
+  /// be better.
+  enum class CastContextHint : uint8_t {
+    None,          ///< The cast is not used with a load/store of any kind.
+    Normal,        ///< The cast is used with a normal load/store.
+    Masked,        ///< The cast is used with a masked load/store.
+    GatherScatter, ///< The cast is used with a gather/scatter.
+    Interleave,    ///< The cast is used with an interleaved load/store.
+    Reversed,      ///< The cast is used with a reversed load/store.
+  };
+
+  /// Calculates a CastContextHint from \p I.
+  /// This should be used by callers of getCastInstrCost if they wish to
+  /// determine the context from some instruction.
+  /// \returns the CastContextHint for ZExt/SExt/Trunc; None if \p I is
+  /// nullptr or is another type of cast.
+  static CastContextHint getCastContextHint(const Instruction *I);
+
   /// \return The expected cost of cast instructions, such as bitcast, trunc,
   /// zext, etc. If there is an existing instruction that holds Opcode, it
   /// may be passed in the 'I' parameter.
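As a usage sketch (not part of the patch; TTI is an assumed TargetTransformInfo reference, Cast an assumed CastInst pointer, and WideDstTy/WideSrcTy assumed vector types), a caller of the new overload can either derive the hint from the IR or override it to describe the code it actually plans to emit:

    // Sketch only: derive the context from the existing instruction...
    TargetTransformInfo::CastContextHint CCH =
        TargetTransformInfo::getCastContextHint(Cast);
    int Cost = TTI.getCastInstrCost(Cast->getOpcode(), Cast->getDestTy(),
                                    Cast->getSrcTy(), CCH,
                                    TargetTransformInfo::TCK_RecipThroughput,
                                    Cast);
    // ...or override it, e.g. when the load feeding the cast will really
    // become a gather after vectorization:
    int VecCost = TTI.getCastInstrCost(
        Cast->getOpcode(), WideDstTy, WideSrcTy,
        TargetTransformInfo::CastContextHint::GatherScatter,
        TargetTransformInfo::TCK_RecipThroughput);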
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                       TTI::CastContextHint CCH,
                        TTI::TargetCostKind CostKind = TTI::TCK_SizeAndLatency,
                        const Instruction *I = nullptr) const;
 
@@ -1454,6 +1491,7 @@
   virtual int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index,
                              VectorType *SubTp) = 0;
   virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                               CastContextHint CCH,
                                TTI::TargetCostKind CostKind,
                                const Instruction *I) = 0;
   virtual int getExtractWithExtendCost(unsigned Opcode, Type *Dst,
@@ -1882,9 +1920,9 @@
     return Impl.getShuffleCost(Kind, Tp, Index, SubTp);
   }
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I) override {
-    return Impl.getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+    return Impl.getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
   }
   int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,
                                unsigned Index) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -423,6 +423,7 @@
   }
 
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                            TTI::CastContextHint CCH,
                             TTI::TargetCostKind CostKind,
                             const Instruction *I) {
     switch (Opcode) {
@@ -915,7 +916,8 @@
     case Instruction::SExt:
     case Instruction::ZExt:
     case Instruction::AddrSpaceCast:
-      return TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I);
+      return TargetTTI->getCastInstrCost(
+          Opcode, Ty, OpTy, TTI::getCastContextHint(I), CostKind, I);
     case Instruction::Store: {
       auto *SI = cast<StoreInst>(U);
       Type *ValTy = U->getOperand(0)->getType();
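Because the shared user-cost path above now calls TTI::getCastContextHint(I) itself, clients that cost whole instructions pick up the context automatically; only direct callers of getCastInstrCost need to think about the new argument. A rough equivalence, as a sketch (Z is an assumed ZExtInst pointer whose operand is a plain load):

    // These two queries should agree for a zext fed by a normal load:
    int A = TTI.getCastInstrCost(Z->getOpcode(), Z->getDestTy(), Z->getSrcTy(),
                                 TargetTransformInfo::getCastContextHint(Z),
                                 TargetTransformInfo::TCK_RecipThroughput, Z);
    int B = TTI.getCastInstrCost(Z->getOpcode(), Z->getDestTy(), Z->getSrcTy(),
                                 TargetTransformInfo::CastContextHint::Normal,
                                 TargetTransformInfo::TCK_RecipThroughput, Z);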
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -716,9 +716,10 @@
   }
 
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                            TTI::CastContextHint CCH,
                             TTI::TargetCostKind CostKind,
                             const Instruction *I = nullptr) {
-    if (BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I) == 0)
+    if (BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I) == 0)
       return 0;
 
     const TargetLoweringBase *TLI = getTLI();
@@ -756,15 +757,12 @@
       return 0;
     LLVM_FALLTHROUGH;
   case Instruction::SExt:
-    if (!I)
-      break;
-
-    if (getTLI()->isExtFree(I))
+    if (I && getTLI()->isExtFree(I))
       return 0;
 
     // If this is a zext/sext of a load, return 0 if the corresponding
     // extending load exists on target.
-    if (I && isa<LoadInst>(I->getOperand(0))) {
+    if (CCH == TTI::CastContextHint::Normal) {
      EVT ExtVT = EVT::getEVT(Dst);
      EVT LoadVT = EVT::getEVT(Src);
      unsigned LType =
@@ -839,7 +837,7 @@
       unsigned SplitCost =
           (!SplitSrc || !SplitDst) ? TTI->getVectorSplitCost() : 0;
       return SplitCost +
-             (2 * TTI->getCastInstrCost(Opcode, SplitDstTy, SplitSrcTy,
+             (2 * TTI->getCastInstrCost(Opcode, SplitDstTy, SplitSrcTy, CCH,
                                         CostKind, I));
     }
 
@@ -847,7 +845,7 @@
     // the operation will get scalarized.
     unsigned Num = cast<FixedVectorType>(DstVTy)->getNumElements();
     unsigned Cost = thisT()->getCastInstrCost(
-        Opcode, Dst->getScalarType(), Src->getScalarType(), CostKind, I);
+        Opcode, Dst->getScalarType(), Src->getScalarType(), CCH, CostKind, I);
 
     // Return the cost of multiple scalar invocation plus the cost of
     // inserting and extracting the values.
@@ -872,7 +870,7 @@
     return thisT()->getVectorInstrCost(Instruction::ExtractElement, VecTy,
                                        Index) +
            thisT()->getCastInstrCost(Opcode, Dst, VecTy->getElementType(),
-                                     TTI::TCK_RecipThroughput);
+                                     TTI::CastContextHint::None, TTI::TCK_RecipThroughput);
   }
 
   unsigned getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) {
@@ -1522,13 +1520,14 @@
       unsigned ExtOp =
           IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
+      TTI::CastContextHint CCH = TTI::CastContextHint::None;
 
       unsigned Cost = 0;
-      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, RetTy, CostKind);
+      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, RetTy, CCH, CostKind);
       Cost += thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
       Cost += 2 * thisT()->getCastInstrCost(Instruction::Trunc, RetTy, ExtTy,
-                                            CostKind);
+                                            CCH, CostKind);
       Cost += thisT()->getArithmeticInstrCost(Instruction::LShr, RetTy,
                                               CostKind, TTI::OK_AnyValue,
                                               TTI::OK_UniformConstantValue);
@@ -1587,13 +1586,14 @@
       unsigned ExtOp =
           IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
+      TTI::CastContextHint CCH = TTI::CastContextHint::None;
 
       unsigned Cost = 0;
-      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, MulTy, CostKind);
+      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, MulTy, CCH, CostKind);
       Cost += thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
       Cost += 2 * thisT()->getCastInstrCost(Instruction::Trunc, MulTy, ExtTy,
-                                            CostKind);
+                                            CCH, CostKind);
       Cost += thisT()->getArithmeticInstrCost(Instruction::LShr, MulTy,
                                               CostKind, TTI::OK_AnyValue,
                                               TTI::OK_UniformConstantValue);
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -730,12 +730,57 @@
   return Cost;
 }
 
+TTI::CastContextHint
+TargetTransformInfo::getCastContextHint(const Instruction *I) {
+  if (!I)
+    return CastContextHint::None;
+
+  auto getLoadStoreKind = [](const Value *V, unsigned LdStOp, unsigned MaskedOp,
+                             unsigned GatScatOp) {
+    const Instruction *I = dyn_cast<Instruction>(V);
+    if (!I)
+      return CastContextHint::None;
+
+    if (I->getOpcode() == LdStOp)
+      return CastContextHint::Normal;
+
+    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+      if (II->getIntrinsicID() == MaskedOp)
+        return TTI::CastContextHint::Masked;
+      if (II->getIntrinsicID() == GatScatOp)
+        return TTI::CastContextHint::GatherScatter;
+    }
+
+    return TTI::CastContextHint::None;
+  };
+
+  switch (I->getOpcode()) {
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::FPExt:
+    return getLoadStoreKind(I->getOperand(0), Instruction::Load,
+                            Intrinsic::masked_load, Intrinsic::masked_gather);
+  case Instruction::Trunc:
+  case Instruction::FPTrunc:
+    if (I->hasOneUse())
+      return getLoadStoreKind(*I->user_begin(), Instruction::Store,
+                              Intrinsic::masked_store,
+                              Intrinsic::masked_scatter);
+    break;
+  default:
+    return CastContextHint::None;
+  }
+
+  return TTI::CastContextHint::None;
+}
+
 int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                          CastContextHint CCH,
                                           TTI::TargetCostKind CostKind,
                                           const Instruction *I) const {
   assert((I == nullptr || I->getOpcode() == Opcode) &&
          "Opcode should reflect passed instruction.");
-  int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
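To make the helper concrete, a few illustrative results (a sketch; Z is an assumed zext of a plain LoadInst, T an assumed trunc whose single user is an llvm.masked.scatter call):

    #include <cassert>
    // zext fed by a plain LoadInst:
    assert(TargetTransformInfo::getCastContextHint(Z) ==
           TargetTransformInfo::CastContextHint::Normal);
    // trunc whose only user is an llvm.masked.scatter intrinsic:
    assert(TargetTransformInfo::getCastContextHint(T) ==
           TargetTransformInfo::CastContextHint::GatherScatter);
    // Null, or any cast kind the helper doesn't recognize, yields None:
    assert(TargetTransformInfo::getCastContextHint(nullptr) ==
           TargetTransformInfo::CastContextHint::None);

Note that the helper never returns Interleave or Reversed: those describe widening decisions that are not visible in scalar IR, so only the loop vectorizer (below) supplies them explicitly.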
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -114,7 +114,7 @@
   unsigned getMaxInterleaveFactor(unsigned VF);
 
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
 
   int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -270,6 +270,7 @@
 }
 
 int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                     TTI::CastContextHint CCH,
                                      TTI::TargetCostKind CostKind,
                                      const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -306,7 +307,8 @@
   EVT DstTy = TLI->getValueType(DL, Dst);
 
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+    return AdjustCost(
+        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 
   static const TypeConversionCostTblEntry
   ConversionTbl[] = {
@@ -410,7 +412,8 @@
                                                  SrcTy.getSimpleVT()))
     return AdjustCost(Entry->Cost);
 
-  return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+  return AdjustCost(
+      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
 int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
@@ -442,12 +445,14 @@
   // we may get the extension for free. If not, get the default cost for the
   // extend.
   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
-    return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind);
+    return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
+                                   CostKind);
 
   // The destination type should be larger than the element type. If not, get
   // the default cost for the extend.
   if (DstVT.getSizeInBits() < SrcVT.getSizeInBits())
-    return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind);
+    return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
+                                   CostKind);
 
   switch (Opcode) {
   default:
@@ -466,7 +471,8 @@
   }
 
   // If we are unable to perform the extend for free, get the default cost.
-  return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind);
+  return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
+                                 CostKind);
 }
 
 unsigned AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
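The practical effect of the hint here runs through the base implementation: with CastContextHint::Normal, a scalar extend whose corresponding extending load is legal may be reported as free, while CastContextHint::None costs it as a standalone operation. An illustrative query (a sketch, not asserted by the patch; Ctx is an assumed LLVMContext):

    Type *I16 = Type::getInt16Ty(Ctx);
    Type *I32 = Type::getInt32Ty(Ctx);
    // zext i16 -> i32 fed by a plain load: on AArch64 an ldrh already
    // zero-extends, so this may come back as cost 0.
    int Folded = TTI.getCastInstrCost(
        Instruction::ZExt, I32, I16,
        TargetTransformInfo::CastContextHint::Normal,
        TargetTransformInfo::TCK_RecipThroughput);
    // The same cast with no load context is costed from the tables instead.
    int Standalone = TTI.getCastInstrCost(
        Instruction::ZExt, I32, I16,
        TargetTransformInfo::CastContextHint::None,
        TargetTransformInfo::TCK_RecipThroughput);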
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -210,7 +210,7 @@
   }
 
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
 
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -301,6 +301,7 @@
 }
 
 int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                 TTI::CastContextHint CCH,
                                  TTI::TargetCostKind CostKind,
                                  const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -317,7 +318,8 @@
   EVT DstTy = TLI->getValueType(DL, Dst);
 
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+    return AdjustCost(
+        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 
   // The extend of a load is free
   if (I && isa<LoadInst>(I->getOperand(0))) {
@@ -388,8 +390,8 @@
     };
     if (SrcTy.isVector() && ST->hasMVEIntegerOps()) {
       if (const auto *Entry =
-              ConvertCostTableLookup(MVELoadConversionTbl, ISD, SrcTy.getSimpleVT(),
-                                     DstTy.getSimpleVT()))
+              ConvertCostTableLookup(MVELoadConversionTbl, ISD,
+                                     SrcTy.getSimpleVT(), DstTy.getSimpleVT()))
         return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
     }
@@ -399,8 +401,8 @@
     };
     if (SrcTy.isVector() && ST->hasMVEFloatOps()) {
       if (const auto *Entry =
-              ConvertCostTableLookup(MVEFLoadConversionTbl, ISD, SrcTy.getSimpleVT(),
-                                     DstTy.getSimpleVT()))
+              ConvertCostTableLookup(MVEFLoadConversionTbl, ISD,
+                                     SrcTy.getSimpleVT(), DstTy.getSimpleVT()))
         return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
     }
   }
@@ -672,7 +674,7 @@
                      ? ST->getMVEVectorCostFactor()
                      : 1;
   return AdjustCost(
-      BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+      BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
 int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -146,8 +146,9 @@
       ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
       const Instruction *CxtI = nullptr);
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                            TTI::TargetCostKind CostKind,
-                            const Instruction *I = nullptr);
+                            TTI::CastContextHint CCH,
+                            TTI::TargetCostKind CostKind,
+                            const Instruction *I = nullptr);
   unsigned getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
 
   unsigned getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) {
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -270,7 +270,9 @@
 }
 
 unsigned HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy,
-      Type *SrcTy, TTI::TargetCostKind CostKind, const Instruction *I) {
+                                          Type *SrcTy, TTI::CastContextHint CCH,
+                                          TTI::TargetCostKind CostKind,
+                                          const Instruction *I) {
   if (SrcTy->isFPOrFPVectorTy() || DstTy->isFPOrFPVectorTy()) {
     unsigned SrcN = SrcTy->isFPOrFPVectorTy() ? getTypeNumElements(SrcTy) : 0;
     unsigned DstN = DstTy->isFPOrFPVectorTy() ? getTypeNumElements(DstTy) : 0;
diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
@@ -106,7 +106,7 @@
                                  const Instruction *CxtI = nullptr);
   int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp);
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind);
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -879,11 +879,12 @@
 }
 
 int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                 TTI::CastContextHint CCH,
                                  TTI::TargetCostKind CostKind,
                                  const Instruction *I) {
   assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode");
 
-  int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
   Cost = vectorCostAdjustment(Cost, Opcode, Dst, Src);
   // TODO: Allow non-throughput costs that aren't binary.
   if (CostKind != TTI::TCK_RecipThroughput)
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -93,7 +93,7 @@
   unsigned getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst,
                                          const Instruction *I);
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
                          TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -699,11 +699,12 @@
 }
 
 int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                     TTI::CastContextHint CCH,
                                      TTI::TargetCostKind CostKind,
                                      const Instruction *I) {
   // FIXME: Can the logic below also be used for these cost kinds?
   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency) {
-    int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+    int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
     return BaseCost == 0 ? BaseCost : 1;
   }
 
@@ -786,8 +787,8 @@
       // Return the cost of multiple scalar invocation plus the cost of
       // inserting and extracting the values. Base implementation does not
      // realize float->int gets scalarized.
-      unsigned ScalarCost = getCastInstrCost(Opcode, Dst->getScalarType(),
-                                             Src->getScalarType(), CostKind);
+      unsigned ScalarCost = getCastInstrCost(
+          Opcode, Dst->getScalarType(), Src->getScalarType(), CCH, CostKind);
       unsigned TotCost = VF * ScalarCost;
       bool NeedsInserts = true, NeedsExtracts = true;
       // FP128 registers do not get inserted or extracted.
@@ -828,7 +829,7 @@
     }
   }
 
-  return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 }
 
 // Scalar i8 / i16 operations will typically be made after first extending
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -130,7 +130,7 @@
   int getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, int Index,
                      VectorType *SubTp);
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
                          TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -1367,6 +1367,7 @@
 }
 
 int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                 TTI::CastContextHint CCH,
                                  TTI::TargetCostKind CostKind,
                                  const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -1988,7 +1989,7 @@
   // The function getSimpleVT only handles simple value types.
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind));
+    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind));
 
   MVT SimpleSrcTy = SrcTy.getSimpleVT();
   MVT SimpleDstTy = DstTy.getSimpleVT();
@@ -2049,7 +2050,8 @@
       return AdjustCost(Entry->Cost);
   }
 
-  return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+  return AdjustCost(
+      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
 int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
--- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -2025,8 +2025,8 @@
 
       Type *SrcTy = CI->getOperand(0)->getType();
       Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy,
-                                   TargetTransformInfo::TCK_SizeAndLatency,
-                                   CI);
+                                   TTI.getCastContextHint(CI),
+                                   TargetTransformInfo::TCK_SizeAndLatency, CI);
 
     } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) {
       // Cost of the address calculation
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2150,8 +2150,9 @@
       llvm_unreachable("There are no other cast types.");
     }
     const SCEV *Op = CastExpr->getOperand();
-    BudgetRemaining -= TTI.getCastInstrCost(Opcode, /*Dst=*/S->getType(),
-                                            /*Src=*/Op->getType(), CostKind);
+    BudgetRemaining -= TTI.getCastInstrCost(
+        Opcode, /*Dst=*/S->getType(),
+        /*Src=*/Op->getType(), TTI::CastContextHint::None, CostKind);
     Worklist.emplace_back(Op);
     return false; // Will answer upon next entry into this function.
   }
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6458,13 +6458,54 @@
   case Instruction::Trunc:
   case Instruction::FPTrunc:
   case Instruction::BitCast: {
+    // Computes the CastContextHint from a Load/Store instruction.
+    auto ComputeCCH = [&](Instruction *I) -> TTI::CastContextHint {
+      assert((isa<LoadInst>(I) || isa<StoreInst>(I)) &&
+             "Expected a load or a store!");
+
+      if (VF == 1)
+        return TTI::CastContextHint::Normal;
+
+      switch (getWideningDecision(I, VF)) {
+      case LoopVectorizationCostModel::CM_GatherScatter:
+        return TTI::CastContextHint::GatherScatter;
+      case LoopVectorizationCostModel::CM_Interleave:
+        return TTI::CastContextHint::Interleave;
+      case LoopVectorizationCostModel::CM_Scalarize:
+      case LoopVectorizationCostModel::CM_Widen:
+        return Legal->isMaskRequired(I) ? TTI::CastContextHint::Masked
+                                        : TTI::CastContextHint::Normal;
+      case LoopVectorizationCostModel::CM_Widen_Reverse:
+        return TTI::CastContextHint::Reversed;
+      case LoopVectorizationCostModel::CM_Unknown:
+        llvm_unreachable("Instr did not go through cost modelling?");
+      }
+
+      llvm_unreachable("Unhandled case!");
+    };
+
+    unsigned Opcode = I->getOpcode();
+    TTI::CastContextHint CCH = TTI::CastContextHint::None;
+    // For Trunc, the context is the only user, which must be a StoreInst.
+    if (Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) {
+      if (I->hasOneUse())
+        if (StoreInst *Store = dyn_cast<StoreInst>(*I->user_begin()))
+          CCH = ComputeCCH(Store);
+    }
+    // For Z/Sext, the context is the operand, which must be a LoadInst.
+    else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
+             Opcode == Instruction::FPExt) {
+      if (LoadInst *Load = dyn_cast<LoadInst>(I->getOperand(0)))
+        CCH = ComputeCCH(Load);
+    }
+
     // We optimize the truncation of induction variables having constant
     // integer steps. The cost of these truncations is the same as the scalar
     // operation.
     if (isOptimizableIVTruncate(I, VF)) {
       auto *Trunc = cast<TruncInst>(I);
       return TTI.getCastInstrCost(Instruction::Trunc, Trunc->getDestTy(),
-                                  Trunc->getSrcTy(), CostKind, Trunc);
+                                  Trunc->getSrcTy(), CCH, CostKind, Trunc);
     }
 
     Type *SrcScalarTy = I->getOperand(0)->getType();
@@ -6477,12 +6518,11 @@
       //
       // Calculate the modified src and dest types.
       Type *MinVecTy = VectorTy;
-      if (I->getOpcode() == Instruction::Trunc) {
+      if (Opcode == Instruction::Trunc) {
         SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy);
         VectorTy =
             largestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy);
-      } else if (I->getOpcode() == Instruction::ZExt ||
-                 I->getOpcode() == Instruction::SExt) {
+      } else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) {
         SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy);
         VectorTy =
             smallestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy);
@@ -6490,8 +6530,8 @@
     }
 
     unsigned N = isScalarAfterVectorization(I, VF) ? VF : 1;
-    return N * TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy,
-                                    CostKind, I);
+    return N *
+           TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I);
   }
   case Instruction::Call: {
     bool NeedToScalarize;
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3399,8 +3399,8 @@
                 Ext->getOpcode(), Ext->getType(), VecTy, i);
             // Add back the cost of s|zext which is subtracted separately.
             DeadCost += TTI->getCastInstrCost(
-                Ext->getOpcode(), Ext->getType(), E->getType(), CostKind,
-                Ext);
+                Ext->getOpcode(), Ext->getType(), E->getType(),
+                TTI::getCastContextHint(Ext), CostKind, Ext);
             continue;
           }
         }
@@ -3424,8 +3424,8 @@
     case Instruction::BitCast: {
       Type *SrcTy = VL0->getOperand(0)->getType();
       int ScalarEltCost =
-          TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy, CostKind,
-                                VL0);
+          TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy,
+                                TTI::getCastContextHint(VL0), CostKind, VL0);
       if (NeedToShuffleReuses) {
         ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost;
       }
@@ -3437,9 +3437,10 @@
       int VecCost = 0;
       // Check if the values are candidates to demote.
       if (!MinBWs.count(VL0) || VecTy != SrcVecTy) {
-        VecCost = ReuseShuffleCost +
-                  TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy,
-                                        CostKind, VL0);
+        VecCost =
+            ReuseShuffleCost +
+            TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy,
+                                  TTI::getCastContextHint(VL0), CostKind, VL0);
       }
       return VecCost - ScalarCost;
     }
@@ -3644,9 +3645,9 @@
         auto *Src0Ty = FixedVectorType::get(Src0SclTy, VL.size());
         auto *Src1Ty = FixedVectorType::get(Src1SclTy, VL.size());
         VecCost = TTI->getCastInstrCost(E->getOpcode(), VecTy, Src0Ty,
-                                        CostKind);
+                                        TTI::CastContextHint::None, CostKind);
         VecCost += TTI->getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty,
-                                         CostKind);
+                                         TTI::CastContextHint::None, CostKind);
       }
       VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_Select, VecTy, 0);
       return ReuseShuffleCost + VecCost - ScalarCost;
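Finally, a sketch of how a target override might actually exploit the hint (MyTTIImpl and isExtLoadLegalFor are hypothetical, only the shape of the logic is the point: an extend folded into a contiguous load can be free, while the same extend consuming an interleaved load still needs real shuffles):

    int MyTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                    TTI::CastContextHint CCH,
                                    TTI::TargetCostKind CostKind,
                                    const Instruction *I) {
      if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) &&
          CCH == TTI::CastContextHint::Normal &&
          isExtLoadLegalFor(Src, Dst)) // hypothetical target query
        return 0;                      // folded into an extending load
      if (CCH == TTI::CastContextHint::Interleave)
        // Interleaved accesses keep the extend and add shuffles; never free.
        return 2 * BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
      return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
    }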