diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -577,6 +577,77 @@
   return !isEphemeralValueOf(Inv, CxtI);
 }
 
+/// Return true if some llvm.assume that is valid at the query context Q.CxtI
+/// implies that \p V is non-zero (for pointers: non-null). The recognized
+/// forms are `icmp Pred V, RHS` and `icmp Pred (ptrtoint V), RHS`.
+static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
+  // Use of assumptions is context-sensitive. If we don't have a context, we
+  // cannot use them!
+  if (!Q.AC || !Q.CxtI)
+    return false;
+
+  // Note that the patterns below need to be kept in sync with the code
+  // in AssumptionCache::updateAffectedValues.
+
+  // Returns true if \p Cmp, when known to be true, rules out V being zero.
+  // Whether the assume is valid at Q.CxtI is checked separately by the loop.
+  auto CmpExcludesZero = [V](ICmpInst *Cmp) {
+    auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
+
+    Value *RHS;
+    CmpInst::Predicate Pred;
+    if (!match(Cmp, m_ICmp(Pred, m_V, m_Value(RHS))))
+      return false;
+
+    // assume(v u> y) -> assume(v != 0)
+    if (Pred == ICmpInst::ICMP_UGT)
+      return true;
+
+    // assume(v != 0)
+    // We special-case this one to ensure that we handle `assume(v != null)`.
+    if (Pred == ICmpInst::ICMP_NE)
+      return match(RHS, m_Zero());
+
+    // All other predicates: `v Pred C` excludes zero iff the region of values
+    // satisfying it does not contain 0.
+    ConstantInt *CI;
+    if (!match(RHS, m_ConstantInt(CI)))
+      return false;
+    ConstantRange RHSRange(CI->getValue());
+    ConstantRange TrueValues =
+        ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
+    return !TrueValues.contains(APInt::getNullValue(CI->getBitWidth()));
+  };
+
+  for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
+    if (!AssumeVH)
+      continue;
+    CallInst *I = cast<CallInst>(AssumeVH);
+    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
+           "Got assumption for the wrong function!");
+    if (Q.isExcluded(I))
+      continue;
+
+    // Warning: This loop can end up being somewhat performance sensitive.
+    // We're running this loop once for each value queried, resulting in a
+    // runtime of ~O(#assumes * #values).
+
+    assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
+           "must be an assume intrinsic");
+
+    ICmpInst *Cmp = dyn_cast<ICmpInst>(I->getArgOperand(0));
+    if (!Cmp)
+      continue;
+
+    // Only return on success: an assume that is not valid at Q.CxtI must not
+    // end the scan, since a later assumption about V may still apply.
+    if (CmpExcludesZero(Cmp) && isValidAssumeForContext(I, Q.CxtI, Q.DT))
+      return true;
+  }
+
+  return false;
+}
+
 static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                                         unsigned Depth, const Query &Q) {
   // Use of assumptions is context-sensitive. If we don't have a context, we
@@ -2080,6 +2151,9 @@
     }
   }
 
+  if (isKnownNonZeroFromAssume(V, Q))
+    return true;
+
   // Some of the tests below are recursive, so bail out if we hit the limit.
   if (Depth++ >= MaxDepth)
     return false;
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -252,8 +252,7 @@
 ; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32* [[A:%.*]], align 4
 ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0
 ; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT: [[RVAL:%.*]] = icmp eq i32 [[LOAD]], 0
-; CHECK-NEXT: ret i1 [[RVAL]]
+; CHECK-NEXT: ret i1 false
 ;
   %load = load i32, i32* %a
   %cmp = icmp ne i32 %load, 0
@@ -268,13 +267,12 @@
 define i1 @nonnull3(i32** %a, i1 %control) {
 ; CHECK-LABEL: @nonnull3(
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8
 ; CHECK-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]]
 ; CHECK: taken:
