diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -530,6 +530,13 @@
   /// ignoring a possible zero value contained in the input range.
   ConstantRange ctlz(bool ZeroIsPoison = false) const;
 
+  /// Calculate cttz range. If \p ZeroIsPoison is set, the range is computed
+  /// ignoring a possible zero value contained in the input range.
+  ConstantRange cttz(bool ZeroIsPoison = false) const;
+
+  /// Calculate ctpop range.
+  ConstantRange ctpop() const;
+
   /// Represents whether an operation on the given constant range is known to
   /// always or never overflow.
   enum class OverflowResult {
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -945,6 +945,8 @@
   case Intrinsic::smax:
   case Intrinsic::abs:
   case Intrinsic::ctlz:
+  case Intrinsic::cttz:
+  case Intrinsic::ctpop:
     return true;
   default:
     return false;
@@ -982,6 +984,14 @@
     assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
     return Ops[0].ctlz(ZeroIsPoison->getBoolValue());
   }
+  case Intrinsic::cttz: {
+    const APInt *ZeroIsPoison = Ops[1].getSingleElement();
+    assert(ZeroIsPoison && "Must be known (immarg)");
+    assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
+    return Ops[0].cttz(ZeroIsPoison->getBoolValue());
+  }
+  case Intrinsic::ctpop:
+    return Ops[0].ctpop();
   default:
     assert(!isIntrinsicSupported(IntrinsicID) && "Shouldn't be supported");
     llvm_unreachable("Unsupported intrinsic");
@@ -1724,6 +1734,134 @@
   return getNonEmpty(APInt(getBitWidth(), getUnsignedMax().countl_zero()),
                      APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1));
 }
+
+/// Return the range of cttz over the non-wrapped half-open interval
+/// [Lower, Upper), or the empty set if Lower == Upper.
+static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower,
+                                                        const APInt &Upper) {
+  assert(Lower.ule(Upper) && "Unexpected wrapped set");
+  unsigned BitWidth = Lower.getBitWidth();
+  if (Lower == Upper)
+    return ConstantRange::getEmpty(BitWidth);
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(BitWidth, Lower.countr_zero()));
+  if (Lower.isZero())
+    return ConstantRange(APInt::getZero(BitWidth),
+                         APInt(BitWidth, BitWidth + 1));
+
+  // The interval contains at least two consecutive values, hence an odd one,
+  // so the minimum is always 0.
+  // Calculate longest common prefix.
+  unsigned LCPLength = (Lower ^ (Upper - 1)).countl_zero();
+  // If Lower is {LCP, 000...}, the maximum is Lower.countr_zero().
+  // Otherwise, the maximum is BitWidth - LCPLength - 1 ({LCP, 100...}).
+  return ConstantRange(
+      APInt::getZero(BitWidth),
+      APInt(BitWidth, std::max(BitWidth - LCPLength, Lower.countr_zero() + 1)));
+}
+
+/// Return the range of cttz over [Lower, UINT_MAX] for a nonzero \p Lower.
+/// With D = UINT_MAX - Lower + 1, these values are 2^BitWidth - D, ...,
+/// 2^BitWidth - 1. The largest cttz among them is floor(log2(D)), attained by
+/// 2^BitWidth - 2^floor(log2(D)); the smallest is 0, attained by UINT_MAX.
+static ConstantRange
+getUnsignedCountTrailingZerosRangeToMax(const APInt &Lower) {
+  assert(!Lower.isZero() && "Lower must be nonzero");
+  unsigned BitWidth = Lower.getBitWidth();
+  APInt Count = APInt::getMaxValue(BitWidth) - Lower + 1;
+  return ConstantRange(APInt::getZero(BitWidth),
+                       APInt(BitWidth, Count.logBase2() + 1));
+}
+
+ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
+  if (isEmptySet())
+    return getEmpty();
+
+  unsigned BitWidth = getBitWidth();
+  APInt Zero = APInt::getZero(BitWidth);
+  APInt One(BitWidth, 1);
+
+  if (ZeroIsPoison && contains(Zero)) {
+    // ZeroIsPoison is set and zero is contained, so zero is excluded from the
+    // input. We discern three cases in which a zero can appear:
+    // 1) Lower is zero, handling cases of kind [0, 1), [0, 2), etc.
+    // 2) Upper is one, i.e. zero is the last element of a wrapped set,
+    //    handling cases of kind [3, 1), etc.
+    // 3) Zero is contained in the middle of a wrapped set, e.g., [3, 2).
+    if (getLower().isZero()) {
+      // [0, 1) contains nothing but zero, so no defined result remains.
+      if (getUpper().isOne())
+        return getEmpty();
+      // Compute the resulting range by excluding zero from Lower.
+      return getUnsignedCountTrailingZerosRange(One, getUpper());
+    }
+    // Compute the resulting range by excluding zero from Upper.
+    if (getUpper().isOne())
+      return getUnsignedCountTrailingZerosRangeToMax(getLower());
+    // Split around the excluded zero into [Lower, UINT_MAX] and [1, Upper).
+    return getUnsignedCountTrailingZerosRangeToMax(getLower())
+        .unionWith(getUnsignedCountTrailingZerosRange(One, getUpper()));
+  }
+
+  if (isFullSet())
+    return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
+  if (!isUpperWrapped())
+    return getUnsignedCountTrailingZerosRange(getLower(), getUpper());
+  // The range wraps: split it into [Lower, UINT_MAX] and [0, Upper).
+  return getUnsignedCountTrailingZerosRangeToMax(getLower())
+      .unionWith(getUnsignedCountTrailingZerosRange(Zero, getUpper()));
+}
+
+/// Return the range of ctpop over the non-wrapped half-open interval
+/// [Lower, Upper), or the empty set if Lower == Upper.
+static ConstantRange getUnsignedPopCountRange(const APInt &Lower,
+                                              const APInt &Upper) {
+  assert(Lower.ule(Upper) && "Unexpected wrapped set");
+  unsigned BitWidth = Lower.getBitWidth();
+  if (Lower == Upper)
+    return ConstantRange::getEmpty(BitWidth);
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(BitWidth, Lower.popcount()));
+
+  APInt Max = Upper - 1;
+  // Calculate longest common prefix.
+  unsigned LCPLength = (Lower ^ Max).countl_zero();
+  unsigned LCPPopCount = Lower.getHiBits(LCPLength).popcount();
+  // If Lower is {LCP, 000...}, the minimum is the popcount of LCP.
+  // Otherwise, the minimum is the popcount of LCP + 1 ({LCP, 100...}).
+  unsigned MinBits =
+      LCPPopCount + (Lower.countr_zero() < BitWidth - LCPLength ? 1 : 0);
+  // If Max is {LCP, 111...}, the maximum is the popcount of LCP + (BitWidth -
+  // length of LCP).
+  // Otherwise, the maximum is the popcount of LCP + (BitWidth -
+  // length of LCP - 1), attained by {LCP, 0111...}.
+  unsigned MaxBits = LCPPopCount + (BitWidth - LCPLength) -
+                     (Max.countr_one() < BitWidth - LCPLength ? 1 : 0);
+  return ConstantRange(APInt(BitWidth, MinBits), APInt(BitWidth, MaxBits + 1));
+}
+
+ConstantRange ConstantRange::ctpop() const {
+  if (isEmptySet())
+    return getEmpty();
+
+  unsigned BitWidth = getBitWidth();
+  APInt Zero = APInt::getZero(BitWidth);
+  if (isFullSet())
+    return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
+  if (!isUpperWrapped())
+    return getUnsignedPopCountRange(getLower(), getUpper());
+  // The range wraps: split it into [Lower, UINT_MAX] and [0, Upper).
+  // The complements of [Lower, UINT_MAX] are exactly [0, D - 1] with
+  // D = UINT_MAX - Lower + 1, whose largest popcount is floor(log2(D)). So the
+  // popcounts of [Lower, UINT_MAX] lie in [BitWidth - floor(log2(D)),
+  // BitWidth], where the maximum is attained by UINT_MAX.
+  ConstantRange CR1(
+      APInt(BitWidth,
+            BitWidth - (getUnsignedMax() - getLower() + 1).logBase2()),
+      APInt(BitWidth, BitWidth + 1));
+  ConstantRange CR2 = getUnsignedPopCountRange(Zero, getUpper());
+  return CR1.unionWith(CR2);
+}
 
 ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow(
     const ConstantRange &Other) const {
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
--- a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
@@ -1010,6 +1010,60 @@
   ret i1 %res2
 }
 
