diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -1103,8 +1103,8 @@ /// is using a compare with the specified predicate as condition. When vector /// types are passed, \p VecPred must be used for all lanes. InstructionCost - getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy = nullptr, - CmpInst::Predicate VecPred = CmpInst::BAD_ICMP_PREDICATE, + getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, + CmpInst::Predicate VecPred, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput, const Instruction *I = nullptr) const; diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -348,6 +348,9 @@ SinkAndHoistLICMFlags *LICMFlags = nullptr, OptimizationRemarkEmitter *ORE = nullptr); +/// Returns the comparison predicate used when expanding a min/max reduction. +CmpInst::Predicate getMinMaxReductionPredicate(RecurKind RK); + /// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind. /// The Builder's fast-math-flags must be set to propagate the expected values. 
Value *createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -889,32 +889,28 @@ return true; } -Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, - Value *Right) { - CmpInst::Predicate Pred; +CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { switch (RK) { default: llvm_unreachable("Unknown min/max recurrence kind"); case RecurKind::UMin: - Pred = CmpInst::ICMP_ULT; - break; + return CmpInst::ICMP_ULT; case RecurKind::UMax: - Pred = CmpInst::ICMP_UGT; - break; + return CmpInst::ICMP_UGT; case RecurKind::SMin: - Pred = CmpInst::ICMP_SLT; - break; + return CmpInst::ICMP_SLT; case RecurKind::SMax: - Pred = CmpInst::ICMP_SGT; - break; + return CmpInst::ICMP_SGT; case RecurKind::FMin: - Pred = CmpInst::FCMP_OLT; - break; + return CmpInst::FCMP_OLT; case RecurKind::FMax: - Pred = CmpInst::FCMP_OGT; - break; + return CmpInst::FCMP_OGT; } +} +Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, + Value *Right) { + CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK); Value *Cmp = Builder.CreateCmp(Pred, Left, Right, "rdx.minmax.cmp"); Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select"); return Select; diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -8557,28 +8557,32 @@ } case RecurKind::FMax: case RecurKind::FMin: { + auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy); auto *VecCondTy = cast(CmpInst::makeCmpResultType(VectorTy)); VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy, /*unsigned=*/false, CostKind); - ScalarCost = - TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy) + - 
TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, - CmpInst::makeCmpResultType(ScalarTy)); + CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind); + ScalarCost = TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy, + SclCondTy, RdxPred, CostKind) + + TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, + SclCondTy, RdxPred, CostKind); break; } case RecurKind::SMax: case RecurKind::SMin: case RecurKind::UMax: case RecurKind::UMin: { + auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy); auto *VecCondTy = cast(CmpInst::makeCmpResultType(VectorTy)); bool IsUnsigned = RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin; VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy, IsUnsigned, CostKind); - ScalarCost = - TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy) + - TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, - CmpInst::makeCmpResultType(ScalarTy)); + CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind); + ScalarCost = TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy, + SclCondTy, RdxPred, CostKind) + + TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, + SclCondTy, RdxPred, CostKind); break; } default: diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -82,7 +82,7 @@ ExtractElementInst *Ext1, unsigned PreferredExtractIndex) const; bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1, - unsigned Opcode, + const Instruction &I, ExtractElementInst *&ConvertToShuffle, unsigned PreferredExtractIndex); void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1, @@ -299,12 +299,13 @@ /// \p ConvertToShuffle to that extract instruction. 
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1, - unsigned Opcode, + const Instruction &I, ExtractElementInst *&ConvertToShuffle, unsigned PreferredExtractIndex) { assert(isa(Ext0->getOperand(1)) && isa(Ext1->getOperand(1)) && "Expected constant extract indexes"); + unsigned Opcode = I.getOpcode(); Type *ScalarTy = Ext0->getType(); auto *VecTy = cast(Ext0->getOperand(0)->getType()); InstructionCost ScalarOpCost, VectorOpCost; @@ -317,10 +318,11 @@ } else { assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && "Expected a compare"); - ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy, - CmpInst::makeCmpResultType(ScalarTy)); - VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy, - CmpInst::makeCmpResultType(VecTy)); + CmpInst::Predicate Pred = cast(I).getPredicate(); + ScalarOpCost = TTI.getCmpSelInstrCost( + Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred); + VectorOpCost = TTI.getCmpSelInstrCost( + Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred); } // Get cost estimates for the extract elements. 
These costs will factor into @@ -495,8 +497,7 @@ m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex))); ExtractElementInst *ExtractToChange; - if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange, - InsertIndex)) + if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex)) return false; if (ExtractToChange) { @@ -640,8 +641,11 @@ unsigned Opcode = I.getOpcode(); InstructionCost ScalarOpCost, VectorOpCost; if (IsCmp) { - ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy); - VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy); + CmpInst::Predicate Pred = cast(I).getPredicate(); + ScalarOpCost = TTI.getCmpSelInstrCost( + Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred); + VectorOpCost = TTI.getCmpSelInstrCost( + Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred); } else { ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy); VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy); @@ -741,7 +745,10 @@ InstructionCost OldCost = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0); OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1); - OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2; + OldCost += + TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(), + CmpInst::makeCmpResultType(I0->getType()), Pred) * + 2; OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType()); // The proposed vector pattern is: @@ -750,7 +757,8 @@ int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0; int ExpensiveIndex = ConvertToShuf == Ext0 ? 
Index0 : Index1; auto *CmpTy = cast(CmpInst::makeCmpResultType(X->getType())); - InstructionCost NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType()); + InstructionCost NewCost = TTI.getCmpSelInstrCost( + CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred); SmallVector ShufMask(VecTy->getNumElements(), UndefMaskElem); ShufMask[CheapIndex] = ExpensiveIndex; NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,