diff --git a/llvm/include/llvm/Analysis/ConstraintSystem.h b/llvm/include/llvm/Analysis/ConstraintSystem.h
--- a/llvm/include/llvm/Analysis/ConstraintSystem.h
+++ b/llvm/include/llvm/Analysis/ConstraintSystem.h
@@ -122,12 +122,40 @@
     // The negated constraint R is obtained by multiplying by -1 and adding 1 to
     // the constant.
-    R[0] += 1;
+    if (AddOverflow(R[0], int64_t(1), R[0]))
+      return {};
+    return negateOrEqual(R);
+  }
+
+  /// Multiplies each coefficient in the given vector by -1, which reverses the
+  /// direction of the inequality: a row encoding `R[1] * x1 + ... <= R[0]`
+  /// becomes one encoding `R[1] * x1 + ... >= R[0]`. Does not modify the
+  /// original vector.
+  ///
+  /// \param R The vector of coefficients to be negated.
+  /// \return The reversed constraint, or an empty vector if multiplying a
+  /// coefficient by -1 overflows.
+  static SmallVector<int64_t, 8> negateOrEqual(SmallVector<int64_t, 8> R) {
+    // Multiply every entry, including the constant R[0], by -1.
     for (auto &C : R)
       if (MulOverflow(C, int64_t(-1), C))
         return {};
     return R;
   }
 
+  /// Converts the given vector to form a strict less than inequality: a row
+  /// encoding `R[1] * x1 + ... <= R[0]` becomes one encoding
+  /// `R[1] * x1 + ... < R[0]`. Does not modify the original vector.
+  ///
+  /// \param R The vector of coefficients to be converted.
+  /// \return The strict constraint, or an empty vector if decrementing the
+  /// constant overflows.
+  static SmallVector<int64_t, 8> toStrictLessThan(SmallVector<int64_t, 8> R) {
+    // Over the integers, `X < C` is equivalent to `X <= C - 1`.
+    if (SubOverflow(R[0], int64_t(1), R[0]))
+      return {};
+    return R;
+  }
+
   bool isConditionImplied(SmallVector<int64_t, 8> R) const;
 
   SmallVector<int64_t> getLastConstraint() const {
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -38,6 +38,7 @@
 #include "llvm/Transforms/Utils/ValueMapper.h"
 
 #include <cmath>
+#include <optional>
 #include <string>
 
 using namespace llvm;
@@ -109,12 +110,11 @@
   SmallVector<SmallVector<int64_t, 8>> ExtraInfo;
 
   bool IsSigned = false;
-  bool IsEq = false;
 
   ConstraintTy() = default;
 
-  ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned)
-      : Coefficients(Coefficients), IsSigned(IsSigned) {}
+  ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned, bool IsEq)
+      : Coefficients(Coefficients), IsSigned(IsSigned), IsEq(IsEq) {}
 
   unsigned size() const { return Coefficients.size(); }
 
@@ -123,6 +123,19 @@
   /// Returns true if all preconditions for this list of constraints are
   /// satisfied given \p CS and the corresponding \p Value2Index mapping.
   bool isValid(const ConstraintInfo &Info) const;
+
+  /// Returns true if this constraint encodes an equality (`==`) comparison.
+  bool isEq() const { return IsEq; }
+
+  /// Check if the current constraint is implied by the given ConstraintSystem.
+  ///
+  /// \return true or false if the constraint is proven to be respectively true
+  /// or false. When the constraint cannot be proven to be either true or false,
+  /// std::nullopt is returned.
+  std::optional<bool> isImpliedBy(const ConstraintSystem &CS) const;
+
+private:
+  bool IsEq = false;
 };
 
 /// Wrapper encapsulating separate constraint systems and corresponding value
@@ -480,11 +493,10 @@
   // subtracting all coefficients from B.
   ConstraintTy Res(
       SmallVector<int64_t, 8>(Value2Index.size() + NewVariables.size() + 1, 0),
-      IsSigned);
+      IsSigned, IsEq);
   // Collect variables that are known to be positive in all uses in the
   // constraint.
   DenseMap<Value *, bool> KnownNonNegativeVariables;
-  Res.IsEq = IsEq;
   auto &R = Res.Coefficients;
   for (const auto &KV : VariablesA) {
     R[GetOrAddIndex(KV.Variable)] += KV.Coefficient;
@@ -547,7 +559,7 @@
 
   SmallVector<Value *> NewVariables;
   ConstraintTy R = getConstraint(Pred, Op0, Op1, NewVariables);
-  if (R.IsEq || !NewVariables.empty())
+  if (!NewVariables.empty())
     return {};
   return R;
 }
@@ -559,6 +571,44 @@
          });
 }
 
