diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1969,24 +1969,35 @@ return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy)); } - if (MulC->isZero() || (!Mul->hasNoSignedWrap() && !Mul->hasNoUnsignedWrap())) + if (MulC->isZero()) return nullptr; - // If the multiply does not wrap, try to divide the compare constant by the - // multiplication factor. + // If the multiply does not wrap or the constant is odd, try to divide the + // compare constant by the multiplication factor. if (Cmp.isEquality()) { - // (mul nsw X, MulC) == C --> X == C /s MulC + // (mul nsw X, MulC) eq/ne C --> X eq/ne C /s MulC if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) { Constant *NewC = ConstantInt::get(MulTy, C.sdiv(*MulC)); return new ICmpInst(Pred, X, NewC); } - // (mul nuw X, MulC) == C --> X == C /u MulC - if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) { - Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC)); - return new ICmpInst(Pred, X, NewC); + + // C % MulC == 0 is weaker than we could use if MulC is odd because it is + // correct to transform if MulC * N == C including overflow. I.e., with i8 + // (icmp eq (mul X, 5), 101) -> (icmp eq X, 225) but since 101 % 5 != 0, we + // miss that case. 
+ if (C.urem(*MulC).isZero()) { + // (mul nuw X, MulC) eq/ne C --> X eq/ne C /u MulC + // (mul X, OddC) eq/ne N * OddC --> X eq/ne N + if ((*MulC & 1).isOne() || Mul->hasNoUnsignedWrap()) { + Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC)); + return new ICmpInst(Pred, X, NewC); + } } } + if (!Mul->hasNoSignedWrap() && !Mul->hasNoUnsignedWrap()) + return nullptr; + // With a matching no-overflow guarantee, fold the constants: // (X * MulC) < C --> X < (C / MulC) // (X * MulC) > C --> X > (C / MulC) @@ -4327,16 +4338,58 @@ } { - // Try to remove shared constant multiplier from equality comparison: - // X * C == Y * C (with no overflowing/aliasing) --> X == Y - Value *X, *Y; - const APInt *C; - if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 && - match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality()) - if (!C->countTrailingZeros() || - (BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) || - (BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) - return new ICmpInst(Pred, X, Y); + // Try to remove shared multiplier from comparison: + // X * Z u{lt/le/gt/ge}/eq/ne Y * Z + Value *X, *Y, *A, *B; + if (Pred == ICmpInst::getUnsignedPredicate(Pred) && + match(Op0, m_Mul(m_Value(X), m_Value(A))) && + match(Op1, m_Mul(m_Value(Y), m_Value(B)))) { + // Find Z, we can't use `m_Specific(...)` in the match for + // Op1 as our 'Z' can be matched as either first or second operand in Op0. 
+ Value *Z = nullptr; + if (X == Y || X == B) { + // 'X' is our 'Z' + // 'A' is our 'X' + Z = X; + X = A; + } else if (A == Y || A == B) { + // 'X' is our 'X' + // 'A' is our 'Z' + Z = A; + } + + if (Z == Y) { + // 'Y' matches 'Z' + // 'B' is our 'Y' + Y = B; + } + // else if (Z == B) -> 'Y' is our 'Y' + + if (Z != nullptr) { + bool NonZero; + if (ICmpInst::isEquality(Pred)) { + KnownBits ZKnown = computeKnownBits(Z, 0, &I); + // if Z % 2 != 0 + // X * Z eq/ne Y * Z -> X eq/ne Y + if (ZKnown.countMaxTrailingZeros() == 0) + return new ICmpInst(Pred, X, Y); + NonZero = !ZKnown.One.isZero() || + isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + // if Z != 0 and nsw(X * Z) and nsw(Y * Z) + // X * Z eq/ne Y * Z -> X eq/ne Y + if (NonZero && BO0 && BO1 && BO0->hasNoSignedWrap() && + BO1->hasNoSignedWrap()) + return new ICmpInst(Pred, X, Y); + } else + NonZero = isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + + // If Z != 0 and nuw(X * Z) and nuw(Y * Z) + // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y + if (NonZero && BO0 && BO1 && BO0->hasNoUnsignedWrap() && + BO1->hasNoUnsignedWrap()) + return new ICmpInst(Pred, X, Y); + } + } } BinaryOperator *SRem = nullptr; diff --git a/llvm/test/Transforms/InstCombine/icmp-mul.ll b/llvm/test/Transforms/InstCombine/icmp-mul.ll --- a/llvm/test/Transforms/InstCombine/icmp-mul.ll +++ b/llvm/test/Transforms/InstCombine/icmp-mul.ll @@ -424,8 +424,7 @@ define i1 @eq_rem_zero_nonuw(i8 %x) { ; CHECK-LABEL: @eq_rem_zero_nonuw( -; CHECK-NEXT: [[A:%.*]] = mul i8 [[X:%.*]], 5 -; CHECK-NEXT: [[B:%.*]] = icmp eq i8 [[A]], 20 +; CHECK-NEXT: [[B:%.*]] = icmp eq i8 [[X:%.*]], 4 ; CHECK-NEXT: ret i1 [[B]] ; %a = mul i8 %x, 5 @@ -435,8 +434,7 @@ define i1 @ne_rem_zero_nonuw(i8 %x) { ; CHECK-LABEL: @ne_rem_zero_nonuw( -; CHECK-NEXT: [[A:%.*]] = mul i8 [[X:%.*]], 5 -; CHECK-NEXT: [[B:%.*]] = icmp ne i8 [[A]], 30 +; CHECK-NEXT: [[B:%.*]] = icmp ne i8 [[X:%.*]], 6 ; CHECK-NEXT: ret i1 [[B]] ; %a = mul i8 %x, 5 @@ 
-995,8 +993,7 @@ define <2 x i1> @mul_oddC_ne_vec(<2 x i8> %v) { ; CHECK-LABEL: @mul_oddC_ne_vec( -; CHECK-NEXT: [[MUL:%.*]] = mul <2 x i8> [[V:%.*]], -; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[MUL]], +; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[V:%.*]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %mul = mul <2 x i8> %v, @@ -1050,9 +1047,7 @@ ; CHECK-NEXT: [[LB:%.*]] = and i8 [[Z:%.*]], 1 ; CHECK-NEXT: [[NZ:%.*]] = icmp ne i8 [[LB]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[NZ]]) -; CHECK-NEXT: [[MULX:%.*]] = mul i8 [[X:%.*]], [[Z]] -; CHECK-NEXT: [[MULY:%.*]] = mul i8 [[Z]], [[Y:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[MULX]], [[MULY]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[CMP]] ; %lb = and i8 %z, 1 @@ -1067,9 +1062,8 @@ define <2 x i1> @reused_mul_nsw_xy_z_setnonzero_vec_ne(<2 x i8> %x, <2 x i8> %y, <2 x i8> %zi) { ; CHECK-LABEL: @reused_mul_nsw_xy_z_setnonzero_vec_ne( ; CHECK-NEXT: [[Z:%.*]] = or <2 x i8> [[ZI:%.*]], -; CHECK-NEXT: [[MULX:%.*]] = mul nsw <2 x i8> [[Z]], [[X:%.*]] ; CHECK-NEXT: [[MULY:%.*]] = mul nsw <2 x i8> [[Z]], [[Y:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[MULY]], [[MULX]] +; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[Y]], [[X:%.*]] ; CHECK-NEXT: call void @usev2xi8(<2 x i8> [[MULY]]) ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; @@ -1101,8 +1095,7 @@ ; CHECK-NEXT: [[NZ:%.*]] = icmp ne i8 [[Z:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[NZ]]) ; CHECK-NEXT: [[MULX:%.*]] = mul nuw i8 [[X:%.*]], [[Z]] -; CHECK-NEXT: [[MULY:%.*]] = mul nuw i8 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp uge i8 [[MULY]], [[MULX]] +; CHECK-NEXT: [[CMP:%.*]] = icmp uge i8 [[Y:%.*]], [[X]] ; CHECK-NEXT: call void @use(i8 [[MULX]]) ; CHECK-NEXT: ret i1 [[CMP]] ; @@ -1117,10 +1110,7 @@ define <2 x i1> @mul_nuw_xy_z_setnonzero_vec_eq(<2 x i8> %x, <2 x i8> %y, <2 x i8> %zi) { ; CHECK-LABEL: @mul_nuw_xy_z_setnonzero_vec_eq( -; CHECK-NEXT: [[Z:%.*]] = or <2 x i8> [[ZI:%.*]], -; CHECK-NEXT: 
[[MULX:%.*]] = mul nuw <2 x i8> [[Z]], [[X:%.*]] -; CHECK-NEXT: [[MULY:%.*]] = mul nuw <2 x i8> [[Z]], [[Y:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[MULX]], [[MULY]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %z = or <2 x i8> %zi, @@ -1135,9 +1125,7 @@ ; CHECK-NEXT: [[NZ_NOT:%.*]] = icmp eq i8 [[Z:%.*]], 0 ; CHECK-NEXT: br i1 [[NZ_NOT]], label [[FALSE:%.*]], label [[TRUE:%.*]] ; CHECK: true: -; CHECK-NEXT: [[MULX:%.*]] = mul nuw i8 [[X:%.*]], [[Z]] -; CHECK-NEXT: [[MULY:%.*]] = mul nuw i8 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[MULY]], [[MULX]] +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[Y:%.*]], [[X:%.*]] ; CHECK-NEXT: ret i1 [[CMP]] ; CHECK: false: ; CHECK-NEXT: call void @use(i8 [[Z]]) diff --git a/llvm/test/Transforms/InstCombine/pr38677.ll b/llvm/test/Transforms/InstCombine/pr38677.ll --- a/llvm/test/Transforms/InstCombine/pr38677.ll +++ b/llvm/test/Transforms/InstCombine/pr38677.ll @@ -12,9 +12,7 @@ ; CHECK-NEXT: br label [[FINAL]] ; CHECK: final: ; CHECK-NEXT: [[USE2:%.*]] = phi i32 [ 1, [[ENTRY:%.*]] ], [ select (i1 icmp eq (ptr @A, ptr @B), i32 2, i32 1), [[DELAY]] ] -; CHECK-NEXT: [[B7:%.*]] = mul i32 [[USE2]], 2147483647 -; CHECK-NEXT: [[C3:%.*]] = icmp eq i32 [[B7]], 0 -; CHECK-NEXT: store i1 [[C3]], ptr [[DST:%.*]], align 1 +; CHECK-NEXT: store i1 false, ptr [[DST:%.*]], align 1 ; CHECK-NEXT: ret i32 [[USE2]] ; entry: