Index: llvm/lib/Analysis/LazyValueInfo.cpp
===================================================================
--- llvm/lib/Analysis/LazyValueInfo.cpp
+++ llvm/lib/Analysis/LazyValueInfo.cpp
@@ -424,6 +424,8 @@
                                                          BasicBlock *BB);
   std::optional<ValueLatticeElement>
   solveBlockValueOverflowIntrinsic(WithOverflowInst *WO, BasicBlock *BB);
+  std::optional<ValueLatticeElement>
+  solveBlockValueCtlzIntrinsic(IntrinsicInst *II, BasicBlock *BB);
   std::optional<ValueLatticeElement> solveBlockValueIntrinsic(IntrinsicInst *II,
                                                               BasicBlock *BB);
   std::optional<ValueLatticeElement>
@@ -988,8 +990,106 @@
       });
 }
 
+/// Compute the lattice value of a call to llvm.ctlz in \p BB from the lattice
+/// value of its first operand.
+///
+/// \returns std::nullopt if the operand's block value is not available,
+/// otherwise a conservative lattice value for the number of leading zeros.
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::solveBlockValueCtlzIntrinsic(IntrinsicInst *II,
+                                                BasicBlock *BB) {
+  using VLE = ValueLatticeElement;
+
+  Value *Argument = II->getArgOperand(0);
+  std::optional<VLE> BlockVal = getBlockValue(Argument, BB);
+  if (!BlockVal)
+    return std::nullopt;
+  const VLE &V = *BlockVal;
+
+  // llvm.ctlz returns the type of its operand, so this is the width of both.
+  const unsigned BitWidth = DL.getTypeSizeInBits(II->getType());
+  auto GetAPInt = [BitWidth](uint64_t N) { return APInt(BitWidth, N); };
+  // Lattice value for a count in [Min, Max], both bounds inclusive. Max is at
+  // most BitWidth and always fits in BitWidth bits, but the exclusive bound
+  // Max + 1 wraps to zero for i1; getNonEmpty() maps that case to the full set
+  // rather than to an empty range.
+  auto GetCountRange = [&GetAPInt](uint64_t Min, uint64_t Max) {
+    return VLE::getRange(
+        ConstantRange::getNonEmpty(GetAPInt(Min), GetAPInt(Max) + 1));
+  };
+
+  const bool ZeroIsPoison = cast<ConstantInt>(II->getArgOperand(1))->isOne();
+
+  if (V.isUnknownOrUndef()) {
+    // Nothing is known about the operand; forward its state unchanged.
+    // NOTE(review): the count of an undef operand still lies in [0, BitWidth];
+    // confirm that forwarding undef is intended here.
+    return V;
+  }
+
+  if (V.isOverdefined()) {
+    // The operand may be zero, which is poison when ZeroIsPoison is set.
+    if (ZeroIsPoison)
+      return VLE::getOverdefined();
+    // Otherwise the count can be anything from 0 to the bit width.
+    return GetCountRange(0, BitWidth);
+  }
+
+  if (V.isConstant()) {
+    const auto *CI = dyn_cast_or_null<ConstantInt>(V.getConstant());
+    // An explicit zero, or a constant that cannot be inspected, may be poison.
+    if (ZeroIsPoison && (!CI || CI->isZero()))
+      return VLE::getOverdefined();
+    // The constant is known, so the count is exact.
+    if (CI)
+      return VLE::get(
+          ConstantInt::get(II->getType(), CI->getValue().countLeadingZeros()));
+    // Zero is allowed but the constant is not an integer we can inspect.
+    return GetCountRange(0, BitWidth);
+  }
+
+  if (V.isNotConstant()) {
+    const auto *CI = dyn_cast_or_null<ConstantInt>(V.getNotConstant());
+    // The operand is known to be non-zero: at least one bit is set, so the
+    // count is at most BitWidth - 1.
+    if (CI && CI->isZero())
+      return GetCountRange(0, BitWidth - 1);
+    // Zero cannot be excluded and would be poison.
+    if (ZeroIsPoison)
+      return VLE::getOverdefined();
+    // Zero is allowed. Excluding a single other operand value cannot be
+    // expressed as one contiguous range of counts.
+    return GetCountRange(0, BitWidth);
+  }
+
+  if (V.isConstantRange()) {
+    const ConstantRange &Range = V.getConstantRange();
+    // Zero would be poison and the range does not exclude it.
+    if (ZeroIsPoison && Range.contains(GetAPInt(0)))
+      return VLE::getOverdefined();
+    // A wrapped or full range contains both the all-zeros and the all-ones
+    // encodings, so the count spans the whole [0, BitWidth] interval.
+    if (Range.isWrappedSet() || Range.isFullSet())
+      return GetCountRange(0, BitWidth);
+    // The range is [Lower, Upper - 1] without wrapping. The count never grows
+    // with the operand, so the smallest operand gives the largest count; the
+    // swap below only guards against that ordering not holding.
+    unsigned MaxCount = Range.getLower().countLeadingZeros();
+    unsigned MinCount = (Range.getUpper() - 1).countLeadingZeros();
+    if (MinCount > MaxCount)
+      std::swap(MinCount, MaxCount);
+    return GetCountRange(MinCount, MaxCount);
+  }
+
+  // No lattice state is expected to reach this point; stay conservative.
+  return VLE::getOverdefined();
+}
+
 std::optional<ValueLatticeElement>
 LazyValueInfoImpl::solveBlockValueIntrinsic(IntrinsicInst *II, BasicBlock *BB) {
+  if (II->getIntrinsicID() == Intrinsic::ctlz)
+    return solveBlockValueCtlzIntrinsic(II, BB);
+
   if (!ConstantRange::isIntrinsicSupported(II->getIntrinsicID())) {
     LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName()
                       << "' - unknown intrinsic.\n");
Index: llvm/test/Analysis/LazyValueAnalysis/lvi-for-ctlz.ll
===================================================================
--- /dev/null
+++ llvm/test/Analysis/LazyValueAnalysis/lvi-for-ctlz.ll
@@ -0,0 +1,35 @@
+; RUN: opt < %s -passes=jump-threading -print-lvi-after-jump-threading -disable-output 2>&1 | FileCheck %s
+
+; Test that LazyValueInfo propagates constant-range lattice values through the
+; ctlz intrinsic. In the example below, basic block forward.0, which counts the
+; leading zeros of %val, is taken only if %val is within [1, 9]. This lets LVI
+; infer that ctlz.i32 returns a count in the interval [28, 31].
+define i32 @test_ctlz(i32 %val) {
+; CHECK-LABEL: LVI for function 'test_ctlz':
+entry:
+  %add = add i32 %val, -1
+  %cond.0 = icmp ult i32 %add, 9
+  br i1 %cond.0, label %forward.0, label %forward.1
+
+forward.0:                                        ; preds = %entry
+; CHECK-LABEL: forward.0:
+; CHECK: ; LatticeVal for: 'i32 %val' is: constantrange<1, 10>
+; CHECK: ; LatticeVal for: ' %ret_val.0 = tail call i32 @llvm.ctlz.i32(i32 %val, i1 true), !range !0' in BB: '%forward.0' is: constantrange<28, 32>
+; CHECK-NOT: ; LatticeVal for: ' %ret_val.0 = tail call i32 @llvm.ctlz.i32(i32 %val, i1 true), !range !0' in BB: '%forward.0' is: constantrange<0, 33>
+; CHECK: %ret_val.0 = tail call i32 @llvm.ctlz.i32(i32 %val, i1 true), !range !0
+  %ret_val.0 = tail call i32 @llvm.ctlz.i32(i32 %val, i1 true), !range !0
+  ret i32 %ret_val.0
+
+forward.1:                                        ; preds = %entry
+; CHECK-LABEL: forward.1:
+; CHECK: ; LatticeVal for: 'i32 %val' is: constantrange<10, 1>
+; CHECK: ; LatticeVal for: ' %cond.1 = icmp ugt i32 %val, 20' in BB: '%forward.1' is: overdefined
+; CHECK: %cond.1 = icmp ugt i32 %val, 20
+  %cond.1 = icmp ugt i32 %val, 20
+  %ret_val.1 = select i1 %cond.1, i32 10, i32 0
+  ret i32 %ret_val.1
+}
+
+declare i32 @llvm.ctlz.i32(i32, i1 immarg) nounwind willreturn
+
+!0 = !{i32 0, i32 33}