Index: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6838,9 +6838,48 @@
     for (ReductionOpsType &RdxOp : ReductionOps)
       IgnoreList.append(RdxOp.begin(), RdxOp.end());
 
+    unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
+    if (NumReducedVals > ReduxWidth) {
+      // In the loop below, we are building a tree based on the first
+      // 'ReduxWidth' values.
+      // If the operands of those values have common traits (compare predicate,
+      // constant operand, etc.), then we want to group those together to
+      // minimize the cost of the reduction.
+
+      // TODO: This should be extended to count common operands for
+      // compares and binops.
+
+      // Step 1: Count the number of times each compare predicate occurs.
+      SmallDenseMap<unsigned, unsigned> PredCountMap;
+      for (Value *RdxVal : ReducedVals) {
+        CmpInst::Predicate Pred;
+        if (match(RdxVal, m_Cmp(Pred, m_Value(), m_Value())))
+          ++PredCountMap[Pred];
+      }
+
+      // Step 2: Sort the values so the most common predicates come first.
+      // Rank each value by (predicate count, predicate). Non-compares get
+      // rank (0, 0), so they sort after every compare (each counted
+      // predicate has a count of at least 1). Comparing whole ranks is a
+      // strict weak ordering, as the sort requires; the predicate tie-break
+      // keeps equally common predicates adjacent instead of interleaved.
+      // The stable sort preserves the original order of equal-rank values,
+      // so the result does not depend on the sort implementation.
+      auto GetRank = [&PredCountMap](Value *V) {
+        CmpInst::Predicate Pred;
+        if (!match(V, m_Cmp(Pred, m_Value(), m_Value())))
+          return std::make_pair(0u, 0u);
+        return std::make_pair(PredCountMap.lookup(Pred),
+                              static_cast<unsigned>(Pred));
+      };
+      std::stable_sort(ReducedVals.begin(), ReducedVals.end(),
+                       [&GetRank](Value *A, Value *B) {
+                         return GetRank(A) > GetRank(B);
+                       });
+    }
+
     Value *VectorizedTree = nullptr;
     unsigned i = 0;
-    unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
     while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
       ArrayRef<Value *> VL = makeArrayRef(&ReducedVals[i], ReduxWidth);
       V.buildTree(VL, ExternallyUsedValues, IgnoreList);
Index: llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
===================================================================
--- llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
+++ llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
@@ -81,20 +81,12 @@
 
 define float @merge_anyof_v4f32_wrong_first(<4 x float> %x) {
 ; CHECK-LABEL: @merge_anyof_v4f32_wrong_first(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <4 x float> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x float> [[X]], i32 1
-; CHECK-NEXT:    [[X2:%.*]] = extractelement <4 x float> [[X]], i32 2
-; CHECK-NEXT:    [[X3:%.*]] = extractelement <4 x float> [[X]], i32 3
-; CHECK-NEXT:    [[CMP3WRONG:%.*]] = fcmp olt float [[X3]], 4.200000e+01
-; CHECK-NEXT:    [[CMP0:%.*]] = fcmp ogt float [[X0]], 1.000000e+00
-; CHECK-NEXT:    [[CMP1:%.*]] = fcmp ogt float [[X1]], 1.000000e+00
-; CHECK-NEXT:    [[CMP2:%.*]] = fcmp ogt float [[X2]], 1.000000e+00
-; CHECK-NEXT:    [[CMP3:%.*]] = fcmp ogt float [[X3]], 1.000000e+00
-; CHECK-NEXT:    [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[OR031:%.*]] = or i1 [[OR03]], [[CMP1]]
-; CHECK-NEXT:    [[OR0312:%.*]] = or i1 [[OR031]], [[CMP2]]
-; CHECK-NEXT:    [[OR03123:%.*]] = or i1 [[OR0312]], [[CMP3]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[OR03123]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i32 3
+; CHECK-NEXT:    [[CMP3WRONG:%.*]] = fcmp olt float [[TMP1]], 4.200000e+01
+; CHECK-NEXT:    [[TMP2:%.*]] = fcmp ogt <4 x float> [[X]], <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP4]], float -1.000000e+00, float 1.000000e+00
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %x0 = extractelement <4 x float> %x, i32 0
@@ -143,20 +135,12 @@
 
 define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) {
 ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1
-; CHECK-NEXT:    [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2
-; CHECK-NEXT:    [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3
-; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[X3]], 42
-; CHECK-NEXT:    [[CMP0:%.*]] = icmp sgt i32 [[X0]], 1
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i32 [[X1]], 1
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[X2]], 1
-; CHECK-NEXT:    [[CMP3:%.*]] = icmp sgt i32 [[X3]], 1
-; CHECK-NEXT:    [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3]]
-; CHECK-NEXT:    [[OR033:%.*]] = or i1 [[OR03]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[OR0332:%.*]] = or i1 [[OR033]], [[CMP2]]
-; CHECK-NEXT:    [[OR03321:%.*]] = or i1 [[OR0332]], [[CMP1]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[OR03321]], i32 -1, i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
+; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP1]], 42
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp sgt <4 x i32> [[X]], <i32 1, i32 1, i32 1, i32 1>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP4]], i32 -1, i32 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %x0 = extractelement <4 x i32> %x, i32 0
@@ -176,29 +160,18 @@
   ret i32 %r
 }
 
+; Operand/predicate swapping allows forming a reduction, but the
+; ideal reduction groups all of the original 'sgt' ops together.
+
 define i32 @merge_anyof_v4i32_wrong_middle_better_rdx(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle_better_rdx(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1
-; CHECK-NEXT:    [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2
-; CHECK-NEXT:    [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3
-; CHECK-NEXT:    [[Y0:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractelement <4 x i32> [[Y]], i32 1
-; CHECK-NEXT:    [[Y2:%.*]] = extractelement <4 x i32> [[Y]], i32 2
-; CHECK-NEXT:    [[Y3:%.*]] = extractelement <4 x i32> [[Y]], i32 3
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i32 [[X1]], [[Y1]]
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32> undef, i32 [[X0]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[X3]], i32 1
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[Y3]], i32 2
-; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[X2]], i32 3
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x i32> undef, i32 [[Y0]], i32 0
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x i32> [[TMP5]], i32 [[Y3]], i32 1
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x i32> [[TMP6]], i32 [[X3]], i32 2
-; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x i32> [[TMP7]], i32 [[Y2]], i32 3
-; CHECK-NEXT:    [[TMP9:%.*]] = icmp sgt <4 x i32> [[TMP4]], [[TMP8]]
-; CHECK-NEXT:    [[TMP10:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP9]])
-; CHECK-NEXT:    [[TMP11:%.*]] = or i1 [[TMP10]], [[CMP1]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP11]], i32 -1, i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 3
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
+; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt <4 x i32> [[X]], [[Y]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP3]])
+; CHECK-NEXT:    [[TMP5:%.*]] = or i1 [[TMP4]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP5]], i32 -1, i32 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %x0 = extractelement <4 x i32> %x, i32 0