Index: lib/IR/ConstantRange.cpp
===================================================================
--- lib/IR/ConstantRange.cpp
+++ lib/IR/ConstantRange.cpp
@@ -308,81 +308,95 @@
   return intersectWith(CR.inverse());
 }
 
+namespace {
+/// An unsigned integer one bit wider than LowBits, with HighBit as its most
+/// significant bit. Unwrapping range limits into this wider space turns a
+/// wrapped range [L, U) into the ordinary interval [L, U + 2^BitWidth), so
+/// limits can be compared without worrying about wrap-around.
+///
+/// LowBits is a reference: the APInt it refers to must outlive this object.
+struct UnwrappedLimit {
+  const APInt &LowBits;
+  bool HighBit;
+
+  explicit UnwrappedLimit(const APInt &Bits, bool High = false)
+      : LowBits(Bits), HighBit(High) {}
+
+  bool ult(const UnwrappedLimit &Other) const {
+    if (HighBit != Other.HighBit)
+      return Other.HighBit;
+    return LowBits.ult(Other.LowBits);
+  }
+
+  UnwrappedLimit umin(const UnwrappedLimit &Other) const {
+    return ult(Other) ? *this : Other;
+  }
+
+  UnwrappedLimit umax(const UnwrappedLimit &Other) const {
+    return ult(Other) ? Other : *this;
+  }
+};
+} // end anonymous namespace
+
 /// intersectWith - Return the range that results from the intersection of this
 /// range with another range.  The resultant range is guaranteed to include all
 /// elements contained in both input ranges, and to have the smallest possible
 /// set size that does so.  Because there may be two intersections with the
 /// same set size, A.intersectWith(B) might not be equal to B.intersectWith(A).
 ConstantRange ConstantRange::intersectWith(const ConstantRange &CR) const {
-  assert(getBitWidth() == CR.getBitWidth() && 
+  assert(getBitWidth() == CR.getBitWidth() &&
          "ConstantRange types don't agree!");
 
   // Handle common cases.
   if (   isEmptySet() || CR.isFullSet()) return *this;
   if (CR.isEmptySet() ||    isFullSet()) return CR;
 
   if (!isWrappedSet() && CR.isWrappedSet())
     return CR.intersectWith(*this);
 
-  if (!isWrappedSet() && !CR.isWrappedSet()) {
-    if (Lower.ult(CR.Lower)) {
-      if (Upper.ule(CR.Lower))
-        return ConstantRange(getBitWidth(), false);
-
-      if (Upper.ult(CR.Upper))
-        return ConstantRange(CR.Lower, Upper);
-
-      return CR;
-    }
-    if (Upper.ult(CR.Upper))
-      return *this;
-
-    if (Lower.ult(CR.Upper))
-      return ConstantRange(Lower, CR.Upper);
-
-    return ConstantRange(getBitWidth(), false);
-  }
-
-  if (isWrappedSet() && !CR.isWrappedSet()) {
-    if (CR.Lower.ult(Upper)) {
-      if (CR.Upper.ult(Upper))
-        return CR;
-
-      if (CR.Upper.ule(Lower))
-        return ConstantRange(CR.Lower, Upper);
-
-      if (getSetSize().ult(CR.getSetSize()))
-        return *this;
-      return CR;
-    }
-    if (CR.Lower.ult(Lower)) {
-      if (CR.Upper.ule(Lower))
-        return ConstantRange(getBitWidth(), false);
-
-      return ConstantRange(Lower, CR.Upper);
-    }
-    return CR;
-  }
-
-  if (CR.Upper.ult(Upper)) {
-    if (CR.Lower.ult(Upper)) {
-      if (getSetSize().ult(CR.getSetSize()))
-        return *this;
-      return CR;
-    }
-
-    if (CR.Lower.ult(Lower))
-      return ConstantRange(Lower, CR.Upper);
-
-    return CR;
-  }
-  if (CR.Upper.ule(Lower)) {
-    if (CR.Lower.ult(Lower))
-      return *this;
-
-    return ConstantRange(CR.Lower, Upper);
-  }
+  // Intersect *this and CR as ordinary half-open intervals over BitWidth+1 bit
+  // integers. A non-wrapped range [L, U) keeps its limits and a wrapped one
+  // becomes [L, U + 2^BitWidth). If ShiftCR is set, a non-wrapped CR is moved
+  // up to [L + 2^BitWidth, U + 2^BitWidth) so it can meet the low part of a
+  // wrapped *this. Returns the empty set if the intervals do not overlap.
+  auto IntersectUnwrapped = [&](bool ShiftCR) -> ConstantRange {
+    bool CRWraps = CR.isWrappedSet();
+    UnwrappedLimit SelfLower(Lower), SelfUpper(Upper, isWrappedSet());
+    UnwrappedLimit CRLower(CR.Lower, ShiftCR && !CRWraps);
+    UnwrappedLimit CRUpper(CR.Upper, ShiftCR || CRWraps);
+
+    UnwrappedLimit NewLower = SelfLower.umax(CRLower);
+    UnwrappedLimit NewUpper = SelfUpper.umin(CRUpper);
+    if (NewLower.ult(NewUpper))
+      return ConstantRange(NewLower.LowBits, NewUpper.LowBits);
+    return ConstantRange(getBitWidth(), /*isFullSet=*/false);
+  };
+
+  // Neither range wraps: a plain interval intersection.
+  if (!isWrappedSet())
+    return IntersectUnwrapped(/*ShiftCR=*/false);
+
+  if (CR.isWrappedSet()) {
+    // Both ranges wrap, so they always overlap around the wrap point. That
+    // overlap is the whole intersection unless the high part of one range
+    // also reaches down into the low part of the other.
+    if (!CR.Lower.ult(Upper) && !Lower.ult(CR.Upper))
+      return IntersectUnwrapped(/*ShiftCR=*/false);
+  } else {
+    // Only *this wraps. CR may overlap its high part [Lower, 2^BitWidth),
+    // its low part [0, Upper), or both. A single overlap is the intersection.
+    ConstantRange HighPart = IntersectUnwrapped(/*ShiftCR=*/false);
+    ConstantRange LowPart = IntersectUnwrapped(/*ShiftCR=*/true);
+    if (LowPart.isEmptySet())
+      return HighPart;
+    if (HighPart.isEmptySet())
+      return LowPart;
+  }
+
+  // The intersection is split into two disjoint pieces. The only candidates
+  // for the smallest range covering both are *this and CR themselves, so
+  // return the smaller one (CR on a tie).
   if (getSetSize().ult(CR.getSetSize()))
     return *this;
   return CR;
 }
Index: unittests/IR/ConstantRangeTest.cpp
===================================================================
--- unittests/IR/ConstantRangeTest.cpp
+++ unittests/IR/ConstantRangeTest.cpp
@@ -278,6 +278,14 @@
   LHS = ConstantRange(APInt(32, 15), APInt(32, 0));
   RHS = ConstantRange(APInt(32, 7), APInt(32, 6));
   EXPECT_EQ(LHS.intersectWith(RHS), ConstantRange(APInt(32, 15), APInt(32, 0)));
+
+  // [0, 15) /\ [8, 7) = [0, 15) with 4-bit values: the exact intersection is
+  // split into {0, ..., 6} and {8, ..., 14}, and no range smaller than the
+  // operands covers both pieces.
+  LHS = ConstantRange(APInt(4, 0), APInt(4, -1));
+  RHS = ConstantRange(APInt(4, -8), APInt(4, 7));
+  EXPECT_EQ(LHS.intersectWith(RHS), LHS);
+  EXPECT_EQ(RHS.intersectWith(LHS), LHS);
 }
 
 TEST_F(ConstantRangeTest, UnionWith) {