diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -530,6 +530,13 @@
   /// ignoring a possible zero value contained in the input range.
   ConstantRange ctlz(bool ZeroIsPoison = false) const;
 
+  /// Calculate cttz range. If \p ZeroIsPoison is set, the range is computed
+  /// ignoring a possible zero value contained in the input range.
+  ConstantRange cttz(bool ZeroIsPoison = false) const;
+
+  /// Calculate ctpop range.
+  ConstantRange ctpop() const;
+
   /// Represents whether an operation on the given constant range is known to
   /// always or never overflow.
   enum class OverflowResult {
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -945,6 +945,8 @@
   case Intrinsic::smax:
   case Intrinsic::abs:
   case Intrinsic::ctlz:
+  case Intrinsic::cttz:
+  case Intrinsic::ctpop:
     return true;
   default:
     return false;
@@ -982,6 +984,14 @@
     assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
     return Ops[0].ctlz(ZeroIsPoison->getBoolValue());
   }
+  case Intrinsic::cttz: {
+    const APInt *ZeroIsPoison = Ops[1].getSingleElement();
+    assert(ZeroIsPoison && "Must be known (immarg)");
+    assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
+    return Ops[0].cttz(ZeroIsPoison->getBoolValue());
+  }
+  case Intrinsic::ctpop:
+    return Ops[0].ctpop();
   default:
     assert(!isIntrinsicSupported(IntrinsicID) && "Shouldn't be supported");
     llvm_unreachable("Unsupported intrinsic");
@@ -1711,6 +1721,158 @@
                      APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1));
 }
 
