Index: llvm/include/llvm/Analysis/ConstraintSystem.h
===================================================================
--- llvm/include/llvm/Analysis/ConstraintSystem.h
+++ llvm/include/llvm/Analysis/ConstraintSystem.h
@@ -128,6 +128,25 @@
     return R;
   }
 
+  /// Multiplies every entry of \p R, including the constant R[0], by -1.
+  /// This flips the direction of the constraint while keeping it
+  /// non-strict. Returns an empty vector if a multiplication overflows.
+  static SmallVector<int64_t, 8> negateOrEqual(SmallVector<int64_t, 8> R) {
+    for (auto &C : R)
+      if (MulOverflow(C, int64_t(-1), C))
+        return {};
+    return R;
+  }
+
+  /// Turns the non-strict constraint \p R into a strict less-than by
+  /// subtracting 1 from the constant R[0]. Returns an empty vector if \p R is
+  /// empty or the subtraction overflows.
+  static SmallVector<int64_t, 8> toStrictLessThan(SmallVector<int64_t, 8> R) {
+    if (R.empty() || SubOverflow(R[0], int64_t(1), R[0]))
+      return {};
+    return R;
+  }
+
   bool isConditionImplied(SmallVector<int64_t, 8> R) const;
 
   SmallVector<int64_t, 8> getLastConstraint() const {
Index: llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -102,6 +102,15 @@
       : Pred(Pred), Op0(Op0), Op1(Op1) {}
 };
 
+/// Result of ConstraintTy::isImpliedBy.
+enum class ConditionImpliedTy {
+  None,                // Non-equality constraint: not implied to be true
+  ConditionHoldTrue,   // Non-equality constraint: implied to be true
+  EqualityHoldTrue,    // Equality constraint: implied to be true
+  EqualityHoldFalse,   // Equality constraint: implied to be false
+  EqualityDoesNotHold  // Equality constraint: neither outcome could be proven
+};
+
 struct ConstraintTy {
   SmallVector<int64_t, 8> Coefficients;
   SmallVector<ConditionTy, 2> Preconditions;
@@ -109,12 +118,11 @@
   SmallVector<SmallVector<int64_t, 8>> ExtraInfo;
 
   bool IsSigned = false;
-  bool IsEq = false;
 
   ConstraintTy() = default;
 
-  ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned)
-      : Coefficients(Coefficients), IsSigned(IsSigned) {}
+  ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned, bool IsEq)
+      : Coefficients(Coefficients), IsSigned(IsSigned), IsEq(IsEq) {}
 
   unsigned size() const { return Coefficients.size(); }
 
@@ -123,6 +131,16 @@
   /// Returns true if all preconditions for this list of constraints are
   /// satisfied given \p CS and the corresponding \p Value2Index mapping.
   bool isValid(const ConstraintInfo &Info) const;
+
+  /// Returns true if this is an equality constraint.
+  bool isEq() const { return IsEq; }
+
+  /// Checks whether \p CS implies this constraint. Equality constraints are
+  /// additionally checked for being implied false; see ConditionImpliedTy.
+  ConditionImpliedTy isImpliedBy(const ConstraintSystem &CS) const;
+
+private:
+  bool IsEq = false;
 };
 
 /// Wrapper encapsulating separate constraint systems and corresponding value
@@ -480,11 +498,10 @@
   // subtracting all coefficients from B.
   ConstraintTy Res(
       SmallVector<int64_t, 8>(Value2Index.size() + NewVariables.size() + 1, 0),
-      IsSigned);
+      IsSigned, IsEq);
   // Collect variables that are known to be positive in all uses in the
   // constraint.
   DenseMap<Value *, bool> KnownNonNegativeVariables;
-  Res.IsEq = IsEq;
   auto &R = Res.Coefficients;
   for (const auto &KV : VariablesA) {
     R[GetOrAddIndex(KV.Variable)] += KV.Coefficient;
@@ -547,7 +564,7 @@
 
   SmallVector<Value *> NewVariables;
   ConstraintTy R = getConstraint(Pred, Op0, Op1, NewVariables);
-  if (R.IsEq || !NewVariables.empty())
+  if (!NewVariables.empty())
     return {};
   return R;
 }
@@ -559,6 +576,35 @@
          });
 }
 
