diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -566,17 +566,86 @@
   if (Inv == CxtI)
     return false;
 
-  // The context comes first, but they're both in the same block. Make sure
-  // there is nothing in between that might interrupt the control flow.
-  for (BasicBlock::const_iterator I =
-         std::next(BasicBlock::const_iterator(CxtI)), IE(Inv);
-       I != IE; ++I)
+  // The context comes first, but they're both in the same block.
+  // Make sure there is nothing in between that might interrupt
+  // the control flow, not even CxtI itself.
+  for (BasicBlock::const_iterator I(CxtI), IE(Inv); I != IE; ++I)
     if (!isGuaranteedToTransferExecutionToSuccessor(&*I))
       return false;
 
   return !isEphemeralValueOf(Inv, CxtI);
 }
 
+/// Return true if some llvm.assume that is valid at Q.CxtI takes an icmp of
+/// V (or of ptrtoint V) whose truth rules out V == 0. Requires both an
+/// AssumptionCache and a context instruction; otherwise returns false.
+static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
+  // Use of assumptions is context-sensitive. If we don't have a context, we
+  // cannot use them!
+  if (!Q.AC || !Q.CxtI)
+    return false;
+
+  // Note that the patterns below need to be kept in sync with the code
+  // in AssumptionCache::updateAffectedValues.
+
+  auto CmpExcludesZero = [V](ICmpInst *Cmp) {
+    auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
+
+    Value *RHS;
+    CmpInst::Predicate Pred;
+    if (!match(Cmp, m_c_ICmp(Pred, m_V, m_Value(RHS))))
+      return false;
+    // Canonicalize 'v' to be on the LHS of the comparison.
+    if (Cmp->getOperand(1) != RHS)
+      Pred = CmpInst::getSwappedPredicate(Pred);
+
+    // assume(v u> y) -> assume(v != 0)
+    if (Pred == ICmpInst::ICMP_UGT)
+      return true;
+
+    // assume(v != 0)
+    // We special-case this one to ensure that we handle `assume(v != null)`.
+    if (Pred == ICmpInst::ICMP_NE)
+      return match(RHS, m_Zero());
+
+    // All other predicates - rely on generic ConstantRange handling.
+    ConstantInt *CI;
+    if (!match(RHS, m_ConstantInt(CI)))
+      return false;
+    ConstantRange RHSRange(CI->getValue());
+    ConstantRange TrueValues =
+        ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
+    return !TrueValues.contains(APInt::getNullValue(CI->getBitWidth()));
+  };
+
+  for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
+    if (!AssumeVH)
+      continue;
+    CallInst *I = cast<CallInst>(AssumeVH);
+    assert(I->getFunction() == Q.CxtI->getFunction() &&
+           "Got assumption for the wrong function!");
+    if (Q.isExcluded(I))
+      continue;
+
+    // Warning: This loop can end up being somewhat performance sensitive.
+    // We're running this loop once for each value queried, resulting in a
+    // runtime of ~O(#assumes * #values).
+
+    assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
+           "must be an assume intrinsic");
+
+    Value *Arg = I->getArgOperand(0);
+    ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
+    if (!Cmp)
+      continue;
+
+    if (CmpExcludesZero(Cmp) && isValidAssumeForContext(I, Q.CxtI, Q.DT))
+      return true;
+  }
+
+  return false;
+}
+
 static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                                        unsigned Depth, const Query &Q) {
   // Use of assumptions is context-sensitive. If we don't have a context, we
@@ -2080,6 +2149,9 @@
     }
   }
 
+  if (isKnownNonZeroFromAssume(V, Q))
+    return true;
+
   // Some of the tests below are recursive, so bail out if we hit the limit.
   if (Depth++ >= MaxDepth)
     return false;
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -252,8 +252,7 @@
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i32, i32* [[A:%.*]], align 4
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[RVAL:%.*]] = icmp eq i32 [[LOAD]], 0
-; CHECK-NEXT:    ret i1 [[RVAL]]
+; CHECK-NEXT:    ret i1 false
 ;
   %load = load i32, i32* %a
   %cmp = icmp ne i32 %load, 0
@@ -268,13 +267,12 @@
 define i1 @nonnull3(i32** %a, i1 %control) {
 ; CHECK-LABEL: @nonnull3(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8
 ; CHECK-NEXT:    br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]]
 ; CHECK:       taken:
+; CHECK-NEXT:    [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32* [[LOAD]], null
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[RVAL:%.*]] = icmp eq i32* [[LOAD]], null
-; CHECK-NEXT:    ret i1 [[RVAL]]
+; CHECK-NEXT:    ret i1 false
 ; CHECK:       not_taken:
 ; CHECK-NEXT:    ret i1 true
 ;
@@ -300,8 +298,7 @@
 ; CHECK-NEXT:    tail call void @escape(i32* [[LOAD]])
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32* [[LOAD]], null
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[RVAL:%.*]] = icmp eq i32* [[LOAD]], null
-; CHECK-NEXT:    ret i1 [[RVAL]]
+; CHECK-NEXT:    ret i1 false
 ;
   %load = load i32*, i32** %a
   ;; This call may throw!
@@ -311,6 +308,24 @@
   %rval = icmp eq i32* %load, null
   ret i1 %rval
 }
