[ValueTracking] Use odd multiplicands to prove `mul` non-zero

Factor the product/sign logic of computeKnownBitsMul into a new helper,
computeKnownBitsMulDoMul, which takes the operands' already-computed known
bits instead of computing them itself. The Instruction::Mul case of
isKnownNonZero then reasons as follows:

  * With nsw or nuw, the answer is still isKnownNonZero(X) &&
    isKnownNonZero(Y), but it is now returned directly instead of breaking
    out of the switch.
  * Otherwise, if either operand has its low bit known set (it is odd), the
    product is non-zero whenever the other operand is non-zero: an odd value
    is invertible modulo 2^BitWidth, so multiplying by it cannot turn a
    non-zero value into 0.
  * Otherwise, the product is non-zero if the known bits produced for it by
    computeKnownBitsMulDoMul contain a set bit.

Operand order: computeKnownBitsMul passes Op1's known bits as Known0 and
Op0's as Known1, whereas isKnownNonZero passes operand 0's bits as Known0.
The sign formulas in the helper are symmetric in the two operands, and
isKnownNonZero only calls it with NSW = false.

Test: @mul_nonzero_odd in known-non-zero.ll now folds to `ret i1 false`.

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -412,13 +412,11 @@
   KnownOut = KnownBits::computeForAddSub(Add, NSW, Known2, KnownOut);
 }
 
-static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
-                                const APInt &DemandedElts, KnownBits &Known,
-                                KnownBits &Known2, unsigned Depth,
-                                const Query &Q) {
-  computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q);
-  computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
-
+static KnownBits computeKnownBitsMulDoMul(const Value *Op0, const Value *Op1,
+                                          bool NSW, const APInt &DemandedElts,
+                                          const KnownBits &Known0,
+                                          const KnownBits &Known1,
+                                          unsigned Depth, const Query &Q) {
   bool isKnownNegative = false;
   bool isKnownNonNegative = false;
   // If the multiplication is known not to overflow, compute the sign bit.
@@ -427,10 +425,10 @@
       // The product of a number with itself is non-negative.
       isKnownNonNegative = true;
     } else {
-      bool isKnownNonNegativeOp1 = Known.isNonNegative();
-      bool isKnownNonNegativeOp0 = Known2.isNonNegative();
-      bool isKnownNegativeOp1 = Known.isNegative();
-      bool isKnownNegativeOp0 = Known2.isNegative();
+      bool isKnownNonNegativeOp1 = Known0.isNonNegative();
+      bool isKnownNonNegativeOp0 = Known1.isNonNegative();
+      bool isKnownNegativeOp1 = Known0.isNegative();
+      bool isKnownNegativeOp0 = Known1.isNegative();
       // The product of two numbers with the same sign is non-negative.
       isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
                            (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
@@ -439,8 +437,8 @@
       if (!isKnownNonNegative)
         isKnownNegative =
             (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
-             Known2.isNonZero()) ||
-            (isKnownNegativeOp0 && isKnownNonNegativeOp1 && Known.isNonZero());
+             Known1.isNonZero()) ||
+            (isKnownNegativeOp0 && isKnownNonNegativeOp1 && Known0.isNonZero());
     }
   }
 
@@ -449,17 +447,30 @@
   if (SelfMultiply)
     SelfMultiply &=
         isGuaranteedNotToBeUndefOrPoison(Op0, Q.AC, Q.CxtI, Q.DT, Depth + 1);
-  Known = KnownBits::mul(Known, Known2, SelfMultiply);
+  KnownBits KnownOut = KnownBits::mul(Known0, Known1, SelfMultiply);
 
   // Only make use of no-wrap flags if we failed to compute the sign bit
   // directly.  This matters if the multiplication always overflows, in
   // which case we prefer to follow the result of the direct computation,
   // though as the program is invoking undefined behaviour we can choose
   // whatever we like here.
-  if (isKnownNonNegative && !Known.isNegative())
-    Known.makeNonNegative();
-  else if (isKnownNegative && !Known.isNonNegative())
-    Known.makeNegative();
+  if (isKnownNonNegative && !KnownOut.isNegative())
+    KnownOut.makeNonNegative();
+  else if (isKnownNegative && !KnownOut.isNonNegative())
+    KnownOut.makeNegative();
+
+  return KnownOut;
+}
+
+static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
+                                const APInt &DemandedElts, KnownBits &Known,
+                                KnownBits &Known2, unsigned Depth,
+                                const Query &Q) {
+  computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q);
+  computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
+
+  Known = computeKnownBitsMulDoMul(Op0, Op1, NSW, DemandedElts, Known, Known2,
+                                   Depth, Q);
 }
 
 void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
@@ -2872,11 +2883,27 @@
     // If X and Y are non-zero then so is X * Y as long as the multiplication
     // does not overflow.
     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V);
-    if ((Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) &&
-        isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q) &&
-        isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q))
-      return true;
-    break;
+    if (Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO))
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q) &&
+             isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q);
+
+    // If either X or Y is odd, then if the other is non-zero the result can't
+    // be zero.
+    KnownBits XKnown =
+        computeKnownBits(I->getOperand(0), DemandedElts, Depth, Q);
+    if (XKnown.One[0])
+      return isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q);
+
+    KnownBits YKnown =
+        computeKnownBits(I->getOperand(1), DemandedElts, Depth, Q);
+    if (YKnown.One[0])
+      return XKnown.isNonZero() ||
+             isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q);
+
+    return computeKnownBitsMulDoMul(I->getOperand(0), I->getOperand(1),
+                                    /*NSW*/ false, DemandedElts, XKnown, YKnown,
+                                    Depth, Q)
+        .isNonZero();
   }
   case Instruction::Select:
     // (C ? X : Y) != 0 if X != 0 and Y != 0.
diff --git a/llvm/test/Analysis/ValueTracking/known-non-zero.ll b/llvm/test/Analysis/ValueTracking/known-non-zero.ll
--- a/llvm/test/Analysis/ValueTracking/known-non-zero.ll
+++ b/llvm/test/Analysis/ValueTracking/known-non-zero.ll
@@ -767,13 +767,9 @@
 
 define i1 @mul_nonzero_odd(i8 %xx, i8 %y, i8 %ind) {
 ; CHECK-LABEL: @mul_nonzero_odd(
-; CHECK-NEXT:    [[XO:%.*]] = or i8 [[XX:%.*]], 1
 ; CHECK-NEXT:    [[Y_NZ:%.*]] = icmp ne i8 [[Y:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[Y_NZ]])
-; CHECK-NEXT:    [[X:%.*]] = mul i8 [[XO]], [[Y]]
-; CHECK-NEXT:    [[Z:%.*]] = or i8 [[X]], [[IND:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[Z]], 0
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 false
 ;
   %xo = or i8 %xx, 1
   %y_nz = icmp ne i8 %y, 0