+/// Compute the cttz range of the interval [Lower, Upper), which must not wrap
+/// (Lower <= Upper as unsigned values). A zero element yields BitWidth.
+static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower,
+                                                        const APInt &Upper) {
+  assert(Lower.ule(Upper) && "Unexpected wrapped set");
+  unsigned BitWidth = Lower.getBitWidth();
+  if (Lower == Upper)
+    return ConstantRange::getEmpty(BitWidth);
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(BitWidth, Lower.countr_zero()));
+  if (Lower.isZero())
+    return ConstantRange(APInt::getZero(BitWidth),
+                         APInt(BitWidth, BitWidth + 1));
+  // At least two consecutive non-zero values are present, so one of them is
+  // odd and the minimum is 0. For the maximum, walk Lower, Lower + lowbit,
+  // ...: each step moves to the smallest value above the current one with
+  // more trailing zeros, so the last step still below Upper is maximal.
+  unsigned MaxValue = 0;
+  APInt Cur = Lower;
+  while (Cur.ult(Upper)) {
+    MaxValue = Cur.countr_zero();
+    APInt Next = Cur + APInt::getOneBitSet(BitWidth, MaxValue);
+    if (!Cur.ult(Next)) // Wrapped around to zero.
+      break;
+    Cur = Next;
+  }
+  return ConstantRange(APInt::getZero(BitWidth), APInt(BitWidth, MaxValue + 1));
+}
+
+/// Compute the cttz range of the interval [Lower, UINT_MAX] for a non-zero
+/// Lower. With K = floor(log2(2^BitWidth - Lower)), the values
+/// 2^BitWidth - 2^J (whose cttz is J) lie in the interval for every J <= K,
+/// while no element has more than K trailing zeros. Hence the result is
+/// exactly [0, K].
+static ConstantRange
+getUnsignedCountTrailingZerosRangeToMax(const APInt &Lower) {
+  assert(!Lower.isZero() && "Expected a non-zero lower bound");
+  unsigned BitWidth = Lower.getBitWidth();
+  // 2^BitWidth - Lower; this does not overflow because Lower is non-zero.
+  APInt Size = APInt::getMaxValue(BitWidth) - Lower + 1;
+  return ConstantRange(APInt::getZero(BitWidth),
+                       APInt(BitWidth, Size.logBase2() + 1));
+}
+
+ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
+  if (isEmptySet())
+    return getEmpty();
+
+  unsigned BitWidth = getBitWidth();
+  APInt Zero = APInt::getZero(BitWidth);
+  APInt One(BitWidth, 1);
+
+  if (ZeroIsPoison && contains(Zero)) {
+    // ZeroIsPoison is set and zero is contained, so zero is excluded from
+    // the input before computing the result. We discern three cases in
+    // which a zero can appear:
+    // 1) Lower is zero, e.g. [0, 1), [0, 2), etc.
+    // 2) Upper is one, i.e. zero is the last element of a wrapped set,
+    //    e.g. [3, 1) == [3, 0], etc.
+    // 3) Zero is strictly inside a wrapped (or full) set, e.g. [3, 2).
+    if (getLower().isZero()) {
+      // [0, 1) contains nothing but the poison value.
+      if (getUpper().isOne())
+        return getEmpty();
+      // Exclude zero from the front: [1, Upper).
+      return getUnsignedCountTrailingZerosRange(One, getUpper());
+    }
+    // Exclude zero from the back: [Lower, UINT_MAX].
+    if (getUpper().isOne())
+      return getUnsignedCountTrailingZerosRangeToMax(getLower());
+    // Split around zero: [Lower, UINT_MAX] and [1, Upper).
+    return getUnsignedCountTrailingZerosRangeToMax(getLower())
+        .unionWith(getUnsignedCountTrailingZerosRange(One, getUpper()));
+  }
+
+  if (isFullSet())
+    return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
+  if (!isUpperWrapped())
+    return getUnsignedCountTrailingZerosRange(getLower(), getUpper());
+  // The set wraps around or ends at UINT_MAX: handle [Lower, UINT_MAX] and
+  // [0, Upper) separately.
+  return getUnsignedCountTrailingZerosRangeToMax(getLower())
+      .unionWith(getUnsignedCountTrailingZerosRange(Zero, getUpper()));
+}
+
+/// Compute the ctpop range of the interval [Lower, Upper), which must not wrap
+/// (Lower <= Upper as unsigned values). The interval is tiled by aligned
+/// blocks [P, P + 2^C) where the low C bits of P are zero; within such a
+/// block the popcount takes every value in [popcount(P), popcount(P) + C].
+static ConstantRange getUnsignedPopCountRange(const APInt &Lower,
+                                              const APInt &Upper) {
+  assert(Lower.ule(Upper) && "Unexpected wrapped set");
+  unsigned BitWidth = Lower.getBitWidth();
+  if (Lower == Upper)
+    return ConstantRange::getEmpty(BitWidth);
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(BitWidth, Lower.popcount()));
+  unsigned MinBits = BitWidth;
+  unsigned MaxBits = 0;
+  // Fold the popcounts of the aligned block [Start, Start + 2^Ctz) into
+  // [MinBits, MaxBits].
+  auto AddBlock = [&](const APInt &Start, unsigned Ctz) {
+    unsigned PrefixPopCount = Start.popcount();
+    MinBits = std::min(MinBits, PrefixPopCount);
+    MaxBits = std::max(MaxBits, PrefixPopCount + Ctz);
+  };
+  // Grow upwards from Lower with [LowerBound, LowerBound + lowbit) blocks
+  // for as long as a block neither wraps around nor passes Upper.
+  APInt LowerBound = Lower;
+  while (!LowerBound.isZero()) {
+    unsigned Ctz = LowerBound.countr_zero();
+    APInt Next = LowerBound + APInt::getOneBitSet(BitWidth, Ctz);
+    if (!(LowerBound.ult(Next) && Next.ule(Upper)))
+      break;
+    AddBlock(LowerBound, Ctz);
+    LowerBound = Next;
+  }
+  // Cover the rest downwards from Upper with [UpperBound - lowbit,
+  // UpperBound) blocks until reaching LowerBound.
+  APInt UpperBound = Upper;
+  while (!UpperBound.isZero()) {
+    unsigned Ctz = UpperBound.countr_zero();
+    APInt Next = UpperBound - APInt::getOneBitSet(BitWidth, Ctz);
+    if (!(LowerBound.ule(Next) && Next.ult(UpperBound)))
+      break;
+    AddBlock(Next, Ctz);
+    UpperBound = Next;
+  }
+  assert(LowerBound == UpperBound && "Blocks must tile [Lower, Upper)");
+  return ConstantRange(APInt(BitWidth, MinBits), APInt(BitWidth, MaxBits + 1));
+}
+
+ConstantRange ConstantRange::ctpop() const {
+  if (isEmptySet())
+    return getEmpty();
+
+  unsigned BitWidth = getBitWidth();
+  APInt Zero = APInt::getZero(BitWidth);
+  if (isFullSet())
+    return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
+  if (!isUpperWrapped())
+    return getUnsignedPopCountRange(getLower(), getUpper());
+  // The set wraps around or ends at UINT_MAX. getUnsignedPopCountRange
+  // takes a half-open interval that cannot extend past UINT_MAX, so split
+  // the set into [Lower, UINT_MAX), {UINT_MAX} and [0, Upper).
+  APInt Max = APInt::getMaxValue(BitWidth);
+  ConstantRange CR1 = getUnsignedPopCountRange(getLower(), Max);
+  ConstantRange CR2 = getUnsignedPopCountRange(Zero, getUpper());
+  ConstantRange CR3(APInt(BitWidth, BitWidth)); // popcount(UINT_MAX)
+  return CR1.unionWith(CR2).unionWith(CR3);
+}
+
 ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow(
     const ConstantRange &Other) const {
   if (isEmptySet() || Other.isEmptySet())
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
--- a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
@@ -1010,6 +1010,60 @@
   ret i1 %res2
 }
 