+
+define i1 @nonnull5(i32** %a) {
+; CHECK-LABEL: @nonnull5(
+; CHECK-NEXT:    [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8
+; CHECK-NEXT:    tail call void @escape(i32* [[LOAD]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32* [[LOAD]], null
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    ret i1 false
+;
+  %load = load i32*, i32** %a
+  ;; This call may throw!
+  tail call void @escape(i32* %load)
+  %integral = ptrtoint i32* %load to i64
+  %cmp = icmp slt i64 %integral, 0
+  tail call void @llvm.assume(i1 %cmp) ; sign bit of %integral is set, so %load != null
+  %rval = icmp eq i32* %load, null
+  ret i1 %rval
+}
 
 ; PR35846 - https://bugs.llvm.org/show_bug.cgi?id=35846
 
@@ -336,12 +351,12 @@
 
 define void @debug_interference(i8 %x) {
 ; CHECK-LABEL: @debug_interference(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i8 [[X:%.*]], 0
-; CHECK-NEXT:    tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
-; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP1]])
-; CHECK-NEXT:    tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ne i8 [[X:%.*]], 0
 ; CHECK-NEXT:    tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 false)
+; CHECK-NEXT:    tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
+; CHECK-NEXT:    tail call void @llvm.dbg.value(metadata i32 5, metadata !7, metadata !DIExpression()), !dbg !9
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP2]])
 ; CHECK-NEXT:    ret void
 ;
   %cmp1 = icmp eq i8 %x, 0
diff --git a/llvm/test/Transforms/InstSimplify/assume-non-zero.ll b/llvm/test/Transforms/InstSimplify/assume-non-zero.ll
--- a/llvm/test/Transforms/InstSimplify/assume-non-zero.ll
+++ b/llvm/test/Transforms/InstSimplify/assume-non-zero.ll
@@ -10,8 +10,7 @@
 ; CHECK-LABEL: @nonnull0_true(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i8* [[X:%.*]], null
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT:    [[Q:%.*]] = icmp ne i8* [[X]], null
-; CHECK-NEXT:    ret i1 [[Q]]
+; CHECK-NEXT:    ret i1 true
 ;
   %a = icmp ne i8* %x, null
   call void @llvm.assume(i1 %a)
@@ -24,8 +23,7 @@
 ; CHECK-NEXT:    [[INTPTR:%.*]] = ptrtoint i8* [[X:%.*]] to i64
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i64 [[INTPTR]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT:    [[Q:%.*]] = icmp ne i8* [[X]], null
-; CHECK-NEXT:    ret i1 [[Q]]
+; CHECK-NEXT:    ret i1 true
 ;
   %intptr = ptrtoint i8* %x to i64
   %a = icmp ne i64 %intptr, 0
@@ -38,14 +36,25 @@
 ; CHECK-LABEL: @nonnull2_true(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ugt i8 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT:    [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT:    ret i1 [[Q]]
+; CHECK-NEXT:    ret i1 true
 ;
   %a = icmp ugt i8 %x, %y
   call void @llvm.assume(i1 %a)
   %q = icmp ne i8 %x, 0
   ret i1 %q
 }
+
+define i1 @nonnull2_true_swapped(i8 %x, i8 %y) {
+; CHECK-LABEL: @nonnull2_true_swapped(
+; CHECK-NEXT:    [[A:%.*]] = icmp ult i8 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    call void @llvm.assume(i1 [[A]])
+; CHECK-NEXT:    ret i1 true
+;
+  %a = icmp ult i8 %y, %x
+  call void @llvm.assume(i1 %a)
+  %q = icmp ne i8 %x, 0
+  ret i1 %q
+}
 
 define i1 @nonnull3_unknown(i8 %x) {
 ; CHECK-LABEL: @nonnull3_unknown(
@@ -61,8 +70,7 @@
 ; CHECK-LABEL: @nonnull4_true(
 ; CHECK-NEXT:    [[A:%.*]] = icmp uge i8 [[X:%.*]], 1
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT:    [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT:    ret i1 [[Q]]
+; CHECK-NEXT:    ret i1 true
 ;
   %a = icmp uge i8 %x, 1
   call void @llvm.assume(i1 %a)
@@ -86,8 +94,7 @@
 ; CHECK-LABEL: @nonnull6_true(
 ; CHECK-NEXT:    [[A:%.*]] = icmp sgt i8 [[X:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT:    [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT:    ret i1 [[Q]]
+; CHECK-NEXT:    ret i1 true
 ;
   %a = icmp sgt i8 %x, 0
   call void @llvm.assume(i1 %a)
@@ -98,8 +105,7 @@
 ; CHECK-LABEL: @nonnull7_true(
 ; CHECK-NEXT:    [[A:%.*]] = icmp sgt i8 [[X:%.*]], 1
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT:    [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT:    ret i1 [[Q]]
+; CHECK-NEXT:    ret i1 true
 ;
   %a = icmp sgt i8 %x, 1
   call void @llvm.assume(i1 %a)
@@ -135,8 +141,7 @@
 ; CHECK-LABEL: @nonnull10_true(
 ; CHECK-NEXT:    [[A:%.*]] = icmp sge i8 [[X:%.*]], 1
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[A]])
-; CHECK-NEXT:    [[Q:%.*]] = icmp ne i8 [[X]], 0
-; CHECK-NEXT:    ret i1 [[Q]]
+; CHECK-NEXT:    ret i1 true
 ;
   %a = icmp sge i8 %x, 1
   call void @llvm.assume(i1 %a)