+; CHECK-NEXT: [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8
 ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32* [[LOAD]], null
 ; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT: [[RVAL:%.*]] = icmp eq i32* [[LOAD]], null
-; CHECK-NEXT: ret i1 [[RVAL]]
+; CHECK-NEXT: ret i1 false
 ; CHECK: not_taken:
 ; CHECK-NEXT: ret i1 true
 ;
@@ -297,11 +295,10 @@
 define i1 @nonnull4(i32** %a) {
 ; CHECK-LABEL: @nonnull4(
 ; CHECK-NEXT: [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8
-; CHECK-NEXT: tail call void @escape(i32* [[LOAD]])
+; CHECK-NEXT: tail call void @escape(i32* nonnull [[LOAD]])
 ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32* [[LOAD]], null
 ; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT: [[RVAL:%.*]] = icmp eq i32* [[LOAD]], null
-; CHECK-NEXT: ret i1 [[RVAL]]
+; CHECK-NEXT: ret i1 false
 ;
   %load = load i32*, i32** %a
   ;; This call may throw!
@@ -336,12 +333,12 @@
 
 define void @debug_interference(i8 %x) {
 ; CHECK-LABEL: @debug_interference(
-; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i8 [[X:%.*]], 0
-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
-; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP1]])
-; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
+; CHECK-NEXT: [[CMP2:%.*]] = icmp ne i8 [[X:%.*]], 0
 ; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
 ; CHECK-NEXT: tail call void @llvm.assume(i1 false)
+; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
+; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
+; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP2]])
 ; CHECK-NEXT: ret void
 ;
   %cmp1 = icmp eq i8 %x, 0
diff --git a/llvm/test/Transforms/InstSimplify/assume-non-zero.ll b/llvm/test/Transforms/InstSimplify/assume-non-zero.ll
--- a/llvm/test/Transforms/InstSimplify/assume-non-zero.ll
+++ b/llvm/test/Transforms/InstSimplify/assume-non-zero.ll
@@ -10,8 +10,7 @@
 ; CHECK-LABEL: @nonnull0_true(
 ; CHECK-NEXT: [[A:%.*]] = icmp ne i8* [[X:%.*]], null
 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT: [[Q:%.*]] = icmp ne i8* [[X]], null
-; CHECK-NEXT: ret i1 [[Q]]
+; CHECK-NEXT: ret i1 true
 ;
   %a = icmp ne i8* %x, null
   call void @llvm.assume(i1 %a)
@@ -24,8 +23,7 @@
 ; CHECK-NEXT: [[INTPTR:%.*]] = ptrtoint i8* [[X:%.*]] to i64
 ; CHECK-NEXT: [[A:%.*]] = icmp ne i64 [[INTPTR]], 0
 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT: [[Q:%.*]] = icmp ne i8* [[X]], null
-; CHECK-NEXT: ret i1 [[Q]]
+; CHECK-NEXT: ret i1 true
 ;
   %intptr = ptrtoint i8* %x to i64
   %a = icmp ne i64 %intptr, 0
@@ -38,8 +36,7 @@
 ; CHECK-LABEL: @nonnull2_true(
 ; CHECK-NEXT: [[A:%.*]] = icmp ugt i8 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT: ret i1 [[Q]]
+; CHECK-NEXT: ret i1 true
 ;
   %a = icmp ugt i8 %x, %y
   call void @llvm.assume(i1 %a)
@@ -61,8 +58,7 @@
 ; CHECK-LABEL: @nonnull4_true(
 ; CHECK-NEXT: [[A:%.*]] = icmp uge i8 [[X:%.*]], 1
 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT: ret i1 [[Q]]
+; CHECK-NEXT: ret i1 true
 ;
   %a = icmp uge i8 %x, 1
   call void @llvm.assume(i1 %a)
@@ -86,8 +82,7 @@
 ; CHECK-LABEL: @nonnull6_true(
 ; CHECK-NEXT: [[A:%.*]] = icmp sgt i8 [[X:%.*]], 0
 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT: ret i1 [[Q]]
+; CHECK-NEXT: ret i1 true
 ;
   %a = icmp sgt i8 %x, 0
   call void @llvm.assume(i1 %a)
@@ -98,8 +93,7 @@
 ; CHECK-LABEL: @nonnull7_true(
 ; CHECK-NEXT: [[A:%.*]] = icmp sgt i8 [[X:%.*]], 1
 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT: ret i1 [[Q]]
+; CHECK-NEXT: ret i1 true
 ;
   %a = icmp sgt i8 %x, 1
   call void @llvm.assume(i1 %a)
@@ -135,8 +129,7 @@
 ; CHECK-LABEL: @nonnull10_true(
 ; CHECK-NEXT: [[A:%.*]] = icmp sge i8 [[X:%.*]], 1
 ; CHECK-NEXT: call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT: [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT: ret i1 [[Q]]
+; CHECK-NEXT: ret i1 true
 ;
   %a = icmp sge i8 %x, 1
   call void @llvm.assume(i1 %a)