diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -382,7 +382,8 @@
 
   /// Compute known bits for shl(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
-  static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS);
+  static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS,
+                       bool NUW = false, bool NSW = false);
 
   /// Compute known bits for lshr(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1353,20 +1353,10 @@
     break;
   }
   case Instruction::Shl: {
+    bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
-    auto KF = [NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
-      KnownBits Result = KnownBits::shl(KnownVal, KnownAmt);
-      // If this shift has "nsw" keyword, then the result is either a poison
-      // value or has the same sign bit as the first operand.
-      if (NSW) {
-        if (KnownVal.Zero.isSignBitSet())
-          Result.Zero.setSignBit();
-        if (KnownVal.One.isSignBitSet())
-          Result.One.setSignBit();
-        if (Result.hasConflict())
-          Result.setAllZero();
-      }
-      return Result;
+    auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
+      return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW);
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -164,21 +164,51 @@
   return Flip(umax(Flip(LHS), Flip(RHS)));
 }
 
-KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
+KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
+                         bool NSW) {
   unsigned BitWidth = LHS.getBitWidth();
-  KnownBits Known(BitWidth);
+  auto ShiftByConst = [&](const KnownBits &LHS,
+                          uint64_t ShiftAmt) -> std::optional<KnownBits> {
+    KnownBits Known;
+    Known.Zero = LHS.Zero << ShiftAmt;
+    Known.Zero.setLowBits(ShiftAmt);
+    Known.One = LHS.One << ShiftAmt;
+    if ((!NUW && !NSW) || ShiftAmt == 0)
+      return Known;
+
+    KnownBits ShiftedOutBits = LHS.extractBits(ShiftAmt, BitWidth - ShiftAmt);
+    if (NUW && !ShiftedOutBits.One.isZero())
+      // One bit has been shifted out.
+      return std::nullopt;
+    if (NSW) {
+      if (!ShiftedOutBits.Zero.isZero() && !ShiftedOutBits.One.isZero())
+        // Both zeros and ones have been shifted out.
+        return std::nullopt;
+      if (NUW || !ShiftedOutBits.Zero.isZero()) {
+        if (Known.isNegative())
+          // Zero bit has been shifted out, but result sign is negative.
+          return std::nullopt;
+        Known.makeNonNegative();
+      } else if (!ShiftedOutBits.One.isZero()) {
+        if (Known.isNonNegative())
+          // One bit has been shifted out, but result sign is non-negative.
+          return std::nullopt;
+        Known.makeNegative();
+      }
+    }
+    return Known;
+  };
 
   // If the shift amount is a valid constant then transform LHS directly.
   if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
-    unsigned Shift = RHS.getConstant().getZExtValue();
-    Known = LHS;
-    Known.Zero <<= Shift;
-    Known.One <<= Shift;
-    // Low bits are known zero.
-    Known.Zero.setLowBits(Shift);
+    if (auto Res = ShiftByConst(LHS, RHS.getConstant().getZExtValue()))
+      return *Res;
+    KnownBits Known(BitWidth);
+    Known.setAllZero();
     return Known;
   }
 
+  KnownBits Known(BitWidth);
   APInt MinShiftAmount = RHS.getMinValue();
   if (MinShiftAmount.uge(BitWidth)) {
     // Always poison. Return zero because we don't like returning conflict.
@@ -193,6 +223,8 @@
     MinTrailingZeros += MinShiftAmount.getZExtValue();
     MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
     Known.Zero.setLowBits(MinTrailingZeros);
+    if (NUW && NSW && !MinShiftAmount.isZero())
+      Known.makeNonNegative();
     return Known;
   }
 
@@ -210,15 +242,20 @@
     if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
       continue;
-    KnownBits SpecificShift;
-    SpecificShift.Zero = LHS.Zero << ShiftAmt;
-    SpecificShift.Zero.setLowBits(ShiftAmt);
-    SpecificShift.One = LHS.One << ShiftAmt;
-    Known = Known.intersectWith(SpecificShift);
+    auto Res = ShiftByConst(LHS, ShiftAmt);
+    if (!Res)
+      // All larger shift amounts will overflow as well.
+      break;
+    Known = Known.intersectWith(*Res);
     if (Known.isUnknown())
       break;
   }
 
+  // All shift amounts may result in poison.
+  if (Known.hasConflict()) {
+    assert((NUW || NSW) && "Can only happen with nowrap flags");
+    Known.setAllZero();
+  }
   return Known;
 }
 
diff --git a/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll b/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
--- a/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
+++ b/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
@@ -513,7 +513,7 @@
 ; CHECK-NEXT:    [[IDX_SHL_1:%.*]] = shl nuw nsw i8 [[IDX]], 1
 ; CHECK-NEXT:    [[ADD_PTR_SHL_1:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_1]]
 ; CHECK-NEXT:    [[C_MAX_0:%.*]] = icmp ult ptr [[ADD_PTR_SHL_1]], [[MAX]]
-; CHECK-NEXT:    call void @use(i1 [[C_MAX_0]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[IDX_SHL_2:%.*]] = shl nuw i8 [[IDX]], 2
 ; CHECK-NEXT:    [[ADD_PTR_SHL_2:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_2]]
 ; CHECK-NEXT:    [[C_MAX_1:%.*]] = icmp ult ptr [[ADD_PTR_SHL_2]], [[MAX]]
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
@@ -82,7 +82,7 @@
 ; CHECK-NEXT:    [[ARRAYIDX3:%.*]] = getelementptr inbounds [1024 x i32], ptr @CD, i64 0, i64 [[OR]]
 ; CHECK-NEXT:    store i32 [[MUL]], ptr [[ARRAYIDX3]], align 4
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 2
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV]], 1022
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[INDVARS_IV]], 1022
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END]], !llvm.loop [[LOOP3:![0-9]+]]
 ; CHECK: for.end:
 ; CHECK-NEXT:    ret void
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -343,6 +343,41 @@
           return std::nullopt;
         return N1.shl(N2);
       });
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::shl(Known1, Known2, /* NUW */ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool Overflow;
+        APInt Res = N1.ushl_ov(N2, Overflow);
+        if (Overflow)
+          return std::nullopt;
+        return Res;
+      });
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::shl(Known1, Known2, /* NUW */ false, /* NSW */ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool Overflow;
+        APInt Res = N1.sshl_ov(N2, Overflow);
+        if (Overflow)
+          return std::nullopt;
+        return Res;
+      });
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::shl(Known1, Known2, /* NUW */ true, /* NSW */ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool OverflowUnsigned, OverflowSigned;
+        APInt Res = N1.ushl_ov(N2, OverflowUnsigned);
+        (void)N1.sshl_ov(N2, OverflowSigned);
+        if (OverflowUnsigned || OverflowSigned)
+          return std::nullopt;
+        return Res;
+      });
+
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::lshr(Known1, Known2);
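
Editor's note (not part of the patch): a minimal standalone sketch of how the extended KnownBits::shl might be exercised, assuming an LLVM development tree and linking against LLVMSupport; the file name, variable names, and constants are illustrative only.

    // shl_nowrap_sketch.cpp -- illustrative sketch, not part of the patch.
    #include "llvm/ADT/APInt.h"
    #include "llvm/Support/KnownBits.h"
    #include <cassert>

    using namespace llvm;

    int main() {
      // A nuw+nsw shl by a known-nonzero amount cannot produce a negative
      // value without being poison, so the sign bit becomes known zero even
      // though nothing is known about the shifted value itself.
      KnownBits Val(8);  // completely unknown i8 value
      KnownBits Amt(8);
      Amt.One.setBit(0); // low bit set => shift amount is at least 1
      KnownBits Res = KnownBits::shl(Val, Amt, /* NUW */ true, /* NSW */ true);
      assert(Res.isNonNegative() && "sign bit is known zero");
      assert(Res.Zero[0] && "low bit is known zero after shifting by >= 1");

      // A constant nuw shift that always overflows is always poison; as the
      // comment in the patch says, this is reported as all-zero rather than
      // as a conflict.
      KnownBits C = KnownBits::makeConstant(APInt(8, 0x80));
      KnownBits AmtOne = KnownBits::makeConstant(APInt(8, 1));
      KnownBits Poison = KnownBits::shl(C, AmtOne, /* NUW */ true);
      assert(Poison.isZero() && "always-poison shl is reported as zero");
      return 0;
    }

Both cases correspond to the exhaustive tests added in KnownBitsTest.cpp above, which check KnownBits::shl against APInt::ushl_ov/sshl_ov over all inputs, treating overflowing shifts as poison (std::nullopt).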