diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1147,18 +1147,6 @@
   /// Construct a vectorizable tree that starts at \p Roots.
   void buildTree(ArrayRef<Value *> Roots);
 
-  /// Checks if the very first tree node is going to be vectorized.
-  bool isVectorizedFirstNode() const {
-    return !VectorizableTree.empty() &&
-           VectorizableTree.front()->State == TreeEntry::Vectorize;
-  }
-
-  /// Returns the main instruction for the very first node.
-  Instruction *getFirstNodeMainOp() const {
-    assert(!VectorizableTree.empty() && "No tree to get the first node from");
-    return VectorizableTree.front()->getMainOp();
-  }
-
   /// Returns whether the root node has in-tree uses.
   bool doesRootHaveInTreeUses() const {
     return !VectorizableTree.empty() &&
@@ -13313,22 +13301,7 @@
       // Estimate cost.
       InstructionCost TreeCost = V.getTreeCost(VL);
       InstructionCost ReductionCost =
-          getReductionCost(TTI, VL, ReduxWidth, RdxFMF);
-      if (V.isVectorizedFirstNode() && isa<LoadInst>(VL.front())) {
-        Instruction *MainOp = V.getFirstNodeMainOp();
-        for (Value *V : VL) {
-          auto *VI = dyn_cast<LoadInst>(V);
-          // Add the costs of scalar GEP pointers, to be removed from the
-          // code.
-          if (!VI || VI == MainOp)
-            continue;
-          auto *Ptr = dyn_cast<GetElementPtrInst>(VI->getPointerOperand());
-          if (!Ptr || !Ptr->hasOneUse() || Ptr->hasAllConstantIndices())
-            continue;
-          TreeCost -= TTI->getArithmeticInstrCost(
-              Instruction::Add, Ptr->getType(), TTI::TCK_RecipThroughput);
-        }
-      }
+          getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
       InstructionCost Cost = TreeCost + ReductionCost;
       LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n");
       if (!Cost.isValid())
@@ -13564,7 +13537,8 @@
   /// Calculate the cost of a reduction.
   InstructionCost getReductionCost(TargetTransformInfo *TTI,
                                    ArrayRef<Value *> ReducedVals,
-                                   unsigned ReduxWidth, FastMathFlags FMF) {
+                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
+                                   FastMathFlags FMF) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Value *FirstReducedVal = ReducedVals.front();
     Type *ScalarTy = FirstReducedVal->getType();
@@ -13573,6 +13547,35 @@
     // If all of the reduced values are constant, the vector cost is 0, since
     // the reduction value can be calculated at the compile time.
     bool AllConsts = allConstant(ReducedVals);
+    auto EvaluateScalarCost = [&](function_ref<InstructionCost()> GenCostFn) {
+      InstructionCost Cost = 0;
+      // Scalar cost is repeated for N-1 elements.
+      int Cnt = ReducedVals.size();
+      for (Value *RdxVal : ReducedVals) {
+        if (Cnt == 1)
+          break;
+        --Cnt;
+        if (RdxVal->hasNUsesOrMore(IsCmpSelMinMax ? 3 : 2)) {
+          Cost += GenCostFn();
+          continue;
+        }
+        InstructionCost ScalarCost = 0;
+        for (User *U : RdxVal->users()) {
+          auto *RdxOp = cast<Instruction>(U);
+          if (hasRequiredNumberOfUses(IsCmpSelMinMax, RdxOp)) {
+            ScalarCost += TTI->getInstructionCost(RdxOp, CostKind);
+            continue;
+          }
+          ScalarCost = InstructionCost::getInvalid();
+          break;
+        }
+        if (ScalarCost.isValid())
+          Cost += ScalarCost;
+        else
+          Cost += GenCostFn();
+      }
+      return Cost;
+    };
     switch (RdxKind) {
     case RecurKind::Add:
     case RecurKind::Mul:
@@ -13585,7 +13588,9 @@
       if (!AllConsts)
         VectorCost =
             TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind);
-      ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind);
+      ScalarCost = EvaluateScalarCost([&]() {
+        return TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind);
+      });
       break;
     }
     case RecurKind::FMax:
@@ -13599,10 +13604,12 @@
                                                  /*IsUnsigned=*/false, CostKind);
       }
       CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind);
-      ScalarCost = TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy,
-                                           SclCondTy, RdxPred, CostKind) +
-                   TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
-                                           SclCondTy, RdxPred, CostKind);
+      ScalarCost = EvaluateScalarCost([&]() {
+        return TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy, SclCondTy,
+                                       RdxPred, CostKind) +
+               TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, SclCondTy,
+                                       RdxPred, CostKind);
+      });
       break;
     }
     case RecurKind::SMax:
