diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -245,6 +245,10 @@
   /// Return the smallest signed value contained in the ConstantRange.
   APInt getSignedMin() const;
 
+  /// Return the element(s) of this range with the smallest absolute value:
+  /// none for an empty range, or both V and -V when they are tied.
+  SmallVector<APInt, 2> getClosestToZero() const;
+
   /// Return true if this range is equal to another range.
   bool operator==(const ConstantRange &CR) const {
     return Lower == CR.Lower && Upper == CR.Upper;
@@ -363,6 +367,14 @@
   /// treating both this and \p Other as unsigned ranges.
   ConstantRange multiply(const ConstantRange &Other) const;
 
+  /// Return a new range representing the possible values resulting
+  /// from a multiplication with wrap type \p NoWrapKind of a value in this
+  /// range and a value in \p Other.
+  /// If the result range is disjoint, the preferred range is determined by the
+  /// \p PreferredRangeType.
+  ConstantRange mulWithNoWrap(const ConstantRange &Other, unsigned NoWrapKind,
+                              PreferredRangeType RangeType = Smallest) const;
+
   /// Return a new range representing the possible values resulting
   /// from a signed maximum of a value in this range and a value in \p Other.
   ConstantRange smax(const ConstantRange &Other) const;
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -385,6 +385,32 @@
   return getLower();
 }
 
+SmallVector<APInt, 2> ConstantRange::getClosestToZero() const {
+  if (isEmptySet())
+    return {};
+
+  // If the range contains zero, zero itself is the unique answer.
+  APInt Zero = APInt::getNullValue(getBitWidth());
+  if (contains(Zero))
+    return {Zero};
+
+  // Zero is excluded, so walking from the first to the last element never
+  // passes through it: |V| is minimized at one of the two endpoints.
+  APInt First = getLower();
+  APInt Last = getUpper() - 1; // Inclusive, unlike the member Upper.
+
+  // If both endpoints are equidistant from zero, then we *MUST* return both!
+  // We don't bother checking for zero; we know we don't have it by now.
+  if (First.isNonNegative() && Last.isNegative() && First == -Last)
+    return {First, Last};
+
+  // Pick the endpoint with the smaller abs(); ult ranks abs(INT_MIN) largest.
+  auto CompareAbs = [](const APInt &A, const APInt &B) {
+    return A.abs().ult(B.abs());
+  };
+  return {std::min({First, Last}, CompareAbs)};
+}
+
 bool ConstantRange::contains(const APInt &V) const {
   if (Lower == Upper)
     return isFullSet();
@@ -984,6 +1010,65 @@
   return UR.isSizeStrictlySmallerThan(SR) ? UR : SR;
 }
 
