Index: llvm/include/llvm/Analysis/ConstraintSystem.h
===================================================================
--- llvm/include/llvm/Analysis/ConstraintSystem.h
+++ llvm/include/llvm/Analysis/ConstraintSystem.h
@@ -128,6 +128,27 @@
     return R;
   }
 
+  /// Returns \p R with every coefficient multiplied by -1. The equality
+  /// reasoning uses this as the opposite, still non-strict, inequality of
+  /// \p R. Returns an empty vector if any multiplication overflows.
+  static SmallVector<int64_t, 8> negateOrEqual(SmallVector<int64_t, 8> R) {
+    // Flip the sign of every coefficient; give up as soon as one overflows.
+    for (auto &C : R)
+      if (MulOverflow(C, int64_t(-1), C))
+        return {};
+    return R;
+  }
+
+  /// Returns \p R turned into a strict inequality by subtracting 1 from the
+  /// constant term R[0]. Returns an empty vector if \p R is empty or the
+  /// subtraction overflows.
+  static SmallVector<int64_t, 8> toStrictLessThan(SmallVector<int64_t, 8> R) {
+    // The strict less than is obtained by subtracting 1 from the constant.
+    if (R.empty() || SubOverflow(R[0], int64_t(1), R[0]))
+      return {};
+    return R;
+  }
+
   bool isConditionImplied(SmallVector<int64_t, 8> R) const;
 
   SmallVector<int64_t, 8> getLastConstraint() const {
Index: llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -547,7 +547,7 @@
 
   SmallVector<Value *> NewVariables;
   ConstraintTy R = getConstraint(Pred, Op0, Op1, NewVariables);
-  if (R.IsEq || !NewVariables.empty())
+  if (!NewVariables.empty())
     return {};
   return R;
 }
@@ -947,43 +947,83 @@
       CSToUse.popLastConstraint();
   });
 
-  bool Changed = false;
-  if (CSToUse.isConditionImplied(R.Coefficients)) {
+  // Replaces the uses of Cmp outside of llvm.assume calls with the constant
+  // true or false. Returns false, without changing anything, if the debug
+  // counter declines the transformation; returns true otherwise.
+  auto ReplaceCmpWithConstant = [&](CmpInst *Cmp, bool IsTrue) {
     if (!DebugCounter::shouldExecute(EliminatedCounter))
       return false;
 
     LLVM_DEBUG({
-      dbgs() << "Condition " << *Cmp << " implied by dominating constraints\n";
+      dbgs() << "Condition " << (IsTrue ? "" : "!") << *Cmp
+             << " implied by dominating constraints\n";
       CSToUse.dump();
     });
     generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT);
-    Constant *TrueC =
-        ConstantInt::getTrue(CmpInst::makeCmpResultType(Cmp->getType()));
-    Cmp->replaceUsesWithIf(TrueC, [](Use &U) {
+    Constant *ConstantC =
+        IsTrue
+            ? ConstantInt::getTrue(CmpInst::makeCmpResultType(Cmp->getType()))
+            : ConstantInt::getFalse(CmpInst::makeCmpResultType(Cmp->getType()));
+    Cmp->replaceUsesWithIf(ConstantC, [](Use &U) {
       // Conditions in an assume trivially simplify to true. Skip uses
       // in assume calls to not destroy the available information.
       auto *II = dyn_cast<IntrinsicInst>(U.getUser());
       return !II || II->getIntrinsicID() != Intrinsic::assume;
     });
     NumCondsRemoved++;
-    Changed = true;
-  }
-  auto Negated = ConstraintSystem::negate(R.Coefficients);
-  if (!Negated.empty() && CSToUse.isConditionImplied(Negated)) {
-    if (!DebugCounter::shouldExecute(EliminatedCounter))
-      return false;
+    return true;
+  };
 
-    LLVM_DEBUG({
-      dbgs() << "Condition !" << *Cmp << " implied by dominating constraints\n";
-      CSToUse.dump();
-    });
-    generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT);
-    Constant *FalseC =
-        ConstantInt::getFalse(CmpInst::makeCmpResultType(Cmp->getType()));
-    Cmp->replaceAllUsesWith(FalseC);
-    NumCondsRemoved++;
-    Changed = true;
+  bool Changed = false;
+  bool IsConditionImplied = CSToUse.isConditionImplied(R.Coefficients);
+
+  if (R.IsEq) {
+    auto NegatedOrEqual = ConstraintSystem::negateOrEqual(R.Coefficients);
+    bool IsNegatedOrEqualImplied =
+        !NegatedOrEqual.empty() && CSToUse.isConditionImplied(NegatedOrEqual);
+
+    auto Negated = ConstraintSystem::negate(R.Coefficients);
+    bool IsNegatedImplied =
+        !Negated.empty() && CSToUse.isConditionImplied(Negated);
+
+    auto StrictLessThan = ConstraintSystem::toStrictLessThan(R.Coefficients);
+    bool IsStrictLessThanImplied =
+        !StrictLessThan.empty() && CSToUse.isConditionImplied(StrictLessThan);
+
+    // `%a == %b` is known to be true only if both `%a >= %b` and
+    // `%a <= %b` are implied.
+    if (IsConditionImplied && IsNegatedOrEqualImplied) {
+      if (!ReplaceCmpWithConstant(Cmp, true))
+        return Changed;
+      Changed = true;
+    }
+
+    // `%a == %b` is known to be false if either `%a > %b` or `%a < %b` is
+    // implied.
+    if (IsNegatedImplied || IsStrictLessThanImplied) {
+      if (!ReplaceCmpWithConstant(Cmp, false))
+        return Changed;
+      Changed = true;
+    }
+  } else {
+    if (IsConditionImplied) {
+      if (!ReplaceCmpWithConstant(Cmp, true))
+        return Changed;
+      Changed = true;
+    }
+
+    // The condition is known to be false if its negation is implied.
+    auto Negated = ConstraintSystem::negate(R.Coefficients);
+    bool IsNegatedImplied =
+        !Negated.empty() && CSToUse.isConditionImplied(Negated);
+
+    if (IsNegatedImplied) {
+      if (!ReplaceCmpWithConstant(Cmp, false))
+        return Changed;
+      Changed = true;
+    }
   }
+
   return Changed;
 }
 
