diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -820,6 +820,10 @@
     return P >= FIRST_ICMP_PREDICATE && P <= LAST_ICMP_PREDICATE;
   }
 
+  /// Returns true if this compare has a user that is an llvm.assume call.
+  /// If \p OneUse is set, the compare must also have exactly one use.
+  bool isAssumedTrue(bool OneUse) const;
+
   static StringRef getPredicateName(Predicate P);
 
   bool isFPPredicate() const { return isFPPredicate(getPredicate()); }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2313,7 +2313,42 @@
   CmpInst::Predicate Pred;
   bool NonNullIfTrue;
   const User *CmpUse = nullptr;
-  if (match(U, m_c_ICmp(Pred, m_Specific(V), m_Value(RHS)))) {
+  const APInt *AddC;
+  if (match(U, m_Add(m_Specific(V), m_APInt(AddC))) && !AddC->isZero()) {
+    unsigned NumUsesExplored = 0;
+    for (const auto *U2 : U->users()) {
+      // This is the inner loop of a loop over V->users(), so be more
+      // conservative here.
+      if (NumUsesExplored++ >= (DomConditionsMaxUses + 3) / 4)
+        break;
+
+      // Only handles (A + C1) u< C2, which is the canonical form of
+      // A > C3 && A < C4.
+      if (match(U2, m_ICmp(Pred, m_Specific(U), m_Value(RHS)))) {
+        auto *RHSC = dyn_cast<ConstantInt>(RHS);
+        // TODO: Since we are already here, we could check all (or at
+        // least more) conditions; at the moment we only handle
+        // `ICMP_ULT` to get the `A u< C1 && A u> C2` canonicalization.
+        if (RHSC == nullptr || Pred != ICmpInst::ICMP_ULT)
+          continue;
+
+        // With A == 0 the compare folds to `C1 u< C2`. If C2 u<= C1 that
+        // is false, so A is non-zero whenever the compare is true;
+        // otherwise A is non-zero whenever the compare is false.
+        NonNullIfTrue = RHSC->getValue().ule(*AddC);
+
+        // An assume of the compare only proves A != 0 where it is valid.
+        if (NonNullIfTrue)
+          for (const User *CmpU : U2->users())
+            if (const auto *AI = dyn_cast<AssumeInst>(CmpU))
+              if (isValidAssumeForContext(AI, Q.CxtI, Q.DT))
+                return true;
+
+        CmpUse = U2;
+        break;
+      }
+    }
+  } else if (match(U, m_c_ICmp(Pred, m_Specific(V), m_Value(RHS)))) {
     CmpUse = U;
     if (cmpExcludesZero(Pred, RHS))
       NonNullIfTrue = true;
@@ -2321,7 +2356,9 @@
       NonNullIfTrue = false;
     else
       return false;
-  } else
+  }
+
+  if (CmpUse == nullptr)
     return false;
 
   SmallVector<const User *, 4> WorkList;
@@ -2410,6 +2447,7 @@
     bool ImpliesNonZero = false;
     if (auto *OpU = dyn_cast<Operator>(U)) {
       switch (OpU->getOpcode()) {
+      case Instruction::Add:
       case Instruction::ICmp:
         if (checkDominatingConditionForNonNull(V, U, Q))
           return true;
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -26,6 +26,7 @@
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
@@ -4073,6 +4074,15 @@
   llvm_unreachable("Unsupported predicate kind");
 }
 
+bool CmpInst::isAssumedTrue(bool OneUse) const {
+  if (OneUse && !hasOneUse())
+    return false;
+  for (const User *U : users())
+    if (isa<AssumeInst>(U))
+      return true;
+  return false;
+}
+
 CmpInst::Predicate CmpInst::getInversePredicate(Predicate pred) {
   switch (pred) {
   default: llvm_unreachable("Unknown cmp predicate!");
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2821,9 +2821,13 @@
 
   // (X + -1) <u C --> X <=u C (if X is never null)
   if (Pred == CmpInst::ICMP_ULT && C2->isAllOnes()) {
-    const SimplifyQuery Q = SQ.getWithInstruction(&Cmp);
-    if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT))
-      return new ICmpInst(ICmpInst::ICMP_ULE, X, ConstantInt::get(Ty, C));
+    // Don't do this fold if the only use of this Cmp is an assume, as we
+    // would both lose information and de-canonicalize the range check.
+    if (!Cmp.isAssumedTrue(/*OneUse*/ true)) {
+      const SimplifyQuery Q = SQ.getWithInstruction(&Cmp);
+      if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT))
+        return new ICmpInst(ICmpInst::ICMP_ULE, X, ConstantInt::get(Ty, C));
+    }
   }
 
   if (!Add->hasOneUse())