+ConditionImpliedTy ConstraintTy::isImpliedBy(const ConstraintSystem &CS) const {
+  bool IsConditionImplied = CS.isConditionImplied(Coefficients);
+
+  if (!IsEq)
+    return IsConditionImplied ? ConditionImpliedTy::ConditionHoldTrue
+                              : ConditionImpliedTy::None;
+
+  // `%a == %b` holds if both `%a <= %b` and `%a >= %b` are implied, i.e. the
+  // constraint and its coefficient-negated, still non-strict, counterpart.
+  if (IsConditionImplied) {
+    auto NegatedOrEqual = ConstraintSystem::negateOrEqual(Coefficients);
+    if (!NegatedOrEqual.empty() && CS.isConditionImplied(NegatedOrEqual))
+      return ConditionImpliedTy::EqualityHoldTrue;
+  }
+
+  // `%a == %b` is false if either `%a > %b` (the negated constraint) or
+  // `%a < %b` (the strict form of the constraint) is implied. These queries
+  // are only issued once equality could not be proven.
+  auto Negated = ConstraintSystem::negate(Coefficients);
+  if (!Negated.empty() && CS.isConditionImplied(Negated))
+    return ConditionImpliedTy::EqualityHoldFalse;
+
+  auto StrictLessThan = ConstraintSystem::toStrictLessThan(Coefficients);
+  if (!StrictLessThan.empty() && CS.isConditionImplied(StrictLessThan))
+    return ConditionImpliedTy::EqualityHoldFalse;
+
+  return ConditionImpliedTy::EqualityDoesNotHold;
+}
+
 bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A,
                               Value *B) const {
   auto R = getConstraintForSolving(Pred, A, B);
@@ -977,9 +1023,17 @@
   };
 
   bool Changed = false;
-  bool IsConditionImplied = CSToUse.isConditionImplied(R.Coefficients);
+  auto CondImpliedStatus = R.isImpliedBy(CSToUse);
+
+  // For equality constraints, isImpliedBy alone decides the outcome.
+  if (CondImpliedStatus == ConditionImpliedTy::EqualityHoldTrue)
+    return ReplaceCmpWithConstant(Cmp, true);
+  if (CondImpliedStatus == ConditionImpliedTy::EqualityHoldFalse)
+    return ReplaceCmpWithConstant(Cmp, false);
+  if (CondImpliedStatus == ConditionImpliedTy::EqualityDoesNotHold)
+    return false;
 
