Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1289,6 +1289,71 @@
   return nullptr;
 }
 
+/// Reduce a sequence of min/max with a common operand.
+///
+/// Matches SPF(SPF(A, B), SPF(C, D)) where all three have the same integer
+/// min/max flavor and the two inner operations share an operand, and rebuilds
+/// it as SPF(MinMaxOp, ThirdOp) so one inner min/max is reused and the other
+/// can become dead. The compare is emitted through \p Builder; the returned
+/// select is not yet inserted. Returns nullptr if the fold does not apply.
+static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS,
+                                        Value *RHS,
+                                        InstCombiner::BuilderTy &Builder) {
+  assert(SelectPatternResult::isMinOrMax(SPF) && "Expected a min/max");
+  // TODO: Allow FP min/max with nnan/nsz.
+  if (!LHS->getType()->isIntOrIntVectorTy())
+    return nullptr;
+
+  // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
+  Value *A, *B, *C, *D;
+  SelectPatternResult L = matchSelectPattern(LHS, A, B);
+  SelectPatternResult R = matchSelectPattern(RHS, C, D);
+  if (SPF != L.Flavor || L.Flavor != R.Flavor)
+    return nullptr;
+
+  // Look for a common operand. The use checks are different than usual because
+  // a min/max pattern typically has 2 uses of each op: 1 by the cmp and 1 by
+  // the select, so 3 or more uses indicates a use outside of this tree.
+  // hasNUsesOrMore() stops walking the use list once it has seen enough uses,
+  // whereas getNumUses() always counts all of them.
+  Value *MinMaxOp = nullptr;
+  Value *ThirdOp = nullptr;
+  if (!LHS->hasNUsesOrMore(3) && RHS->hasNUsesOrMore(3)) {
+    // If the LHS is only used in this chain and the RHS is used outside of it,
+    // reuse the RHS min/max because that will eliminate the LHS.
+    if (D == A || C == A) {
+      // min(min(a, b), min(c, a)) --> min(min(c, a), b)
+      // min(min(a, b), min(a, d)) --> min(min(a, d), b)
+      MinMaxOp = RHS;
+      ThirdOp = B;
+    } else if (D == B || C == B) {
+      // min(min(a, b), min(c, b)) --> min(min(c, b), a)
+      // min(min(a, b), min(b, d)) --> min(min(b, d), a)
+      MinMaxOp = RHS;
+      ThirdOp = A;
+    }
+  } else if (!RHS->hasNUsesOrMore(3)) {
+    // Reuse the LHS. This will eliminate the RHS.
+    if (D == A || D == B) {
+      // min(min(a, b), min(c, a)) --> min(min(a, b), c)
+      // min(min(a, b), min(c, b)) --> min(min(a, b), c)
+      MinMaxOp = LHS;
+      ThirdOp = C;
+    } else if (C == A || C == B) {
+      // min(min(a, b), min(a, d)) --> min(min(a, b), d)
+      // min(min(a, b), min(b, d)) --> min(min(a, b), d)
+      MinMaxOp = LHS;
+      ThirdOp = D;
+    }
+  }
+  if (!MinMaxOp || !ThirdOp)
+    return nullptr;
+
+  CmpInst::Predicate P = getCmpPredicateForMinMax(SPF);
+  Value *CmpABC = Builder.CreateICmp(P, MinMaxOp, ThirdOp);
+  return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp);
+}
+
 Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -1563,6 +1628,9 @@
       Value *NewSel = Builder.CreateSelect(InvertedCmp, A, B);
       return BinaryOperator::CreateNot(NewSel);
     }