+ConstantRange ConstantRange::mulWithNoWrap(const ConstantRange &Other,
+                                           unsigned NoWrapKind,
+                                           PreferredRangeType RangeType) const {
+  // Calculate the range for "X * Y" which is guaranteed not to wrap(overflow).
+  // (X is from this, and Y is from Other)
+  if (isEmptySet() || Other.isEmptySet())
+    return getEmpty();
+  if (isFullSet() && Other.isFullSet())
+    return getFull();
+
+  using OBO = OverflowingBinaryOperator;
+  ConstantRange Result = multiply(Other);
+
+  if (NoWrapKind & OBO::NoUnsignedWrap) {
+    bool Overflow;
+
+    // Every exact product is at least the product of the unsigned minima, so
+    // if that one overflows, all of them do.
+    (void)getUnsignedMin().umul_ov(Other.getUnsignedMin(), Overflow);
+    if (Overflow)
+      return getEmpty(); // Always overflows.
+
+    // With nsw as well: if the unsigned minima overflow as a signed product,
+    // every pair overflows in either the signed or the unsigned sense.
+    if (NoWrapKind & OBO::NoSignedWrap) {
+      (void)getUnsignedMin().smul_ov(Other.getUnsignedMin(), Overflow);
+      if (Overflow)
+        return getEmpty(); // Always overflows.
+    }
+
+    Result = Result.intersectWith(umul_sat(Other), RangeType);
+  }
+
+  if (NoWrapKind & OBO::NoSignedWrap) {
+    // The operands closest to zero give the products of smallest magnitude;
+    // if every such combination overflows, every other pair does too.
+    bool AllOverflow = true;
+    SmallVector<APInt, 2> ClosestToZero = getClosestToZero();
+    SmallVector<APInt, 2> OtherClosestToZero = Other.getClosestToZero();
+    for (const APInt &L : ClosestToZero) {
+      for (const APInt &R : OtherClosestToZero) {
+        bool Overflow;
+        (void)L.smul_ov(R, Overflow);
+        AllOverflow &= Overflow;
+        if (!AllOverflow)
+          break;
+      }
+      if (!AllOverflow)
+        break;
+    }
+    if (AllOverflow)
+      return getEmpty(); // Always overflows.
+
+    Result = Result.intersectWith(smul_sat(Other), RangeType);
+  }
+
+  return Result;
+}
+
 ConstantRange
 ConstantRange::smax(const ConstantRange &Other) const {
   // X smax Y is: range(smax(X_smin, Y_smin),
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -644,7 +644,8 @@
 }
 
 template <typename Fn1, typename Fn2>
