diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1228,7 +1228,8 @@
 /// Given operands for an Shl, LShr or AShr, see if we can fold the result.
 /// If not, this returns null.
 static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
-                            Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) {
+                            Value *Op1, bool isNSW, const SimplifyQuery &Q,
+                            unsigned MaxRecurse) {
   if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
     return C;
 
@@ -1262,14 +1263,29 @@
   // If any bits in the shift amount make that value greater than or equal to
   // the number of bits in the type, the shift is undefined.
-  KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
-  if (Known.getMinValue().uge(Known.getBitWidth()))
+  KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+  if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth()))
     return PoisonValue::get(Op0->getType());
 
+  // Check for nsw shl leading to a poison value.
+  if (isNSW) {
+    assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction");
+    KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt);
+
+    if (KnownVal.Zero.isSignBitSet())
+      KnownShl.Zero.setSignBit();
+    if (KnownVal.One.isSignBitSet())
+      KnownShl.One.setSignBit();
+
+    if (KnownShl.hasConflict())
+      return PoisonValue::get(Op0->getType());
+  }
+
   // If all valid bits in the shift amount are known zero, the first operand is
   // unchanged.
-  unsigned NumValidShiftBits = Log2_32_Ceil(Known.getBitWidth());
-  if (Known.countMinTrailingZeros() >= NumValidShiftBits)
+  unsigned NumValidShiftBits = Log2_32_Ceil(KnownAmt.getBitWidth());
+  if (KnownAmt.countMinTrailingZeros() >= NumValidShiftBits)
     return Op0;
 
   return nullptr;
 }
@@ -1280,7 +1296,7 @@
 static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
                                  Value *Op1, bool isExact,
                                  const SimplifyQuery &Q, unsigned MaxRecurse) {
-  if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse))
+  if (Value *V = SimplifyShift(Opcode, Op0, Op1, false, Q, MaxRecurse))
     return V;
 
   // X >> X -> 0
@@ -1306,7 +1322,7 @@
 /// If not, this returns null.
 static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
                               const SimplifyQuery &Q, unsigned MaxRecurse) {
-  if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, Q, MaxRecurse))
+  if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, isNSW, Q, MaxRecurse))
     return V;
 
   // undef << X -> 0
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -994,30 +994,26 @@
 
   bool ShiftAmtIsConstant = Known.isConstant();
   bool MaxShiftAmtIsOutOfRange = Known.getMaxValue().uge(BitWidth);
 
-  if (ShiftAmtIsConstant) {
-    Known = KF(Known2, Known);
-
-    // If the known bits conflict, this must be an overflowing left shift, so
-    // the shift result is poison. We can return anything we want. Choose 0 for
-    // the best folding opportunity.
-    if (Known.hasConflict())
-      Known.setAllZero();
+  // Use the KF callback to get an initial knownbits approximation.
+  Known = KF(Known2, Known);
+  // If the known bits conflict, this must be an overflowing left shift, so
+  // the shift result is poison.
+  if (Known.hasConflict()) {
+    Known.resetAll();
     return;
   }
 
+  // Constant shift amount - we're not going to improve on this.
+  if (ShiftAmtIsConstant)
+    return;
+
   // If the shift amount could be greater than or equal to the bit-width of the
   // LHS, the value could be poison, but bail out because the check below is
   // expensive.
   // TODO: Should we just carry on?
-  if (MaxShiftAmtIsOutOfRange) {
-    Known.resetAll();
+  if (MaxShiftAmtIsOutOfRange)
     return;
-  }
-
-  // It would be more-clearly correct to use the two temporaries for this
-  // calculation. Reusing the APInts here to prevent unnecessary allocations.
-  Known.resetAll();
 
   // If we know the shifter operand is nonzero, we can sometimes infer more
   // known bits. However this is expensive to compute, so be lazy about it and
@@ -1057,10 +1053,9 @@
         Known, KF(Known2, KnownBits::makeConstant(APInt(32, ShiftAmt))));
   }
 
-  // If the known bits conflict, the result is poison. Return a 0 and hope the
-  // caller can further optimize that.
+  // If the known bits conflict, the result is poison.
   if (Known.hasConflict())
