Index: llvm/include/llvm/Analysis/ConstraintSystem.h
===================================================================
--- llvm/include/llvm/Analysis/ConstraintSystem.h
+++ llvm/include/llvm/Analysis/ConstraintSystem.h
@@ -128,6 +128,25 @@
     return R;
   }
 
+  /// Multiplies every entry of R, including the constant R[0], by -1, turning
+  /// `Sum <= C` into `Sum >= C`. Returns an empty vector on overflow.
+  ///
+  static SmallVector<int64_t, 8> negateOrEqual(SmallVector<int64_t, 8> R) {
+    for (auto &C : R)
+      if (MulOverflow(C, int64_t(-1), C))
+        return {};
+    return R;
+  }
+
+  /// Subtracts 1 from the constant R[0], turning `Sum <= C` into the strict
+  /// `Sum <= C - 1`, i.e. `Sum < C` for integers. Returns an empty vector on
+  /// overflow.
+  static SmallVector<int64_t, 8> toStrictLessThan(SmallVector<int64_t, 8> R) {
+    if (SubOverflow(R[0], int64_t(1), R[0]))
+      return {};
+    return R;
+  }
+
   bool isConditionImplied(SmallVector<int64_t, 8> R) const;
 
   SmallVector<int64_t> getLastConstraint() const {
Index: llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -102,6 +102,11 @@
       : Pred(Pred), Op0(Op0), Op1(Op1) {}
 };
 
+/// Outcome of ConstraintTy::isEqImplied(): None if the constraint is not an
+/// equality, HoldTrue/HoldFalse if the equality is known to be true/false, and
+/// Unknown if neither could be proven.
+enum class EqualityImpliedTy { None, HoldTrue, HoldFalse, Unknown };
+
 struct ConstraintTy {
   SmallVector<int64_t, 8> Coefficients;
   SmallVector<ConditionTy, 2> Preconditions;
@@ -109,12 +114,11 @@
   SmallVector<SmallVector<int64_t, 8>> ExtraInfo;
 
   bool IsSigned = false;
-  bool IsEq = false;
 
   ConstraintTy() = default;
 
-  ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned)
-      : Coefficients(Coefficients), IsSigned(IsSigned) {}
+  ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned, bool IsEq)
+      : Coefficients(Coefficients), IsSigned(IsSigned), IsEq(IsEq) {}
 
   unsigned size() const { return Coefficients.size(); }
 
@@ -123,6 +127,16 @@
   /// Returns true if all preconditions for this list of constraints are
   /// satisfied given \p CS and the corresponding \p Value2Index mapping.
   bool isValid(const ConstraintInfo &Info) const;
+
+  /// Returns true if this constraint was built from an equality predicate.
+  bool isEq() const { return IsEq; }
+
+  /// Tries to decide this equality constraint using \p CS. Returns
+  /// EqualityImpliedTy::None if the constraint is not an equality.
+  EqualityImpliedTy isEqImplied(const ConstraintSystem &CS) const;
+
+private:
+  bool IsEq = false;
 };
 
 /// Wrapper encapsulating separate constraint systems and corresponding value
@@ -480,11 +494,10 @@
   // subtracting all coefficients from B.
   ConstraintTy Res(
       SmallVector<int64_t, 8>(Value2Index.size() + NewVariables.size() + 1, 0),
-      IsSigned);
+      IsSigned, IsEq);
   // Collect variables that are known to be positive in all uses in the
   // constraint.
   DenseMap<Value *, bool> KnownNonNegativeVariables;
-  Res.IsEq = IsEq;
   auto &R = Res.Coefficients;
   for (const auto &KV : VariablesA) {
     R[GetOrAddIndex(KV.Variable)] += KV.Coefficient;
@@ -547,7 +560,9 @@
 
   SmallVector<Value *> NewVariables;
   ConstraintTy R = getConstraint(Pred, Op0, Op1, NewVariables);
-  if (R.IsEq || !NewVariables.empty())
+  // Equality constraints are returned too. Their Coefficients only encode the
+  // `<=` half of the predicate, so callers must decide them via isEqImplied().
+  if (!NewVariables.empty())
     return {};
   return R;
 }