+std::optional<bool>
+ConstraintTy::isImpliedBy(const ConstraintSystem &CS) const {
+  bool IsConditionImplied = CS.isConditionImplied(Coefficients);
+
+  if (IsEq) {
+    // `%a == %b` is true iff both `%a <= %b` (the constraint itself) and
+    // `%a >= %b` hold. Only query the solver for the latter if the former is
+    // already implied.
+    if (IsConditionImplied) {
+      auto NegatedOrEqual = ConstraintSystem::negateOrEqual(Coefficients);
+      if (!NegatedOrEqual.empty() && CS.isConditionImplied(NegatedOrEqual))
+        return true;
+    }
+
+    // `%a == %b` is false if either `%a > %b` or `%a < %b` holds; stop at the
+    // first one that is proven.
+    auto Negated = ConstraintSystem::negate(Coefficients);
+    if (!Negated.empty() && CS.isConditionImplied(Negated))
+      return false;
+
+    auto StrictLessThan = ConstraintSystem::toStrictLessThan(Coefficients);
+    if (!StrictLessThan.empty() && CS.isConditionImplied(StrictLessThan))
+      return false;
+
+    return std::nullopt;
+  }
+
+  if (IsConditionImplied)
+    return true;
+
+  auto Negated = ConstraintSystem::negate(Coefficients);
+  if (!Negated.empty() && CS.isConditionImplied(Negated))
+    return false;
+
+  // Neither the condition nor its negation is implied, so nothing was proven.
+  return std::nullopt;
+}
+
 bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A,
                               Value *B) const {
   auto R = getConstraintForSolving(Pred, A, B);
@@ -976,27 +1026,10 @@
     return true;
   };
 
-  bool Changed = false;
-  bool IsConditionImplied = CSToUse.isConditionImplied(R.Coefficients);
+  if (auto ImpliedCondition = R.isImpliedBy(CSToUse))
+    return ReplaceCmpWithConstant(Cmp, *ImpliedCondition);
 
-  if (IsConditionImplied) {
-    Changed = ReplaceCmpWithConstant(Cmp, true);
-    if (!Changed)
-      return false;
-  }
-
-  // Compute them separately.
-  auto Negated = ConstraintSystem::negate(R.Coefficients);
-  auto IsNegatedImplied =
-      !Negated.empty() && CSToUse.isConditionImplied(Negated);
-
-  if (IsNegatedImplied) {
-    Changed = ReplaceCmpWithConstant(Cmp, false);
-    if (!Changed)
-      return false;
-  }
-
-  return Changed;
+  return false;
 }
 
 static void
@@ -1054,7 +1087,7 @@
     DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned,
                             std::move(ValuesToRelease));
 
