diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -214,6 +214,7 @@
   SQSHLU_I,
   SRSHR_I,
   URSHR_I,
+  RSHRNB_I,
 
   // Vector shift by constant and insert
   VSLI,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2572,6 +2572,7 @@
     MAKE_CASE(AArch64ISD::CALL_BTI)
     MAKE_CASE(AArch64ISD::MRRS)
     MAKE_CASE(AArch64ISD::MSRR)
+    MAKE_CASE(AArch64ISD::RSHRNB_I)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -20070,7 +20071,49 @@
   return SDValue();
 }
 
-static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG) {
+// Try to simplify:
+//    t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
+//    t2 = nxv8i16 srl(t1, ShiftValue)
+// to
+//    t1 = nxv8i16 rshrnb(X, ShiftValue).
+// rshrnb will zero the top half bits of each element. Therefore, this combine
+// should only be performed when a following instruction with the rshrnb
+// as an operand does not care about the top half of each element. For example,
+// a uzp1 or a truncating store.
+static SDValue trySimplifySrlAddToRshrnb(SDNode *Srl, SelectionDAG &DAG,
+                                         const AArch64Subtarget *Subtarget) {
+  EVT VT = Srl->getValueType(0);
+  if (!Subtarget->hasSVE2() || Srl->getOpcode() != ISD::SRL ||
+      !VT.isScalableVector())
+    return SDValue();
+
+  auto SrlOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
+  if (!SrlOp1)
+    return SDValue();
+  unsigned ShiftValue = SrlOp1->getZExtValue();
+
+  SDValue Add = Srl->getOperand(0);
+  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
+    return SDValue();
+  auto AddOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
+  if (!AddOp1)
+    return SDValue();
+  int64_t AddValue = AddOp1->getZExtValue();
+
+  if (AddValue != 1 << (ShiftValue - 1))
+    return SDValue();
+
+  SDLoc DL(Srl);
+  auto Rshrnb = DAG.getNode(
+      AArch64ISD::RSHRNB_I, DL, VT,
+      {Add->getOperand(0), DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
+  return Rshrnb;
+}
+
+static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
+                                 const AArch64Subtarget *Subtarget) {
   SDLoc DL(N);
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
@@ -20103,6 +20146,14 @@
     }
   }
 
+  if (SDValue Rshrnb =
+          trySimplifySrlAddToRshrnb(cast<SDNode>(Op0), DAG, Subtarget))
+    return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
+
+  if (SDValue Rshrnb =
+          trySimplifySrlAddToRshrnb(cast<SDNode>(Op1), DAG, Subtarget))
+    return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
+
   // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
   if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
     if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
@@ -20719,6 +20770,18 @@
   if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
     return Store;
 
+  if (ST->isTruncatingStore())
+    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(
+            cast<SDNode>(ST->getOperand(1)), DAG, Subtarget)) {
+      EVT RshrnbVT = Rshrnb.getValueType();
+      EVT StoreVT = ST->getMemoryVT();
+      if ((RshrnbVT == MVT::nxv8i16 && StoreVT == MVT::nxv8i8) ||
+          (RshrnbVT == MVT::nxv4i32 && StoreVT == MVT::nxv4i16) ||
+          (RshrnbVT == MVT::nxv2i64 && StoreVT == MVT::nxv2i32))
+        return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
+                                 StoreVT, ST->getMemOperand());
+    }
+
   return SDValue();
 }
 
@@ -23036,7 +23099,7 @@
   case AArch64ISD::UUNPKHI:
     return performUnpackCombine(N, DAG, Subtarget);
   case AArch64ISD::UZP1:
-    return performUzpCombine(N, DAG);
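For illustration only, not part of the patch: the combine is sound because an unsigned add of the rounding bias followed by a logical shift right is exactly a round-to-nearest shift, which is the value RSHRNB writes into the bottom half of each destination element. A minimal scalar sketch, with the element type fixed to 16 bits and the narrow half to 8 bits (rshrnb_model is a hypothetical helper, not LLVM code):

    // (x + (1 << (sh - 1))) >> sh is x >> sh rounded to nearest. RSHRNB puts
    // that value in the bottom half of each element and zeroes the top half,
    // which is why the combine is only legal when every user (a uzp1 or a
    // truncating store) ignores the top half.
    #include <cassert>
    #include <cstdint>

    static uint8_t rshrnb_model(uint16_t x, unsigned sh) {
      assert(sh >= 1 && sh <= 8);
      // The cast wraps the sum modulo 2^16, matching the nxv8i16 add in the DAG;
      // for sh <= 8 the lost carry only affects the discarded top half.
      uint16_t rounded = (uint16_t)(x + (1u << (sh - 1))) >> sh;
      return (uint8_t)rounded; // bottom half of the widened element
    }

    int main() {
      assert(rshrnb_model(100, 6) == 2);    // a truncating 100 >> 6 would give 1
      assert(rshrnb_model(0xFFFF, 8) == 0); // the add wraps before the shift
      return 0;
    }
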
+    return performUzpCombine(N, DAG, Subtarget);
   case AArch64ISD::SETCC_MERGE_ZERO:
     return performSetccMergeZeroCombine(N, DCI);
   case AArch64ISD::REINTERPRET_CAST:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -820,6 +820,12 @@
                              SDTypeProfile<1, 1, [SDTCisVT<0, i64>, SDTCisVT<1, i32>]>,
                              [SDNPHasChain, SDNPOutGlue]>;
 
+def SD_AArch64rshrnb : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>, SDTCisInt<2>]>;
+def AArch64rshrnb : SDNode<"AArch64ISD::RSHRNB_I", SD_AArch64rshrnb>;
+def AArch64rshrnb_pf : PatFrags<(ops node:$rs, node:$i),
+                                [(AArch64rshrnb node:$rs, node:$i),
+                                 (int_aarch64_sve_rshrnb node:$rs, node:$i)]>;
+
 // Match add node and also treat an 'or' node is as an 'add' if the or'ed operands
 // have no common bits.
 def add_and_or_is_add : PatFrags<(ops node:$lhs, node:$rhs),
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3519,7 +3519,7 @@
   defm SQSHRUNB_ZZI  : sve2_int_bin_shift_imm_right_narrow_bottom<0b000, "sqshrunb",  int_aarch64_sve_sqshrunb>;
   defm SQRSHRUNB_ZZI : sve2_int_bin_shift_imm_right_narrow_bottom<0b001, "sqrshrunb", int_aarch64_sve_sqrshrunb>;
   defm SHRNB_ZZI     : sve2_int_bin_shift_imm_right_narrow_bottom<0b010, "shrnb",     int_aarch64_sve_shrnb>;
-  defm RSHRNB_ZZI    : sve2_int_bin_shift_imm_right_narrow_bottom<0b011, "rshrnb",    int_aarch64_sve_rshrnb>;
+  defm RSHRNB_ZZI    : sve2_int_bin_shift_imm_right_narrow_bottom<0b011, "rshrnb",    AArch64rshrnb_pf>;
   defm SQSHRNB_ZZI   : sve2_int_bin_shift_imm_right_narrow_bottom<0b100, "sqshrnb",   int_aarch64_sve_sqshrnb>;
   defm SQRSHRNB_ZZI  : sve2_int_bin_shift_imm_right_narrow_bottom<0b101, "sqrshrnb",  int_aarch64_sve_sqrshrnb>;
   defm UQSHRNB_ZZI   : sve2_int_bin_shift_imm_right_narrow_bottom<0b110, "uqshrnb",   int_aarch64_sve_uqshrnb>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -4305,6 +4305,10 @@
   def : SVE_2_Op_Imm_Pat<nxv16i8, op, nxv8i16, i32, tvecshiftR8,  !cast<Instruction>(NAME # _B)>;
   def : SVE_2_Op_Imm_Pat<nxv8i16, op, nxv4i32, i32, tvecshiftR16, !cast<Instruction>(NAME # _H)>;
   def : SVE_2_Op_Imm_Pat<nxv4i32, op, nxv2i64, i32, tvecshiftR32, !cast<Instruction>(NAME # _S)>;
+
+  def : SVE_2_Op_Imm_Pat<nxv8i16, op, nxv8i16, i32, tvecshiftR8,  !cast<Instruction>(NAME # _B)>;
+  def : SVE_2_Op_Imm_Pat<nxv4i32, op, nxv4i32, i32, tvecshiftR16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_2_Op_Imm_Pat<nxv2i64, op, nxv2i64, i32, tvecshiftR32, !cast<Instruction>(NAME # _S)>;
 }
 
 class sve2_int_bin_shift_imm_narrow_top<bits<3> tsz8_64, bits<3> opc,
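The PatFrags is the glue here: the RSHRNB_ZZI selection patterns now accept either the existing int_aarch64_sve_rshrnb intrinsic or the new AArch64ISD::RSHRNB_I node produced by the DAG combine, without duplicating the instruction definitions. The extra SVE_2_Op_Imm_Pat entries cover the combine-created node, whose result keeps the wide element type. Conceptually (a hand-written sketch, not a real -debug dump), the truncating-store case goes from

    t1: nxv8i16 = add t0, splat (32)
    t2: nxv8i16 = srl t1, splat (6)
        truncating store of t2 as nxv8i8

to

    t1: nxv8i16 = AArch64ISD::RSHRNB_I t0, TargetConstant:i32<6>
        truncating store of t1 as nxv8i8

after which instruction selection emits rshrnb directly.
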
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -0,0 +1,237 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2 < %s | FileCheck %s
+
+define void @add_lshr_rshrnb_b_6(ptr %ptr, ptr %dst, i64 %index){
+; CHECK-LABEL: add_lshr_rshrnb_b_6:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.b, z0.h, #6
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 8 x i16>, ptr %ptr, align 2
+  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+  store <vscale x 8 x i8> %3, ptr %4, align 1
+  ret void
+}
+
+define void @neg_add_lshr_rshrnb_b_6(ptr %ptr, ptr %dst, i64 %index){
+; CHECK-LABEL: neg_add_lshr_rshrnb_b_6:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    add z0.h, z0.h, #1 // =0x1
+; CHECK-NEXT:    lsr z0.h, z0.h, #6
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 8 x i16>, ptr %ptr, align 2
+  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 1, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+  store <vscale x 8 x i8> %3, ptr %4, align 1
+  ret void
+}
+
+define void @add_lshr_rshrnb_h_7(ptr %ptr, ptr %dst, i64 %index){
+; CHECK-LABEL: add_lshr_rshrnb_h_7:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.b, z0.h, #7
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 8 x i16>, ptr %ptr, align 2
+  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 64, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 7, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+  store <vscale x 8 x i8> %3, ptr %4, align 1
+  ret void
+}
+
+define void @add_lshr_rshrn_h_6(ptr %ptr, ptr %dst, i64 %index){
+; CHECK-LABEL: add_lshr_rshrn_h_6:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.h, z0.s, #6
+; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 4 x i32>, ptr %ptr, align 2
+  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 32, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 6, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
+  %4 = getelementptr inbounds i16, ptr %dst, i64 %index
+  store <vscale x 4 x i16> %3, ptr %4, align 1
+  ret void
+}
+
+define void @add_lshr_rshrnb_h_2(ptr %ptr, ptr %dst, i64 %index){
+; CHECK-LABEL: add_lshr_rshrnb_h_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.h, z0.s, #2
+; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 4 x i32>, ptr %ptr, align 2
+  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
+  %4 = getelementptr inbounds i16, ptr %dst, i64 %index
+  store <vscale x 4 x i16> %3, ptr %4, align 1
+  ret void
+}
+
+define void @neg_add_lshr_rshrnb_h_0(ptr %ptr, ptr %dst, i64 %index){
+; CHECK-LABEL: neg_add_lshr_rshrnb_h_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %load = load <vscale x 4 x i32>, ptr %ptr, align 2
+  %1 = add <vscale x 4 x i32> %load, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 -1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
+  %4 = getelementptr inbounds i16, ptr %dst, i64 %index
+  store <vscale x 4 x i16> %3, ptr %4, align 1
+  ret void
+}
+
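+; The next two tests use a vector type twice the legal width. Legalisation
+; splits the value in two, the combine fires on each half independently, and
+; the final uzp1 keeps only the bottom halves that rshrnb produced.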
+define void @wide_add_shift_add_rshrnb_b(ptr %dest, i64 %index, <vscale x 16 x i16> %arg1){
+; CHECK-LABEL: wide_add_shift_add_rshrnb_b:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    rshrnb z1.b, z1.h, #6
+; CHECK-NEXT:    ld1b { z2.b }, p0/z, [x0, x1]
+; CHECK-NEXT:    rshrnb z0.b, z0.h, #6
+; CHECK-NEXT:    uzp1 z0.b, z0.b, z1.b
+; CHECK-NEXT:    add z0.b, z2.b, z0.b
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0, x1]
+; CHECK-NEXT:    ret
+  %1 = add <vscale x 16 x i16> %arg1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 32, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
+  %2 = lshr <vscale x 16 x i16> %1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 6, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
+  %3 = getelementptr inbounds i8, ptr %dest, i64 %index
+  %load = load <vscale x 16 x i8>, ptr %3, align 2
+  %4 = trunc <vscale x 16 x i16> %2 to <vscale x 16 x i8>
+  %5 = add <vscale x 16 x i8> %load, %4
+  store <vscale x 16 x i8> %5, ptr %3, align 2
+  ret void
+}
+
+define void @wide_add_shift_add_rshrnb_h(ptr %dest, i64 %index, <vscale x 8 x i32> %arg1){
+; CHECK-LABEL: wide_add_shift_add_rshrnb_h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    rshrnb z1.h, z1.s, #6
+; CHECK-NEXT:    ld1h { z2.h }, p0/z, [x0, x1, lsl #1]
+; CHECK-NEXT:    rshrnb z0.h, z0.s, #6
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z1.h
+; CHECK-NEXT:    add z0.h, z2.h, z0.h
+; CHECK-NEXT:    st1h { z0.h }, p0, [x0, x1, lsl #1]
+; CHECK-NEXT:    ret
+  %1 = add <vscale x 8 x i32> %arg1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
+  %2 = lshr <vscale x 8 x i32> %1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
+  %3 = getelementptr inbounds i16, ptr %dest, i64 %index
+  %load = load <vscale x 8 x i16>, ptr %3, align 2
+  %4 = trunc <vscale x 8 x i32> %2 to <vscale x 8 x i16>
+  %5 = add <vscale x 8 x i16> %load, %4
+  store <vscale x 8 x i16> %5, ptr %3, align 2
+  ret void
+}
+
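+; The neg_ tests below exercise the combine's guards: both the add and the
+; shift amount must be splat constants, the add must have no other users, and
+; a truncating store must narrow to exactly half the element width.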
+define void @neg_trunc_lsr_add_op1_not_splat(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i16> %add_op1){
+; CHECK-LABEL: neg_trunc_lsr_add_op1_not_splat:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z1.h }, p0/z, [x0]
+; CHECK-NEXT:    add z0.h, z1.h, z0.h
+; CHECK-NEXT:    lsr z0.h, z0.h, #6
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 8 x i16>, ptr %ptr, align 2
+  %1 = add <vscale x 8 x i16> %load, %add_op1
+  %2 = lshr <vscale x 8 x i16> %1, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 6, i64 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+  store <vscale x 8 x i8> %3, ptr %4, align 1
+  ret void
+}
+
+define void @neg_trunc_lsr_op1_not_splat(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i16> %lshr_op1){
+; CHECK-LABEL: neg_trunc_lsr_op1_not_splat:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z1.h }, p0/z, [x0]
+; CHECK-NEXT:    add z1.h, z1.h, #32 // =0x20
+; CHECK-NEXT:    lsrr z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 8 x i16>, ptr %ptr, align 2
+  %1 = add <vscale x 8 x i16> %load, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 32, i64 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  %2 = lshr <vscale x 8 x i16> %1, %lshr_op1
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+  store <vscale x 8 x i8> %3, ptr %4, align 1
+  ret void
+}
+
+define void @neg_add_has_two_uses(ptr %ptr, ptr %dst, ptr %dst2, i64 %index){
+; CHECK-LABEL: neg_add_has_two_uses:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    add z0.h, z0.h, #32 // =0x20
+; CHECK-NEXT:    lsr z1.h, z0.h, #6
+; CHECK-NEXT:    add z0.h, z0.h, z0.h
+; CHECK-NEXT:    st1h { z0.h }, p0, [x2, x3, lsl #1]
+; CHECK-NEXT:    st1b { z1.h }, p0, [x1, x3]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 8 x i16>, ptr %ptr, align 2
+  %1 = add <vscale x 8 x i16> %load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = add <vscale x 8 x i16> %1, %1
+  %4 = getelementptr inbounds i16, ptr %dst2, i64 %index
+  %5 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %6 = getelementptr inbounds i8, ptr %dst, i64 %index
+  store <vscale x 8 x i16> %3, ptr %4, align 1
+  store <vscale x 8 x i8> %5, ptr %6, align 1
+  ret void
+}
+
+define void @add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
+; CHECK-LABEL: add_lshr_rshrnb_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.s, z0.d, #6
+; CHECK-NEXT:    st1w { z0.d }, p0, [x1, x2, lsl #2]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 2 x i64>, ptr %ptr, align 2
+  %1 = add <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  %2 = lshr <vscale x 2 x i64> %1, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 6, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  %3 = trunc <vscale x 2 x i64> %2 to <vscale x 2 x i32>
+  %4 = getelementptr inbounds i32, ptr %dst, i64 %index
+  store <vscale x 2 x i32> %3, ptr %4, align 1
+  ret void
+}
+
+define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
+; CHECK-LABEL: neg_add_lshr_rshrnb_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    st1h { z0.d }, p0, [x1, x2, lsl #1]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 2 x i64>, ptr %ptr, align 2
+  %1 = add <vscale x 2 x i64> %load, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 32, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  %2 = lshr <vscale x 2 x i64> %1, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 6, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  %3 = trunc <vscale x 2 x i64> %2 to <vscale x 2 x i16>
+  %4 = getelementptr inbounds i16, ptr %dst, i64 %index
+  store <vscale x 2 x i16> %3, ptr %4, align 1
+  ret void
+}
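To refresh the CHECK lines after a rebase, the generator named in the NOTE line can be re-run from the source root; something like the following should work (the build directory name is an assumption):

    python3 llvm/utils/update_llc_test_checks.py --llc-binary build/bin/llc \
        llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
    build/bin/llvm-lit -v llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll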