@@ -559,6 +574,31 @@
          });
 }
 
+EqualityImpliedTy ConstraintTy::isEqImplied(const ConstraintSystem &CS) const {
+  if (!IsEq)
+    return EqualityImpliedTy::None;
+
+  // For `%a == %b`, Coefficients encode `%a <= %b`. The equality holds if both
+  // `%a <= %b` and `%a >= %b` are implied.
+  if (CS.isConditionImplied(Coefficients)) {
+    auto NegatedOrEqual = ConstraintSystem::negateOrEqual(Coefficients);
+    if (!NegatedOrEqual.empty() && CS.isConditionImplied(NegatedOrEqual))
+      return EqualityImpliedTy::HoldTrue;
+  }
+
+  // The equality is false if either `%a > %b` or `%a < %b` is implied. Each
+  // query is only issued when the previous ones did not decide the result.
+  auto Negated = ConstraintSystem::negate(Coefficients);
+  if (!Negated.empty() && CS.isConditionImplied(Negated))
+    return EqualityImpliedTy::HoldFalse;
+
+  auto StrictLessThan = ConstraintSystem::toStrictLessThan(Coefficients);
+  if (!StrictLessThan.empty() && CS.isConditionImplied(StrictLessThan))
+    return EqualityImpliedTy::HoldFalse;
+
+  return EqualityImpliedTy::Unknown;
+}
+
 bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A,
                               Value *B) const {
   auto R = getConstraintForSolving(Pred, A, B);
@@ -947,43 +987,55 @@
       CSToUse.popLastConstraint();
   });
 
-  bool Changed = false;
-  if (CSToUse.isConditionImplied(R.Coefficients)) {
+  // Replaces all uses of Cmp outside of assumes with the constant IsTrue.
+  // Returns false if the debug counter vetoed the replacement.
+  auto ReplaceCmpWithConstant = [&](bool IsTrue) {
     if (!DebugCounter::shouldExecute(EliminatedCounter))
       return false;
 
     LLVM_DEBUG({
-      dbgs() << "Condition " << *Cmp << " implied by dominating constraints\n";
+      dbgs() << "Condition " << (IsTrue ? "" : "!") << *Cmp
+             << " implied by dominating constraints\n";
       CSToUse.dump();
     });
     generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT);
-    Constant *TrueC =
-        ConstantInt::getTrue(CmpInst::makeCmpResultType(Cmp->getType()));
-    Cmp->replaceUsesWithIf(TrueC, [](Use &U) {
+    Constant *ConstantC = ConstantInt::getBool(
+        CmpInst::makeCmpResultType(Cmp->getType()), IsTrue);
+    Cmp->replaceUsesWithIf(ConstantC, [](Use &U) {
       // Conditions in an assume trivially simplify to true. Skip uses
       // in assume calls to not destroy the available information.
       auto *II = dyn_cast<IntrinsicInst>(U.getUser());
       return !II || II->getIntrinsicID() != Intrinsic::assume;
     });
     NumCondsRemoved++;
-    Changed = true;
-  }
-  auto Negated = ConstraintSystem::negate(R.Coefficients);
-  if (!Negated.empty() && CSToUse.isConditionImplied(Negated)) {
-    if (!DebugCounter::shouldExecute(EliminatedCounter))
-      return false;
+    return true;
+  };
 
-    LLVM_DEBUG({
-      dbgs() << "Condition !" << *Cmp << " implied by dominating constraints\n";
-      CSToUse.dump();
-    });
-    generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT);
-    Constant *FalseC =
-        ConstantInt::getFalse(CmpInst::makeCmpResultType(Cmp->getType()));
-    Cmp->replaceAllUsesWith(FalseC);
-    NumCondsRemoved++;
-    Changed = true;
-  }
+  // Equalities need two implication queries to be decided; isEqImplied()
+  // handles them and returns None for all other predicates.
+  switch (R.isEqImplied(CSToUse)) {
+  case EqualityImpliedTy::HoldTrue:
+    return ReplaceCmpWithConstant(true);
+  case EqualityImpliedTy::HoldFalse:
+    return ReplaceCmpWithConstant(false);
+  case EqualityImpliedTy::Unknown:
+    return false;
+  case EqualityImpliedTy::None:
+    break;
+  }
+
+  bool Changed = false;
+  if (CSToUse.isConditionImplied(R.Coefficients)) {
+    if (!ReplaceCmpWithConstant(true))
+      return false;
+    Changed = true;
+  }
+
+  // Keep reporting an earlier replacement even if the debug counter vetoes
+  // this one; returning false would hide that the IR was modified.
+  auto Negated = ConstraintSystem::negate(R.Coefficients);
+  if (!Negated.empty() && CSToUse.isConditionImplied(Negated))
+    Changed |= ReplaceCmpWithConstant(false);
 
   return Changed;
 }
