Index: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -6838,9 +6838,37 @@ for (ReductionOpsType &RdxOp : ReductionOps) IgnoreList.append(RdxOp.begin(), RdxOp.end()); + unsigned ReduxWidth = PowerOf2Floor(NumReducedVals); + if (NumReducedVals > ReduxWidth) { + // In the loop below, we are building a tree based on the first + // 'ReduxWidth' values. + // If the operands of those values have common traits (compare predicate, + // constant operand, etc), then we want to group those together to + // minimize the cost of the reduction. + + // TODO: This should be extended to count common operands for + // compares and binops. + + // Step 1: Count the number of times each compare predicate occurs. + SmallDenseMap<unsigned, unsigned> PredCountMap; + for (Value *RdxVal : ReducedVals) { + CmpInst::Predicate Pred; + if (match(RdxVal, m_Cmp(Pred, m_Value(), m_Value()))) + ++PredCountMap[Pred]; + } + // Step 2: Sort the values so the most common predicates come first. 
+ stable_sort(ReducedVals, [&PredCountMap](Value *A, Value *B) { + CmpInst::Predicate PredA, PredB; + if (match(A, m_Cmp(PredA, m_Value(), m_Value())) && + match(B, m_Cmp(PredB, m_Value(), m_Value()))) { + return PredCountMap[PredA] > PredCountMap[PredB]; + } + return false; + }); + } + + Value *VectorizedTree = nullptr; unsigned i = 0; - unsigned ReduxWidth = PowerOf2Floor(NumReducedVals); while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) { ArrayRef<Value *> VL = makeArrayRef(&ReducedVals[i], ReduxWidth); V.buildTree(VL, ExternallyUsedValues, IgnoreList); Index: llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll =================================================================== --- llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll +++ llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll @@ -81,20 +81,12 @@ define float @merge_anyof_v4f32_wrong_first(<4 x float> %x) { ; CHECK-LABEL: @merge_anyof_v4f32_wrong_first( -; CHECK-NEXT: [[X0:%.*]] = extractelement <4 x float> [[X:%.*]], i32 0 -; CHECK-NEXT: [[X1:%.*]] = extractelement <4 x float> [[X]], i32 1 -; CHECK-NEXT: [[X2:%.*]] = extractelement <4 x float> [[X]], i32 2 -; CHECK-NEXT: [[X3:%.*]] = extractelement <4 x float> [[X]], i32 3 -; CHECK-NEXT: [[CMP3WRONG:%.*]] = fcmp olt float [[X3]], 4.200000e+01 -; CHECK-NEXT: [[CMP0:%.*]] = fcmp ogt float [[X0]], 1.000000e+00 -; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[X1]], 1.000000e+00 -; CHECK-NEXT: [[CMP2:%.*]] = fcmp ogt float [[X2]], 1.000000e+00 -; CHECK-NEXT: [[CMP3:%.*]] = fcmp ogt float [[X3]], 1.000000e+00 -; CHECK-NEXT: [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3WRONG]] -; CHECK-NEXT: [[OR031:%.*]] = or i1 [[OR03]], [[CMP1]] -; CHECK-NEXT: [[OR0312:%.*]] = or i1 [[OR031]], [[CMP2]] -; CHECK-NEXT: [[OR03123:%.*]] = or i1 [[OR0312]], [[CMP3]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[OR03123]], float -1.000000e+00, float 1.000000e+00 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i32 3 +; CHECK-NEXT: [[CMP3WRONG:%.*]] = fcmp
olt float [[TMP1]], 4.200000e+01 +; CHECK-NEXT: [[TMP2:%.*]] = fcmp ogt <4 x float> [[X]], <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00> +; CHECK-NEXT: [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]]) +; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]] +; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP4]], float -1.000000e+00, float 1.000000e+00 ; CHECK-NEXT: ret float [[R]] ; %x0 = extractelement <4 x float> %x, i32 0 @@ -143,20 +135,12 @@ define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) { ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle( -; CHECK-NEXT: [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0 -; CHECK-NEXT: [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1 -; CHECK-NEXT: [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2 -; CHECK-NEXT: [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3 -; CHECK-NEXT: [[CMP3WRONG:%.*]] = icmp slt i32 [[X3]], 42 -; CHECK-NEXT: [[CMP0:%.*]] = icmp sgt i32 [[X0]], 1 -; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[X1]], 1 -; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[X2]], 1 -; CHECK-NEXT: [[CMP3:%.*]] = icmp sgt i32 [[X3]], 1 -; CHECK-NEXT: [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3]] -; CHECK-NEXT: [[OR033:%.*]] = or i1 [[OR03]], [[CMP3WRONG]] -; CHECK-NEXT: [[OR0332:%.*]] = or i1 [[OR033]], [[CMP2]] -; CHECK-NEXT: [[OR03321:%.*]] = or i1 [[OR0332]], [[CMP1]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[OR03321]], i32 -1, i32 1 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3 +; CHECK-NEXT: [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP1]], 42 +; CHECK-NEXT: [[TMP2:%.*]] = icmp sgt <4 x i32> [[X]], <i32 1, i32 1, i32 1, i32 1> +; CHECK-NEXT: [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]]) +; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]] +; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP4]], i32 -1, i32 1 ; CHECK-NEXT: ret i32 [[R]] ; %x0 = extractelement <4 x i32> %x, i32 0