-    if (R.IsEq) {
+    if (R.isEq()) {
       // Also add the inverted constraint for equality constraints.
for (auto &Coeff : R.Coefficients) Coeff *= -1; diff --git a/llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll b/llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll --- a/llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll +++ b/llvm/test/Transforms/ConstraintElimination/constants-unsigned-predicates.ll @@ -75,9 +75,9 @@ ; CHECK-NEXT: entry: ; CHECK-NEXT: [[F_0:%.*]] = icmp eq i8 10, 11 ; CHECK-NEXT: [[T_0:%.*]] = icmp eq i8 10, 10 -; CHECK-NEXT: [[RES_1:%.*]] = xor i1 [[T_0]], [[F_0]] +; CHECK-NEXT: [[RES_1:%.*]] = xor i1 true, false ; CHECK-NEXT: [[F_1:%.*]] = icmp eq i8 10, 9 -; CHECK-NEXT: [[RES_2:%.*]] = xor i1 [[RES_1]], [[F_1]] +; CHECK-NEXT: [[RES_2:%.*]] = xor i1 [[RES_1]], false ; CHECK-NEXT: ret i1 [[RES_2]] ; entry: diff --git a/llvm/test/Transforms/ConstraintElimination/eq.ll b/llvm/test/Transforms/ConstraintElimination/eq.ll --- a/llvm/test/Transforms/ConstraintElimination/eq.ll +++ b/llvm/test/Transforms/ConstraintElimination/eq.ll @@ -1,6 +1,8 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt -passes=constraint-elimination -S %s | FileCheck %s +declare void @llvm.assume(i1) + define i1 @test_eq_1(i8 %a, i8 %b) { ; CHECK-LABEL: @test_eq_1( ; CHECK-NEXT: entry: @@ -11,9 +13,9 @@ ; CHECK-NEXT: [[T_2:%.*]] = icmp ule i8 [[A]], [[B]] ; CHECK-NEXT: [[RES_1:%.*]] = xor i1 true, true ; CHECK-NEXT: [[T_3:%.*]] = icmp eq i8 [[A]], [[B]] -; CHECK-NEXT: [[RES_2:%.*]] = xor i1 [[RES_1]], [[T_3]] +; CHECK-NEXT: [[RES_2:%.*]] = xor i1 [[RES_1]], true ; CHECK-NEXT: [[T_4:%.*]] = icmp eq i8 [[B]], [[A]] -; CHECK-NEXT: [[RES_3:%.*]] = xor i1 [[RES_2]], [[T_4]] +; CHECK-NEXT: [[RES_3:%.*]] = xor i1 [[RES_2]], true ; CHECK-NEXT: [[F_1:%.*]] = icmp ugt i8 [[B]], [[A]] ; CHECK-NEXT: [[RES_4:%.*]] = xor i1 [[RES_3]], false ; CHECK-NEXT: [[F_2:%.*]] = icmp ult i8 [[B]], [[A]] @@ -376,3 +378,88 @@ ret i1 %xor.12 } + +define i1 @assume_b_plus_1_ult_a(i64 %a, i64 
%b) { +; CHECK-LABEL: @assume_b_plus_1_ult_a( +; CHECK-NEXT: [[TMP1:%.*]] = add nuw i64 [[B:%.*]], 1 +; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i64 [[TMP1]], [[A:%.*]] +; CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP2]]) +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[A]], [[B]] +; CHECK-NEXT: ret i1 false +; + %1 = add nuw i64 %b, 1 + %2 = icmp ult i64 %1, %a + tail call void @llvm.assume(i1 %2) + %3 = icmp eq i64 %a, %b + ret i1 %3 +} + +define i1 @assume_a_plus_1_eq_b(i64 %a, i64 %b) { +; CHECK-LABEL: @assume_a_plus_1_eq_b( +; CHECK-NEXT: [[TMP1:%.*]] = add nuw i64 [[A:%.*]], 1 +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i64 [[TMP1]], [[B:%.*]] +; CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP2]]) +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[A]], [[B]] +; CHECK-NEXT: ret i1 false +; + %1 = add nuw i64 %a, 1 + %2 = icmp eq i64 %1, %b + tail call void @llvm.assume(i1 %2) + %3 = icmp eq i64 %a, %b + ret i1 %3 +} + +define i1 @assume_a_ge_b_and_b_ge_c(i64 %a, i64 %b, i64 %c) { +; CHECK-LABEL: @assume_a_ge_b_and_b_ge_c( +; CHECK-NEXT: [[TMP1:%.*]] = icmp uge i64 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP1]]) +; CHECK-NEXT: [[TMP2:%.*]] = icmp uge i64 [[B]], [[C:%.*]] +; CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP2]]) +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[A]], [[C]] +; CHECK-NEXT: ret i1 [[TMP3]] +; + %1 = icmp uge i64 %a, %b + tail call void @llvm.assume(i1 %1) + %2 = icmp uge i64 %b, %c + tail call void @llvm.assume(i1 %2) + %3 = icmp eq i64 %a, %c + ret i1 %3 +} + +define i1 @test_transitivity_of_equality_and_plus_1(i64 %a, i64 %b, i64 %c) { +; CHECK-LABEL: @test_transitivity_of_equality_and_plus_1( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[PRE_1:%.*]] = icmp eq i64 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: br i1 [[PRE_1]], label [[AB_EQUAL:%.*]], label [[NOT_EQ:%.*]] +; CHECK: ab_equal: +; CHECK-NEXT: [[BC_EQ:%.*]] = icmp eq i64 [[B]], [[C:%.*]] +; CHECK-NEXT: br i1 [[BC_EQ]], label [[BC_EQUAL:%.*]], label [[NOT_EQ]] +; CHECK: bc_equal: +; 
CHECK-NEXT: [[AC_EQ:%.*]] = icmp eq i64 [[A]], [[C]] +; CHECK-NEXT: [[A_PLUS_1:%.*]] = add nuw i64 [[A]], 1 +; CHECK-NEXT: [[C_PLUS_1:%.*]] = add nuw i64 [[C]], 1 +; CHECK-NEXT: [[AC_PLUS_1_EQ:%.*]] = icmp eq i64 [[A_PLUS_1]], [[C_PLUS_1]] +; CHECK-NEXT: [[RESULT:%.*]] = and i1 true, true +; CHECK-NEXT: ret i1 [[RESULT]] +; CHECK: not_eq: +; CHECK-NEXT: ret i1 false +; +entry: + %pre.1 = icmp eq i64 %a, %b + br i1 %pre.1, label %ab_equal, label %not_eq + +ab_equal: + %bc_eq = icmp eq i64 %b, %c + br i1 %bc_eq, label %bc_equal, label %not_eq + +bc_equal: + %ac_eq = icmp eq i64 %a, %c + %a_plus_1 = add nuw i64 %a, 1 + %c_plus_1 = add nuw i64 %c, 1 + %ac_plus_1_eq = icmp eq i64 %a_plus_1, %c_plus_1 + %result = and i1 %ac_eq, %ac_plus_1_eq + ret i1 %result + +not_eq: + ret i1 false +} diff --git a/llvm/test/Transforms/ConstraintElimination/reproducer-remarks-debug.ll b/llvm/test/Transforms/ConstraintElimination/reproducer-remarks-debug.ll --- a/llvm/test/Transforms/ConstraintElimination/reproducer-remarks-debug.ll +++ b/llvm/test/Transforms/ConstraintElimination/reproducer-remarks-debug.ll @@ -9,7 +9,6 @@ ; CHECK-NEXT: Creating reproducer for %c.2 = icmp eq ptr %a, null ; CHECK-NEXT: found external input ptr %a ; CHECK-NEXT: Materializing assumption %c.1 = icmp eq ptr %a, null -; CHECK-NEXT: --- define i1 @test_ptr_null_constant(ptr %a) { ; CHECK-LABEL: define i1 @"{{.+}}test_ptr_null_constantrepro"(ptr %a) {