diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -167,6 +167,10 @@ const DominatorTree *DT = nullptr, bool UseInstrInfo = true); +/// Check if an assume intrinsic covers the non-equality information +bool getKnownNonEqualFromAssume(const Value *V1, const Value *V2, + const SimplifyQuery &Q); + /// Return true if the given values are known to be non-equal when defined. /// Supports scalar integer types only. bool isKnownNonEqual(const Value *V1, const Value *V2, const DataLayout &DL, diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -3107,6 +3107,50 @@ return true; } +bool llvm::getKnownNonEqualFromAssume(const Value *V1, const Value *V2, + const SimplifyQuery &Q) { + // Use of assumptions is context-sensitive. If we don't have a context, we + // cannot use them! + if (!Q.AC || !Q.CxtI) + return false; + + // Note that the patterns below need to be kept in sync with the code + // in AssumptionCache::updateAffectedValues. + + for (auto &AssumeVH : Q.AC->assumptionsFor(V1)) { + if (!AssumeVH) + continue; + + CallInst *I = cast<CallInst>(AssumeVH); + assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() && + "Got assumption for the wrong function!"); + + // Warning: This loop can end up being somewhat performance sensitive. + // We're running this loop for once for each value queried resulting in a + // runtime of ~O(#assumes * #values). 
+ + assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume && + "must be an assume intrinsic"); + + Value *Arg = I->getArgOperand(0); + ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg); + if (!Cmp || !isValidAssumeForContext(I, Q.CxtI, Q.DT)) + continue; + if (!Cmp->isEquality()) + continue; + + // If we have a matching assumption: + // - EQ: return false + // - NE: return true + Value *Op0 = Cmp->getOperand(0); + Value *Op1 = Cmp->getOperand(1); + if ((Op0 == V1 && Op1 == V2) || (Op0 == V2 && Op1 == V1)) + return Cmp->getPredicate() == CmpInst::ICMP_NE; + } + + return false; +} + /// Return true if it is known that V1 != V2. static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth, const SimplifyQuery &Q) { @@ -3156,7 +3200,9 @@ Known2.Zero.intersects(Known1.One)) return true; } - return false; + + // Check whether a nearby assume intrinsic can determine some known bits. + return getKnownNonEqualFromAssume(V1, V2, Q); } /// Return true if 'V & Mask' is known to be zero. We use this predicate to diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -5164,6 +5164,23 @@ return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType())); } + // ((A / -B) == (A / B)) -> ((A / B) == 0) + // if known A != INT_MIN or B != INT_MIN + { + const unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); + APInt SignedMinValue = APInt::getSignedMinValue(OpWidth); + Constant *MinInt = ConstantInt::get(Op0->getType(), SignedMinValue); + if (match(Op0, m_OneUse(m_SDiv(m_Value(A), m_OneUse(m_Neg(m_Value(B)))))) && + match(Op1, m_SDiv(m_Specific(A), m_Specific(B)))) { + // Check if A is known to be != INT_MIN + if (isKnownNonEqual(A, MinInt, DL, &AC, &I, &DT, true)) + return new ICmpInst(Pred, Op1, Constant::getNullValue(A->getType())); + // Check if B is known to 
be != INT_MIN + if (isKnownNonEqual(B, MinInt, DL, &AC, &I, &DT, true)) + return new ICmpInst(Pred, Op1, Constant::getNullValue(A->getType())); + } + } + // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { diff --git a/llvm/test/Transforms/InstCombine/icmp-sdiv-sdiv.ll b/llvm/test/Transforms/InstCombine/icmp-sdiv-sdiv.ll --- a/llvm/test/Transforms/InstCombine/icmp-sdiv-sdiv.ll +++ b/llvm/test/Transforms/InstCombine/icmp-sdiv-sdiv.ll @@ -9,10 +9,8 @@ ; CHECK-LABEL: @icmp_sdiv_sdiv_normal_i8( ; CHECK-NEXT: [[PRECOND:%.*]] = icmp ne i8 [[C:%.*]], -128 ; CHECK-NEXT: call void @llvm.assume(i1 [[PRECOND]]) -; CHECK-NEXT: [[NEGC:%.*]] = sub i8 0, [[C]] -; CHECK-NEXT: [[D1:%.*]] = sdiv i8 [[X:%.*]], [[NEGC]] -; CHECK-NEXT: [[D2:%.*]] = sdiv i8 [[X]], [[C]] -; CHECK-NEXT: [[C:%.*]] = icmp eq i8 [[D1]], [[D2]] +; CHECK-NEXT: [[D2:%.*]] = sdiv i8 [[X:%.*]], [[C]] +; CHECK-NEXT: [[C:%.*]] = icmp eq i8 [[D2]], 0 ; CHECK-NEXT: ret i1 [[C]] ; %precond = icmp ne i8 %C, -128 @@ -28,10 +26,8 @@ ; CHECK-LABEL: @icmp_sdiv_sdiv_normal_i64( ; CHECK-NEXT: [[PRECOND:%.*]] = icmp ne i64 [[C:%.*]], -9223372036854775808 ; CHECK-NEXT: call void @llvm.assume(i1 [[PRECOND]]) -; CHECK-NEXT: [[NEGC:%.*]] = sub i64 0, [[C]] -; CHECK-NEXT: [[D1:%.*]] = sdiv i64 [[X:%.*]], [[NEGC]] -; CHECK-NEXT: [[D2:%.*]] = sdiv i64 [[X]], [[C]] -; CHECK-NEXT: [[C:%.*]] = icmp eq i64 [[D1]], [[D2]] +; CHECK-NEXT: [[D2:%.*]] = sdiv i64 [[X:%.*]], [[C]] +; CHECK-NEXT: [[C:%.*]] = icmp eq i64 [[D2]], 0 ; CHECK-NEXT: ret i1 [[C]] ; %precond = icmp ne i64 %C, -9223372036854775808