diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -173,6 +173,10 @@
                      One.extractBits(NumBits, BitPosition));
   }
 
+  /// Return KnownBits based on this, but updated given that the underlying
+  /// value is known to be (unsigned) greater than or equal to Val.
+  KnownBits makeGE(const APInt &Val) const;
+
   /// Returns the minimum number of trailing zero bits.
   unsigned countMinTrailingZeros() const {
     return Zero.countTrailingOnes();
@@ -241,6 +245,18 @@
   static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
                                     KnownBits RHS);
 
+  /// Compute known bits for umax(LHS, RHS).
+  static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for umin(LHS, RHS).
+  static KnownBits umin(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for smax(LHS, RHS).
+  static KnownBits smax(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for smin(LHS, RHS).
+  static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
+
   /// Insert the bits from a smaller known bits starting at bitPosition.
void insertBits(const KnownBits &SubBits, unsigned BitPosition) { Zero.insertBits(SubBits.Zero, BitPosition); diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -1212,59 +1212,41 @@ if (SelectPatternResult::isMinOrMax(SPF)) { computeKnownBits(RHS, Known, Depth + 1, Q); computeKnownBits(LHS, Known2, Depth + 1, Q); - } else { - computeKnownBits(I->getOperand(2), Known, Depth + 1, Q); - computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + switch (SPF) { + default: + llvm_unreachable("Unhandled select pattern flavor!"); + case SPF_SMAX: + Known = KnownBits::smax(Known, Known2); + break; + case SPF_SMIN: + Known = KnownBits::smin(Known, Known2); + break; + case SPF_UMAX: + Known = KnownBits::umax(Known, Known2); + break; + case SPF_UMIN: + Known = KnownBits::umin(Known, Known2); + break; + } + break; } - unsigned MaxHighOnes = 0; - unsigned MaxHighZeros = 0; - if (SPF == SPF_SMAX) { - // If both sides are negative, the result is negative. - if (Known.isNegative() && Known2.isNegative()) - // We can derive a lower bound on the result by taking the max of the - // leading one bits. - MaxHighOnes = - std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes()); - // If either side is non-negative, the result is non-negative. - else if (Known.isNonNegative() || Known2.isNonNegative()) - MaxHighZeros = 1; - } else if (SPF == SPF_SMIN) { - // If both sides are non-negative, the result is non-negative. - if (Known.isNonNegative() && Known2.isNonNegative()) - // We can derive an upper bound on the result by taking the max of the - // leading zero bits. - MaxHighZeros = std::max(Known.countMinLeadingZeros(), - Known2.countMinLeadingZeros()); - // If either side is negative, the result is negative. 
-      else if (Known.isNegative() || Known2.isNegative())
-        MaxHighOnes = 1;
-    } else if (SPF == SPF_UMAX) {
-      // We can derive a lower bound on the result by taking the max of the
-      // leading one bits.
-      MaxHighOnes =
-          std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes());
-    } else if (SPF == SPF_UMIN) {
-      // We can derive an upper bound on the result by taking the max of the
-      // leading zero bits.
-      MaxHighZeros =
-          std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros());
-    } else if (SPF == SPF_ABS) {
+    computeKnownBits(I->getOperand(2), Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+
+    // Only known if known in both the LHS and RHS.
+    Known.One &= Known2.One;
+    Known.Zero &= Known2.Zero;
+
+    if (SPF == SPF_ABS) {
       // RHS from matchSelectPattern returns the negation part of abs pattern.
       // If the negate has an NSW flag we can assume the sign bit of the result
       // will be 0 because that makes abs(INT_MIN) undefined.
       if (match(RHS, m_Neg(m_Specific(LHS))) &&
           Q.IIQ.hasNoSignedWrap(cast<Instruction>(RHS)))
-        MaxHighZeros = 1;
+        Known.Zero.setSignBit();
     }
 
-    // Only known if known in both the LHS and RHS.
-    Known.One &= Known2.One;
-    Known.Zero &= Known2.Zero;
-    if (MaxHighOnes > 0)
-      Known.One.setHighBits(MaxHighOnes);
-    if (MaxHighZeros > 0)
-      Known.Zero.setHighBits(MaxHighZeros);
     break;
   }
   case Instruction::FPTrunc:
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -308,11 +308,24 @@
                         Known, DemandedElts, Depth + 1);
     break;
   }
