Index: llvm/include/llvm/Support/KnownBits.h
===================================================================
--- llvm/include/llvm/Support/KnownBits.h
+++ llvm/include/llvm/Support/KnownBits.h
@@ -370,7 +370,11 @@
 
   /// Compute known bits for shl(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
-  static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS);
+  /// If \p NUW / \p NSW is set, the shift is assumed to carry the matching
+  /// no-wrap flag: shift amounts that would wrap yield poison and are ignored,
+  /// and if every possible shift amount wraps, all bits are reported as zero.
+  static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS,
+                       bool NUW = false, bool NSW = false);
 
   /// Compute known bits for lshr(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
Index: llvm/lib/Analysis/ValueTracking.cpp
===================================================================
--- llvm/lib/Analysis/ValueTracking.cpp
+++ llvm/lib/Analysis/ValueTracking.cpp
@@ -1348,18 +1348,10 @@
     break;
   }
   case Instruction::Shl: {
+    bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
-    auto KF = [NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
-      KnownBits Result = KnownBits::shl(KnownVal, KnownAmt);
-      // If this shift has "nsw" keyword, then the result is either a poison
-      // value or has the same sign bit as the first operand.
-      if (NSW) {
-        if (KnownVal.Zero.isSignBitSet())
-          Result.Zero.setSignBit();
-        if (KnownVal.One.isSignBitSet())
-          Result.One.setSignBit();
-      }
-      return Result;
+    auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
+      return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW);
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
Index: llvm/lib/Support/KnownBits.cpp
===================================================================
--- llvm/lib/Support/KnownBits.cpp
+++ llvm/lib/Support/KnownBits.cpp
@@ -164,22 +164,54 @@
   return Flip(umax(Flip(LHS), Flip(RHS)));
 }
 
-KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
+KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
+                         bool NSW) {
   unsigned BitWidth = LHS.getBitWidth();
-  KnownBits Known(BitWidth);
+  // Known bits of LHS << ShiftAmt for a single in-range shift amount, or
+  // std::nullopt if the NUW/NSW flags imply that this shift produces poison.
+  auto ShiftByConst = [&](const KnownBits &LHS,
+                          uint64_t ShiftAmt) -> std::optional<KnownBits> {
+    KnownBits Known;
+    Known.Zero = LHS.Zero << ShiftAmt;
+    Known.Zero.setLowBits(ShiftAmt);
+    Known.One = LHS.One << ShiftAmt;
+    if ((NUW || NSW) && ShiftAmt != 0) {
+      KnownBits ShiftedOutBits = LHS.extractBits(ShiftAmt, BitWidth - ShiftAmt);
+      if (NUW && !ShiftedOutBits.One.isZero())
+        // One bit has been shifted out.
+        return std::nullopt;
+      if (NSW) {
+        if (!ShiftedOutBits.Zero.isZero() && !ShiftedOutBits.One.isZero())
+          // Both zeros and ones have been shifted out.
+          return std::nullopt;
+        if (NUW || !ShiftedOutBits.Zero.isZero()) {
+          if (Known.isNegative())
+            // Zero bit has been shifted out, but result sign is negative.
+            return std::nullopt;
+          Known.makeNonNegative();
+        } else if (!ShiftedOutBits.One.isZero()) {
+          if (Known.isNonNegative())
+            // One bit has been shifted out, but result sign is non-negative.
+            return std::nullopt;
+          Known.makeNegative();
+        }
+      }
+    }
+    return Known;
+  };
 
   // If the shift amount is a valid constant then transform LHS directly.
   if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
-    unsigned Shift = RHS.getConstant().getZExtValue();
-    Known = LHS;
-    Known.Zero <<= Shift;
-    Known.One <<= Shift;
-    // Low bits are known zero.
-    Known.Zero.setLowBits(Shift);
+    if (auto Res = ShiftByConst(LHS, RHS.getConstant().getZExtValue()))
+      return *Res;
+    // This shift is known to wrap, so the result is poison: return zero.
+    KnownBits Known(BitWidth);
+    Known.setAllZero();
     return Known;
   }
 
+  KnownBits Known(BitWidth);
   // No matter the shift amount, the trailing zeros will stay zero.
   unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
 
   APInt MinShiftAmount = RHS.getMinValue();
