diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -577,6 +577,60 @@
   return !isEphemeralValueOf(Inv, CxtI);
 }
 
+/// Return true if an llvm.assume registered for \p V in Q.AC proves that \p V
+/// is non-zero at the context instruction Q.CxtI. Only assumes whose condition
+/// is `icmp ne X, 0` or `icmp ugt X, 0`, where X is \p V or a ptrtoint of
+/// \p V, are recognized, and only when the assume is valid for Q.CxtI.
+/// Returns false if Q has no assumption cache or no context instruction.
+static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
+  // Use of assumptions is context-sensitive. If we don't have a context, we
+  // cannot use them!
+  if (!Q.AC || !Q.CxtI)
+    return false;
+
+  // Note that the patterns below need to be kept in sync with the code
+  // in AssumptionCache::updateAffectedValues.
+
+  for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
+    if (!AssumeVH)
+      continue;
+    CallInst *I = cast<CallInst>(AssumeVH);
+    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
+           "Got assumption for the wrong function!");
+    if (Q.isExcluded(I))
+      continue;
+
+    // Warning: This loop can end up being somewhat performance sensitive.
+    // We're running this loop once for each value queried resulting in a
+    // runtime of ~O(#assumes * #values).
+
+    assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
+           "must be an assume intrinsic");
+
+    Value *Arg = I->getArgOperand(0);
+
+    ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
+    if (!Cmp)
+      continue;
+
+    auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
+
+    CmpInst::Predicate Pred;
+    switch (Cmp->getPredicate()) {
+    default:
+      break;
+    case ICmpInst::ICMP_NE:
+    case ICmpInst::ICMP_UGT:
+      // assume(v !=/u> 0)
+      if (match(Cmp, m_ICmp(Pred, m_V, m_Zero())) &&
+          isValidAssumeForContext(I, Q.CxtI, Q.DT))
+        return true;
+    }
+  }
+
+  return false;
+}
+
 static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                                        unsigned Depth, const Query &Q) {
   // Use of assumptions is context-sensitive. If we don't have a context, we
@@ -2080,6 +2134,9 @@
     }
   }
 
+  if (isKnownNonZeroFromAssume(V, Q))
+    return true;
+
   // Some of the tests below are recursive, so bail out if we hit the limit.
if (Depth++ >= MaxDepth) return false; diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll --- a/llvm/test/Transforms/InstCombine/assume.ll +++ b/llvm/test/Transforms/InstCombine/assume.ll @@ -252,8 +252,7 @@ ; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32* [[A:%.*]], align 4 ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0 ; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[RVAL:%.*]] = icmp eq i32 [[LOAD]], 0 -; CHECK-NEXT: ret i1 [[RVAL]] +; CHECK-NEXT: ret i1 false ; %load = load i32, i32* %a %cmp = icmp ne i32 %load, 0 @@ -268,13 +267,12 @@ define i1 @nonnull3(i32** %a, i1 %control) { ; CHECK-LABEL: @nonnull3( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8 ; CHECK-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]] ; CHECK: taken: +; CHECK-NEXT: [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8 ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32* [[LOAD]], null ; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[RVAL:%.*]] = icmp eq i32* [[LOAD]], null -; CHECK-NEXT: ret i1 [[RVAL]] +; CHECK-NEXT: ret i1 false ; CHECK: not_taken: ; CHECK-NEXT: ret i1 true ; @@ -297,11 +295,10 @@ define i1 @nonnull4(i32** %a) { ; CHECK-LABEL: @nonnull4( ; CHECK-NEXT: [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8 -; CHECK-NEXT: tail call void @escape(i32* [[LOAD]]) +; CHECK-NEXT: tail call void @escape(i32* nonnull [[LOAD]]) ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32* [[LOAD]], null ; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[RVAL:%.*]] = icmp eq i32* [[LOAD]], null -; CHECK-NEXT: ret i1 [[RVAL]] +; CHECK-NEXT: ret i1 false ; %load = load i32*, i32** %a ;; This call may throw! 
@@ -336,12 +333,12 @@ define void @debug_interference(i8 %x) { ; CHECK-LABEL: @debug_interference( -; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i8 [[X:%.*]], 0 -; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9 -; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP1]]) -; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9 +; CHECK-NEXT: [[CMP2:%.*]] = icmp ne i8 [[X:%.*]], 0 ; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9 ; CHECK-NEXT: tail call void @llvm.assume(i1 false) +; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9 +; CHECK-NEXT: tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9 +; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP2]]) ; CHECK-NEXT: ret void ; %cmp1 = icmp eq i8 %x, 0