diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1129,6 +1129,28 @@
   KnownBits Known2;

   switch (Op.getOpcode()) {
+  case ISD::VSCALE: {
+    Function const &F = TLO.DAG.getMachineFunction().getFunction();
+    Attribute const &Attr = F.getFnAttribute(Attribute::VScaleRange);
+    if (!Attr.isValid())
+      return false;
+    Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
+    if (!MaxVScale.has_value())
+      return false;
+    int64_t VScaleResultUpperbound = *MaxVScale;
+    if (auto *MulImm = dyn_cast<ConstantSDNode>(Op.getOperand(0))) {
+      VScaleResultUpperbound *= MulImm->getSExtValue();
+    } else {
+      return false;
+    }
+    bool Negative = false;
+    if ((Negative = VScaleResultUpperbound < 0))
+      VScaleResultUpperbound = -VScaleResultUpperbound;
+    unsigned RequiredBits = Log2_64(VScaleResultUpperbound) + 1;
+    if (RequiredBits < BitWidth)
+      (Negative ? Known.One : Known.Zero).setHighBits(BitWidth - RequiredBits);
+    return false;
+  }
   case ISD::SCALAR_TO_VECTOR: {
     if (!DemandedElts[0])
       return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
@@ -2677,7 +2699,9 @@
     [[fallthrough]];
   }
   default:
-    if (Op.getOpcode() >= ISD::BUILTIN_OP_END) {
+    // We also ask the target about intrinsics (which could be specific to it).
+    if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
+        Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
       if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
                                             Known, TLO, Depth))
         return true;
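Note (not part of the patch): the new ISD::VSCALE case only refines Known and returns false; it never rewrites the node itself. The sketch below restates the bound computation in plain C++ so the arithmetic is easy to check. The helper name and the HighKnownBits struct are made up for illustration and do not exist in LLVM. For a positive multiplier every bit above the top bit of the largest reachable value is known zero; for a negative multiplier the result is negative, so after sign extension those bits are known one.

// Illustration only; mirrors the ISD::VSCALE case above with plain integers.
#include <bit>      // std::bit_width (C++20)
#include <cstdint>

struct HighKnownBits {
  unsigned Count; // number of known high bits
  bool KnownOne;  // true: known one (negative result), false: known zero
};

// MaxVScale comes from the vscale_range attribute, MulImm from the VSCALE
// operand, BitWidth is the width of the value being simplified.
HighKnownBits vscaleHighKnownBits(uint64_t MaxVScale, int64_t MulImm,
                                  unsigned BitWidth) {
  int64_t UpperBound = static_cast<int64_t>(MaxVScale) * MulImm;
  bool Negative = UpperBound < 0;
  uint64_t Magnitude = Negative ? 0 - static_cast<uint64_t>(UpperBound)
                                : static_cast<uint64_t>(UpperBound);
  // std::bit_width(x) == Log2_64(x) + 1 for x > 0, i.e. the number of bits
  // needed to represent the magnitude of the bound.
  unsigned RequiredBits = static_cast<unsigned>(std::bit_width(Magnitude));
  unsigned Count = RequiredBits < BitWidth ? BitWidth - RequiredBits : 0;
  return {Count, Negative};
}

With, say, vscale_range(1,16) and a multiplier of 1 the bound is 16, which needs 5 bits, so the top 27 bits of an i32 vscale are known zero; that is what lets the redundant and with 31 in the first test below be removed.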
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -92,6 +92,7 @@
 #include
 #include
 #include
+#include <optional>
 #include
 #include
 #include
@@ -14995,17 +14996,20 @@
   return CSNeg;
 }

-static bool IsSVECntIntrinsic(SDValue S) {
+static std::optional<unsigned> IsSVECntIntrinsic(SDValue S) {
   switch(getIntrinsicID(S.getNode())) {
   default:
     break;
   case Intrinsic::aarch64_sve_cntb:
+    return 8;
   case Intrinsic::aarch64_sve_cnth:
+    return 16;
   case Intrinsic::aarch64_sve_cntw:
+    return 32;
   case Intrinsic::aarch64_sve_cntd:
-    return true;
+    return 64;
   }
-  return false;
+  return {};
 }

 /// Calculates what the pre-extend type is, based on the extension
@@ -23290,6 +23294,20 @@
     }
   }

+  if (auto ElementSize = IsSVECntIntrinsic(Op)) {
+    // The SVE count intrinsics don't support the multiplier immediate so we
+    // don't have to account for that here. The value returned may be slightly
+    // over the true required bits, as this is based on the "ALL" pattern. The
+    // other patterns are also exposed by these intrinsics, but they all return
+    // a value that's strictly less than "ALL".
+    unsigned MaxElements = AArch64::SVEMaxBitsPerVector / *ElementSize;
+    unsigned RequiredBits = Log2_64(MaxElements) + 1;
+    unsigned BitWidth = Known.Zero.getBitWidth();
+    if (RequiredBits < BitWidth)
+      Known.Zero.setHighBits(BitWidth - RequiredBits);
+    return false;
+  }
+
   return TargetLowering::SimplifyDemandedBitsForTargetNode(
       Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
 }
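Note (not part of the patch): the AArch64 hook above bounds the SVE count intrinsics by the architectural maximum vector length, the 2048 bits that AArch64::SVEMaxBitsPerVector denotes. The snippet below is a standalone illustration of the resulting bounds; it is not LLVM code and the 2048 constant is written out locally.

// Illustration only; prints the element-count bounds the hook derives.
#include <bit>    // std::bit_width (C++20)
#include <cstdio>

int main() {
  const unsigned SVEMaxBitsPerVector = 2048;        // architectural maximum VL
  const unsigned ElementSizes[] = {8, 16, 32, 64};  // cntb, cnth, cntw, cntd
  for (unsigned ElementSize : ElementSizes) {
    unsigned MaxElements = SVEMaxBitsPerVector / ElementSize;
    unsigned RequiredBits = static_cast<unsigned>(std::bit_width(MaxElements));
    std::printf("element size %2u: at most %3u elements, %u significant bits\n",
                ElementSize, MaxElements, RequiredBits);
  }
  return 0;
}

For cntb that is at most 256 elements, i.e. 9 significant bits, so in the tests below the and with 0x1ff folds away while the and with 0x1fc is kept, and the zext/sext cases collapse to the bare count instruction because the upper 32 bits are already known zero.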
diff --git a/llvm/test/CodeGen/AArch64/vscale-and-sve-cnt-demandedbits.ll b/llvm/test/CodeGen/AArch64/vscale-and-sve-cnt-demandedbits.ll
--- a/llvm/test/CodeGen/AArch64/vscale-and-sve-cnt-demandedbits.ll
+++ b/llvm/test/CodeGen/AArch64/vscale-and-sve-cnt-demandedbits.ll
@@ -14,9 +14,8 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    and w9, w8, #0x1f
-; CHECK-NEXT:    and w8, w8, #0xfffffffc
-; CHECK-NEXT:    add w0, w9, w8
+; CHECK-NEXT:    and w9, w8, #0x1c
+; CHECK-NEXT:    add w0, w8, w9
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %and_redundant = and i32 %vscale, 31
@@ -29,9 +28,8 @@
 ; CHECK-LABEL: cntb_and_elimination:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    cntb x8
-; CHECK-NEXT:    and x9, x8, #0x1ff
-; CHECK-NEXT:    and x8, x8, #0x3fffffffc
-; CHECK-NEXT:    add x0, x9, x8
+; CHECK-NEXT:    and x9, x8, #0x1fc
+; CHECK-NEXT:    add x0, x8, x9
 ; CHECK-NEXT:    ret
   %cntb = call i64 @llvm.aarch64.sve.cntb(i32 31)
   %and_redundant = and i64 %cntb, 511
@@ -44,9 +42,8 @@
 ; CHECK-LABEL: cnth_and_elimination:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    cnth x8
-; CHECK-NEXT:    and x9, x8, #0x3ff
-; CHECK-NEXT:    and x8, x8, #0x3fffffffc
-; CHECK-NEXT:    add x0, x9, x8
+; CHECK-NEXT:    and x9, x8, #0xfc
+; CHECK-NEXT:    add x0, x8, x9
 ; CHECK-NEXT:    ret
   %cnth = call i64 @llvm.aarch64.sve.cnth(i32 31)
   %and_redundant = and i64 %cnth, 1023
@@ -59,9 +56,8 @@
 ; CHECK-LABEL: cntw_and_elimination:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    cntw x8
-; CHECK-NEXT:    and x9, x8, #0x7f
-; CHECK-NEXT:    and x8, x8, #0x3fffffffc
-; CHECK-NEXT:    add x0, x9, x8
+; CHECK-NEXT:    and x9, x8, #0x7c
+; CHECK-NEXT:    add x0, x8, x9
 ; CHECK-NEXT:    ret
   %cntw = call i64 @llvm.aarch64.sve.cntw(i32 31)
   %and_redundant = and i64 %cntw, 127
@@ -74,9 +70,8 @@
 ; CHECK-LABEL: cntd_and_elimination:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    cntd x8
-; CHECK-NEXT:    and x9, x8, #0x3f
-; CHECK-NEXT:    and x8, x8, #0x3fffffffc
-; CHECK-NEXT:    add x0, x9, x8
+; CHECK-NEXT:    and x9, x8, #0x3c
+; CHECK-NEXT:    add x0, x8, x9
 ; CHECK-NEXT:    ret
   %cntd = call i64 @llvm.aarch64.sve.cntd(i32 31)
   %and_redundant = and i64 %cntd, 63
@@ -89,8 +84,7 @@
 ; CHECK-LABEL: vscale_trunc_zext:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdvl x8, #1
-; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    and x0, x8, #0xffffffff
+; CHECK-NEXT:    lsr x0, x8, #4
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %zext = zext i32 %vscale to i64
@@ -101,8 +95,7 @@
 ; CHECK-LABEL: vscale_trunc_sext:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdvl x8, #1
-; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    sxtw x0, w8
+; CHECK-NEXT:    lsr x0, x8, #4
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %sext = sext i32 %vscale to i64
@@ -112,8 +105,7 @@
 define i64 @count_bytes_trunc_zext() {
 ; CHECK-LABEL: count_bytes_trunc_zext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cntb x8
-; CHECK-NEXT:    and x0, x8, #0xffffffff
+; CHECK-NEXT:    cntb x0
 ; CHECK-NEXT:    ret
   %cnt = call i64 @llvm.aarch64.sve.cntb(i32 31)
   %trunc = trunc i64 %cnt to i32
@@ -124,8 +116,7 @@
 define i64 @count_halfs_trunc_zext() {
 ; CHECK-LABEL: count_halfs_trunc_zext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cnth x8
-; CHECK-NEXT:    and x0, x8, #0xffffffff
+; CHECK-NEXT:    cnth x0
 ; CHECK-NEXT:    ret
   %cnt = call i64 @llvm.aarch64.sve.cnth(i32 31)
   %trunc = trunc i64 %cnt to i32
@@ -136,8 +127,7 @@
 define i64 @count_words_trunc_zext() {
 ; CHECK-LABEL: count_words_trunc_zext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cntw x8
-; CHECK-NEXT:    and x0, x8, #0xffffffff
+; CHECK-NEXT:    cntw x0
 ; CHECK-NEXT:    ret
   %cnt = call i64 @llvm.aarch64.sve.cntw(i32 31)
   %trunc = trunc i64 %cnt to i32
@@ -148,8 +138,7 @@
 define i64 @count_doubles_trunc_zext() {
 ; CHECK-LABEL: count_doubles_trunc_zext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cntd x8
-; CHECK-NEXT:    and x0, x8, #0xffffffff
+; CHECK-NEXT:    cntd x0
 ; CHECK-NEXT:    ret
   %cnt = call i64 @llvm.aarch64.sve.cntd(i32 31)
   %trunc = trunc i64 %cnt to i32
@@ -160,8 +149,7 @@
 define i64 @count_bytes_trunc_sext() {
 ; CHECK-LABEL: count_bytes_trunc_sext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cntb x8
-; CHECK-NEXT:    sxtw x0, w8
+; CHECK-NEXT:    cntb x0
 ; CHECK-NEXT:    ret
   %cnt = call i64 @llvm.aarch64.sve.cntb(i32 31)
   %trunc = trunc i64 %cnt to i32
@@ -172,8 +160,7 @@
 define i64 @count_halfs_trunc_sext() {
 ; CHECK-LABEL: count_halfs_trunc_sext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cnth x8
-; CHECK-NEXT:    sxtw x0, w8
+; CHECK-NEXT:    cnth x0
 ; CHECK-NEXT:    ret
   %cnt = call i64 @llvm.aarch64.sve.cnth(i32 31)
   %trunc = trunc i64 %cnt to i32
@@ -184,8 +171,7 @@
 define i64 @count_words_trunc_sext() {
 ; CHECK-LABEL: count_words_trunc_sext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cntw x8
-; CHECK-NEXT:    sxtw x0, w8
+; CHECK-NEXT:    cntw x0
 ; CHECK-NEXT:    ret
   %cnt = call i64 @llvm.aarch64.sve.cntw(i32 31)
   %trunc = trunc i64 %cnt to i32
@@ -196,8 +182,7 @@
 define i64 @count_doubles_trunc_sext() {
 ; CHECK-LABEL: count_doubles_trunc_sext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cntd x8
-; CHECK-NEXT:    sxtw x0, w8
+; CHECK-NEXT:    cntd x0
 ; CHECK-NEXT:    ret
   %cnt = call i64 @llvm.aarch64.sve.cntd(i32 31)
   %trunc = trunc i64 %cnt to i32
@@ -212,9 +197,8 @@
 ; CHECK-NEXT:    mov w9, #5
 ; CHECK-NEXT:    lsr x8, x8, #4
 ; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    and w9, w8, #0x7f
-; CHECK-NEXT:    and w8, w8, #0x3f
-; CHECK-NEXT:    add w0, w9, w8
+; CHECK-NEXT:    and w9, w8, #0x3f
+; CHECK-NEXT:    add w0, w8, w9
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %mul = mul i32 %vscale, 5
@@ -231,9 +215,8 @@
 ; CHECK-NEXT:    mov x9, #-5
 ; CHECK-NEXT:    lsr x8, x8, #4
 ; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    orr w9, w8, #0xffffff80
-; CHECK-NEXT:    and w8, w8, #0xffffffc0
-; CHECK-NEXT:    add w0, w9, w8
+; CHECK-NEXT:    and w9, w8, #0xffffffc0
+; CHECK-NEXT:    add w0, w8, w9
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %mul = mul i32 %vscale, -5