diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -426,6 +426,14 @@
   /// TODO: This isn't fully implemented yet.
   ConstantRange shl(const ConstantRange &Other) const;
 
+  /// Return a new range representing the possible values resulting
+  /// from a left shift with wrap type \p NoWrapKind of a value in this
+  /// range and a value in \p Other.
+  /// If the result range is disjoint, the preferred range is determined by the
+  /// \p PreferredRangeType.
+  ConstantRange shlWithNoWrap(const ConstantRange &Other, unsigned NoWrapKind,
+                              PreferredRangeType RangeType = Smallest) const;
+
   /// Return a new range representing the possible values resulting from a
   /// logical right shift of a value in this range and a value in \p Other.
   ConstantRange lshr(const ConstantRange &Other) const;
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1313,6 +1313,65 @@
   return ConstantRange(std::move(min), std::move(max) + 1);
 }
 
+ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
+                                           unsigned NoWrapKind,
+                                           PreferredRangeType RangeType) const {
+  unsigned BitWidth = Other.getBitWidth();
+
+  // Shift amounts >= BitWidth always produce poison, so only amounts in
+  // [0, BitWidth) can contribute to the result.
+  ConstantRange ShAmt = Other.intersectWith(
+      ConstantRange(APInt(BitWidth, 0), APInt(BitWidth, BitWidth)));
+
+  // Compute the range of "X << Y" (X from this, Y from ShAmt) over the pairs
+  // that do not wrap. If either operand range is empty, every result is poison.
+  if (isEmptySet() || ShAmt.isEmptySet())
+    return getEmpty();
+
+  using OBO = OverflowingBinaryOperator;
+  const APInt MinShAmt = ShAmt.getUnsignedMin();
+  ConstantRange Result = shl(ShAmt);
+
+  if (NoWrapKind & OBO::NoUnsignedWrap) {
+    // The unsigned minimum has the most leading zeros of any value in this
+    // range, so if shifting it by the smallest amount overflows, every shift
+    // overflows.
+    bool Overflow;
+    (void)getUnsignedMin().ushl_ov(MinShAmt, Overflow);
+    if (Overflow)
+      return getEmpty(); // Always overflows.
+
+    if (NoWrapKind & OBO::NoSignedWrap) {
+      (void)getUnsignedMin().sshl_ov(MinShAmt, Overflow);
+      if (Overflow)
+        return getEmpty(); // Always overflows.
+    }
+
+    Result = Result.intersectWith(ushl_sat(ShAmt), RangeType);
+  }
+
+  if (NoWrapKind & OBO::NoSignedWrap) {
+    // Values closest to zero (as returned by getClosestToZero()) have the most
+    // redundant sign bits; if none of them survives the smallest shift without
+    // signed overflow, no value in the range does.
+    bool AllOverflow = true;
+    for (const APInt &V : getClosestToZero()) {
+      bool Overflow;
+      (void)V.sshl_ov(MinShAmt, Overflow);
+      if (!Overflow) {
+        AllOverflow = false;
+        break;
+      }
+    }
+    if (AllOverflow)
+      return getEmpty(); // Always overflows.
+
+    Result = Result.intersectWith(sshl_sat(ShAmt), RangeType);
+  }
+
+  return Result;
+}
+
 ConstantRange
 ConstantRange::lshr(const ConstantRange &Other) const {
   if (isEmptySet() || Other.isEmptySet())
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1345,6 +1345,40 @@
   EXPECT_EQ(One.shl(WrapNullMax), Full);
 }
 
+TEST_F(ConstantRangeTest, ShlWithNoWrap) {
+  typedef OverflowingBinaryOperator OBO;
+
+  TestAddWithNoSignedWrapExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
+      },
+      [](bool &IsOverflow, const APInt &N1, const APInt &N2) {
+        return N1.sshl_ov(N2, IsOverflow);
+      },
+      /*CorrectnessOnly=*/true);
+
+  TestAddWithNoUnsignedWrapExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+      },
+      [](bool &IsOverflow, const APInt &N1, const APInt &N2) {
+        return N1.ushl_ov(N2, IsOverflow);
+      },
+      /*CorrectnessOnly=*/true);
+
+  TestAddWithNoSignedUnsignedWrapExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap | OBO::NoSignedWrap);
+      },
+      [](bool &IsOverflow, const APInt &N1, const APInt &N2) {
+        return N1.sshl_ov(N2, IsOverflow);
+      },
+      [](bool &IsOverflow, const APInt &N1, const APInt &N2) {
+        return N1.ushl_ov(N2, IsOverflow);
+      },
+      /*CorrectnessOnly=*/true);
+}
+
 TEST_F(ConstantRangeTest, Lshr) {
   EXPECT_EQ(Full.lshr(Full), Full);
   EXPECT_EQ(Full.lshr(Empty), Empty);