diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -576,6 +576,76 @@
   return !isEphemeralValueOf(Inv, CxtI);
 }
 
+/// Return true if some llvm.assume that is valid at the context instruction
+/// Q.CxtI has an icmp condition on \p V (or on a ptrtoint of \p V) which can
+/// only hold when \p V is non-zero.
+static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
+  // Use of assumptions is context-sensitive. If we don't have a context, we
+  // cannot use them!
+  if (!Q.AC || !Q.CxtI)
+    return false;
+
+  // Note that the patterns below need to be kept in sync with the code
+  // in AssumptionCache::updateAffectedValues.
+
+  auto CmpExcludesZero = [V](ICmpInst *Cmp) {
+    auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
+
+    Value *RHS;
+    CmpInst::Predicate Pred;
+    if (!match(Cmp, m_c_ICmp(Pred, m_V, m_Value(RHS))))
+      return false;
+    // Canonicalize 'v' to be on the LHS of the comparison.
+    if (Cmp->getOperand(1) != RHS)
+      Pred = CmpInst::getSwappedPredicate(Pred);
+
+    // assume(v u> y) -> assume(v != 0)
+    if (Pred == ICmpInst::ICMP_UGT)
+      return true;
+
+    // assume(v != 0)
+    // We special-case this one to ensure that we handle `assume(v != null)`.
+    if (Pred == ICmpInst::ICMP_NE)
+      return match(RHS, m_Zero());
+
+    // All other predicates - rely on generic ConstantRange handling.
+    ConstantInt *CI;
+    if (!match(RHS, m_ConstantInt(CI)))
+      return false;
+    ConstantRange RHSRange(CI->getValue());
+    ConstantRange TrueValues =
+        ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
+    return !TrueValues.contains(APInt::getNullValue(CI->getBitWidth()));
+  };
+
+  for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
+    if (!AssumeVH)
+      continue;
+    CallInst *I = cast<CallInst>(AssumeVH);
+    assert(I->getFunction() == Q.CxtI->getFunction() &&
+           "Got assumption for the wrong function!");
+    if (Q.isExcluded(I))
+      continue;
+
+    // Warning: This loop can end up being somewhat performance sensitive.
+    // We're running this loop once for each value queried resulting in a
+    // runtime of ~O(#assumes * #values).
+
+    assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
+           "must be an assume intrinsic");
+
+    Value *Arg = I->getArgOperand(0);
+    ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
+    if (!Cmp)
+      continue;
+
+    if (CmpExcludesZero(Cmp) && isValidAssumeForContext(I, Q.CxtI, Q.DT))
+      return true;
+  }
+
+  return false;
+}
+
 static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                                        unsigned Depth, const Query &Q) {
   // Use of assumptions is context-sensitive. If we don't have a context, we
@@ -2079,6 +2149,9 @@
     }
   }
 
+  if (isKnownNonZeroFromAssume(V, Q))
+    return true;
+
   // Some of the tests below are recursive, so bail out if we hit the limit.
   if (Depth++ >= MaxDepth)
     return false;
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -252,8 +252,7 @@
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i32, i32* [[A:%.*]], align 4
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[RVAL:%.*]] = icmp eq i32 [[LOAD]], 0
-; CHECK-NEXT:    ret i1 [[RVAL]]
+; CHECK-NEXT:    ret i1 false
 ;
   %load = load i32, i32* %a
   %cmp = icmp ne i32 %load, 0
@@ -273,10 +272,10 @@
 ; CHECK:       taken:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32* [[LOAD]], null
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[RVAL:%.*]] = icmp eq i32* [[LOAD]], null
-; CHECK-NEXT:    ret i1 [[RVAL]]
+; CHECK-NEXT:    ret i1 false
 ; CHECK:       not_taken:
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    [[RVAL_2:%.*]] = icmp sgt i32* [[LOAD]], null
+; CHECK-NEXT:    ret i1 [[RVAL_2]]
 ;
 entry:
   %load = load i32*, i32** %a
@@ -287,7 +286,8 @@
   %rval = icmp eq i32* %load, null
   ret i1 %rval
 not_taken:
-  ret i1 true
+  %rval.2 = icmp sgt i32* %load, null
+  ret i1 %rval.2
 }
 
 ; Make sure the above canonicalization does not trigger
@@ -300,8 +300,7 @@
 ; CHECK-NEXT:    tail call void @escape(i32* [[LOAD]])
 ; CHECK-NEXT:    [[CMP:%.*]]