-  case TargetOpcode::G_SMIN:
+  case TargetOpcode::G_SMIN: {
+    // TODO: Handle clamp pattern with number of sign bits
+    KnownBits KnownRHS;
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
+                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
+                         Depth + 1);
+    Known = KnownBits::smin(Known, KnownRHS);
+    break;
+  }
   case TargetOpcode::G_SMAX: {
     // TODO: Handle clamp pattern with number of sign bits
-    computeKnownBitsMin(MI.getOperand(1).getReg(), MI.getOperand(2).getReg(),
-                        Known, DemandedElts, Depth + 1);
+    KnownBits KnownRHS;
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
+                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
+                         Depth + 1);
+    Known = KnownBits::smax(Known, KnownRHS);
     break;
   }
   case TargetOpcode::G_UMIN: {
@@ -321,13 +334,7 @@
                          DemandedElts, Depth + 1);
     computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS,
                          DemandedElts, Depth + 1);
-
-    // UMIN - we know that the result will have the maximum of the
-    // known zero leading bits of the inputs.
-    unsigned LeadZero = Known.countMinLeadingZeros();
-    LeadZero = std::max(LeadZero, KnownRHS.countMinLeadingZeros());
-    Known &= KnownRHS;
-    Known.Zero.setHighBits(LeadZero);
+    Known = KnownBits::umin(Known, KnownRHS);
     break;
   }
   case TargetOpcode::G_UMAX: {
@@ -336,14 +343,7 @@
                          DemandedElts, Depth + 1);
     computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS,
                          DemandedElts, Depth + 1);
-
-    // UMAX - we know that the result will have the maximum of the
-    // known one leading bits of the inputs.
-    unsigned LeadOne = Known.countMinLeadingOnes();
-    LeadOne = std::max(LeadOne, KnownRHS.countMinLeadingOnes());
-    Known.Zero &= KnownRHS.Zero;
-    Known.One &= KnownRHS.One;
-    Known.One.setHighBits(LeadOne);
+    Known = KnownBits::umax(Known, KnownRHS);
     break;
   }
   case TargetOpcode::G_FCMP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3387,29 +3387,13 @@
   case ISD::UMIN: {
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-
-    // UMIN - we know that the result will have the maximum of the
-    // known zero leading bits of the inputs.
-    unsigned LeadZero = Known.countMinLeadingZeros();
-    LeadZero = std::max(LeadZero, Known2.countMinLeadingZeros());
-
-    Known.Zero &= Known2.Zero;
-    Known.One &= Known2.One;
-    Known.Zero.setHighBits(LeadZero);
+    Known = KnownBits::umin(Known, Known2);
     break;
   }
   case ISD::UMAX: {
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-
-    // UMAX - we know that the result will have the maximum of the
-    // known one leading bits of the inputs.
-    unsigned LeadOne = Known.countMinLeadingOnes();
-    LeadOne = std::max(LeadOne, Known2.countMinLeadingOnes());
-
-    Known.Zero &= Known2.Zero;
-    Known.One &= Known2.One;
-    Known.One.setHighBits(LeadOne);
+    Known = KnownBits::umax(Known, Known2);
     break;
   }
   case ISD::SMIN:
@@ -3443,12 +3427,14 @@
       }
     }
 
-    // Fallback - just get the shared known bits of the operands.
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
-    if (Known.isUnknown()) break; // Early-out
+    // No early-out on an unknown operand 0: smax/smin can still derive bits
+    // from operand 1 alone (e.g. smax(X, 0) is always non-negative).
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known.Zero &= Known2.Zero;
-    Known.One &= Known2.One;
+    if (IsMax)
+      Known = KnownBits::smax(Known, Known2);
+    else
+      Known = KnownBits::smin(Known, Known2);
     break;
   }
   case ISD::FrameIndex:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -83,6 +83,68 @@
   return KnownOut;
 }
 
