[KnownBits] shl: fold NUW/NSW poison handling into the shift-amount range

KnownBits::shl now derives an upper bound on the shift amount up front,
from RHS and from the NUW/NSW flags. Previously ShiftByConst returned
std::nullopt for each shift amount that was known to produce poison.

  * New helper getMaxShiftAmount():
    - When BitWidth is a power of two, the bound is the low
      Log2_32(BitWidth) bits of RHS's maximum value.
    - Otherwise it is RHS's maximum value clamped to BitWidth - 1. The
      code comment calls this "only an approximate upper bound".
  * NUW caps the bound at LHS.countMaxLeadingZeros().
  * NSW caps it at max(countMaxLeadingZeros, countMaxLeadingOnes) - 1.
  * NUW+NSW together cap it at countMaxLeadingZeros() - 1.
    - When countMaxLeadingZeros() is 0, this unsigned subtraction wraps.
      std::min then leaves the bound unchanged.
    - The NUW clamp that follows still reduces the bound to 0.
  * New fast path, taken when the range is [0, BitWidth - 1] and BitWidth
    is a power of two. It returns:
    - LHS's known trailing zeros;
    - the sign bit when LHS is all ones;
    - with NSW, LHS's known sign.

With poison shift amounts excluded by the bound, ShiftByConst returns
KnownBits directly. The loop over shift amounts now intersects every
feasible candidate instead of breaking on std::nullopt.

Index: llvm/lib/Support/KnownBits.cpp
===================================================================
--- llvm/lib/Support/KnownBits.cpp
+++ llvm/lib/Support/KnownBits.cpp
@@ -164,41 +164,33 @@
   return Flip(umax(Flip(LHS), Flip(RHS)));
 }
 
+static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
+  if (isPowerOf2_32(BitWidth))
+    return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);
+  // This is only an approximate upper bound.
+  return MaxValue.getLimitedValue(BitWidth - 1);
+}
+
 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
                          bool NSW) {
   unsigned BitWidth = LHS.getBitWidth();
-  auto ShiftByConst = [&](const KnownBits &LHS,
-                          unsigned ShiftAmt) -> std::optional<KnownBits> {
+  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known;
     bool ShiftedOutZero, ShiftedOutOne;
     Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero);
     Known.Zero.setLowBits(ShiftAmt);
     Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne);
 
-    if (NUW) {
-      if (ShiftedOutOne)
-        // One bit has been shifted out.
-        return std::nullopt;
-      if (ShiftAmt != 0)
+    // All cases returning poison have been handled by MaxShiftAmount already.
+    if (NSW) {
+      if (NUW && ShiftAmt != 0)
         // NUW means we can assume anything shifted out was a zero.
         ShiftedOutZero = true;
-    }
 
-    if (NSW) {
-      if (ShiftedOutZero && ShiftedOutOne)
-        // Both zeros and ones have been shifted out.
-        return std::nullopt;
-      if (ShiftedOutZero) {
-        if (Known.isNegative())
-          // Zero bit has been shifted out, but result sign is negative.
-          return std::nullopt;
+      if (ShiftedOutZero)
         Known.makeNonNegative();
-      } else if (ShiftedOutOne) {
-        if (Known.isNonNegative())
-          // One bit has been shifted out, but result sign is non-negative.
-          return std::nullopt;
+      else if (ShiftedOutOne)
         Known.makeNegative();
-      }
     }
     return Known;
   };
@@ -218,8 +210,34 @@
     return Known;
   }
 
+  // Determine maximum shift amount, taking NUW/NSW flags into account.
+  APInt MaxValue = RHS.getMaxValue();
+  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+  if (NUW && NSW)
+    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1);
+  if (NUW)
+    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros());
+  if (NSW)
+    MaxShiftAmount = std::min(
+        MaxShiftAmount,
+        std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1);
+
+  // Fast path for common case where the shift amount is unknown.
+  if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
+      isPowerOf2_32(BitWidth)) {
+    Known.Zero.setLowBits(LHS.countMinTrailingZeros());
+    if (LHS.isAllOnes())
+      Known.One.setSignBit();
+    if (NSW) {
+      if (LHS.isNonNegative())
+        Known.makeNonNegative();
+      if (LHS.isNegative())
+        Known.makeNegative();
+    }
+    return Known;
+  }
+
   // Find the common bits from all possible shifts.
-  unsigned MaxShiftAmount = RHS.getMaxValue().getLimitedValue(BitWidth - 1);
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
@@ -230,11 +248,7 @@
     if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
       continue;
-    auto Res = ShiftByConst(LHS, ShiftAmt);
-    if (!Res)
-      // All larger shift amounts will overflow as well.
-      break;
-    Known = Known.intersectWith(*Res);
+    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
     if (Known.isUnknown())
       break;
   }