diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -865,6 +865,17 @@ LI->getPointerAddressSpace(), CostKind, I); } + case Instruction::Select: { + Type *CondTy = U->getOperand(0)->getType(); + return TargetTTI->getCmpSelInstrCost(Opcode, U->getType(), CondTy, + CostKind, I); + } + case Instruction::ICmp: + case Instruction::FCmp: { + Type *ValTy = U->getOperand(0)->getType(); + return TargetTTI->getCmpSelInstrCost(Opcode, ValTy, U->getType(), + CostKind, I); + } } // By default, just classify everything as 'basic'. return TTI::TCC_Basic; diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -838,6 +838,10 @@ int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + // Selects on vectors are actually vector selects. 
if (ISD == ISD::SELECT) { assert(CondTy && "CondTy must exist"); diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1296,18 +1296,10 @@ Op1VK, Op2VK, Op1VP, Op2VP, Operands, I); } - case Instruction::Select: { - const SelectInst *SI = cast<SelectInst>(I); - Type *CondTy = SI->getCondition()->getType(); - return getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy, - CostKind, I); - } + case Instruction::Select: case Instruction::ICmp: - case Instruction::FCmp: { - Type *ValTy = I->getOperand(0)->getType(); - return getCmpSelInstrCost(I->getOpcode(), ValTy, I->getType(), - CostKind, I); - } + case Instruction::FCmp: + return getUserCost(I, CostKind); case Instruction::Store: case Instruction::Load: return getUserCost(I, CostKind); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -620,6 +620,9 @@ Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); int ISD = TLI->InstructionOpcodeToISD(Opcode); // We don't lower some vector selects well that are wider than the register diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -508,6 +508,10 @@ int ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + // TODO: Handle other cost kinds. 
+ if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + int ISD = TLI->InstructionOpcodeToISD(Opcode); // On NEON a vector select gets lowered to vbsl. if (ST->hasNEON() && ValTy->isVectorTy() && ISD == ISD::SELECT) { diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -236,7 +236,7 @@ unsigned HexagonTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { - if (ValTy->isVectorTy()) { + if (ValTy->isVectorTy() && CostKind == TTI::TCK_RecipThroughput) { std::pair<int, MVT> LT = TLI.getTypeLegalizationCost(DL, ValTy); if (Opcode == Instruction::FCmp) return LT.first + FloatFactor * getTypeNumElements(ValTy); diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp --- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp +++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp @@ -779,6 +779,9 @@ TTI::TargetCostKind CostKind, const Instruction *I) { int Cost = BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + // TODO: Handle other cost kinds. 
+ if (CostKind != TTI::TCK_RecipThroughput) + return Cost; return vectorCostAdjustment(Cost, Opcode, ValTy, nullptr); } diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp --- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp +++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp @@ -837,6 +837,9 @@ Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + if (!ValTy->isVectorTy()) { switch (Opcode) { case Instruction::ICmp: { diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -2051,6 +2051,10 @@ int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + // Legalize the type. std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);