@@ -196,7 +228,7 @@
   // If the maximum shift is in range, then find the common bits from all
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (!LHS.isUnknown()) {
+  if (!LHS.isUnknown() || (NUW && NSW)) {
     uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
     uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
@@ -209,11 +241,16 @@
       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
           (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
         continue;
-      KnownBits SpecificShift;
-      SpecificShift.Zero = LHS.Zero << ShiftAmt;
-      SpecificShift.Zero.setLowBits(ShiftAmt);
-      SpecificShift.One = LHS.One << ShiftAmt;
-      Known = Known.intersectWith(SpecificShift);
+      if (auto Res = ShiftByConst(LHS, ShiftAmt)) {
+        Known = Known.intersectWith(*Res);
+      } else {
+        // All larger shift amounts will overflow as well. If no shift amount
+        // has contributed yet, Known still conflicts and every possible shift
+        // is poison; report zero rather than a conflict, as the constant-shift
+        // path does.
+        if (Known.hasConflict())
+          Known.setAllZero();
+        break;
+      }
       if (Known.isUnknown())
         break;
     }
Index: llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
===================================================================
--- llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
+++ llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
@@ -513,7 +513,7 @@
 ; CHECK-NEXT:    [[IDX_SHL_1:%.*]] = shl nuw nsw i8 [[IDX]], 1
 ; CHECK-NEXT:    [[ADD_PTR_SHL_1:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_1]]
 ; CHECK-NEXT:    [[C_MAX_0:%.*]] = icmp ult ptr [[ADD_PTR_SHL_1]], [[MAX]]
-; CHECK-NEXT:    call void @use(i1 [[C_MAX_0]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[IDX_SHL_2:%.*]] = shl nuw i8 [[IDX]], 2
 ; CHECK-NEXT:    [[ADD_PTR_SHL_2:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_2]]
 ; CHECK-NEXT:    [[C_MAX_1:%.*]] = icmp ult ptr [[ADD_PTR_SHL_2]], [[MAX]]
Index: llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
===================================================================
--- llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
+++ llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
@@ -82,7 +82,7 @@
 ; CHECK-NEXT:    [[ARRAYIDX3:%.*]] = getelementptr inbounds [1024 x i32], ptr @CD, i64 0, i64 [[OR]]
 ; CHECK-NEXT:    store i32 [[MUL]], ptr [[ARRAYIDX3]], align 4
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 2
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV]], 1022
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[INDVARS_IV]], 1022
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END]], !llvm.loop [[LOOP3:![0-9]+]]
 ; CHECK:       for.end:
 ; CHECK-NEXT:    ret void
Index: llvm/unittests/Support/KnownBitsTest.cpp
===================================================================
--- llvm/unittests/Support/KnownBitsTest.cpp
+++ llvm/unittests/Support/KnownBitsTest.cpp
@@ -310,6 +310,50 @@
           return std::nullopt;
         return N1.shl(N2);
       });
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::shl(Known1, Known2, /* NUW */ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool Overflow;
+        APInt Res = N1.ushl_ov(N2, Overflow);
+        if (Overflow)
+          return std::nullopt;
+        return Res;
+      });
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::shl(Known1, Known2, /* NUW */ false, /* NSW */ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool Overflow;
+        APInt Res = N1.sshl_ov(N2, Overflow);
+        if (Overflow)
+          return std::nullopt;
+        return Res;
+      });
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::shl(Known1, Known2, /* NUW */ true, /* NSW */ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool OverflowUnsigned, OverflowSigned;
+        APInt Res = N1.ushl_ov(N2, OverflowUnsigned);
+        (void)N1.sshl_ov(N2, OverflowSigned);
+        if (OverflowUnsigned || OverflowSigned)
+          return std::nullopt;
+        return Res;
+      });
+  {
+    // shl nuw of 0x80 by 1 or 3 always wraps, so the result is poison. The
+    // returned known bits must still be free of conflicts.
+    KnownBits Val = KnownBits::makeConstant(APInt(8, 0x80));
+    KnownBits Amt(8);
+    Amt.Zero = APInt(8, 0xFC);
+    Amt.One = APInt(8, 0x01);
+    EXPECT_FALSE(KnownBits::shl(Val, Amt, /* NUW */ true).hasConflict());
+  }
+
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::lshr(Known1, Known2);