Index: lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2412,6 +2412,20 @@
   // however such select could have been recognized as a bit test of the
   // highest big by other combine transforms. Here we try to recognize this
   // select in some particular cases.
+  Value *MaybeShiftWithCast;
+  Value *ShiftValue = ConstantInt::get(
+      LHS->getType(), LHS->getType()->getIntegerBitWidth() - 1);
+  auto IsAShrByWidthMaybeCasted = [&](Value *V) {
+    // If it is a cast, check that it is a non-zext cast to the type we expect
+    // (zext would turn the all-ones shift result into a non-negative value).
+    if (auto *Cast = dyn_cast<CastInst>(V)) {
+      if (isa<ZExtInst>(Cast) || Cast->getDestTy() != Equal->getType())
+        return false;
+      V = Cast->getOperand(0);
+    }
+    // Check that the value is a signed right shift of LHS by ShiftValue.
+    return match(V, m_AShr(m_Specific(LHS), m_Specific(ShiftValue)));
+  };
   // Try to recognize:
   // select i1 (a < 0), i32 -1, i32 Greater
   // simplified in form:
@@ -2419,11 +2433,9 @@
   // If %a was negative, then the result of "ashr" is -1 and the result of
   // "or" is also -1. If %a was non-negative, then the result of "ashr" is 0
   // and the result of "or" is "Greater".