+define i1 @cttz_fold(i16 %x) {
+; CHECK-LABEL: @cttz_fold(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    ret i1 false
+; CHECK:       else:
+; CHECK-NEXT:    [[CTTZ2:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    [[RES2:%.*]] = icmp ult i16 [[CTTZ2]], 8
+; CHECK-NEXT:    ret i1 [[RES2]]
+;
+  %cmp = icmp ult i16 %x, 256
+  br i1 %cmp, label %if, label %else
+
+if:
+  %cttz = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+  %res = icmp uge i16 %cttz, 8
+  ret i1 %res
+
+else:
+  %cttz2 = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+  %res2 = icmp ult i16 %cttz2, 8
+  ret i1 %res2
+}
+
+define i1 @ctpop_fold(i16 %x) {
+; CHECK-LABEL: @ctpop_fold(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    [[CTPOP:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT:    ret i1 true
+; CHECK:       else:
+; CHECK-NEXT:    [[CTPOP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT:    [[RES2:%.*]] = icmp ugt i16 [[CTPOP2]], 8
+; CHECK-NEXT:    ret i1 [[RES2]]
+;
+  %cmp = icmp ult i16 %x, 256
+  br i1 %cmp, label %if, label %else
+
+if:
+  %ctpop = call i16 @llvm.ctpop.i16(i16 %x)
+  %res = icmp ule i16 %ctpop, 8
+  ret i1 %res
+
+else:
+  %ctpop2 = call i16 @llvm.ctpop.i16(i16 %x)
+  %res2 = icmp ugt i16 %ctpop2, 8
+  ret i1 %res2
+}
+
 declare i16 @llvm.ctlz.i16(i16, i1)
+declare i16 @llvm.cttz.i16(i16, i1)
+declare i16 @llvm.ctpop.i16(i16)
 declare i16 @llvm.abs.i16(i16, i1)
 declare void @llvm.assume(i1)
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -2411,6 +2411,28 @@
   });
 }
 
+TEST_F(ConstantRangeTest, Cttz) {
+  // Without ZeroIsPoison, a zero input contributes its bit width.
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.cttz(); },
+      [](const APInt &N) { return APInt(N.getBitWidth(), N.countr_zero()); });
+
+  // With ZeroIsPoison, a zero input contributes no result value at all.
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.cttz(/*ZeroIsPoison=*/true); },
+      [](const APInt &N) -> std::optional<APInt> {
+        if (N.isZero())
+          return std::nullopt;
+        return APInt(N.getBitWidth(), N.countr_zero());
+      });
+}
+
+TEST_F(ConstantRangeTest, Ctpop) {
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.ctpop(); },
+      [](const APInt &N) { return APInt(N.getBitWidth(), N.popcount()); });
+}
+
 TEST_F(ConstantRangeTest, castOps) {
   ConstantRange A(APInt(16, 66), APInt(16, 128));
   ConstantRange FpToI8 = A.castOp(Instruction::FPToSI, 8);