diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -168,30 +168,34 @@ bool NSW) { unsigned BitWidth = LHS.getBitWidth(); auto ShiftByConst = [&](const KnownBits &LHS, - uint64_t ShiftAmt) -> std::optional { + unsigned ShiftAmt) -> std::optional { KnownBits Known; - Known.Zero = LHS.Zero << ShiftAmt; + bool ShiftedOutZero, ShiftedOutOne; + Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero); Known.Zero.setLowBits(ShiftAmt); - Known.One = LHS.One << ShiftAmt; - if ((!NUW && !NSW) || ShiftAmt == 0) - return Known; + Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne); + + if (NUW) { + if (ShiftedOutOne) + // One bit has been shifted out. + return std::nullopt; + if (ShiftAmt != 0) + // NUW means we can assume anything shifted out was a zero. + ShiftedOutZero = true; + } - KnownBits ShiftedOutBits = LHS.extractBits(ShiftAmt, BitWidth - ShiftAmt); - if (NUW && !ShiftedOutBits.One.isZero()) - // One bit has been shifted out. - return std::nullopt; if (NSW) { - if (!ShiftedOutBits.Zero.isZero() && !ShiftedOutBits.One.isZero()) + if (ShiftedOutZero && ShiftedOutOne) // Both zeros and ones have been shifted out. return std::nullopt; - if (NUW || !ShiftedOutBits.Zero.isZero()) { + if (ShiftedOutZero) { if (Known.isNegative()) // Zero bit has been shifted out, but result sign is negative. return std::nullopt; Known.makeNonNegative(); - } else if (!ShiftedOutBits.One.isZero()) { + } else if (ShiftedOutOne) { if (Known.isNonNegative()) - // One bit has been shifted out, but result sign is negative. + // One bit has been shifted out, but result sign is non-negative. return std::nullopt; Known.makeNegative(); } @@ -199,47 +203,31 @@ return Known; }; - // If the shift amount is a valid constant then transform LHS directly. - if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) { - if (auto Res = ShiftByConst(LHS, RHS.getConstant().getZExtValue())) - return *Res; - KnownBits Known(BitWidth); - Known.setAllZero(); - return Known; - } - + // Fast path for a common case when LHS is completely unknown. KnownBits Known(BitWidth); - APInt MinShiftAmount = RHS.getMinValue(); - if (MinShiftAmount.uge(BitWidth)) { - // Always poison. Return zero because we don't like returning conflict. - Known.setAllZero(); - return Known; - } - + unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); if (LHS.isUnknown()) { - // No matter the shift amount, the trailing zeros will stay zero. - unsigned MinTrailingZeros = LHS.countMinTrailingZeros(); - // Minimum shift amount low bits are known zero. - MinTrailingZeros += MinShiftAmount.getZExtValue(); - MinTrailingZeros = std::min(MinTrailingZeros, BitWidth); - Known.Zero.setLowBits(MinTrailingZeros); - if (NUW && NSW && !MinShiftAmount.isZero()) + if (MinShiftAmount == BitWidth) { + // Always poison. Return zero because we don't like returning conflict. + Known.setAllZero(); + return Known; + } + Known.Zero.setLowBits(MinShiftAmount); + if (NUW && NSW && MinShiftAmount != 0) Known.makeNonNegative(); return Known; } // Find the common bits from all possible shifts. - APInt MaxShiftAmount = RHS.getMaxValue(); - uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue(); - uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue(); - assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range"); + unsigned MaxShiftAmount = RHS.getMaxValue().getLimitedValue(BitWidth - 1); + unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue(); + unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue(); Known.Zero.setAllBits(); Known.One.setAllBits(); - for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(), - MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1); - ShiftAmt <= MaxShiftAmt; ++ShiftAmt) { + for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; + ++ShiftAmt) { // Skip if the shift amount is impossible. - if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt || + if ((ShiftAmtZeroMask & ShiftAmt) != 0 || (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) continue; auto Res = ShiftByConst(LHS, ShiftAmt); @@ -252,10 +240,8 @@ } // All shift amounts may result in poison. - if (Known.hasConflict()) { - assert((NUW || NSW) && "Can only happen with nowrap flags"); + if (Known.hasConflict()) Known.setAllZero(); - } return Known; }