@@ -13619,18 +13626,18 @@
                                                  IsUnsigned, CostKind);
       }
       CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind);
-      ScalarCost = TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy,
-                                           SclCondTy, RdxPred, CostKind) +
-                   TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
-                                           SclCondTy, RdxPred, CostKind);
+      ScalarCost = EvaluateScalarCost([&]() {
+        return TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy, SclCondTy,
+                                       RdxPred, CostKind) +
+               TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, SclCondTy,
+                                       RdxPred, CostKind);
+      });
       break;
     }
     default:
       llvm_unreachable("Expected arithmetic or min/max reduction operation");
     }
 
-    // Scalar cost is repeated for N-1 elements.
-    ScalarCost *= (ReduxWidth - 1);
     LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost
                       << " for reduction that starts with " << *FirstReducedVal
                       << " (It is a splitting reduction)\n");
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/horizontal-smax.ll b/llvm/test/Transforms/SLPVectorizer/X86/horizontal-smax.ll
--- a/llvm/test/Transforms/SLPVectorizer/X86/horizontal-smax.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/horizontal-smax.ll
@@ -88,11 +88,46 @@
   ret i32 %16
 }
 
+; FIXME: looks like the cost of @llvm.smax.i32 is not correct, lowered as select+cmp
 define i32 @smax_v16i32(i32) {
-; CHECK-LABEL: @smax_v16i32(
-; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr @arr, align 16
-; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.smax.v16i32(<16 x i32> [[TMP2]])
-; CHECK-NEXT:    ret i32 [[TMP3]]
+; SSE-LABEL: @smax_v16i32(
+; SSE-NEXT:    [[TMP2:%.*]] = load i32, ptr @arr, align 16
+; SSE-NEXT:    [[TMP3:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 1), align 4
+; SSE-NEXT:    [[TMP4:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 2), align 8
+; SSE-NEXT:    [[TMP5:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 3), align 4
+; SSE-NEXT:    [[TMP6:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 4), align 16
+; SSE-NEXT:    [[TMP7:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 5), align 4
+; SSE-NEXT:    [[TMP8:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 6), align 8
+; SSE-NEXT:    [[TMP9:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 7), align 4
+; SSE-NEXT:    [[TMP10:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 8), align 16
+; SSE-NEXT:    [[TMP11:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 9), align 4
+; SSE-NEXT:    [[TMP12:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 10), align 8
+; SSE-NEXT:    [[TMP13:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 11), align 4
+; SSE-NEXT:    [[TMP14:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 12), align 16
+; SSE-NEXT:    [[TMP15:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 13), align 4
+; SSE-NEXT:    [[TMP16:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 14), align 8
+; SSE-NEXT:    [[TMP17:%.*]] = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 15), align 4
+; SSE-NEXT:    [[TMP18:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP2]], i32 [[TMP3]])
+; SSE-NEXT:    [[TMP19:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP18]], i32 [[TMP4]])
+; SSE-NEXT:    [[TMP20:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP19]], i32 [[TMP5]])
+; SSE-NEXT:    [[TMP21:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP20]], i32 [[TMP6]])
+; SSE-NEXT:    [[TMP22:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP21]], i32 [[TMP7]])
+; SSE-NEXT:    [[TMP23:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP22]], i32 [[TMP8]])
+; SSE-NEXT:    [[TMP24:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP23]], i32 [[TMP9]])
+; SSE-NEXT:    [[TMP25:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP24]], i32 [[TMP10]])
+; SSE-NEXT:    [[TMP26:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP25]], i32 [[TMP11]])
+; SSE-NEXT:    [[TMP27:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP26]], i32 [[TMP12]])
+; SSE-NEXT:    [[TMP28:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP27]], i32 [[TMP13]])
+; SSE-NEXT:    [[TMP29:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP28]], i32 [[TMP14]])
+; SSE-NEXT:    [[TMP30:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP29]], i32 [[TMP15]])
+; SSE-NEXT:    [[TMP31:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP30]], i32 [[TMP16]])
+; SSE-NEXT:    [[TMP32:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP31]], i32 [[TMP17]])
+; SSE-NEXT:    ret i32 [[TMP32]]
+;
+; AVX-LABEL: @smax_v16i32(
+; AVX-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr @arr, align 16
+; AVX-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.smax.v16i32(<16 x i32> [[TMP2]])
+; AVX-NEXT:    ret i32 [[TMP3]]
 ;
   %2 = load i32, ptr @arr, align 16
   %3 = load i32, ptr getelementptr inbounds ([32 x i32], ptr @arr, i64 0, i64 1), align 4