Index: llvm/test/Transforms/ConstraintElimination/assumes.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/assumes.ll
+++ llvm/test/Transforms/ConstraintElimination/assumes.ll
@@ -622,3 +622,18 @@
   tail call void @llvm.assume(i1 %c.2)
   ret i1 %c.2
 }
+
+define i1 @assume_x_ult_y_plus_1(i64 %x, i64 %y) {
+; CHECK-LABEL: @assume_x_ult_y_plus_1(
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i64 [[Y:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[X]], [[Y]]
+; CHECK-NEXT:    ret i1 false
+;
+  %1 = add nuw i64 %y, 1
+  %2 = icmp ult i64 %1, %x
+  tail call void @llvm.assume(i1 %2)
+  %3 = icmp eq i64 %x, %y
+  ret i1 %3
+}
Index: llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll
+++ llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll
@@ -75,9 +75,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[F_0:%.*]] = icmp eq i8 10, 11
 ; CHECK-NEXT:    [[T_0:%.*]] = icmp eq i8 10, 10
-; CHECK-NEXT:    [[RES_1:%.*]] = xor i1 [[T_0]], [[F_0]]
+; CHECK-NEXT:    [[RES_1:%.*]] = xor i1 true, false
 ; CHECK-NEXT:    [[F_1:%.*]] = icmp eq i8 10, 9
-; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], [[F_1]]
+; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], false
 ; CHECK-NEXT:    ret i1 [[RES_2]]
 ;
 entry:
Index: llvm/test/Transforms/ConstraintElimination/eq.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/eq.ll
+++ llvm/test/Transforms/ConstraintElimination/eq.ll
@@ -11,9 +11,9 @@
 ; CHECK-NEXT:    [[T_2:%.*]] = icmp ule i8 [[A]], [[B]]
 ; CHECK-NEXT:    [[RES_1:%.*]] = xor i1 true, true
 ; CHECK-NEXT:    [[T_3:%.*]] = icmp eq i8 [[A]], [[B]]
-; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], [[T_3]]
+; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], true
 ; CHECK-NEXT:    [[T_4:%.*]] = icmp eq i8 [[B]], [[A]]
-; CHECK-NEXT:    [[RES_3:%.*]] = xor i1 [[RES_2]], [[T_4]]
+; CHECK-NEXT:    [[RES_3:%.*]] = xor i1 [[RES_2]], true
 ; CHECK-NEXT:    [[F_1:%.*]] = icmp ugt i8 [[B]], [[A]]
 ; CHECK-NEXT:    [[RES_4:%.*]] = xor i1 [[RES_3]], false
 ; CHECK-NEXT:    [[F_2:%.*]] = icmp ult i8 [[B]], [[A]]
@@ -376,3 +376,33 @@
 
   ret i1 %xor.12
 }
+
+define i1 @test_transitivity_of_equality(i64 %a, i64 %b, i64 %c) {
+; CHECK-LABEL: @test_transitivity_of_equality(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[PRE_1:%.*]] = icmp eq i64 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    br i1 [[PRE_1]], label [[AB_EQUAL:%.*]], label [[NOT_EQ:%.*]]
+; CHECK:       ab_equal:
+; CHECK-NEXT:    [[BC_EQ:%.*]] = icmp eq i64 [[B]], [[C:%.*]]
+; CHECK-NEXT:    br i1 [[BC_EQ]], label [[BC_EQUAL:%.*]], label [[NOT_EQ]]
+; CHECK:       bc_equal:
+; CHECK-NEXT:    [[AC_EQ:%.*]] = icmp eq i64 [[A]], [[C]]
+; CHECK-NEXT:    ret i1 true
+; CHECK:       not_eq:
+; CHECK-NEXT:    ret i1 false
+;
+entry:
+  %pre.1 = icmp eq i64 %a, %b
+  br i1 %pre.1, label %ab_equal, label %not_eq
+
+ab_equal:
+  %bc_eq = icmp eq i64 %b, %c
+  br i1 %bc_eq, label %bc_equal, label %not_eq
+
+bc_equal:
+  %ac_eq = icmp eq i64 %a, %c
+  ret i1 %ac_eq
+
+not_eq:
+  ret i1 false
+}