diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1165,15 +1165,32 @@ /// /// We can't replace %sel with %add unless we strip away the flags. /// TODO: Wrapping flags could be preserved in some cases with better analysis. -static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp, - const SimplifyQuery &Q) { +static Instruction *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp, + const SimplifyQuery &Q, + InstCombiner &IC) { if (!Cmp.isEquality()) return nullptr; // Canonicalize the pattern to ICMP_EQ by swapping the select operands. Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); - if (Cmp.getPredicate() == ICmpInst::ICMP_NE) + bool Swapped = false; + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) { std::swap(TrueVal, FalseVal); + Swapped = true; + } + + // In X == Y ? f(X) : Z, try to evaluate f(X) and replace the operand. + // Take care to avoid replacing X == Y ? X : Z with X == Y ? Y : Z, as that + // would lead to an infinite replacement cycle. + Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); + if (TrueVal != CmpLHS) + if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, + /* AllowRefinement */ true)) + return IC.replaceOperand(Sel, Swapped ? 2 : 1, V); + if (TrueVal != CmpRHS) + if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, + /* AllowRefinement */ true)) + return IC.replaceOperand(Sel, Swapped ? 2 : 1, V); auto *FalseInst = dyn_cast(FalseVal); if (!FalseInst) @@ -1198,12 +1215,11 @@ // We have an 'EQ' comparison, so the select's false value will propagate. // Example: // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 - Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, /* AllowRefinement */ false) == TrueVal || SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, /* AllowRefinement */ false) == TrueVal) { - return FalseVal; + return IC.replaceInstUsesWith(Sel, FalseVal); } // Restore poison-generating flags if the transform did not apply. @@ -1439,8 +1455,8 @@ /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { - if (Value *V = foldSelectValueEquivalence(SI, *ICI, SQ)) - return replaceInstUsesWith(SI, V); + if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI, SQ, *this)) + return NewSel; if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *this)) return NewSel; diff --git a/llvm/test/Transforms/InstCombine/rem.ll b/llvm/test/Transforms/InstCombine/rem.ll --- a/llvm/test/Transforms/InstCombine/rem.ll +++ b/llvm/test/Transforms/InstCombine/rem.ll @@ -50,8 +50,7 @@ define i5 @biggest_divisor(i5 %x) { ; CHECK-LABEL: @biggest_divisor( ; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i5 [[X:%.*]], -1 -; CHECK-NEXT: [[TMP1:%.*]] = zext i1 [[DOTNOT]] to i5 -; CHECK-NEXT: [[REM:%.*]] = add i5 [[TMP1]], [[X]] +; CHECK-NEXT: [[REM:%.*]] = select i1 [[DOTNOT]], i5 0, i5 [[X]] ; CHECK-NEXT: ret i5 [[REM]] ; %rem = urem i5 %x, -1 diff --git a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll --- a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll +++ b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll @@ -564,12 +564,10 @@ ret <2 x i8> %C } -; TODO: support for undefs, check for an identity constant does not handle them yet -define <2 x i8> @select_xor_icmp_vec_bad_2(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) { -; CHECK-LABEL: @select_xor_icmp_vec_bad_2( +define <2 x i8> @select_xor_icmp_vec_undef(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) { +; CHECK-LABEL: @select_xor_icmp_vec_undef( ; CHECK-NEXT: [[A:%.*]] = icmp eq <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[B:%.*]] = xor <2 x i8> [[X]], [[Z:%.*]] -; CHECK-NEXT: [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[B]], <2 x i8> [[Y:%.*]] +; CHECK-NEXT: [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[Z:%.*]], <2 x i8> [[Y:%.*]] ; CHECK-NEXT: ret <2 x i8> [[C]] ; %A = icmp eq <2 x i8> %x, @@ -604,11 +602,10 @@ ret i32 %C } -define i32 @select_and_icmp_bad(i32 %x, i32 %y, i32 %z) { -; CHECK-LABEL: @select_and_icmp_bad( +define i32 @select_and_icmp_zero(i32 %x, i32 %y, i32 %z) { +; CHECK-LABEL: @select_and_icmp_zero( ; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 0 -; CHECK-NEXT: [[B:%.*]] = and i32 [[X]], [[Z:%.*]] -; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[B]], i32 [[Y:%.*]] +; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 0, i32 [[Y:%.*]] ; CHECK-NEXT: ret i32 [[C]] ; %A = icmp eq i32 %x, 0 diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll --- a/llvm/test/Transforms/InstCombine/select.ll +++ b/llvm/test/Transforms/InstCombine/select.ll @@ -2606,8 +2606,7 @@ define i8 @select_replacement_add_eq(i8 %x, i8 %y) { ; CHECK-LABEL: @select_replacement_add_eq( ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], 1 -; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], 1 -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[ADD]], i8 [[Y:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 2, i8 [[Y:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; %cmp = icmp eq i8 %x, 1 @@ -2620,8 +2619,7 @@ ; CHECK-LABEL: @select_replacement_add_ne( ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[X:%.*]], 1 ; CHECK-NEXT: call void @use(i1 [[CMP]]) -; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], 1 -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[Y:%.*]], i8 [[ADD]] +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[Y:%.*]], i8 2 ; CHECK-NEXT: ret i8 [[SEL]] ; %cmp = icmp ne i8 %x, 1 @@ -2634,8 +2632,7 @@ define i8 @select_replacement_add_nuw(i8 %x, i8 %y) { ; CHECK-LABEL: @select_replacement_add_nuw( ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], 1 -; CHECK-NEXT: [[ADD:%.*]] = add nuw i8 [[X]], 1 -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[ADD]], i8 [[Y:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 2, i8 [[Y:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; %cmp = icmp eq i8 %x, 1 @@ -2647,8 +2644,7 @@ define i8 @select_replacement_sub(i8 %x, i8 %y, i8 %z) { ; CHECK-LABEL: @select_replacement_sub( ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[SUB:%.*]] = sub i8 [[X]], [[Y]] -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[SUB]], i8 [[Z:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 0, i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; %cmp = icmp eq i8 %x, %y @@ -2661,8 +2657,7 @@ ; CHECK-LABEL: @select_replacement_shift( ; CHECK-NEXT: [[SHR:%.*]] = lshr exact i8 [[X:%.*]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[SHR]], [[Y:%.*]] -; CHECK-NEXT: [[SHL:%.*]] = shl i8 [[Y]], 1 -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[SHL]], i8 [[Z:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[X]], i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; %shr = lshr exact i8 %x, 1