diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20075,7 +20075,49 @@
   return SDValue();
 }
 
-static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG) {
+static SDValue trySimplifySrlAddToRshrnb(SDNode *Srl, SelectionDAG &DAG,
+                                         const AArch64Subtarget *Subtarget,
+                                         EVT VT, EVT TruncatedVT) {
+  if (!Subtarget->hasSVE2() || Srl->getOpcode() != ISD::SRL ||
+      !VT.isScalableVector() || !TruncatedVT.isScalableVector())
+    return SDValue();
+
+  unsigned Opc;
+  EVT TruncatedElTy = TruncatedVT.getVectorElementType();
+  EVT EltTy = VT.getVectorElementType();
+  if (TruncatedElTy == MVT::i8 && EltTy == MVT::i16)
+    Opc = AArch64::RSHRNB_ZZI_B;
+  else if (TruncatedElTy == MVT::i16 && EltTy == MVT::i32)
+    Opc = AArch64::RSHRNB_ZZI_H;
+  else
+    return SDValue();
+
+  auto SrlOp1 = dyn_cast<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
+  if (!SrlOp1)
+    return SDValue();
+  unsigned ShiftValue = SrlOp1->getZExtValue();
+
+  SDValue Add = Srl->getOperand(0);
+  if (Add->getOpcode() != ISD::ADD)
+    return SDValue();
+  auto AddOp1 = dyn_cast<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
+  if (!AddOp1)
+    return SDValue();
+  int64_t AddValue = AddOp1->getZExtValue();
+
+  if (AddValue != 1 << (ShiftValue - 1))
+    return SDValue();
+
+  SDLoc DL(Srl);
+  return SDValue(
+      DAG.getMachineNode(Opc, DL, VT,
+                         {Add->getOperand(0),
+                          DAG.getTargetConstant(ShiftValue, DL, MVT::i32)}),
+      0);
+}
+
+static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
+                                 const AArch64Subtarget *Subtarget) {
   SDLoc DL(N);
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
@@ -20108,6 +20150,33 @@
     }
   }
 
+  if (Op0.getOpcode() == ISD::SRL) {
+    EVT DoubleVT = MVT::Other;
+    if (ResVT == MVT::nxv8i16) {
+      DoubleVT = MVT::nxv4i32;
+    } else if (ResVT == MVT::nxv16i8) {
+      DoubleVT = MVT::nxv8i16;
+    } else {
+      return SDValue();
+    }
+    if (SDValue res = trySimplifySrlAddToRshrnb(cast<SDNode>(Op0), DAG,
+                                                Subtarget, DoubleVT, ResVT))
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, res, Op1);
+  }
+  if (Op1.getOpcode() == ISD::SRL) {
+    EVT DoubleVT = MVT::Other;
+    if (ResVT == MVT::nxv8i16) {
+      DoubleVT = MVT::nxv4i32;
+    } else if (ResVT == MVT::nxv16i8) {
+      DoubleVT = MVT::nxv8i16;
+    } else {
+      return SDValue();
+    }
+    if (SDValue res = trySimplifySrlAddToRshrnb(cast<SDNode>(Op1), DAG,
+                                                Subtarget, DoubleVT, ResVT))
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, res);
+  }
+
   // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
   if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
     if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
@@ -20724,6 +20793,13 @@
   if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
     return Store;
 
+  if (ST->isTruncatingStore())
+    if (SDValue RSHRNB = trySimplifySrlAddToRshrnb(
+            cast<SDNode>(ST->getOperand(1)), DAG, Subtarget,
+            ST->getValue().getValueType(), ST->getMemoryVT()))
+      return DAG.getTruncStore(ST->getChain(), ST, RSHRNB, ST->getBasePtr(),
+                               ST->getMemoryVT(), ST->getMemOperand());
+
   return SDValue();
 }
 
@@ -23033,7 +23109,7 @@
   case AArch64ISD::UUNPKHI:
     return performUnpackCombine(N, DAG, Subtarget);
   case AArch64ISD::UZP1:
-    return performUzpCombine(N, DAG);
+    return performUzpCombine(N, DAG, Subtarget);
   case AArch64ISD::SETCC_MERGE_ZERO:
     return performSetccMergeZeroCombine(N, DCI);
   case AArch64ISD::REINTERPRET_CAST:
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -0,0 +1,144 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2 < %s | FileCheck %s
+
+define void @add_lshr_rshrnb_b_6(ptr %ptr, ptr %dst, i64 %index12){
+; CHECK-LABEL: add_lshr_rshrnb_b_6:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.b, z0.h, #6
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %wide.load13 = load <vscale x 8 x i16>, ptr %ptr, align 2
+  %1 = add <vscale x 8 x i16> %wide.load13, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index12
+  store <vscale x 8 x i8> %3, ptr %4, align 1
+  ret void
+}
+
+define void @neg_add_lshr_rshrnb_b_6(ptr %ptr, ptr %dst, i64 %index12){
+; CHECK-LABEL: neg_add_lshr_rshrnb_b_6:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    add z0.h, z0.h, #1 // =0x1
+; CHECK-NEXT:    lsr z0.h, z0.h, #6
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %wide.load13 = load <vscale x 8 x i16>, ptr %ptr, align 2
+  %1 = add <vscale x 8 x i16> %wide.load13, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 1, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index12
+  store <vscale x 8 x i8> %3, ptr %4, align 1
+  ret void
+}
+
+define void @add_lshr_rshrnb_h_7(ptr %ptr, ptr %dst, i64 %index12){
+; CHECK-LABEL: add_lshr_rshrnb_h_7:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.b, z0.h, #7
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %wide.load13 = load <vscale x 8 x i16>, ptr %ptr, align 2
+  %1 = add <vscale x 8 x i16> %wide.load13, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 64, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 7, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index12
+  store <vscale x 8 x i8> %3, ptr %4, align 1
+  ret void
+}
+
+define void @add_lshr_rshrn_h_6(ptr %ptr, ptr %dst, i64 %index12){
+; CHECK-LABEL: add_lshr_rshrn_h_6:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.h, z0.s, #6
+; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
+; CHECK-NEXT:    ret
+  %wide.load13 = load <vscale x 4 x i32>, ptr %ptr, align 2
+  %1 = add <vscale x 4 x i32> %wide.load13, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 32, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 6, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
+  %4 = getelementptr inbounds i16, ptr %dst, i64 %index12
+  store <vscale x 4 x i16> %3, ptr %4, align 1
+  ret void
+}
+
+define void @add_lshr_rshrnb_h_2(ptr %ptr, ptr %dst, i64 %index12){
+; CHECK-LABEL: add_lshr_rshrnb_h_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.h, z0.s, #2
+; CHECK-NEXT:    st1h { z0.s }, p0, [x1, x2, lsl #1]
+; CHECK-NEXT:    ret
+  %wide.load13 = load <vscale x 4 x i32>, ptr %ptr, align 2
+  %1 = add <vscale x 4 x i32> %wide.load13, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 2, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
+  %4 = getelementptr inbounds i16, ptr %dst, i64 %index12
+  store <vscale x 4 x i16> %3, ptr %4, align 1
+  ret void
+}
+
+define void @neg_add_lshr_rshrnb_h_0(ptr %ptr, ptr %dst, i64 %index12){
+; CHECK-LABEL: neg_add_lshr_rshrnb_h_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %wide.load13 = load <vscale x 4 x i32>, ptr %ptr, align 2
+  %1 = add <vscale x 4 x i32> %wide.load13, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %2 = lshr <vscale x 4 x i32> %1, trunc (<vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 -1, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer) to <vscale x 4 x i32>)
+  %3 = trunc <vscale x 4 x i32> %2 to <vscale x 4 x i16>
+  %4 = getelementptr inbounds i16, ptr %dst, i64 %index12
+  store <vscale x 4 x i16> %3, ptr %4, align 1
+  ret void
+}
+
+define void @wide_add_shift_add_rshrnb_b(ptr %dest, i64 %index, <vscale x 16 x i16> %arg1){
+; CHECK-LABEL: wide_add_shift_add_rshrnb_b:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    rshrnb z1.b, z1.h, #6
+; CHECK-NEXT:    ld1b { z2.b }, p0/z, [x0, x1]
+; CHECK-NEXT:    rshrnb z0.b, z0.h, #6
+; CHECK-NEXT:    uzp1 z0.b, z0.b, z1.b
+; CHECK-NEXT:    add z0.b, z2.b, z0.b
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0, x1]
+; CHECK-NEXT:    ret
+  %1 = add <vscale x 16 x i16> %arg1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 32, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
+  %2 = lshr <vscale x 16 x i16> %1, shufflevector (<vscale x 16 x i16> insertelement (<vscale x 16 x i16> poison, i16 6, i64 0), <vscale x 16 x i16> poison, <vscale x 16 x i32> zeroinitializer)
+  %3 = getelementptr inbounds i8, ptr %dest, i64 %index
+  %wide.load12 = load <vscale x 16 x i8>, ptr %3, align 2
+  %4 = trunc <vscale x 16 x i16> %2 to <vscale x 16 x i8>
+  %5 = add <vscale x 16 x i8> %wide.load12, %4
+  store <vscale x 16 x i8> %5, ptr %3, align 2
+  ret void
+}
+
+define void @wide_add_shift_add_rshrnb_h(ptr %dest, i64 %index, <vscale x 8 x i32> %arg1){
+; CHECK-LABEL: wide_add_shift_add_rshrnb_h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    rshrnb z1.h, z1.s, #6
+; CHECK-NEXT:    ld1h { z2.h }, p0/z, [x0, x1, lsl #1]
+; CHECK-NEXT:    rshrnb z0.h, z0.s, #6
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z1.h
+; CHECK-NEXT:    add z0.h, z2.h, z0.h
+; CHECK-NEXT:    st1h { z0.h }, p0, [x0, x1, lsl #1]
+; CHECK-NEXT:    ret
+  %1 = add <vscale x 8 x i32> %arg1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
+  %2 = lshr <vscale x 8 x i32> %1, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
+  %3 = getelementptr inbounds i16, ptr %dest, i64 %index
+  %wide.load12 = load <vscale x 8 x i16>, ptr %3, align 2
+  %4 = trunc <vscale x 8 x i32> %2 to <vscale x 8 x i16>
+  %5 = add <vscale x 8 x i16> %wide.load12, %4
+  store <vscale x 8 x i16> %5, ptr %3, align 2
+  ret void
+}