= icmp ne i32* [[LOAD]], null ; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[RVAL:%.*]] = icmp eq i32* [[LOAD]], null -; CHECK-NEXT: ret i1 [[RVAL]] +; CHECK-NEXT: ret i1 false ; %load = load i32*, i32** %a ;; This call may throw! @@ -353,12 +352,12 @@ define void @debug_interference(i8 %x) { ; CHECK-LABEL: @debug_interference( -; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i8 [[X:%.*]], 0 +; CHECK-NEXT: [[CMP2:%.*]] = icmp ne i8 [[X:%.*]], 0 ; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9 -; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP1]]) +; CHECK-NEXT: tail call void @llvm.assume(i1 false) ; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9 ; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9 -; CHECK-NEXT: tail call void @llvm.assume(i1 false) +; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP2]]) ; CHECK-NEXT: ret void ; %cmp1 = icmp eq i8 %x, 0 diff --git a/llvm/test/Transforms/InstSimplify/assume-non-zero.ll b/llvm/test/Transforms/InstSimplify/assume-non-zero.ll --- a/llvm/test/Transforms/InstSimplify/assume-non-zero.ll +++ b/llvm/test/Transforms/InstSimplify/assume-non-zero.ll @@ -10,8 +10,7 @@ ; CHECK-LABEL: @nonnull0_true( ; CHECK-NEXT: [[A:%.*]] = icmp ne i8* [[X:%.*]], null ; CHECK-NEXT: call void @llvm.assume(i1 [[A]]) -; CHECK-NEXT: [[Q:%.*]] = icmp ne i8* [[X]], null -; CHECK-NEXT: ret i1 [[Q]] +; CHECK-NEXT: ret i1 true ; %a = icmp ne i8* %x, null call void @llvm.assume(i1 %a) @@ -24,8 +23,7 @@ ; CHECK-NEXT: [[INTPTR:%.*]] = ptrtoint i8* [[X:%.*]] to i64 ; CHECK-NEXT: [[A:%.*]] = icmp ne i64 [[INTPTR]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]]) -; CHECK-NEXT: [[Q:%.*]] = icmp ne i8* [[X]], null -; CHECK-NEXT: ret i1 [[Q]] +; CHECK-NEXT: ret i1 true ; %intptr = ptrtoint i8* %x to i64 %a = icmp ne i64 %intptr, 0 @@ -38,14 +36,24 @@ ; CHECK-LABEL: 
@nonnull2_true( ; CHECK-NEXT: [[A:%.*]] = icmp ugt i8 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: call void @llvm.assume(i1 [[A]]) -; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0 -; CHECK-NEXT: ret i1 [[Q]] +; CHECK-NEXT: ret i1 true ; %a = icmp ugt i8 %x, %y call void @llvm.assume(i1 %a) %q = icmp ne i8 %x, 0 ret i1 %q } +define i1 @nonnull2_true_swapped(i8 %x, i8 %y) { +; CHECK-LABEL: @nonnull2_true_swapped( +; CHECK-NEXT: [[A:%.*]] = icmp ult i8 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: call void @llvm.assume(i1 [[A]]) +; CHECK-NEXT: ret i1 true +; + %a = icmp ult i8 %y, %x + call void @llvm.assume(i1 %a) + %q = icmp ne i8 %x, 0 + ret i1 %q +} define i1 @nonnull3_unknown(i8 %x) { ; CHECK-LABEL: @nonnull3_unknown( @@ -61,8 +69,7 @@ ; CHECK-LABEL: @nonnull4_true( ; CHECK-NEXT: [[A:%.*]] = icmp uge i8 [[X:%.*]], 1 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]]) -; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0 -; CHECK-NEXT: ret i1 [[Q]] +; CHECK-NEXT: ret i1 true ; %a = icmp uge i8 %x, 1 call void @llvm.assume(i1 %a) @@ -86,8 +93,7 @@ ; CHECK-LABEL: @nonnull6_true( ; CHECK-NEXT: [[A:%.*]] = icmp sgt i8 [[X:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]]) -; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0 -; CHECK-NEXT: ret i1 [[Q]] +; CHECK-NEXT: ret i1 true ; %a = icmp sgt i8 %x, 0 call void @llvm.assume(i1 %a) @@ -98,8 +104,7 @@ ; CHECK-LABEL: @nonnull7_true( ; CHECK-NEXT: [[A:%.*]] = icmp sgt i8 [[X:%.*]], 1 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]]) -; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0 -; CHECK-NEXT: ret i1 [[Q]] +; CHECK-NEXT: ret i1 true ; %a = icmp sgt i8 %x, 1 call void @llvm.assume(i1 %a) @@ -135,8 +140,7 @@ ; CHECK-LABEL: @nonnull10_true( ; CHECK-NEXT: [[A:%.*]] = icmp sge i8 [[X:%.*]], 1 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]]) -; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0 -; CHECK-NEXT: ret i1 [[Q]] +; CHECK-NEXT: ret i1 true ; %a = icmp sge i8 %x, 1 call void @llvm.assume(i1 %a)