diff --git a/llvm/test/Analysis/ValueTracking/known-non-zero-range.ll b/llvm/test/Analysis/ValueTracking/known-non-zero-range.ll
--- a/llvm/test/Analysis/ValueTracking/known-non-zero-range.ll
+++ b/llvm/test/Analysis/ValueTracking/known-non-zero-range.ll
@@ -12,9 +12,7 @@
 ; CHECK-NEXT:    br i1 [[NE]], label [[TRUE:%.*]], label [[FALSE:%.*]]
 ; CHECK:       true:
 ; CHECK-NEXT:    [[CMP0:%.*]] = icmp ugt i8 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i8 [[Y]], 0
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP0]], [[CMP1]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 [[CMP0]]
 ; CHECK:       false:
 ; CHECK-NEXT:    [[RT:%.*]] = icmp eq i8 [[X]], 0
 ; CHECK-NEXT:    call void @use1(i1 [[RT]])
@@ -47,9 +45,7 @@
 ; CHECK-NEXT:    br i1 [[NE]], label [[TRUE:%.*]], label [[FALSE:%.*]]
 ; CHECK:       true:
 ; CHECK-NEXT:    [[CMP0:%.*]] = icmp ugt i8 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i8 [[Y]], 0
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP0]], [[CMP1]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 [[CMP0]]
 ; CHECK:       false:
 ; CHECK-NEXT:    [[RT:%.*]] = icmp eq i8 [[X]], 0
 ; CHECK-NEXT:    call void @use1(i1 [[RT]])
@@ -81,9 +77,7 @@
 ; CHECK-NEXT:    [[NE:%.*]] = icmp ult i8 [[UB]], 14
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[NE]])
 ; CHECK-NEXT:    [[CMP0:%.*]] = icmp ugt i8 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i8 [[Y]], 0
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP0]], [[CMP1]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 [[CMP0]]
 ;
   %ub = add i8 %x, -1
   %ne = icmp ult i8 %ub, 14
@@ -103,9 +97,7 @@
 ; CHECK-NEXT:    br i1 [[NE]], label [[TRUE:%.*]], label [[FALSE:%.*]]
 ; CHECK:       true:
 ; CHECK-NEXT:    [[CMP0:%.*]] = icmp ugt i8 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i8 [[Y]], 0
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP0]], [[CMP1]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 [[CMP0]]
 ; CHECK:       false:
 ; CHECK-NEXT:    [[RT:%.*]] = icmp eq i8 [[X]], 0
 ; CHECK-NEXT:    call void @use1(i1 [[RT]])
@@ -139,9 +131,7 @@
 ; CHECK-NEXT:    [[NE:%.*]] = icmp ult i8 [[TMP1]], 14
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[NE]])
 ; CHECK-NEXT:    [[CMP0:%.*]] = icmp ugt i8 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i8 [[Y]], 0
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP0]], [[CMP1]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 [[CMP0]]
 ;
   %x1 = and i8 %x, 123
   %x2 = call i8 @llvm.abs.i8(i8 %x1, i1 true)
@@ -182,9 +172,7 @@
 ; CHECK-NEXT:    br i1 [[NE]], label [[TRUE:%.*]], label [[FALSE:%.*]]
 ; CHECK:       false:
 ; CHECK-NEXT:    [[CMP0:%.*]] = icmp ugt i8 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i8 [[Y]], 0
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP0]], [[CMP1]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 [[CMP0]]
 ; CHECK:       true:
 ; CHECK-NEXT:    [[RT:%.*]] = icmp eq i8 [[X]], 0
 ; CHECK-NEXT:    call void @use1(i1 [[RT]])
diff --git a/llvm/test/Analysis/ValueTracking/known-non-zero-through-dom-use.ll b/llvm/test/Analysis/ValueTracking/known-non-zero-through-dom-use.ll
--- a/llvm/test/Analysis/ValueTracking/known-non-zero-through-dom-use.ll
+++ b/llvm/test/Analysis/ValueTracking/known-non-zero-through-dom-use.ll
@@ -19,9 +19,7 @@
 ; CHECK-NEXT:    [[NE:%.*]] = icmp ult i8 [[NOTSUB]], -5
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[NE]])
 ; CHECK-NEXT:    [[CMP0:%.*]] = icmp ugt i8 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i8 [[Y]], 0
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP0]], [[CMP1]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 [[CMP0]]
 ;
   %z = sub i8 0, %x
   %ne = icmp ugt i8 %z, 4