Index: llvm/lib/Support/KnownBits.cpp =================================================================== --- llvm/lib/Support/KnownBits.cpp +++ llvm/lib/Support/KnownBits.cpp @@ -182,24 +182,26 @@ // No matter the shift amount, the trailing zeros will stay zero. unsigned MinTrailingZeros = LHS.countMinTrailingZeros(); - // Minimum shift amount low bits are known zero. APInt MinShiftAmount = RHS.getMinValue(); - if (MinShiftAmount.ult(BitWidth)) { - MinTrailingZeros += MinShiftAmount.getZExtValue(); - MinTrailingZeros = std::min(MinTrailingZeros, BitWidth); - } + if (MinShiftAmount.uge(BitWidth)) + // Always poison. Return unknown because we don't like returning conflict. + return Known; + + // Minimum shift amount low bits are known zero. + MinTrailingZeros += MinShiftAmount.getZExtValue(); + MinTrailingZeros = std::min(MinTrailingZeros, BitWidth); // If the maximum shift is in range, then find the common bits from all // possible shifts. APInt MaxShiftAmount = RHS.getMaxValue(); - if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) { + if (!LHS.isUnknown()) { uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue(); uint64_t ShiftAmtOneMask = RHS.One.getZExtValue(); assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range"); Known.Zero.setAllBits(); Known.One.setAllBits(); for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(), - MaxShiftAmt = MaxShiftAmount.getZExtValue(); + MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1); ShiftAmt <= MaxShiftAmt; ++ShiftAmt) { // Skip if the shift amount is impossible. if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt || @@ -207,6 +209,7 @@ continue; KnownBits SpecificShift; SpecificShift.Zero = LHS.Zero << ShiftAmt; + SpecificShift.Zero.setLowBits(ShiftAmt); SpecificShift.One = LHS.One << ShiftAmt; Known = KnownBits::commonBits(Known, SpecificShift); if (Known.isUnknown()) @@ -237,22 +240,24 @@ // Minimum shift amount high bits are known zero. APInt MinShiftAmount = RHS.getMinValue(); - if (MinShiftAmount.ult(BitWidth)) { - MinLeadingZeros += MinShiftAmount.getZExtValue(); - MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); - } + if (MinShiftAmount.uge(BitWidth)) + // Always poison. Return unknown because we don't like returning conflict. + return Known; + + MinLeadingZeros += MinShiftAmount.getZExtValue(); + MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); // If the maximum shift is in range, then find the common bits from all // possible shifts. APInt MaxShiftAmount = RHS.getMaxValue(); - if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) { + if (!LHS.isUnknown()) { uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue(); uint64_t ShiftAmtOneMask = RHS.One.getZExtValue(); assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range"); Known.Zero.setAllBits(); Known.One.setAllBits(); for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(), - MaxShiftAmt = MaxShiftAmount.getZExtValue(); + MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1); ShiftAmt <= MaxShiftAmt; ++ShiftAmt) { // Skip if the shift amount is impossible. if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt || @@ -260,6 +265,7 @@ continue; KnownBits SpecificShift = LHS; SpecificShift.Zero.lshrInPlace(ShiftAmt); + SpecificShift.Zero.setHighBits(ShiftAmt); SpecificShift.One.lshrInPlace(ShiftAmt); Known = KnownBits::commonBits(Known, SpecificShift); if (Known.isUnknown()) @@ -289,28 +295,30 @@ // Minimum shift amount high bits are known sign bits. APInt MinShiftAmount = RHS.getMinValue(); - if (MinShiftAmount.ult(BitWidth)) { - if (MinLeadingZeros) { - MinLeadingZeros += MinShiftAmount.getZExtValue(); - MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); - } - if (MinLeadingOnes) { - MinLeadingOnes += MinShiftAmount.getZExtValue(); - MinLeadingOnes = std::min(MinLeadingOnes, BitWidth); - } + if (MinShiftAmount.uge(BitWidth)) + // Always poison. Return unknown because we don't like returning conflict. + return Known; + + if (MinLeadingZeros) { + MinLeadingZeros += MinShiftAmount.getZExtValue(); + MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); + } + if (MinLeadingOnes) { + MinLeadingOnes += MinShiftAmount.getZExtValue(); + MinLeadingOnes = std::min(MinLeadingOnes, BitWidth); } // If the maximum shift is in range, then find the common bits from all // possible shifts. APInt MaxShiftAmount = RHS.getMaxValue(); - if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) { + if (!LHS.isUnknown()) { uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue(); uint64_t ShiftAmtOneMask = RHS.One.getZExtValue(); assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range"); Known.Zero.setAllBits(); Known.One.setAllBits(); for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(), - MaxShiftAmt = MaxShiftAmount.getZExtValue(); + MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1); ShiftAmt <= MaxShiftAmt; ++ShiftAmt) { // Skip if the shift amount is impossible. if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt || Index: llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll =================================================================== --- llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll +++ llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll @@ -221,7 +221,7 @@ ; SI-PROMOTE-VECT: s_load_dword [[IDX:s[0-9]+]] ; SI-PROMOTE-VECT: s_lshl_b32 [[SCALED_IDX:s[0-9]+]], [[IDX]], 4 ; SI-PROMOTE-VECT: s_lshr_b32 [[SREG:s[0-9]+]], 0x10000, [[SCALED_IDX]] -; SI-PROMOTE-VECT: s_and_b32 s{{[0-9]+}}, [[SREG]], 0xffff +; SI-PROMOTE-VECT: s_and_b32 s{{[0-9]+}}, [[SREG]], 1 define amdgpu_kernel void @short_array(ptr addrspace(1) %out, i32 %index) #0 { entry: %0 = alloca [2 x i16], addrspace(5) Index: llvm/test/Transforms/InstCombine/not-add.ll =================================================================== --- llvm/test/Transforms/InstCombine/not-add.ll +++ llvm/test/Transforms/InstCombine/not-add.ll @@ -172,7 +172,7 @@ ; CHECK-NEXT: entry: ; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], 1 ; CHECK-NEXT: [[B15:%.*]] = srem i32 ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537)), [[XOR]] -; CHECK-NEXT: [[B12:%.*]] = add nuw nsw i32 [[B15]], ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537)) +; CHECK-NEXT: [[B12:%.*]] = add nsw i32 [[B15]], ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537)) ; CHECK-NEXT: [[B:%.*]] = xor i32 [[B12]], -1 ; CHECK-NEXT: store i32 [[B]], ptr undef, align 4 ; CHECK-NEXT: ret void Index: llvm/unittests/Support/KnownBitsTest.cpp =================================================================== --- llvm/unittests/Support/KnownBitsTest.cpp +++ llvm/unittests/Support/KnownBitsTest.cpp @@ -270,7 +270,6 @@ }, checkCorrectnessOnlyBinary); - // TODO: Make optimal for non-constant cases. testBinaryOpExhaustive( [](const KnownBits &Known1, const KnownBits &Known2) { return KnownBits::shl(Known1, Known2); @@ -279,9 +278,6 @@ if (N2.uge(N2.getBitWidth())) return std::nullopt; return N1.shl(N2); - }, - [](const KnownBits &, const KnownBits &Known) { - return Known.isConstant(); }); testBinaryOpExhaustive( [](const KnownBits &Known1, const KnownBits &Known2) { @@ -291,9 +287,6 @@ if (N2.uge(N2.getBitWidth())) return std::nullopt; return N1.lshr(N2); - }, - [](const KnownBits &, const KnownBits &Known) { - return Known.isConstant(); }); testBinaryOpExhaustive( [](const KnownBits &Known1, const KnownBits &Known2) { @@ -303,9 +296,6 @@ if (N2.uge(N2.getBitWidth())) return std::nullopt; return N1.ashr(N2); - }, - [](const KnownBits &, const KnownBits &Known) { - return Known.isConstant(); }); testBinaryOpExhaustive(