Index: llvm/include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- llvm/include/llvm/Analysis/TargetTransformInfo.h +++ llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -105,6 +105,27 @@ bool canAnalyze(LoopInfo &LI); }; +/// Represents a hint about the context in which a cast is used. +/// Not every possible scenario is covered here - just the extraordinary cases +/// targets care about. +/// +/// If something is added to this enum, please update getCastInstrCost in +/// TargetTransformInfo.cpp, and, optionally, getInstructionCost in +/// LoopVectorize.cpp. These are the two places where this value is computed. +/// +/// TargetTransformInfo is where the value is computed using the IR instruction +/// (this is only done if the context is unknown). +/// +/// LoopVectorize is where the value is computed based on the vectorizer's +/// decisions and knowledge. +enum class CastContextHint : uint8_t { + /// We don't know the context of this cast, or we don't care about it. + Unknown, + /// This is a sext/zext of a masked load, or a trunc whose only user is a + /// masked store. + MaskedExtOrTrunc, +}; + /// This pass provides access to the codegen interfaces that are needed /// for IR-level transformations. class TargetTransformInfo { @@ -895,7 +916,8 @@ /// zext, etc. If there is an existing instruction that holds Opcode, it /// may be passed in the 'I' parameter. int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I = nullptr) const; + const Instruction *I = nullptr, + CastContextHint CCH = CastContextHint::Unknown) const; /// \return The expected cost of a sign- or zero-extended vector extract. Use /// -1 to indicate that there is no information about the index value. 
@@ -1313,7 +1335,7 @@ virtual int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index, VectorType *SubTp) = 0; virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) = 0; + const Instruction *I, CastContextHint CCH) = 0; virtual int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index) = 0; virtual int getCFInstrCost(unsigned Opcode) = 0; @@ -1712,8 +1734,8 @@ return Impl.getShuffleCost(Kind, Tp, Index, SubTp); } int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) override { - return Impl.getCastInstrCost(Opcode, Dst, Src, I); + const Instruction *I, CastContextHint CCH) override { + return Impl.getCastInstrCost(Opcode, Dst, Src, I, CCH); } int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index) override { Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -381,7 +381,8 @@ } unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) { + const Instruction *I, + CastContextHint CCH = CastContextHint::Unknown) { switch (Opcode) { default: break; Index: llvm/include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -691,7 +691,8 @@ } unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I = nullptr) { + const Instruction *I = nullptr, + CastContextHint CCH = CastContextHint::Unknown) { const TargetLoweringBase *TLI = getTLI(); int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); @@ -721,11 +722,12 @@ case Instruction::SExt: { // If this is a zext/sext of a load, return 0 if the corresponding // extending load 
exists on target. - if (I && isa<LoadInst>(I->getOperand(0))) { + if ((CCH == CastContextHint::Unknown) && I && + isa<LoadInst>(I->getOperand(0))) { EVT ExtVT = EVT::getEVT(Dst); EVT LoadVT = EVT::getEVT(Src); unsigned LType = - ((Opcode == Instruction::ZExt) ? ISD::ZEXTLOAD : ISD::SEXTLOAD); + ((Opcode == Instruction::ZExt) ? ISD::ZEXTLOAD : ISD::SEXTLOAD); if (TLI->isLoadExtLegal(LType, ExtVT, LoadVT)) return 0; } @@ -796,14 +798,14 @@ SrcVTy->getNumElements() / 2); T *TTI = static_cast<T *>(this); return TTI->getVectorSplitCost() + - (2 * TTI->getCastInstrCost(Opcode, SplitDst, SplitSrc, I)); + (2 * TTI->getCastInstrCost(Opcode, SplitDst, SplitSrc, I, CCH)); } // In other cases where the source or destination are illegal, assume // the operation will get scalarized. unsigned Num = DstVTy->getNumElements(); unsigned Cost = static_cast<T *>(this)->getCastInstrCost( - Opcode, Dst->getScalarType(), Src->getScalarType(), I); + Opcode, Dst->getScalarType(), Src->getScalarType(), I, CCH); // Return the cost of multiple scalar invocation plus the cost of // inserting and extracting the values. Index: llvm/lib/Analysis/TargetTransformInfo.cpp =================================================================== --- llvm/lib/Analysis/TargetTransformInfo.cpp +++ llvm/lib/Analysis/TargetTransformInfo.cpp @@ -599,11 +599,47 @@ return Cost; } +// Computes a CastContextHint for a (non-FP) Trunc instruction. +static CastContextHint getCCHForTrunc(const Instruction *I) { + if (I->hasOneUse()) { + const Value *Use = *I->user_begin(); + if (const IntrinsicInst *Intrinsic = dyn_cast<IntrinsicInst>(Use)) + if (Intrinsic->getIntrinsicID() == Intrinsic::masked_store) + return CastContextHint::MaskedExtOrTrunc; + } + return CastContextHint::Unknown; +} + +// Computes a CastContextHint for a SExt/ZExt instruction. 
+static CastContextHint getCCHForExt(const Instruction *I) { + const Value *Operand = I->getOperand(0); + if (const IntrinsicInst *Intrinsic = dyn_cast<IntrinsicInst>(Operand)) + if (Intrinsic->getIntrinsicID() == Intrinsic::masked_load) + return CastContextHint::MaskedExtOrTrunc; + return CastContextHint::Unknown; +} + int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) const { + const Instruction *I, + CastContextHint CCH) const { assert((I == nullptr || I->getOpcode() == Opcode) && "Opcode should reflect passed instruction."); - int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, I); + // If we have an instruction and the CCH is unknown, try to guess it based on + // the instruction. + if (I && CCH == CastContextHint::Unknown) { + switch (Opcode) { + case Instruction::ZExt: + case Instruction::SExt: + CCH = getCCHForExt(I); + break; + case Instruction::Trunc: + CCH = getCCHForTrunc(I); + break; + default: + break; + } + } + int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, I, CCH); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -112,7 +112,8 @@ unsigned getMaxInterleaveFactor(unsigned VF); int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I = nullptr); + const Instruction *I = nullptr, + CastContextHint CCH = CastContextHint::Unknown); int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index); Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -267,7 +267,8 @@ } int 
AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) { + const Instruction *I, + CastContextHint CCH) { int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); Index: llvm/lib/Target/ARM/ARMTargetTransformInfo.h =================================================================== --- llvm/lib/Target/ARM/ARMTargetTransformInfo.h +++ llvm/lib/Target/ARM/ARMTargetTransformInfo.h @@ -194,7 +194,8 @@ } int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I = nullptr); + const Instruction *I = nullptr, + CastContextHint CCH = CastContextHint::Unknown); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, const Instruction *I = nullptr); Index: llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -164,7 +164,7 @@ } int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) { + const Instruction *I, CastContextHint CCH) { int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); @@ -187,7 +187,7 @@ EVT DstTy = TLI->getValueType(DL, Dst); if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src); + return BaseT::getCastInstrCost(Opcode, Dst, Src, nullptr, CCH); // The extend of a load is free if (I && isa<LoadInst>(I->getOperand(0))) { @@ -418,7 +418,7 @@ int BaseCost = ST->hasMVEIntegerOps() && Src->isVectorTy() ? 
ST->getMVEVectorCostFactor() : 1; - return BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src); + return BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, nullptr, CCH); } int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy, Index: llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h =================================================================== --- llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -140,7 +140,8 @@ ArrayRef<const Value *> Args = ArrayRef<const Value *>(), const Instruction *CxtI = nullptr); unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I = nullptr); + const Instruction *I = nullptr, + CastContextHint CCH = CastContextHint::Unknown); unsigned getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); unsigned getCFInstrCost(unsigned Opcode) { Index: llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -258,7 +258,8 @@ } unsigned HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy, - Type *SrcTy, const Instruction *I) { + Type *SrcTy, const Instruction *I, + CastContextHint CCH) { if (SrcTy->isFPOrFPVectorTy() || DstTy->isFPOrFPVectorTy()) { unsigned SrcN = SrcTy->isFPOrFPVectorTy() ? 
getTypeNumElements(DstTy) : 0; Index: llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h =================================================================== --- llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h +++ llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h @@ -99,7 +99,8 @@ const Instruction *CxtI = nullptr); int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp); int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I = nullptr); + const Instruction *I = nullptr, + CastContextHint CCH = CastContextHint::Unknown); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, const Instruction *I = nullptr); int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); Index: llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp +++ llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp @@ -749,7 +749,7 @@ } int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) { + const Instruction *I, CastContextHint CCH) { assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode"); - int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src); + int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, nullptr, CCH); Index: llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h =================================================================== --- llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h +++ llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h @@ -88,7 +88,8 @@ unsigned getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst, const Instruction *I); int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I = nullptr); + const Instruction *I = nullptr, + CastContextHint CCH = CastContextHint::Unknown); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, const Instruction *I = nullptr); int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); Index: 
llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp +++ llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp @@ -684,7 +684,8 @@ } int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) { + const Instruction *I, + CastContextHint CCH) { unsigned DstScalarBits = Dst->getScalarSizeInBits(); unsigned SrcScalarBits = Src->getScalarSizeInBits(); Index: llvm/lib/Target/X86/X86TargetTransformInfo.h =================================================================== --- llvm/lib/Target/X86/X86TargetTransformInfo.h +++ llvm/lib/Target/X86/X86TargetTransformInfo.h @@ -128,7 +128,8 @@ int getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, int Index, VectorType *SubTp); int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I = nullptr); + const Instruction *I = nullptr, + CastContextHint CCH = CastContextHint::Unknown); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, const Instruction *I = nullptr); int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); Index: llvm/lib/Target/X86/X86TargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -1357,7 +1357,7 @@ } int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) { + const Instruction *I, CastContextHint CCH) { int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -6361,13 +6361,26 @@ case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { 
+ CastContextHint CCH = CastContextHint::Unknown; + unsigned Opcode = I->getOpcode(); + if (Opcode == Instruction::Trunc) { + if (I->hasOneUse() && isa<StoreInst>(*I->user_begin()) && + Legal->isMaskRequired(cast<StoreInst>(*I->user_begin()))) + CCH = CastContextHint::MaskedExtOrTrunc; + } else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) { + Value *Operand = I->getOperand(0); + if (isa<LoadInst>(Operand) && + Legal->isMaskRequired(cast<LoadInst>(Operand))) + CCH = CastContextHint::MaskedExtOrTrunc; + } + // We optimize the truncation of induction variables having constant // integer steps. The cost of these truncations is the same as the scalar // operation. if (isOptimizableIVTruncate(I, VF)) { auto *Trunc = cast<TruncInst>(I); return TTI.getCastInstrCost(Instruction::Trunc, Trunc->getDestTy(), - Trunc->getSrcTy(), Trunc); + Trunc->getSrcTy(), Trunc, CCH); } Type *SrcScalarTy = I->getOperand(0)->getType(); @@ -6393,7 +6406,7 @@ } unsigned N = isScalarAfterVectorization(I, VF) ? VF : 1; - return N * TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy, I); + return N * TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy, I, CCH); } case Instruction::Call: { bool NeedToScalarize;