-    Known.setAllZero();
+    Known.resetAll();
 }
 
 static void computeKnownBitsFromOperator(const Operator *I,
@@ -1232,10 +1227,6 @@
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
-    // Trailing zeros of a right-shifted constant never decrease.
-    const APInt *C;
-    if (match(I->getOperand(0), m_APInt(C)))
-      Known.Zero.setLowBits(C->countTrailingZeros());
     break;
   }
   case Instruction::LShr: {
@@ -1244,10 +1235,6 @@
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
-    // Leading zeros of a left-shifted constant never decrease.
-    const APInt *C;
-    if (match(I->getOperand(0), m_APInt(C)))
-      Known.Zero.setHighBits(C->countLeadingZeros());
     break;
   }
   case Instruction::AShr: {
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -187,6 +187,23 @@
     MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
   }
 
+  // If the maximum shift is in range, then find the common bits from all
+  // possible shifts.
+  APInt MaxShiftAmount = RHS.getMaxValue();
+  if (MaxShiftAmount.ult(BitWidth)) {
+    assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+    Known.Zero.setAllBits();
+    Known.One.setAllBits();
+    for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
+                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
+         ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+      KnownBits SpecificShift = LHS;
+      SpecificShift.Zero <<= ShiftAmt;
+      SpecificShift.One <<= ShiftAmt;
+      Known = KnownBits::commonBits(Known, SpecificShift);
+    }
+  }
+
   Known.Zero.setLowBits(MinTrailingZeros);
   return Known;
 }
diff --git a/llvm/test/Transforms/InstCombine/known-signbit-shift.ll b/llvm/test/Transforms/InstCombine/known-signbit-shift.ll
--- a/llvm/test/Transforms/InstCombine/known-signbit-shift.ll
+++ b/llvm/test/Transforms/InstCombine/known-signbit-shift.ll
@@ -31,7 +31,7 @@
 ; This test should not crash opt. The shift produces poison.
 define i32 @test_no_sign_bit_conflict1(i1 %b) {
 ; CHECK-LABEL: @test_no_sign_bit_conflict1(
-; CHECK-NEXT:    ret i32 undef
+; CHECK-NEXT:    ret i32 poison
 ;
   %sel = select i1 %b, i32 8193, i32 8192
   %mul = shl nsw i32 %sel, 18
@@ -42,7 +42,7 @@
 ; This test should not crash opt. The shift produces poison.
 define i32 @test_no_sign_bit_conflict2(i1 %b) {
 ; CHECK-LABEL: @test_no_sign_bit_conflict2(
-; CHECK-NEXT:    ret i32 undef
+; CHECK-NEXT:    ret i32 poison
 ;
   %sel = select i1 %b, i32 -8193, i32 -8194
   %mul = shl nsw i32 %sel, 18
diff --git a/llvm/test/Transforms/InstSimplify/icmp-constant.ll b/llvm/test/Transforms/InstSimplify/icmp-constant.ll
--- a/llvm/test/Transforms/InstSimplify/icmp-constant.ll
+++ b/llvm/test/Transforms/InstSimplify/icmp-constant.ll
@@ -772,7 +772,7 @@
 
 define i1 @ne_shl_by_constant_produces_poison(i8 %x) {
 ; CHECK-LABEL: @ne_shl_by_constant_produces_poison(
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    ret i1 poison
 ;
   %zx = zext i8 %x to i16   ; zx  = 0x00xx
   %xor = xor i16 %zx, 32767 ; xor = 0x7fyy
@@ -784,7 +784,7 @@
 
 define i1 @eq_shl_by_constant_produces_poison(i8 %x) {
 ; CHECK-LABEL: @eq_shl_by_constant_produces_poison(
-; CHECK-NEXT:    ret i1 false
+; CHECK-NEXT:    ret i1 poison
 ;
   %clear_high_bit = and i8 %x, 127                 ; 0x7f
   %set_next_high_bits = or i8 %clear_high_bit, 112 ; 0x70
@@ -799,7 +799,7 @@
 
 define i1 @eq_shl_by_variable_produces_poison(i8 %x) {
 ; CHECK-LABEL: @eq_shl_by_variable_produces_poison(
-; CHECK-NEXT:    ret i1 false
+; CHECK-NEXT:    ret i1 poison
 ;
   %clear_high_bit = and i8 %x, 127                 ; 0x7f
   %set_next_high_bits = or i8 %clear_high_bit, 112 ; 0x70