+KnownBits KnownBits::makeGE(const APInt &Val) const {
+  // Find the highest bit position in which our underlying value might be
+  // strictly greater than Val.
+  unsigned N = (~Zero & ~Val).countLeadingZeros();
+
+  // For every higher bit position, if Val has a 1 in that bit then our
+  // underlying value must also have a 1.
+  APInt MaskedVal(Val);
+  MaskedVal.clearLowBits(getBitWidth() - N);
+  return KnownBits(Zero, One | MaskedVal);
+}
+
+KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
+  // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
+  // RHS. Ideally our caller would already have spotted these cases and
+  // optimized away the umax operation, but we handle them here for
+  // completeness.
+  if (LHS.getMinValue().uge(RHS.getMaxValue()))
+    return LHS;
+  if (RHS.getMinValue().uge(LHS.getMaxValue()))
+    return RHS;
+
+  // If the result of the umax is LHS then it must be greater than or equal to
+  // the minimum possible value of RHS. Likewise for RHS. Any known bits that
Any known bits that + // are common to these two values are also known in the result. + KnownBits L = LHS.makeGE(RHS.getMinValue()); + KnownBits R = RHS.makeGE(LHS.getMinValue()); + return KnownBits(L.Zero & R.Zero, L.One & R.One); +} + +KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) { + // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0] + auto Flip = [](KnownBits Val) { return KnownBits(Val.One, Val.Zero); }; + return Flip(umax(Flip(LHS), Flip(RHS))); +} + +KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) { + // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF] + auto Flip = [](KnownBits Val) { + unsigned SignBitPosition = Val.getBitWidth() - 1; + APInt Zero = Val.Zero; + APInt One = Val.One; + Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]); + One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); + return KnownBits(Zero, One); + }; + return Flip(umax(Flip(LHS), Flip(RHS))); +} + +KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) { + // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0] + auto Flip = [](KnownBits Val) { + unsigned SignBitPosition = Val.getBitWidth() - 1; + APInt Zero = Val.One; + APInt One = Val.Zero; + Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); + One.setBitVal(SignBitPosition, Val.One[SignBitPosition]); + return KnownBits(Zero, One); + }; + return Flip(umax(Flip(LHS), Flip(RHS))); +} + KnownBits &KnownBits::operator&=(const KnownBits &RHS) { // Result bit is 0 if either operand bit is 0. 
Zero |= RHS.Zero; diff --git a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp --- a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp @@ -719,9 +719,9 @@ KnownBits KnownUmax = Info.getKnownBits(CopyUMax); EXPECT_EQ(64u, KnownUmax.getBitWidth()); - EXPECT_EQ(0u, KnownUmax.Zero.getZExtValue()); + EXPECT_EQ(0xffu, KnownUmax.Zero.getZExtValue()); EXPECT_EQ(0xffffffffffffff00, KnownUmax.One.getZExtValue()); - EXPECT_EQ(0u, KnownUmax.Zero.getZExtValue()); + EXPECT_EQ(0xffu, KnownUmax.Zero.getZExtValue()); EXPECT_EQ(0xffffffffffffff00, KnownUmax.One.getZExtValue()); } diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -103,13 +103,15 @@ unsigned Bits = 4; ForeachKnownBits(Bits, [&](const KnownBits &Known1) { ForeachKnownBits(Bits, [&](const KnownBits &Known2) { - KnownBits KnownAnd(Bits), KnownOr(Bits), KnownXor(Bits); + KnownBits KnownAnd(Bits); KnownAnd.Zero.setAllBits(); KnownAnd.One.setAllBits(); - KnownOr.Zero.setAllBits(); - KnownOr.One.setAllBits(); - KnownXor.Zero.setAllBits(); - KnownXor.One.setAllBits(); + KnownBits KnownOr(KnownAnd); + KnownBits KnownXor(KnownAnd); + KnownBits KnownUMax(KnownAnd); + KnownBits KnownUMin(KnownAnd); + KnownBits KnownSMax(KnownAnd); + KnownBits KnownSMin(KnownAnd); ForeachNumInKnownBits(Known1, [&](const APInt &N1) { ForeachNumInKnownBits(Known2, [&](const APInt &N2) { @@ -126,6 +128,22 @@ Res = N1 ^ N2; KnownXor.One &= Res; KnownXor.Zero &= ~Res; + + Res = APIntOps::umax(N1, N2); + KnownUMax.One &= Res; + KnownUMax.Zero &= ~Res; + + Res = APIntOps::umin(N1, N2); + KnownUMin.One &= Res; + KnownUMin.Zero &= ~Res; + + Res = APIntOps::smax(N1, N2); + KnownSMax.One &= Res; + KnownSMax.Zero &= ~Res; + + Res = APIntOps::smin(N1, N2); + KnownSMin.One &= Res; + KnownSMin.Zero &= 
~Res; }); }); @@ -140,6 +158,22 @@ KnownBits ComputedXor = Known1 ^ Known2; EXPECT_EQ(KnownXor.Zero, ComputedXor.Zero); EXPECT_EQ(KnownXor.One, ComputedXor.One); + + KnownBits ComputedUMax = KnownBits::umax(Known1, Known2); + EXPECT_EQ(KnownUMax.Zero, ComputedUMax.Zero); + EXPECT_EQ(KnownUMax.One, ComputedUMax.One); + + KnownBits ComputedUMin = KnownBits::umin(Known1, Known2); + EXPECT_EQ(KnownUMin.Zero, ComputedUMin.Zero); + EXPECT_EQ(KnownUMin.One, ComputedUMin.One); + + KnownBits ComputedSMax = KnownBits::smax(Known1, Known2); + EXPECT_EQ(KnownSMax.Zero, ComputedSMax.Zero); + EXPECT_EQ(KnownSMax.One, ComputedSMax.One); + + KnownBits ComputedSMin = KnownBits::smin(Known1, Known2); + EXPECT_EQ(KnownSMin.Zero, ComputedSMin.Zero); + EXPECT_EQ(KnownSMin.One, ComputedSMin.One); }); }); }