diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -413,16 +413,55 @@ // Absolute value preserves trailing zero count. KnownBits KnownAbs(getBitWidth()); - KnownAbs.Zero.setLowBits(countMinTrailingZeros()); - // We only know that the absolute values's MSB will be zero if INT_MIN is - // poison, or there is a set bit that isn't the sign bit (otherwise it could - // be INT_MIN). - if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) - KnownAbs.Zero.setSignBit(); + // If the input is negative, then abs(x) == -x. + if (isNegative()) { + KnownBits Tmp = *this; + // Special case for IntMinIsPoison. We know the sign bit is set and we know + // all the rest of the bits except one to be zero. Since we have + // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is + // INT_MIN. + if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth()) + Tmp.One.setBit(countMinTrailingZeros()); + + KnownAbs = computeForAddSub( + /*Add*/ false, IntMinIsPoison, + KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp); + + // One more special case for IntMinIsPoison. If we don't know any ones other + // than the sign bit, we know for certain that all the unknowns can't be + // zero. So if we know high zero bits, but have unknown low bits, we know + // for certain those high-zero bits will end up as one. This is because + // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up + // to the high bits. If the input is known INT_MIN, skip this. The result + // is poison anyway. 
+ if (IntMinIsPoison && Tmp.countMinPopulation() == 1 && + Tmp.countMaxPopulation() != 1) { + Tmp.One.clearSignBit(); + Tmp.Zero.setSignBit(); + KnownAbs.One.setBits(getBitWidth() - Tmp.countMinLeadingZeros(), + getBitWidth() - 1); + } + + } else { + unsigned MaxTZ = countMaxTrailingZeros(); + unsigned MinTZ = countMinTrailingZeros(); + + KnownAbs.Zero.setLowBits(MinTZ); + // If we know the lowest set 1, then preserve it. + if (MaxTZ == MinTZ && MaxTZ < getBitWidth()) + KnownAbs.One.setBit(MaxTZ); + + // We only know that the absolute value's MSB will be zero if INT_MIN is + // poison, or there is a set bit that isn't the sign bit (otherwise it could + // be INT_MIN). + if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) { + KnownAbs.One.clearSignBit(); + KnownAbs.Zero.setSignBit(); + } + } - // FIXME: Handle known negative input? - // FIXME: Calculate the negated Known bits and combine them? + assert(!KnownAbs.hasConflict() && "Bad Output"); return KnownAbs; } diff --git a/llvm/test/Analysis/ValueTracking/knownbits-abs.ll b/llvm/test/Analysis/ValueTracking/knownbits-abs.ll --- a/llvm/test/Analysis/ValueTracking/knownbits-abs.ll +++ b/llvm/test/Analysis/ValueTracking/knownbits-abs.ll @@ -4,12 +4,7 @@ define i1 @abs_low_bit_set(i8 %x) { ; CHECK-LABEL: @abs_low_bit_set( -; CHECK-NEXT: [[XX:%.*]] = and i8 [[X:%.*]], -16 -; CHECK-NEXT: [[V:%.*]] = or i8 [[XX]], 4 -; CHECK-NEXT: [[ABS:%.*]] = call i8 @llvm.abs.i8(i8 [[V]], i1 true) -; CHECK-NEXT: [[AND:%.*]] = and i8 [[ABS]], 4 -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %xx = and i8 %x, 240 %v = or i8 %xx, 4 @@ -36,12 +31,7 @@ define i1 @abs_negative(i8 %x) { ; CHECK-LABEL: @abs_negative( -; CHECK-NEXT: [[XX:%.*]] = and i8 [[X:%.*]], -16 -; CHECK-NEXT: [[V:%.*]] = or i8 [[XX]], -124 -; CHECK-NEXT: [[ABS:%.*]] = call i8 @llvm.abs.i8(i8 [[V]], i1 true) -; CHECK-NEXT: [[AND:%.*]] = and i8 [[ABS]], 8 -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 
0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %xx = and i8 %x, 240 %v = or i8 %xx, 132 @@ -53,11 +43,7 @@ define i1 @abs_negative2(i8 %x) { ; CHECK-LABEL: @abs_negative2( -; CHECK-NEXT: [[V:%.*]] = or i8 [[X:%.*]], -125 -; CHECK-NEXT: [[ABS:%.*]] = call i8 @llvm.abs.i8(i8 [[V]], i1 true) -; CHECK-NEXT: [[AND:%.*]] = and i8 [[ABS]], 2 -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 2 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %v = or i8 %x, 131 %abs = call i8 @llvm.abs.i8(i8 %v, i1 true) diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -356,20 +356,15 @@ } TEST(KnownBitsTest, UnaryExhaustive) { - // TODO: Make optimal for cases that are not known non-negative. - testUnaryOpExhaustive( - [](const KnownBits &Known) { return Known.abs(); }, - [](const APInt &N) { return N.abs(); }, - [](const KnownBits &Known) { return Known.isNonNegative(); }); - - testUnaryOpExhaustive( - [](const KnownBits &Known) { return Known.abs(true); }, - [](const APInt &N) -> std::optional<APInt> { - if (N.isMinSignedValue()) - return std::nullopt; - return N.abs(); - }, - [](const KnownBits &Known) { return Known.isNonNegative(); }); + testUnaryOpExhaustive([](const KnownBits &Known) { return Known.abs(); }, + [](const APInt &N) { return N.abs(); }); + + testUnaryOpExhaustive([](const KnownBits &Known) { return Known.abs(true); }, + [](const APInt &N) -> std::optional<APInt> { + if (N.isMinSignedValue()) + return std::nullopt; + return N.abs(); + }); testUnaryOpExhaustive([](const KnownBits &Known) { return Known.blsi(); }, [](const APInt &N) { return N & -N; });