@@ -1042,7 +1094,7 @@
   DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned,
                           std::move(ValuesToRelease));
 
-  if (R.IsEq) {
+  if (R.isEq()) {
     // Also add the inverted constraint for equality constraints.
     for (auto &Coeff : R.Coefficients)
       Coeff *= -1;
Index: llvm/test/Transforms/ConstraintElimination/assumes.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/assumes.ll
+++ llvm/test/Transforms/ConstraintElimination/assumes.ll
@@ -622,3 +622,50 @@
   tail call void @llvm.assume(i1 %c.2)
   ret i1 %c.2
 }
+
+define i1 @assume_b_plus_1_ult_a(i64 %a, i64 %b) {
+; CHECK-LABEL: @assume_b_plus_1_ult_a(
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i64 [[B:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], [[A:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[A]], [[B]]
+; CHECK-NEXT:    ret i1 false
+;
+  %1 = add nuw i64 %b, 1
+  %2 = icmp ult i64 %1, %a
+  tail call void @llvm.assume(i1 %2)
+  %3 = icmp eq i64 %a, %b
+  ret i1 %3
+}
+
+define i1 @assume_a_plus_1_eq_b(i64 %a, i64 %b) {
+; CHECK-LABEL: @assume_a_plus_1_eq_b(
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i64 [[A:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[TMP1]], [[B:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[A]], [[B]]
+; CHECK-NEXT:    ret i1 false
+;
+  %1 = add nuw i64 %a, 1
+  %2 = icmp eq i64 %1, %b
+  tail call void @llvm.assume(i1 %2)
+  %3 = icmp eq i64 %a, %b
+  ret i1 %3
+}
+
+define i1 @assume_a_ge_b_and_b_ge_c(i64 %a, i64 %b, i64 %c) {
+; CHECK-LABEL: @assume_a_ge_b_and_b_ge_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp uge i64 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp uge i64 [[B]], [[C:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[A]], [[C]]
+; CHECK-NEXT:    ret i1 [[TMP3]]
+;
+  %1 = icmp uge i64 %a, %b
+  tail call void @llvm.assume(i1 %1)
+  %2 = icmp uge i64 %b, %c
+  tail call void @llvm.assume(i1 %2)
+  %3 = icmp eq i64 %a, %c
+  ret i1 %3
+}
Index: llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll
+++ llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll
@@ -75,9 +75,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[F_0:%.*]] = icmp eq i8 10, 11
 ; CHECK-NEXT:    [[T_0:%.*]] = icmp eq i8 10, 10
-; CHECK-NEXT:    [[RES_1:%.*]] = xor i1 [[T_0]], [[F_0]]
+; CHECK-NEXT:    [[RES_1:%.*]] = xor i1 true, false
 ; CHECK-NEXT:    [[F_1:%.*]] = icmp eq i8 10, 9
-; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], [[F_1]]
+; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], false
 ; CHECK-NEXT:    ret i1 [[RES_2]]
 ;
 entry:
Index: llvm/test/Transforms/ConstraintElimination/eq.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/eq.ll
+++ llvm/test/Transforms/ConstraintElimination/eq.ll
@@ -11,9 +11,9 @@
 ; CHECK-NEXT:    [[T_2:%.*]] = icmp ule i8 [[A]], [[B]]
 ; CHECK-NEXT:    [[RES_1:%.*]] = xor i1 true, true
 ; CHECK-NEXT:    [[T_3:%.*]] = icmp eq i8 [[A]], [[B]]
-; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], [[T_3]]
+; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], true
 ; CHECK-NEXT:    [[T_4:%.*]] = icmp eq i8 [[B]], [[A]]
-; CHECK-NEXT:    [[RES_3:%.*]] = xor i1 [[RES_2]], [[T_4]]
+; CHECK-NEXT:    [[RES_3:%.*]] = xor i1 [[RES_2]], true
 ; CHECK-NEXT:    [[F_1:%.*]] = icmp ugt i8 [[B]], [[A]]
 ; CHECK-NEXT:    [[RES_4:%.*]] = xor i1 [[RES_3]], false
 ; CHECK-NEXT:    [[F_2:%.*]] = icmp ult i8 [[B]], [[A]]
@@ -376,3 +376,41 @@
 
   ret i1 %xor.12
 }
+
+define i1 @test_transitivity_of_equality_and_plus_1(i64 %a, i64 %b, i64 %c) {
+; CHECK-LABEL: @test_transitivity_of_equality_and_plus_1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[PRE_1:%.*]] = icmp eq i64 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    br i1 [[PRE_1]], label [[AB_EQUAL:%.*]], label [[NOT_EQ:%.*]]
+; CHECK:       ab_equal:
+; CHECK-NEXT:    [[BC_EQ:%.*]] = icmp eq i64 [[B]], [[C:%.*]]
+; CHECK-NEXT:    br i1 [[BC_EQ]], label [[BC_EQUAL:%.*]], label [[NOT_EQ]]
+; CHECK:       bc_equal:
+; CHECK-NEXT:    [[AC_EQ:%.*]] = icmp eq i64 [[A]], [[C]]
+; CHECK-NEXT:    [[A_PLUS_1:%.*]] = add nuw i64 [[A]], 1
+; CHECK-NEXT:    [[C_PLUS_1:%.*]] = add nuw i64 [[C]], 1
+; CHECK-NEXT:    [[AC_PLUS_1_EQ:%.*]] = icmp eq i64 [[A_PLUS_1]], [[C_PLUS_1]]
+; CHECK-NEXT:    [[RESULT:%.*]] = select i1 true, i1 true, i1 false
+; CHECK-NEXT:    ret i1 [[RESULT]]
+; CHECK:       not_eq:
+; CHECK-NEXT:    ret i1 false
+;
+entry:
+  %pre.1 = icmp eq i64 %a, %b
+  br i1 %pre.1, label %ab_equal, label %not_eq
+
+ab_equal:
+  %bc_eq = icmp eq i64 %b, %c
+  br i1 %bc_eq, label %bc_equal, label %not_eq
+
+bc_equal:
+  %ac_eq = icmp eq i64 %a, %c
+  %a_plus_1 = add nuw i64 %a, 1
+  %c_plus_1 = add nuw i64 %c, 1
+  %ac_plus_1_eq = icmp eq i64 %a_plus_1, %c_plus_1
+  %result = select i1 %ac_eq, i1 %ac_plus_1_eq, i1 false
+  ret i1 %result
+
+not_eq:
+  ret i1 false
+}