-  if (IsConditionImplied) {
+  if (CondImpliedStatus == ConditionImpliedTy::ConditionHoldTrue) {
     Changed = ReplaceCmpWithConstant(Cmp, true);
     if (!Changed)
       return false;
@@ -1054,7 +1108,7 @@
   DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned,
                           std::move(ValuesToRelease));
 
-  if (R.IsEq) {
+  if (R.isEq()) {
     // Also add the inverted constraint for equality constraints.
     for (auto &Coeff : R.Coefficients)
       Coeff *= -1;
Index: llvm/test/Transforms/ConstraintElimination/assumes.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/assumes.ll
+++ llvm/test/Transforms/ConstraintElimination/assumes.ll
@@ -622,3 +622,50 @@
   tail call void @llvm.assume(i1 %c.2)
   ret i1 %c.2
 }
+
+define i1 @assume_b_plus_1_ult_a(i64 %a, i64 %b) {
+; CHECK-LABEL: @assume_b_plus_1_ult_a(
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i64 [[B:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], [[A:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[A]], [[B]]
+; CHECK-NEXT:    ret i1 false
+;
+  %1 = add nuw i64 %b, 1
+  %2 = icmp ult i64 %1, %a
+  tail call void @llvm.assume(i1 %2)
+  %3 = icmp eq i64 %a, %b
+  ret i1 %3
+}
+
+define i1 @assume_a_plus_1_eq_b(i64 %a, i64 %b) {
+; CHECK-LABEL: @assume_a_plus_1_eq_b(
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i64 [[A:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[TMP1]], [[B:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[A]], [[B]]
+; CHECK-NEXT:    ret i1 false
+;
+  %1 = add nuw i64 %a, 1
+  %2 = icmp eq i64 %1, %b
+  tail call void @llvm.assume(i1 %2)
+  %3 = icmp eq i64 %a, %b
+  ret i1 %3
+}
+
+define i1 @assume_a_ge_b_and_b_ge_c(i64 %a, i64 %b, i64 %c) {
+; CHECK-LABEL: @assume_a_ge_b_and_b_ge_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp uge i64 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp uge i64 [[B]], [[C:%.*]]
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[A]], [[C]]
+; CHECK-NEXT:    ret i1 [[TMP3]]
+;
+  %1 = icmp uge i64 %a, %b
+  tail call void @llvm.assume(i1 %1)
+  %2 = icmp uge i64 %b, %c
+  tail call void @llvm.assume(i1 %2)
+  %3 = icmp eq i64 %a, %c
+  ret i1 %3
+}
Index: llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll
+++ llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll
@@ -75,9 +75,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[F_0:%.*]] = icmp eq i8 10, 11
 ; CHECK-NEXT:    [[T_0:%.*]] = icmp eq i8 10, 10
-; CHECK-NEXT:    [[RES_1:%.*]] = xor i1 [[T_0]], [[F_0]]
+; CHECK-NEXT:    [[RES_1:%.*]] = xor i1 true, false
 ; CHECK-NEXT:    [[F_1:%.*]] = icmp eq i8 10, 9
-; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], [[F_1]]
+; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], false
 ; CHECK-NEXT:    ret i1 [[RES_2]]
 ;
 entry:
Index: llvm/test/Transforms/ConstraintElimination/eq.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/eq.ll
+++ llvm/test/Transforms/ConstraintElimination/eq.ll
@@ -11,9 +11,9 @@
 ; CHECK-NEXT:    [[T_2:%.*]] = icmp ule i8 [[A]], [[B]]
 ; CHECK-NEXT:    [[RES_1:%.*]] = xor i1 true, true
 ; CHECK-NEXT:    [[T_3:%.*]] = icmp eq i8 [[A]], [[B]]
-; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], [[T_3]]
+; CHECK-NEXT:    [[RES_2:%.*]] = xor i1 [[RES_1]], true
 ; CHECK-NEXT:    [[T_4:%.*]] = icmp eq i8 [[B]], [[A]]
-; CHECK-NEXT:    [[RES_3:%.*]] = xor i1 [[RES_2]], [[T_4]]
+; CHECK-NEXT:    [[RES_3:%.*]] = xor i1 [[RES_2]], true
 ; CHECK-NEXT:    [[F_1:%.*]] = icmp ugt i8 [[B]], [[A]]
 ; CHECK-NEXT:    [[RES_4:%.*]] = xor i1 [[RES_3]], false
 ; CHECK-NEXT:    [[F_2:%.*]] = icmp ult i8 [[B]], [[A]]
@@ -376,3 +376,41 @@
   ret i1 %xor.12
 }
 
+
+define i1 @test_transitivity_of_equality_and_plus_1(i64 %a, i64 %b, i64 %c) {
+; CHECK-LABEL: @test_transitivity_of_equality_and_plus_1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[PRE_1:%.*]] = icmp eq i64 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    br i1 [[PRE_1]], label [[AB_EQUAL:%.*]], label [[NOT_EQ:%.*]]
+; CHECK:       ab_equal:
+; CHECK-NEXT:    [[BC_EQ:%.*]] = icmp eq i64 [[B]], [[C:%.*]]
+; CHECK-NEXT:    br i1 [[BC_EQ]], label [[BC_EQUAL:%.*]], label [[NOT_EQ]]
+; CHECK:       bc_equal:
+; CHECK-NEXT:    [[AC_EQ:%.*]] = icmp eq i64 [[A]], [[C]]
+; CHECK-NEXT:    [[A_PLUS_1:%.*]] = add nuw i64 [[A]], 1
+; CHECK-NEXT:    [[C_PLUS_1:%.*]] = add nuw i64 [[C]], 1
+; CHECK-NEXT:    [[AC_PLUS_1_EQ:%.*]] = icmp eq i64 [[A_PLUS_1]], [[C_PLUS_1]]
+; CHECK-NEXT:    [[RESULT:%.*]] = select i1 true, i1 true, i1 false
+; CHECK-NEXT:    ret i1 [[RESULT]]
+; CHECK:       not_eq:
+; CHECK-NEXT:    ret i1 false
+;
+entry:
+  %pre.1 = icmp eq i64 %a, %b
+  br i1 %pre.1, label %ab_equal, label %not_eq
+
+ab_equal:
+  %bc_eq = icmp eq i64 %b, %c
+  br i1 %bc_eq, label %bc_equal, label %not_eq
+
+bc_equal:
+  %ac_eq = icmp eq i64 %a, %c
+  %a_plus_1 = add nuw i64 %a, 1
+  %c_plus_1 = add nuw i64 %c, 1
+  %ac_plus_1_eq = icmp eq i64 %a_plus_1, %c_plus_1
+  %result = select i1 %ac_eq, i1 %ac_plus_1_eq, i1 false
+  ret i1 %result
+
+not_eq:
+  ret i1 false
+}