+define i1 @cttz_fold(i16 %x) {
+; CHECK-LABEL: @cttz_fold(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    ret i1 false
+; CHECK:       else:
+; CHECK-NEXT:    [[CTTZ2:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    [[RES2:%.*]] = icmp ult i16 [[CTTZ2]], 8
+; CHECK-NEXT:    ret i1 [[RES2]]
+;
+  %cmp = icmp ult i16 %x, 256
+  br i1 %cmp, label %if, label %else
+
+if:
+  %cttz = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+  %res = icmp uge i16 %cttz, 8
+  ret i1 %res
+
+else:
+  %cttz2 = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+  %res2 = icmp ult i16 %cttz2, 8
+  ret i1 %res2
+}
+
+define i1 @ctpop_fold(i16 %x) {
+; CHECK-LABEL: @ctpop_fold(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    [[CTPOP:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT:    ret i1 true
+; CHECK:       else:
+; CHECK-NEXT:    [[CTPOP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT:    [[RES2:%.*]] = icmp ugt i16 [[CTPOP2]], 8
+; CHECK-NEXT:    ret i1 [[RES2]]
+;
+  %cmp = icmp ult i16 %x, 256
+  br i1 %cmp, label %if, label %else
+
+if:
+  %ctpop = call i16 @llvm.ctpop.i16(i16 %x)
+  %res = icmp ule i16 %ctpop, 8
+  ret i1 %res
+
+else:
+  %ctpop2 = call i16 @llvm.ctpop.i16(i16 %x)
+  %res2 = icmp ugt i16 %ctpop2, 8
+  ret i1 %res2
+}
+
 declare i16 @llvm.ctlz.i16(i16, i1)
+declare i16 @llvm.cttz.i16(i16, i1)
+declare i16 @llvm.ctpop.i16(i16)
 declare i16 @llvm.abs.i16(i16, i1)
 declare void @llvm.assume(i1)
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -2431,6 +2431,26 @@
   });
 }
 
+TEST_F(ConstantRangeTest, Cttz) {
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.cttz(); },
+      [](const APInt &N) { return APInt(N.getBitWidth(), N.countr_zero()); });
+
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.cttz(/*ZeroIsPoison=*/true); },
+      [](const APInt &N) -> std::optional<APInt> {
+        if (N.isZero())
+          return std::nullopt;
+        return APInt(N.getBitWidth(), N.countr_zero());
+      });
+}
+
+TEST_F(ConstantRangeTest, Ctpop) {
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.ctpop(); },
+      [](const APInt &N) { return APInt(N.getBitWidth(), N.popcount()); });
+}
+
 TEST_F(ConstantRangeTest, castOps) {
   ConstantRange A(APInt(16, 66), APInt(16, 128));
   ConstantRange FpToI8 = A.castOp(Instruction::FPToSI, 8);