diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -169,6 +169,28 @@
     return SelectSVELogicalImm(N, VT, Imm);
   }
 
+  // Returns a suitable CNT/INC/DEC/RDVL multiplier to calculate VSCALE*N.
+  template<signed Min, signed Max, signed Scale, bool Shift>
+  bool SelectCntImm(SDValue N, SDValue &Imm) {
+    if (!isa<ConstantSDNode>(N))
+      return false;
+
+    int64_t MulImm = cast<ConstantSDNode>(N)->getSExtValue();
+    if (Shift)
+      MulImm = 1 << MulImm;
+
+    if ((MulImm % std::abs(Scale)) != 0)
+      return false;
+
+    MulImm /= Scale;
+    if ((MulImm >= Min) && (MulImm <= Max)) {
+      Imm = CurDAG->getTargetConstant(MulImm, SDLoc(N), MVT::i32);
+      return true;
+    }
+
+    return false;
+  }
+
   /// Form sequences of consecutive 64/128-bit registers for use in NEON
   /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have
   /// between 1 and 4 elements. If it contains a single element that is returned
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9541,6 +9541,19 @@
   return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA);
 }
 
+static bool IsSVECntIntrinsic(SDValue S) {
+  switch(getIntrinsicID(S.getNode())) {
+  default:
+    break;
+  case Intrinsic::aarch64_sve_cntb:
+  case Intrinsic::aarch64_sve_cnth:
+  case Intrinsic::aarch64_sve_cntw:
+  case Intrinsic::aarch64_sve_cntd:
+    return true;
+  }
+  return false;
+}
+
 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -9551,9 +9564,18 @@
   if (!isa<ConstantSDNode>(N->getOperand(1)))
     return SDValue();
 
+  SDValue N0 = N->getOperand(0);
   ConstantSDNode *C = cast<ConstantSDNode>(N->getOperand(1));
   const APInt &ConstValue = C->getAPIntValue();
 
+  // Allow the scaling to be folded into the `cnt` instruction by preventing
+  // it from being obscured here. This makes it easier to pattern match.
+  if (IsSVECntIntrinsic(N0) ||
+      (N0->getOpcode() == ISD::TRUNCATE &&
+       (IsSVECntIntrinsic(N0->getOperand(0)))))
+    if (ConstValue.sge(1) && ConstValue.sle(16))
+      return SDValue();
+
   // Multiplication of a power of two plus/minus one can be done more
   // cheaply as as shift+add/sub. For now, this is true unilaterally. If
   // future CPUs have a cheaper MADD instruction, this may need to be
@@ -9564,7 +9586,6 @@
   // e.g. 6=3*2=(2+1)*2.
   // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
   // which equals to (1+2)*16-(1+2).
-  SDValue N0 = N->getOperand(0);
   // TrailingZeroes is used to test if the mul can be lowered to
   // shift+add+shift.
   unsigned TrailingZeroes = ConstValue.countTrailingZeros();
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -244,6 +244,10 @@
   let DecoderMethod = "DecodeSVEIncDecImm";
 }
 
+// This allows i32 immediate extraction from i64 based arithmetic.
+def sve_cnt_mul_imm : ComplexPattern<i32, 1, "SelectCntImm<1, 16, 1, false>">;
+def sve_cnt_shl_imm : ComplexPattern<i32, 1, "SelectCntImm<1, 16, 1, true>">;
+
 //===----------------------------------------------------------------------===//
 // SVE PTrue - These are used extensively throughout the pattern matching so
 // it's important we define them first.
@@ -635,6 +639,12 @@
   def : InstAlias<asm # "\t$Rd",
                   (!cast<Instruction>(NAME) GPR64:$Rd, 0b11111, 1), 2>;
 
+  def : Pat<(i64 (mul (op sve_pred_enum:$pattern), (sve_cnt_mul_imm i32:$imm))),
+            (!cast<Instruction>(NAME) sve_pred_enum:$pattern, sve_incdec_imm:$imm)>;
+
+  def : Pat<(i64 (shl (op sve_pred_enum:$pattern), (i64 (sve_cnt_shl_imm i32:$imm)))),
+            (!cast<Instruction>(NAME) sve_pred_enum:$pattern, sve_incdec_imm:$imm)>;
+
   def : Pat<(i64 (op sve_pred_enum:$pattern)),
             (!cast<Instruction>(NAME) sve_pred_enum:$pattern, 1)>;
 }
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll
@@ -12,6 +12,24 @@
   ret i64 %out
 }
 
+define i64 @cntb_mul3() {
+; CHECK-LABEL: cntb_mul3:
+; CHECK: cntb x0, vl6, mul #3
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntb(i32 6)
+  %out = mul i64 %cnt, 3
+  ret i64 %out
+}
+
+define i64 @cntb_mul4() {
+; CHECK-LABEL: cntb_mul4:
+; CHECK: cntb x0, vl8, mul #4
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntb(i32 8)
+  %out = mul i64 %cnt, 4
+  ret i64 %out
+}
+
 ;
 ; CNTH
 ;
@@ -24,6 +42,24 @@
   ret i64 %out
 }
 
+define i64 @cnth_mul5() {
+; CHECK-LABEL: cnth_mul5:
+; CHECK: cnth x0, vl7, mul #5
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cnth(i32 7)
+  %out = mul i64 %cnt, 5
+  ret i64 %out
+}
+
+define i64 @cnth_mul8() {
+; CHECK-LABEL: cnth_mul8:
+; CHECK: cnth x0, vl5, mul #8
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cnth(i32 5)
+  %out = mul i64 %cnt, 8
+  ret i64 %out
+}
+
 ;
 ; CNTW
 ;
@@ -36,6 +72,24 @@
   ret i64 %out
 }
 
+define i64 @cntw_mul11() {
+; CHECK-LABEL: cntw_mul11:
+; CHECK: cntw x0, vl8, mul #11
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntw(i32 8)
+  %out = mul i64 %cnt, 11
+  ret i64 %out
+}
+
+define i64 @cntw_mul2() {
+; CHECK-LABEL: cntw_mul2:
+; CHECK: cntw x0, vl6, mul #2
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntw(i32 6)
+  %out = mul i64 %cnt, 2
+  ret i64 %out
+}
+
 ;
 ; CNTD
 ;
@@ -48,6 +102,24 @@
   ret i64 %out
 }
 
+define i64 @cntd_mul15() {
+; CHECK-LABEL: cntd_mul15:
+; CHECK: cntd x0, vl16, mul #15
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntd(i32 9)
+  %out = mul i64 %cnt, 15
+  ret i64 %out
+}
+
+define i64 @cntd_mul16() {
+; CHECK-LABEL: cntd_mul16:
+; CHECK: cntd x0, vl32, mul #16
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntd(i32 10)
+  %out = mul i64 %cnt, 16
+  ret i64 %out
+}
+
 ;
 ; CNTP
 ;
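Note: the new tests only exercise the mul form of the patterns. A multiply by a power of two is usually canonicalized to a shift before instruction selection, which is what the sve_cnt_shl_imm pattern is for (Shift = true recovers the multiplier as 1 << imm in SelectCntImm). A minimal sketch of a test for that path, not part of the patch above; the function name @cnth_shl2 is made up for illustration, the CHECK lines reflect the expected selection under the new shl pattern, and the existing @llvm.aarch64.sve.cnth declaration in this file is assumed:

define i64 @cnth_shl2() {
; CHECK-LABEL: cnth_shl2:
; CHECK: cnth x0, vl4, mul #4
; CHECK-NEXT: ret
  %cnt = call i64 @llvm.aarch64.sve.cnth(i32 4)
  %out = shl i64 %cnt, 2
  ret i64 %out
}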