Index: llvm/include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- llvm/include/llvm/Analysis/TargetTransformInfo.h +++ llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -956,10 +956,43 @@ int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index = 0, VectorType *SubTp = nullptr) const; + /// Represents a hint about the context in which a cast is used. + /// + /// For zext/sext, the context of the cast is the operand, which must be a + /// load of some kind. + /// + /// For trunc, the context of the cast is the single user of the + /// instruction, which must be a store of some kind. + /// + /// This enum allows the vectorizer to give getCastInstrCost an idea of the + /// type of cast it's dealing with, as not every cast is equal. For instance, + /// the zext of a load may be free, but the zext of a masked load can be + /// (very) expensive! + /// + /// See \c getCastContextHint to compute a CastContextHint from a cast + /// Instruction*. Callers can use it if they don't need to override the + /// context and just want it to be calculated from the instruction. + enum class CastContextHint : uint8_t { + None, ///< The cast is not used with a load/store of any kind. + Normal, ///< The cast is used with a normal load/store. + Masked, ///< The cast is used with a masked load/store. + GatherScatter, ///< The cast is used with a gather/scatter. + Interleave, ///< The cast is used with an interleaved load/store. + Reversed, ///< The cast is used with a reversed load/store. + }; + + /// Calculates a CastContextHint from \p I. + /// This should be used by callers of getCastInstrCost if they wish to + /// determine the context from some instruction. + /// \returns the CastContextHint for ZExt/SExt/Trunc, None if \p I is nullptr, + /// or if it's another type of cast.
+ static CastContextHint getCastContextHint(const Instruction *I); + /// \return The expected cost of cast instructions, such as bitcast, trunc, /// zext, etc. If there is an existing instruction that holds Opcode, it /// may be passed in the 'I' parameter. int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind = TTI::TCK_SizeAndLatency, const Instruction *I = nullptr) const; @@ -1382,6 +1415,7 @@ virtual int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index, VectorType *SubTp) = 0; virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) = 0; virtual int getExtractWithExtendCost(unsigned Opcode, Type *Dst, @@ -1794,9 +1828,9 @@ return Impl.getShuffleCost(Kind, Tp, Index, SubTp); } int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - TTI::TargetCostKind CostKind, + CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) override { - return Impl.getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return Impl.getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); } int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index) override { Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -379,6 +379,7 @@ } unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) { switch (Opcode) { @@ -857,14 +858,12 @@ case Instruction::PtrToInt: case Instruction::Trunc: case Instruction::BitCast: - if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) == - TTI::TCC_Free) - return TTI::TCC_Free; - break; case Instruction::FPExt: case Instruction::SExt: case Instruction::ZExt: - if
(TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) == TTI::TCC_Free) + if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, + TTI::getCastContextHint(I), CostKind, + I) == TTI::TCC_Free) return TTI::TCC_Free; break; } Index: llvm/include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -682,9 +682,10 @@ } unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I = nullptr) { - if (BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I) == 0) + if (BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I) == 0) return 0; const TargetLoweringBase *TLI = getTLI(); @@ -730,7 +731,7 @@ // If this is a zext/sext of a load, return 0 if the corresponding // extending load exists on target. - if (I && isa<LoadInst>(I->getOperand(0))) { + if (I && CCH == TTI::CastContextHint::Normal) { EVT ExtVT = EVT::getEVT(Dst); EVT LoadVT = EVT::getEVT(Src); unsigned LType = @@ -807,7 +808,7 @@ unsigned SplitCost = (!SplitSrc || !SplitDst) ? TTI->getVectorSplitCost() : 0; return SplitCost + - (2 * TTI->getCastInstrCost(Opcode, SplitDstTy, SplitSrcTy, + (2 * TTI->getCastInstrCost(Opcode, SplitDstTy, SplitSrcTy, CCH, CostKind, I)); } @@ -815,8 +816,7 @@ // the operation will get scalarized. unsigned Num = DstVTy->getNumElements(); unsigned Cost = static_cast<T *>(this)->getCastInstrCost( - Opcode, Dst->getScalarType(), Src->getScalarType(), - CostKind, I); + Opcode, Dst->getScalarType(), Src->getScalarType(), CCH, CostKind, I); // Return the cost of multiple scalar invocation plus the cost of // inserting and extracting the values.
@@ -840,9 +840,9 @@ VectorType *VecTy, unsigned Index) { return static_cast<T *>(this)->getVectorInstrCost( Instruction::ExtractElement, VecTy, Index) + - static_cast<T *>(this)->getCastInstrCost(Opcode, Dst, - VecTy->getElementType(), - TTI::TCK_RecipThroughput); + static_cast<T *>(this)->getCastInstrCost( + Opcode, Dst, VecTy->getElementType(), TTI::CastContextHint::None, + TTI::TCK_RecipThroughput); } unsigned getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) { @@ -1442,14 +1442,15 @@ unsigned ExtOp = IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt; + TTI::CastContextHint CCH = TTI::CastContextHint::None; unsigned Cost = 0; - Cost += 2 * ConcreteTTI->getCastInstrCost(ExtOp, ExtTy, RetTy, CostKind); + Cost += + 2 * ConcreteTTI->getCastInstrCost(ExtOp, ExtTy, RetTy, CCH, CostKind); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind); - Cost += - 2 * ConcreteTTI->getCastInstrCost(Instruction::Trunc, RetTy, ExtTy, - CostKind); + Cost += 2 * ConcreteTTI->getCastInstrCost(Instruction::Trunc, RetTy, + ExtTy, CCH, CostKind); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::LShr, RetTy, CostKind, TTI::OK_AnyValue, @@ -1512,14 +1513,15 @@ unsigned ExtOp = IID == Intrinsic::smul_fix ?
Instruction::SExt : Instruction::ZExt; + TTI::CastContextHint CCH = TTI::CastContextHint::None; unsigned Cost = 0; - Cost += 2 * ConcreteTTI->getCastInstrCost(ExtOp, ExtTy, MulTy, CostKind); + Cost += + 2 * ConcreteTTI->getCastInstrCost(ExtOp, ExtTy, MulTy, CCH, CostKind); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind); - Cost += - 2 * ConcreteTTI->getCastInstrCost(Instruction::Trunc, MulTy, ExtTy, - CostKind); + Cost += 2 * ConcreteTTI->getCastInstrCost(Instruction::Trunc, MulTy, + ExtTy, CCH, CostKind); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::LShr, MulTy, CostKind, TTI::OK_AnyValue, Index: llvm/lib/Analysis/TargetTransformInfo.cpp =================================================================== --- llvm/lib/Analysis/TargetTransformInfo.cpp +++ llvm/lib/Analysis/TargetTransformInfo.cpp @@ -686,12 +686,78 @@ return Cost; } +static TTI::CastContextHint computeExtCastContextHint(const Instruction *I) { + const Instruction *CtxInst = dyn_cast<Instruction>(I->getOperand(0)); + if (!CtxInst) + return TTI::CastContextHint::None; + + // TODO: Detect interleave and reverse.
+ + if (isa<LoadInst>(CtxInst)) + return TTI::CastContextHint::Normal; + + if (const IntrinsicInst *CtxIntrinsic = dyn_cast<IntrinsicInst>(CtxInst)) { + switch (CtxIntrinsic->getIntrinsicID()) { + case Intrinsic::masked_load: + return TTI::CastContextHint::Masked; + case Intrinsic::masked_gather: + return TTI::CastContextHint::GatherScatter; + } + } + + return TTI::CastContextHint::None; +} + +static TTI::CastContextHint computeTruncCastContextHint(const Instruction *I) { + const Instruction *CtxInst = nullptr; + // Only the stored-value operand (operand 0) makes this a truncating store; + // a trunc that only feeds e.g. the mask of llvm.masked.store gets None. + if (I->hasOneUse() && I->use_begin()->getOperandNo() == 0) + CtxInst = dyn_cast<Instruction>(*I->user_begin()); + + if (!CtxInst) + return TTI::CastContextHint::None; + + // TODO: Detect interleave and reverse. + + if (isa<StoreInst>(CtxInst)) + return TTI::CastContextHint::Normal; + + if (const IntrinsicInst *CtxIntrinsic = dyn_cast<IntrinsicInst>(CtxInst)) { + switch (CtxIntrinsic->getIntrinsicID()) { + case Intrinsic::masked_store: + return TTI::CastContextHint::Masked; + case Intrinsic::masked_scatter: + return TTI::CastContextHint::GatherScatter; + } + } + + return TTI::CastContextHint::None; +} + +TTI::CastContextHint +TargetTransformInfo::getCastContextHint(const Instruction *I) { + if (!I) + return CastContextHint::None; + + switch (I->getOpcode()) { + case Instruction::ZExt: + case Instruction::SExt: + return computeExtCastContextHint(I); + case Instruction::Trunc: + return computeTruncCastContextHint(I); + default: + return CastContextHint::None; + } +} + int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) const { assert((I == nullptr || I->getOpcode() == Opcode) && "Opcode should reflect passed instruction."); - int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, CostKind, I); + int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } @@ -1322,7 +1388,8 @@ case Instruction::BitCast: case Instruction::AddrSpaceCast: { Type *SrcTy = I->getOperand(0)->getType(); - return getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, CostKind, I); + return getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, + TTI::getCastContextHint(I), CostKind, I); } case Instruction::ExtractElement: { const ExtractElementInst *EEI = cast<ExtractElementInst>(I); Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -112,7 +112,7 @@ unsigned getMaxInterleaveFactor(unsigned VF); int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, -
TTI::TargetCostKind CostKind, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I = nullptr); int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -270,6 +270,7 @@ } int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) { int ISD = TLI->InstructionOpcodeToISD(Opcode); @@ -299,7 +300,7 @@ EVT DstTy = TLI->getValueType(DL, Dst); if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); static const TypeConversionCostTblEntry ConversionTbl[] = { @@ -403,7 +404,7 @@ SrcTy.getSimpleVT())) return Entry->Cost; - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); } int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst, @@ -435,12 +436,14 @@ // we may get the extension for free. If not, get the default cost for the // extend. if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT)) - return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind); + return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None, + CostKind); // The destination type should be larger than the element type. If not, get // the default cost for the extend. if (DstVT.getSizeInBits() < SrcVT.getSizeInBits()) - return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind); + return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None, + CostKind); switch (Opcode) { default: @@ -459,7 +462,8 @@ } // If we are unable to perform the extend for free, get the default cost.
- return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind); + return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None, + CostKind); } int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, Index: llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp +++ llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp @@ -1014,7 +1014,8 @@ case Instruction::BitCast: case Instruction::AddrSpaceCast: { return getCastInstrCost(I->getOpcode(), I->getType(), - I->getOperand(0)->getType(), CostKind, I); + I->getOperand(0)->getType(), + TTI::getCastContextHint(I), CostKind, I); } case Instruction::Add: case Instruction::FAdd: Index: llvm/lib/Target/ARM/ARMTargetTransformInfo.h =================================================================== --- llvm/lib/Target/ARM/ARMTargetTransformInfo.h +++ llvm/lib/Target/ARM/ARMTargetTransformInfo.h @@ -197,7 +197,7 @@ } int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - TTI::TargetCostKind CostKind, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I = nullptr); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, Index: llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -168,6 +168,7 @@ } int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) { int ISD = TLI->InstructionOpcodeToISD(Opcode); @@ -192,7 +193,7 @@ EVT DstTy = TLI->getValueType(DL, Dst); if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); // The extend of a load is free if (I &&
isa<LoadInst>(I->getOperand(0))) { @@ -458,7 +459,7 @@ int BaseCost = ST->hasMVEIntegerOps() && Src->isVectorTy() ? ST->getMVEVectorCostFactor() : 1; - return BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); } int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy, Index: llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h =================================================================== --- llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -142,8 +142,9 @@ ArrayRef<const Value *> Args = ArrayRef<const Value *>(), const Instruction *CxtI = nullptr); unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - TTI::TargetCostKind CostKind, - const Instruction *I = nullptr); + TTI::CastContextHint CCH, + TTI::TargetCostKind CostKind, + const Instruction *I = nullptr); unsigned getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); unsigned getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) { Index: llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -256,7 +256,9 @@ } unsigned HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy, - Type *SrcTy, TTI::TargetCostKind CostKind, const Instruction *I) { + Type *SrcTy, TTI::CastContextHint CCH, + TTI::TargetCostKind CostKind, + const Instruction *I) { if (SrcTy->isFPOrFPVectorTy() || DstTy->isFPOrFPVectorTy()) { unsigned SrcN = SrcTy->isFPOrFPVectorTy() ? getTypeNumElements(SrcTy) : 0; unsigned DstN = DstTy->isFPOrFPVectorTy() ?
getTypeNumElements(DstTy) : 0; Index: llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h =================================================================== --- llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h +++ llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h @@ -101,7 +101,7 @@ const Instruction *CxtI = nullptr); int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp); int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - TTI::TargetCostKind CostKind, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I = nullptr); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, Index: llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp +++ llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp @@ -755,11 +755,12 @@ } int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) { assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode"); - int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); return vectorCostAdjustment(Cost, Opcode, Dst, Src); } Index: llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h =================================================================== --- llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h +++ llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h @@ -90,7 +90,7 @@ unsigned getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst, const Instruction *I); int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - TTI::TargetCostKind CostKind, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I = nullptr); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, Index:
llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp +++ llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp @@ -689,6 +689,7 @@ } int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) { unsigned DstScalarBits = Dst->getScalarSizeInBits(); @@ -770,8 +771,8 @@ // Return the cost of multiple scalar invocation plus the cost of // inserting and extracting the values. Base implementation does not // realize float->int gets scalarized. - unsigned ScalarCost = getCastInstrCost(Opcode, Dst->getScalarType(), - Src->getScalarType(), CostKind); + unsigned ScalarCost = getCastInstrCost( + Opcode, Dst->getScalarType(), Src->getScalarType(), CCH, CostKind); unsigned TotCost = VF * ScalarCost; bool NeedsInserts = true, NeedsExtracts = true; // FP128 registers do not get inserted or extracted.
@@ -812,7 +813,7 @@ } } - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); } // Scalar i8 / i16 operations will typically be made after first extending Index: llvm/lib/Target/X86/X86TargetTransformInfo.h =================================================================== --- llvm/lib/Target/X86/X86TargetTransformInfo.h +++ llvm/lib/Target/X86/X86TargetTransformInfo.h @@ -129,7 +129,7 @@ int getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, int Index, VectorType *SubTp); int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - TTI::TargetCostKind CostKind, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I = nullptr); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, Index: llvm/lib/Target/X86/X86TargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -1363,6 +1363,7 @@ } int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) { int ISD = TLI->InstructionOpcodeToISD(Opcode); @@ -1977,7 +1978,7 @@ // The function getSimpleVT only handles simple value types.
if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind); + return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind); MVT SimpleSrcTy = SrcTy.getSimpleVT(); MVT SimpleDstTy = DstTy.getSimpleVT(); @@ -2038,7 +2039,7 @@ return Entry->Cost; } - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); } int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, Index: llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp =================================================================== --- llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -2028,8 +2028,8 @@ Type *SrcTy = CI->getOperand(0)->getType(); Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy, - TargetTransformInfo::TCK_SizeAndLatency, - CI); + TargetTransformInfo::getCastContextHint(CI), + TargetTransformInfo::TCK_SizeAndLatency, CI); } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) { // Cost of the address calculation Index: llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp =================================================================== --- llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -2176,8 +2176,9 @@ llvm_unreachable("There are no other cast types."); } const SCEV *Op = CastExpr->getOperand(); - BudgetRemaining -= TTI.getCastInstrCost(Opcode, /*Dst=*/S->getType(), - /*Src=*/Op->getType(), CostKind); + BudgetRemaining -= TTI.getCastInstrCost( + Opcode, /*Dst=*/S->getType(), + /*Src=*/Op->getType(), TTI::CastContextHint::None, CostKind); Worklist.emplace_back(Op); return false; // Will answer upon next entry into this function.
} Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -6382,13 +6382,55 @@ case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { + // Computes the CastContextHint from a Load/Store instruction. + auto ComputeCCH = [&](Instruction *I) -> TTI::CastContextHint { + assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && + "Expected a load or a store!"); + + // Accesses outside the loop (e.g. an invariant load feeding an extend) + // never get a widening decision (CM_Unknown below); treat as Normal. + if (VF == 1 || !TheLoop->contains(I)) + return TTI::CastContextHint::Normal; + + switch (getWideningDecision(I, VF)) { + case LoopVectorizationCostModel::CM_GatherScatter: + return TTI::CastContextHint::GatherScatter; + case LoopVectorizationCostModel::CM_Interleave: + return TTI::CastContextHint::Interleave; + case LoopVectorizationCostModel::CM_Scalarize: + case LoopVectorizationCostModel::CM_Widen: + return Legal->isMaskRequired(I) ? TTI::CastContextHint::Masked + : TTI::CastContextHint::Normal; + case LoopVectorizationCostModel::CM_Widen_Reverse: + return TTI::CastContextHint::Reversed; + case LoopVectorizationCostModel::CM_Unknown: + llvm_unreachable("Instr did not go through cost modelling?"); + } + + llvm_unreachable("Unhandled case!"); + }; + + unsigned Opcode = I->getOpcode(); + TTI::CastContextHint CCH = TTI::CastContextHint::None; + // For Trunc, the context is the only user, which must be a StoreInst. + if (Opcode == Instruction::Trunc) { + if (I->hasOneUse()) + if (StoreInst *Store = dyn_cast<StoreInst>(*I->user_begin())) + CCH = ComputeCCH(Store); + } + // For Z/Sext, the context is the operand, which must be a LoadInst. + else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) { + if (LoadInst *Load = dyn_cast<LoadInst>(I->getOperand(0))) + CCH = ComputeCCH(Load); + } + // We optimize the truncation of induction variables having constant // integer steps. The cost of these truncations is the same as the scalar // operation.
if (isOptimizableIVTruncate(I, VF)) { auto *Trunc = cast<TruncInst>(I); return TTI.getCastInstrCost(Instruction::Trunc, Trunc->getDestTy(), - Trunc->getSrcTy(), CostKind, Trunc); + Trunc->getSrcTy(), CCH, CostKind, Trunc); } Type *SrcScalarTy = I->getOperand(0)->getType(); @@ -6401,12 +6443,11 @@ // // Calculate the modified src and dest types. Type *MinVecTy = VectorTy; - if (I->getOpcode() == Instruction::Trunc) { + if (Opcode == Instruction::Trunc) { SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy); VectorTy = largestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); - } else if (I->getOpcode() == Instruction::ZExt || - I->getOpcode() == Instruction::SExt) { + } else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) { SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy); VectorTy = smallestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); @@ -6414,8 +6455,8 @@ } unsigned N = isScalarAfterVectorization(I, VF) ? VF : 1; - return N * TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy, - CostKind, I); + return N * + TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I); } case Instruction::Call: { bool NeedToScalarize; Index: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -3390,8 +3390,8 @@ Ext->getOpcode(), Ext->getType(), VecTy, i); // Add back the cost of s|zext which is subtracted separately.
DeadCost += TTI->getCastInstrCost( - Ext->getOpcode(), Ext->getType(), E->getType(), CostKind, - Ext); + Ext->getOpcode(), Ext->getType(), E->getType(), + TTI::getCastContextHint(Ext), CostKind, Ext); continue; } } @@ -3415,8 +3415,8 @@ case Instruction::BitCast: { Type *SrcTy = VL0->getOperand(0)->getType(); int ScalarEltCost = - TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy, CostKind, - VL0); + TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy, + TTI::getCastContextHint(VL0), CostKind, VL0); if (NeedToShuffleReuses) { ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; } @@ -3428,9 +3428,10 @@ int VecCost = 0; // Check if the values are candidates to demote. if (!MinBWs.count(VL0) || VecTy != SrcVecTy) { - VecCost = ReuseShuffleCost + - TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy, - CostKind, VL0); + VecCost = + ReuseShuffleCost + + TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy, + TTI::getCastContextHint(VL0), CostKind, VL0); } return VecCost - ScalarCost; } @@ -3635,9 +3636,9 @@ VectorType *Src0Ty = VectorType::get(Src0SclTy, VL.size()); VectorType *Src1Ty = VectorType::get(Src1SclTy, VL.size()); VecCost = TTI->getCastInstrCost(E->getOpcode(), VecTy, Src0Ty, - CostKind); + TTI::CastContextHint::None, CostKind); VecCost += TTI->getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty, - CostKind); + TTI::CastContextHint::None, CostKind); } VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_Select, VecTy, 0); return ReuseShuffleCost + VecCost - ScalarCost;