+
+    if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder))
+      return I;
   }
 
   if (SPF) {
Index: llvm/trunk/test/Transforms/InstCombine/max-of-nots.ll
===================================================================
--- llvm/trunk/test/Transforms/InstCombine/max-of-nots.ll
+++ llvm/trunk/test/Transforms/InstCombine/max-of-nots.ll
@@ -84,12 +84,10 @@
 
 define i8 @umin3_not(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @umin3_not(
-; CHECK-NEXT:    [[CMPYX:%.*]] = icmp ult i8 %y, %x
 ; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i8 %x, %z
 ; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[TMP1]], i8 %x, i8 %z
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp ugt i8 %y, %z
-; CHECK-NEXT:    [[TMP4:%.*]] = select i1 [[TMP3]], i8 %y, i8 %z
-; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[CMPYX]], i8 [[TMP2]], i8 [[TMP4]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ugt i8 [[TMP2]], %y
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[TMP3]], i8 [[TMP2]], i8 %y
 ; CHECK-NEXT:    [[R:%.*]] = xor i8 [[R:%.*]].v, -1
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
Index: llvm/trunk/test/Transforms/InstCombine/minmax-fold.ll
===================================================================
--- llvm/trunk/test/Transforms/InstCombine/minmax-fold.ll
+++ llvm/trunk/test/Transforms/InstCombine/minmax-fold.ll
@@ -754,10 +754,8 @@
 ; CHECK-LABEL: @common_factor_smin(
 ; CHECK-NEXT:    [[CMP_AB:%.*]] = icmp slt i32 %a, %b
 ; CHECK-NEXT:    [[MIN_AB:%.*]] = select i1 [[CMP_AB]], i32 %a, i32 %b
-; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp slt i32 %b, %c
-; CHECK-NEXT:    [[MIN_BC:%.*]] = select i1 [[CMP_BC]], i32 %b, i32 %c
-; CHECK-NEXT:    [[CMP_AB_BC:%.*]] = icmp slt i32 [[MIN_AB]], [[MIN_BC]]
-; CHECK-NEXT:    [[MIN_ABC:%.*]] = select i1 [[CMP_AB_BC]], i32 [[MIN_AB]], i32 [[MIN_BC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i32 [[MIN_AB]], %c
+; CHECK-NEXT:    [[MIN_ABC:%.*]] = select i1 [[TMP1]], i32 [[MIN_AB]], i32 %c
 ; CHECK-NEXT:    ret i32 [[MIN_ABC]]
 ;
   %cmp_ab = icmp slt i32 %a, %b
@@ -775,10 +773,8 @@
 ; CHECK-LABEL: @common_factor_smax(
 ; CHECK-NEXT:    [[CMP_AB:%.*]] = icmp sgt <2 x i32> %a, %b
 ; CHECK-NEXT:    [[MAX_AB:%.*]] = select <2 x i1> [[CMP_AB]], <2 x i32> %a, <2 x i32> %b
-; CHECK-NEXT:    [[CMP_CB:%.*]] = icmp sgt <2 x i32> %c, %b
-; CHECK-NEXT:    [[MAX_CB:%.*]] = select <2 x i1> [[CMP_CB]], <2 x i32> %c, <2 x i32> %b
-; CHECK-NEXT:    [[CMP_AB_CB:%.*]] = icmp sgt <2 x i32> [[MAX_AB]], [[MAX_CB]]
-; CHECK-NEXT:    [[MAX_ABC:%.*]] = select <2 x i1> [[CMP_AB_CB]], <2 x i32> [[MAX_AB]], <2 x i32> [[MAX_CB]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt <2 x i32> [[MAX_AB]], %c
+; CHECK-NEXT:    [[MAX_ABC:%.*]] = select <2 x i1> [[TMP1]], <2 x i32> [[MAX_AB]], <2 x i32> %c
 ; CHECK-NEXT:    ret <2 x i32> [[MAX_ABC]]
 ;
   %cmp_ab = icmp sgt <2 x i32> %a, %b
@@ -796,10 +792,8 @@
 ; CHECK-LABEL: @common_factor_umin(
 ; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp ult <2 x i32> %b, %c
 ; CHECK-NEXT:    [[MIN_BC:%.*]] = select <2 x i1> [[CMP_BC]], <2 x i32> %b, <2 x i32> %c
-; CHECK-NEXT:    [[CMP_AB:%.*]] = icmp ult <2 x i32> %a, %b
-; CHECK-NEXT:    [[MIN_AB:%.*]] = select <2 x i1> [[CMP_AB]], <2 x i32> %a, <2 x i32> %b
-; CHECK-NEXT:    [[CMP_BC_AB:%.*]] = icmp ult <2 x i32> [[MIN_BC]], [[MIN_AB]]
-; CHECK-NEXT:    [[MIN_ABC:%.*]] = select <2 x i1> [[CMP_BC_AB]], <2 x i32> [[MIN_BC]], <2 x i32> [[MIN_AB]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult <2 x i32> [[MIN_BC]], %a
+; CHECK-NEXT:    [[MIN_ABC:%.*]] = select <2 x i1> [[TMP1]], <2 x i32> [[MIN_BC]], <2 x i32> %a
 ; CHECK-NEXT:    ret <2 x i32> [[MIN_ABC]]
 ;
   %cmp_bc = icmp ult <2 x i32> %b, %c
@@ -817,10 +811,8 @@
 ; CHECK-LABEL: @common_factor_umax(
 ; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp ugt i32 %b, %c
 ; CHECK-NEXT:    [[MAX_BC:%.*]] = select i1 [[CMP_BC]], i32 %b, i32 %c
-; CHECK-NEXT:    [[CMP_BA:%.*]] = icmp ugt i32 %b, %a
-; CHECK-NEXT:    [[MAX_BA:%.*]] = select i1 [[CMP_BA]], i32 %b, i32 %a
-; CHECK-NEXT:    [[CMP_BC_BA:%.*]] = icmp ugt i32 [[MAX_BC]], [[MAX_BA]]
-; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[CMP_BC_BA]], i32 [[MAX_BC]], i32 [[MAX_BA]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[MAX_BC]], %a
+; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[TMP1]], i32 [[MAX_BC]], i32 %a
 ; CHECK-NEXT:    ret i32 [[MAX_ABC]]
 ;
   %cmp_bc = icmp ugt i32 %b, %c
@@ -838,10 +830,8 @@
 ; CHECK-LABEL: @common_factor_umax_extra_use_lhs(
 ; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp ugt i32 %b, %c
 ; CHECK-NEXT:    [[MAX_BC:%.*]] = select i1 [[CMP_BC]], i32 %b, i32 %c
-; CHECK-NEXT:    [[CMP_BA:%.*]] = icmp ugt i32 %b, %a
-; CHECK-NEXT:    [[MAX_BA:%.*]] = select i1 [[CMP_BA]], i32 %b, i32 %a
-; CHECK-NEXT:    [[CMP_BC_BA:%.*]] = icmp ugt i32 [[MAX_BC]], [[MAX_BA]]
-; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[CMP_BC_BA]], i32 [[MAX_BC]], i32 [[MAX_BA]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[MAX_BC]], %a
+; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[TMP1]], i32 [[MAX_BC]], i32 %a
 ; CHECK-NEXT:    call void @extra_use(i32 [[MAX_BC]])
 ; CHECK-NEXT:    ret i32 [[MAX_ABC]]
 ;
@@ -857,12 +847,10 @@
 
 define i32 @common_factor_umax_extra_use_rhs(i32 %a, i32 %b, i32 %c) {
 ; CHECK-LABEL: @common_factor_umax_extra_use_rhs(
-; CHECK-NEXT:    [[CMP_BC:%.*]] = icmp ugt i32 %b, %c
-; CHECK-NEXT:    [[MAX_BC:%.*]] = select i1 [[CMP_BC]], i32 %b, i32 %c
 ; CHECK-NEXT:    [[CMP_BA:%.*]] = icmp ugt i32 %b, %a
 ; CHECK-NEXT:    [[MAX_BA:%.*]] = select i1 [[CMP_BA]], i32 %b, i32 %a
-; CHECK-NEXT:    [[CMP_BC_BA:%.*]] = icmp ugt i32 [[MAX_BC]], [[MAX_BA]]
-; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[CMP_BC_BA]], i32 [[MAX_BC]], i32 [[MAX_BA]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[MAX_BA]], %c
+; CHECK-NEXT:    [[MAX_ABC:%.*]] = select i1 [[TMP1]], i32 [[MAX_BA]], i32 %c
 ; CHECK-NEXT:    call void @extra_use(i32 [[MAX_BA]])
 ; CHECK-NEXT:    ret i32 [[MAX_ABC]]
 ;