- Value *ShiftValue = ConstantInt::get( - LHS->getType(), LHS->getType()->getIntegerBitWidth() - 1); if (match(SI->getFalseValue(), - m_Or(m_AShr(m_Specific(LHS), m_Specific(ShiftValue)), - m_ConstantInt(Greater)))) { + m_Or(m_Value(MaybeShiftWithCast), m_ConstantInt(Greater))) && + IsAShrByWidthMaybeCasted(MaybeShiftWithCast)) { Less = ConstantInt::get(Greater->getType(), -1, true); return true; } @@ -2434,9 +2446,10 @@ // (a s>> 31) & (Less - Greater) + Greater ConstantInt *LessMinusGreater; if (match(SI->getFalseValue(), - m_Add(m_And(m_AShr(m_Specific(LHS), m_Specific(ShiftValue)), + m_Add(m_And(m_Value(MaybeShiftWithCast), m_ConstantInt(LessMinusGreater)), - m_ConstantInt(Greater)))) { + m_ConstantInt(Greater))) && + IsAShrByWidthMaybeCasted(MaybeShiftWithCast)) { APInt LessI = LessMinusGreater->getValue() + Greater->getValue(); Less = ConstantInt::get(Greater->getContext(), LessI); return true; @@ -2449,9 +2462,10 @@ // It is used instead of the pattern above in case if we can prove that // ((Less - Greater) & Greater) = 0. if (match(SI->getFalseValue(), - m_Or(m_And(m_AShr(m_Specific(LHS), m_Specific(ShiftValue)), + m_Or(m_And(m_Value(MaybeShiftWithCast), m_ConstantInt(LessMinusGreater)), m_ConstantInt(Greater))) && + IsAShrByWidthMaybeCasted(MaybeShiftWithCast) && (LessMinusGreater->getValue() & Greater->getValue()).isNullValue()) { APInt LessI = LessMinusGreater->getValue() + Greater->getValue(); Less = ConstantInt::get(Greater->getContext(), LessI); Index: test/Transforms/InstCombine/three-way-comparison.ll =================================================================== --- test/Transforms/InstCombine/three-way-comparison.ll +++ test/Transforms/InstCombine/three-way-comparison.ll @@ -372,3 +372,149 @@ exit: ret i32 42 } + +define i32 @compare_against_arbitrary_value_type_mismatch(i64 %x, i64 %c) { +; TODO: We can prove that if %x s> %c then %x != %c, so there should be no actual +; calculations in callfoo block. @foo can be invoked with 1. 
We only do it +; for constants that are not 0 currently while it could be generalized. +; CHECK-LABEL: @compare_against_arbitrary_value_type_mismatch( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i64 [[X:%.*]], [[C:%.*]] +; CHECK-NEXT: br i1 [[TMP0]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK: callfoo: +; CHECK-NEXT: [[CMP1:%.*]] = icmp ne i64 [[X]], [[C]] +; CHECK-NEXT: [[SELECT2:%.*]] = zext i1 [[CMP1]] to i32 +; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: br label [[EXIT]] +; CHECK: exit: +; CHECK-NEXT: ret i32 42 +; + +entry: + %cmp1 = icmp eq i64 %x, %c + %cmp2 = icmp slt i64 %x, %c + %select1 = select i1 %cmp2, i32 -1, i32 1 + %select2 = select i1 %cmp1, i32 0, i32 %select1 + %cond = icmp sgt i32 %select2, 0 + br i1 %cond, label %callfoo, label %exit + +callfoo: + call void @foo(i32 %select2) + br label %exit + +exit: + ret i32 42 +} + +define i32 @compare_against_zero_type_mismatch_idiomatic(i64 %x) { +; TODO: We can prove that if %x s> 0 then %x != 0, so there should be no actual +; calculations in callfoo block. @foo can be invoked with 1. For some +; reason it does not happen for constant zero while it does happen for +; other constants. 
+; CHECK-LABEL: @compare_against_zero_type_mismatch_idiomatic( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[X:%.*]], 0 +; CHECK-NEXT: [[TMP0:%.*]] = ashr i64 [[X]], 63 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], 1 +; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 0, i32 [[TMP2]] +; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 +; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK: callfoo: +; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: br label [[EXIT]] +; CHECK: exit: +; CHECK-NEXT: ret i32 42 +; + +entry: + %cmp1 = icmp eq i64 %x, 0 + %cmp2 = icmp slt i64 %x, 0 + %select1 = select i1 %cmp2, i32 -1, i32 1 + %select2 = select i1 %cmp1, i32 0, i32 %select1 + %cond = icmp sgt i32 %select2, 0 + br i1 %cond, label %callfoo, label %exit + +callfoo: + call void @foo(i32 %select2) + br label %exit + +exit: + ret i32 42 +} + +define i32 @compare_against_zero_type_mismatch_non_idiomatic_1(i64 %x) { +; TODO: We can prove that if %x s> 0 then %x != 0, so there should be no actual +; calculations in callfoo block. @foo can be invoked with 1. For some +; reason it does not happen for constant zero while it does happen for +; other constants. 
+; CHECK-LABEL: @compare_against_zero_type_mismatch_non_idiomatic_1( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[X:%.*]], 0 +; CHECK-NEXT: [[TMP0:%.*]] = ashr i64 [[X]], 63 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], -8 +; CHECK-NEXT: [[TMP3:%.*]] = or i32 [[TMP2]], 1 +; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 0, i32 [[TMP3]] +; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 +; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK: callfoo: +; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: br label [[EXIT]] +; CHECK: exit: +; CHECK-NEXT: ret i32 42 +; + +entry: + %cmp1 = icmp eq i64 %x, 0 + %cmp2 = icmp slt i64 %x, 0 + %select1 = select i1 %cmp2, i32 -7, i32 1 + %select2 = select i1 %cmp1, i32 0, i32 %select1 + %cond = icmp sgt i32 %select2, 0 + br i1 %cond, label %callfoo, label %exit + +callfoo: + call void @foo(i32 %select2) + br label %exit + +exit: + ret i32 42 +} + +define i32 @compare_against_zero_type_mismatch_non_idiomatic_2(i64 %x) { +; TODO: We can prove that if %x s> 0 then %x != 0, so there should be no actual +; calculations in callfoo block. @foo can be invoked with 1. For some +; reason it does not happen for constant zero while it does happen for +; other constants. 
+; CHECK-LABEL: @compare_against_zero_type_mismatch_non_idiomatic_2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[X:%.*]], 0 +; CHECK-NEXT: [[TMP0:%.*]] = ashr i64 [[X]], 63 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], -7 +; CHECK-NEXT: [[TMP3:%.*]] = add nsw i32 [[TMP2]], 1 +; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 0, i32 [[TMP3]] +; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 +; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK: callfoo: +; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: br label [[EXIT]] +; CHECK: exit: +; CHECK-NEXT: ret i32 42 +; + +entry: + %cmp1 = icmp eq i64 %x, 0 + %cmp2 = icmp slt i64 %x, 0 + %select1 = select i1 %cmp2, i32 -6, i32 1 + %select2 = select i1 %cmp1, i32 0, i32 %select1 + %cond = icmp sgt i32 %select2, 0 + br i1 %cond, label %callfoo, label %exit + +callfoo: + call void @foo(i32 %select2) + br label %exit + +exit: + ret i32 42 +}