-static void TestAddWithNoSignedWrapExhaustive(Fn1 RangeFn, Fn2 IntFn) {
+static void TestAddWithNoSignedWrapExhaustive(Fn1 RangeFn, Fn2 IntFn,
+                                              bool CorrectnessOnly = false) {
   unsigned Bits = 4;
   EnumerateTwoConstantRanges(Bits, [&](const ConstantRange &CR1,
                                        const ConstantRange &CR2) {
@@ -669,7 +670,8 @@
 
     EXPECT_EQ(CR.isEmptySet(), AllOverflow);
 
-    if (!CR1.isSignWrappedSet() && !CR2.isSignWrappedSet()) {
+    if (!CorrectnessOnly && !CR1.isSignWrappedSet() &&
+        !CR2.isSignWrappedSet()) {
       if (Min.sgt(Max)) {
         EXPECT_TRUE(CR.isEmptySet());
         return;
@@ -682,7 +684,8 @@
 }
 
 template <typename Fn1, typename Fn2>
-static void TestAddWithNoUnsignedWrapExhaustive(Fn1 RangeFn, Fn2 IntFn) {
+static void TestAddWithNoUnsignedWrapExhaustive(Fn1 RangeFn, Fn2 IntFn,
+                                                bool CorrectnessOnly = false) {
   unsigned Bits = 4;
   EnumerateTwoConstantRanges(Bits, [&](const ConstantRange &CR1,
                                        const ConstantRange &CR2) {
@@ -707,7 +710,7 @@
 
     EXPECT_EQ(CR.isEmptySet(), AllOverflow);
 
-    if (!CR1.isWrappedSet() && !CR2.isWrappedSet()) {
+    if (!CorrectnessOnly && !CR1.isWrappedSet() && !CR2.isWrappedSet()) {
       if (Min.ugt(Max)) {
         EXPECT_TRUE(CR.isEmptySet());
         return;
@@ -720,9 +723,10 @@
 }
 
 template <typename Fn1, typename Fn2, typename Fn3>
-static void TestAddWithNoSignedUnsignedWrapExhaustive(Fn1 RangeFn,
-                                                      Fn2 IntFnSigned,
-                                                      Fn3 IntFnUnsigned) {
+static void
+TestAddWithNoSignedUnsignedWrapExhaustive(Fn1 RangeFn, Fn2 IntFnSigned,
+                                          Fn3 IntFnUnsigned,
+                                          bool CorrectnessOnly = false) {
   unsigned Bits = 4;
   EnumerateTwoConstantRanges(
       Bits, [&](const ConstantRange &CR1, const ConstantRange &CR2) {
@@ -754,7 +758,7 @@
 
         EXPECT_EQ(CR.isEmptySet(), AllOverflow);
 
-        if (!CR1.isWrappedSet() && !CR2.isWrappedSet() &&
+        if (!CorrectnessOnly && !CR1.isWrappedSet() && !CR2.isWrappedSet() &&
             !CR1.isSignWrappedSet() && !CR2.isSignWrappedSet()) {
           if (UMin.ugt(UMax) || SMin.sgt(SMax)) {
             EXPECT_TRUE(CR.isEmptySet());
@@ -1007,6 +1011,40 @@
             ConstantRange(APInt(8, -2), APInt(8, 1)));
 }
 
+TEST_F(ConstantRangeTest, MulWithNoWrap) {
+  typedef OverflowingBinaryOperator OBO;
+
+  TestAddWithNoSignedWrapExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.mulWithNoWrap(CR2, OBO::NoSignedWrap);
+      },
+      [](bool &IsOverflow, const APInt &N1, const APInt &N2) {
+        return N1.smul_ov(N2, IsOverflow);
+      },
+      /*CorrectnessOnly=*/true);
+
+  TestAddWithNoUnsignedWrapExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.mulWithNoWrap(CR2, OBO::NoUnsignedWrap);
+      },
+      [](bool &IsOverflow, const APInt &N1, const APInt &N2) {
+        return N1.umul_ov(N2, IsOverflow);
+      },
+      /*CorrectnessOnly=*/true);
+
+  TestAddWithNoSignedUnsignedWrapExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.mulWithNoWrap(CR2, OBO::NoUnsignedWrap | OBO::NoSignedWrap);
+      },
+      [](bool &IsOverflow, const APInt &N1, const APInt &N2) {
+        return N1.smul_ov(N2, IsOverflow);
+      },
+      [](bool &IsOverflow, const APInt &N1, const APInt &N2) {
+        return N1.umul_ov(N2, IsOverflow);
+      },
+      /*CorrectnessOnly=*/true);
+}
+
 TEST_F(ConstantRangeTest, UMax) {
   EXPECT_EQ(Full.umax(Full), Full);
   EXPECT_EQ(Full.umax(Empty), Empty);
@@ -2295,4 +2333,37 @@
   });
 }
 
+TEST_F(ConstantRangeTest, getClosestToZero) {
+  unsigned Bits = 4;
+  EnumerateConstantRanges(Bits, [&](const ConstantRange &CR) {
+    const SmallVector<APInt, 2> ClosestToZero = CR.getClosestToZero();
+
+    EXPECT_EQ(ClosestToZero.empty(), CR.isEmptySet());
+    if (ClosestToZero.empty())
+      return;
+
+    EXPECT_LE(ClosestToZero.size(), 2U);
+    if (ClosestToZero.size() == 2U)
+      EXPECT_EQ(ClosestToZero.front(), -(ClosestToZero.back()));
+
+    for (const APInt &V : ClosestToZero)
+      EXPECT_TRUE(CR.contains(V));
+
+    if (llvm::any_of(ClosestToZero, [Bits](const APInt &V) {
+          return V == APInt::getSignedMinValue(Bits);
+        })) {
+      EXPECT_EQ(ClosestToZero.size(), 1U);
+      const APInt *S = CR.getSingleElement();
+      ASSERT_NE(S, nullptr);
+      EXPECT_EQ(*S, ClosestToZero.front());
+
+      return;
+    }
+
+    APInt ClosestToZeroAbs = ClosestToZero.front().abs();
+    for (APInt C = APInt(Bits, 0); C.ult(ClosestToZeroAbs); ++C)
+      EXPECT_FALSE(CR.contains(C) || CR.contains(-C));
+  });
+}
+
 } // anonymous namespace