diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -188,6 +188,10 @@
   SADDV,
   UADDV,

+  // Vector halving addition
+  SHADD,
+  UHADD,
+
   // Vector rounding halving addition
   SRHADD,
   URHADD,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1462,6 +1462,8 @@
     MAKE_CASE(AArch64ISD::UADDV)
     MAKE_CASE(AArch64ISD::SRHADD)
     MAKE_CASE(AArch64ISD::URHADD)
+    MAKE_CASE(AArch64ISD::SHADD)
+    MAKE_CASE(AArch64ISD::UHADD)
     MAKE_CASE(AArch64ISD::SMINV)
     MAKE_CASE(AArch64ISD::UMINV)
     MAKE_CASE(AArch64ISD::SMAXV)
@@ -3299,9 +3301,16 @@
   }
   case Intrinsic::aarch64_neon_srhadd:
-  case Intrinsic::aarch64_neon_urhadd: {
-    bool IsSignedAdd = IntNo == Intrinsic::aarch64_neon_srhadd;
-    unsigned Opcode = IsSignedAdd ? AArch64ISD::SRHADD : AArch64ISD::URHADD;
+  case Intrinsic::aarch64_neon_urhadd:
+  case Intrinsic::aarch64_neon_shadd:
+  case Intrinsic::aarch64_neon_uhadd: {
+    bool IsSignedAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
+                        IntNo == Intrinsic::aarch64_neon_shadd);
+    bool IsRoundingAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
+                          IntNo == Intrinsic::aarch64_neon_urhadd);
+    unsigned Opcode =
+        IsSignedAdd ? (IsRoundingAdd ? AArch64ISD::SRHADD : AArch64ISD::SHADD)
+                    : (IsRoundingAdd ? AArch64ISD::URHADD : AArch64ISD::UHADD);
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2));
   }
@@ -8847,52 +8856,64 @@
 }

 // Attempt to form urhadd(OpA, OpB) from
-// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1)).
-// The original form of this expression is
-// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and before this function
-// is called the srl will have been lowered to AArch64ISD::VLSHR and the
-// ((OpA + OpB + 1) >> 1) expression will have been changed to (OpB - (~OpA)).
-// This pass can also recognize a variant of this pattern that uses sign
-// extension instead of zero extension and form a srhadd(OpA, OpB) from it.
-SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
-                                             SelectionDAG &DAG) const {
-  EVT VT = Op.getValueType();
-
-  if (!VT.isVector() || VT.isScalableVector())
-    return Op;
-
-  if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType()))
-    return LowerFixedLengthVectorTruncateToSVE(Op, DAG);
+// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1))
+// or uhadd(OpA, OpB) from truncate(vlshr(add(zext(OpA), zext(OpB)), 1)).
+// The original form of the first expression is
+// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and the
+// (OpA + OpB + 1) subexpression will have been changed to (OpB - (~OpA)).
+// Before this function is called the srl will have been lowered to
+// AArch64ISD::VLSHR.
+// This pass can also recognize signed variants of the patterns that use sign
+// extension instead of zero extension and form a srhadd(OpA, OpB) or a
+// shadd(OpA, OpB) from them.
+static SDValue tryLowerToHalvingAdd(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);

   // Since we are looking for a right shift by a constant value of 1 and we are
   // operating on types at least 16 bits in length (sign/zero extended OpA and
   // OpB, which are at least 8 bits), it follows that the truncate will always
   // discard the shifted-in bit and therefore the right shift will be logical
   // regardless of the signedness of OpA and OpB.
-  SDValue Shift = Op.getOperand(0);
+  SDValue Shift = N->getOperand(0);
   if (Shift.getOpcode() != AArch64ISD::VLSHR)
-    return Op;
+    return SDValue();

   // Is the right shift using an immediate value of 1?
   uint64_t ShiftAmount = Shift.getConstantOperandVal(1);
   if (ShiftAmount != 1)
-    return Op;
+    return SDValue();

-  SDValue Sub = Shift->getOperand(0);
-  if (Sub.getOpcode() != ISD::SUB)
-    return Op;
+  SDValue ExtendOpA, ExtendOpB;
+  SDValue Sub = Shift.getOperand(0);
+  if (Sub.getOpcode() == ISD::SUB) {

-  SDValue Xor = Sub.getOperand(1);
-  if (Xor.getOpcode() != ISD::XOR)
-    return Op;
+    SDValue Xor = Sub.getOperand(1);
+    if (Xor.getOpcode() != ISD::XOR)
+      return SDValue();
+
+    // Is the XOR using a constant amount of all ones in the right hand side?
+    uint64_t C;
+    if (!isAllConstantBuildVector(Xor.getOperand(1), C))
+      return SDValue();
+
+    unsigned ElemSizeInBits = VT.getScalarSizeInBits();
+    APInt CAsAPInt(ElemSizeInBits, C);
+    if (CAsAPInt != APInt::getAllOnesValue(ElemSizeInBits))
+      return SDValue();
+
+    ExtendOpA = Xor.getOperand(0);
+    ExtendOpB = Sub.getOperand(0);
+  } else if (Sub.getOpcode() == ISD::ADD) {
+    ExtendOpA = Sub.getOperand(0);
+    ExtendOpB = Sub.getOperand(1);
+  } else
+    return SDValue();

-  SDValue ExtendOpA = Xor.getOperand(0);
-  SDValue ExtendOpB = Sub.getOperand(0);
   unsigned ExtendOpAOpc = ExtendOpA.getOpcode();
   unsigned ExtendOpBOpc = ExtendOpB.getOpcode();
   if (!(ExtendOpAOpc == ExtendOpBOpc &&
         (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND)))
-    return Op;
+    return SDValue();

   // Is the result of the right shift being truncated to the same value type as
   // the original operands, OpA and OpB?
@@ -8901,24 +8922,33 @@
   EVT OpAVT = OpA.getValueType();
   assert(ExtendOpA.getValueType() == ExtendOpB.getValueType());
   if (!(VT == OpAVT && OpAVT == OpB.getValueType()))
-    return Op;
+    return SDValue();

-  // Is the XOR using a constant amount of all ones in the right hand side?
-  uint64_t C;
-  if (!isAllConstantBuildVector(Xor.getOperand(1), C))
-    return Op;
+  SDLoc DL(N);
+  bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
+  bool IsRHADD = Sub.getOpcode() == ISD::SUB;
+  unsigned HADDOpc = IsSignExtend
+                         ? (IsRHADD ? AArch64ISD::SRHADD : AArch64ISD::SHADD)
+                         : (IsRHADD ? AArch64ISD::URHADD : AArch64ISD::UHADD);
+  SDValue ResultHADD = DAG.getNode(HADDOpc, DL, VT, OpA, OpB);

-  unsigned ElemSizeInBits = VT.getScalarSizeInBits();
-  APInt CAsAPInt(ElemSizeInBits, C);
-  if (CAsAPInt != APInt::getAllOnesValue(ElemSizeInBits))
+  return ResultHADD;
+}
+
+SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+
+  if (!VT.isVector() || VT.isScalableVector())
     return Op;

-  SDLoc DL(Op);
-  bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
-  unsigned RHADDOpc = IsSignExtend ? AArch64ISD::SRHADD : AArch64ISD::URHADD;
-  SDValue ResultURHADD = DAG.getNode(RHADDOpc, DL, VT, OpA, OpB);
+  if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType()))
+    return LowerFixedLengthVectorTruncateToSVE(Op, DAG);

-  return ResultURHADD;
+  if (SDValue Res = tryLowerToHalvingAdd(Op.getNode(), DAG))
+    return Res;
+
+  return Op;
 }

 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
@@ -11169,9 +11199,9 @@
   if (DCI.isBeforeLegalizeOps())
     return SDValue();

-  // Optimise concat_vectors of two [us]rhadds that use extracted subvectors
-  // from the same original vectors. Combine these into a single [us]rhadd that
-  // operates on the two original vectors. Example:
+  // Optimise concat_vectors of two [us]rhadds or [us]hadds that use extracted
+  // subvectors from the same original vectors. Combine these into a single
+  // [us]rhadd or [us]hadd that operates on the two original vectors. Example:
   //  (v16i8 (concat_vectors (v8i8 (urhadd (extract_subvector (v16i8 OpA, <0>),
   //                                        extract_subvector (v16i8 OpB,
   //                                                           <0>))),
@@ -11181,7 +11211,8 @@
   // ->
   //  (v16i8(urhadd(v16i8 OpA, v16i8 OpB)))
   if (N->getNumOperands() == 2 && N0Opc == N1Opc &&
-      (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD)) {
+      (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD ||
+       N0Opc == AArch64ISD::UHADD || N0Opc == AArch64ISD::SHADD)) {
     SDValue N00 = N0->getOperand(0);
     SDValue N01 = N0->getOperand(1);
     SDValue N10 = N1->getOperand(0);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -556,6 +556,8 @@
 def AArch64srhadd : SDNode<"AArch64ISD::SRHADD", SDT_AArch64binvec>;
 def AArch64urhadd : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>;
+def AArch64shadd : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>;
+def AArch64uhadd : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>;

 def SDT_AArch64SETTAG : SDTypeProfile<0, 2, [SDTCisPtrTy<0>, SDTCisPtrTy<1>]>;
 def AArch64stg : SDNode<"AArch64ISD::STG", SDT_AArch64SETTAG,
                         [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
@@ -4064,7 +4066,7 @@
 defm SABA : SIMDThreeSameVectorBHSTied<0, 0b01111, "saba",
   TriOpFrag<(add node:$LHS, (int_aarch64_neon_sabd node:$MHS, node:$RHS))> >;
 defm SABD : SIMDThreeSameVectorBHS<0,0b01110,"sabd", int_aarch64_neon_sabd>;
-defm SHADD : SIMDThreeSameVectorBHS<0,0b00000,"shadd", int_aarch64_neon_shadd>;
+defm SHADD : SIMDThreeSameVectorBHS<0,0b00000,"shadd", AArch64shadd>;
 defm SHSUB : SIMDThreeSameVectorBHS<0,0b00100,"shsub", int_aarch64_neon_shsub>;
 defm SMAXP : SIMDThreeSameVectorBHS<0,0b10100,"smaxp", int_aarch64_neon_smaxp>;
 defm SMAX : SIMDThreeSameVectorBHS<0,0b01100,"smax", smax>;
@@ -4083,7 +4085,7 @@
 defm UABA : SIMDThreeSameVectorBHSTied<1, 0b01111, "uaba",
   TriOpFrag<(add node:$LHS, (int_aarch64_neon_uabd node:$MHS, node:$RHS))> >;
 defm UABD : SIMDThreeSameVectorBHS<1,0b01110,"uabd", int_aarch64_neon_uabd>;
-defm UHADD : SIMDThreeSameVectorBHS<1,0b00000,"uhadd", int_aarch64_neon_uhadd>;
+defm UHADD : SIMDThreeSameVectorBHS<1,0b00000,"uhadd", AArch64uhadd>;
 defm UHSUB : SIMDThreeSameVectorBHS<1,0b00100,"uhsub", int_aarch64_neon_uhsub>;
 defm UMAXP : SIMDThreeSameVectorBHS<1,0b10100,"umaxp", int_aarch64_neon_umaxp>;
 defm UMAX : SIMDThreeSameVectorBHS<1,0b01100,"umax", umax>;
diff --git a/llvm/test/CodeGen/AArch64/arm64-vhadd.ll b/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
--- a/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
@@ -425,6 +425,96 @@
   ret void
 }

+define void @testLowerToSHADD8b(<8 x i8> %src1, <8 x i8> %src2, <8 x i8>* %dest) nounwind {
+; CHECK-LABEL: testLowerToSHADD8b:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd.8b v0, v0, v1
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+  %sextsrc1 = sext <8 x i8> %src1 to <8 x i16>
+  %sextsrc2 = sext <8 x i8> %src2 to <8 x i16>
+  %add = add <8 x i16> %sextsrc1, %sextsrc2
+  %resulti16 = lshr <8 x i16> %add, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %result = trunc <8 x i16> %resulti16 to <8 x i8>
+  store <8 x i8> %result, <8 x i8>* %dest, align 8
+  ret void
+}
+
+define void @testLowerToSHADD4h(<4 x i16> %src1, <4 x i16> %src2, <4 x i16>* %dest) nounwind {
+; CHECK-LABEL: testLowerToSHADD4h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd.4h v0, v0, v1
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+  %sextsrc1 = sext <4 x i16> %src1 to <4 x i32>
+  %sextsrc2 = sext <4 x i16> %src2 to <4 x i32>
+  %add = add <4 x i32> %sextsrc1, %sextsrc2
+  %resulti16 = lshr <4 x i32> %add, <i32 1, i32 1, i32 1, i32 1>
+  %result = trunc <4 x i32> %resulti16 to <4 x i16>
+  store <4 x i16> %result, <4 x i16>* %dest, align 8
+  ret void
+}
+
+define void @testLowerToSHADD2s(<2 x i32> %src1, <2 x i32> %src2, <2 x i32>* %dest) nounwind {
+; CHECK-LABEL: testLowerToSHADD2s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd.2s v0, v0, v1
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+  %sextsrc1 = sext <2 x i32> %src1 to <2 x i64>
+  %sextsrc2 = sext <2 x i32> %src2 to <2 x i64>
+  %add = add <2 x i64> %sextsrc1, %sextsrc2
+  %resulti16 = lshr <2 x i64> %add, <i64 1, i64 1>
+  %result = trunc <2 x i64> %resulti16 to <2 x i32>
+  store <2 x i32> %result, <2 x i32>* %dest, align 8
+  ret void
+}
+
+define void @testLowerToSHADD16b(<16 x i8> %src1, <16 x i8> %src2, <16 x i8>* %dest) nounwind {
+; CHECK-LABEL: testLowerToSHADD16b:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd.16b v0, v0, v1
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+  %sextsrc1 = sext <16 x i8> %src1 to <16 x i16>
+  %sextsrc2 = sext <16 x i8> %src2 to <16 x i16>
+  %add = add <16 x i16> %sextsrc1, %sextsrc2
+  %resulti16 = lshr <16 x i16> %add, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %result = trunc <16 x i16> %resulti16 to <16 x i8>
+  store <16 x i8> %result, <16 x i8>* %dest, align 16
+  ret void
+}
+
+define void @testLowerToSHADD8h(<8 x i16> %src1, <8 x i16> %src2, <8 x i16>* %dest) nounwind {
+; CHECK-LABEL: testLowerToSHADD8h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd.8h v0, v0, v1
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+  %sextsrc1 = sext <8 x i16> %src1 to <8 x i32>
+  %sextsrc2 = sext <8 x i16> %src2 to <8 x i32>
+  %add = add <8 x i32> %sextsrc1, %sextsrc2
+  %resulti16 = lshr <8 x i32> %add, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
+  %result = trunc <8 x i32> %resulti16 to <8 x i16>
+  store <8 x i16> %result, <8 x i16>* %dest, align 16
+  ret void
+}
+
+define void @testLowerToSHADD4s(<4 x i32> %src1, <4 x i32> %src2, <4 x i32>* %dest) nounwind {
+; CHECK-LABEL: testLowerToSHADD4s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd.4s v0, v0, v1
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+  %sextsrc1 = sext <4 x i32> %src1 to <4 x i64>
+  %sextsrc2 = sext <4 x i32> %src2 to <4 x i64>
+  %add = add <4 x i64> %sextsrc1, %sextsrc2
+  %resulti16 = lshr <4 x i64> %add, <i64 1, i64 1, i64 1, i64 1>
+  %result = trunc <4 x i64> %resulti16 to <4 x i32>
+  store <4 x i32> %result, <4 x i32>* %dest, align 16
+  ret void
+}
+
 define void @testLowerToURHADD8b(<8 x i8> %src1, <8 x i8> %src2, <8 x i8>* %dest) nounwind {
 ; CHECK-LABEL: testLowerToURHADD8b:
 ; CHECK:       // %bb.0:
@@ -521,6 +611,96 @@
   ret void
 }

+define void @testLowerToUHADD8b(<8 x i8> %src1, <8 x i8> %src2, <8 x i8>* %dest) nounwind {
+; CHECK-LABEL: testLowerToUHADD8b:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd.8b v0, v0, v1
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+  %zextsrc1 = zext <8 x i8> %src1 to <8 x i16>
+  %zextsrc2 = zext <8 x i8> %src2 to <8 x i16>
+  %add = add <8 x i16> %zextsrc1, %zextsrc2
+  %resulti16 = lshr <8 x i16> %add, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %result = trunc <8 x i16> %resulti16 to <8 x i8>
+  store <8 x i8> %result, <8 x i8>* %dest, align 8
+  ret void
+}
+
+define void @testLowerToUHADD4h(<4 x i16> %src1, <4 x i16> %src2, <4 x i16>* %dest) nounwind {
+; CHECK-LABEL: testLowerToUHADD4h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd.4h v0, v0, v1
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+  %zextsrc1 = zext <4 x i16> %src1 to <4 x i32>
+  %zextsrc2 = zext <4 x i16> %src2 to <4 x i32>
+  %add = add <4 x i32> %zextsrc1, %zextsrc2
+  %resulti16 = lshr <4 x i32> %add, <i32 1, i32 1, i32 1, i32 1>
+  %result = trunc <4 x i32> %resulti16 to <4 x i16>
+  store <4 x i16> %result, <4 x i16>* %dest, align 8
+  ret void
+}
+
+define void @testLowerToUHADD2s(<2 x i32> %src1, <2 x i32> %src2, <2 x i32>* %dest) nounwind {
+; CHECK-LABEL: testLowerToUHADD2s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd.2s v0, v0, v1
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+  %zextsrc1 = zext <2 x i32> %src1 to <2 x i64>
+  %zextsrc2 = zext <2 x i32> %src2 to <2 x i64>
+  %add = add <2 x i64> %zextsrc1, %zextsrc2
+  %resulti16 = lshr <2 x i64> %add, <i64 1, i64 1>
+  %result = trunc <2 x i64> %resulti16 to <2 x i32>
+  store <2 x i32> %result, <2 x i32>* %dest, align 8
+  ret void
+}
+
+define void @testLowerToUHADD16b(<16 x i8> %src1, <16 x i8> %src2, <16 x i8>* %dest) nounwind {
+; CHECK-LABEL: testLowerToUHADD16b:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd.16b v0, v0, v1
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+  %zextsrc1 = zext <16 x i8> %src1 to <16 x i16>
+  %zextsrc2 = zext <16 x i8> %src2 to <16 x i16>
+  %add = add <16 x i16> %zextsrc1, %zextsrc2
+  %resulti16 = lshr <16 x i16> %add, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %result = trunc <16 x i16> %resulti16 to <16 x i8>
+  store <16 x i8> %result, <16 x i8>* %dest, align 16
+  ret void
+}
+
+define void @testLowerToUHADD8h(<8 x i16> %src1, <8 x i16> %src2, <8 x i16>* %dest) nounwind {
+; CHECK-LABEL: testLowerToUHADD8h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd.8h v0, v0, v1
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+  %zextsrc1 = zext <8 x i16> %src1 to <8 x i32>
+  %zextsrc2 = zext <8 x i16> %src2 to <8 x i32>
+  %add = add <8 x i32> %zextsrc1, %zextsrc2
+  %resulti16 = lshr <8 x i32> %add, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
+  %result = trunc <8 x i32> %resulti16 to <8 x i16>
+  store <8 x i16> %result, <8 x i16>* %dest, align 16
+  ret void
+}
+
+define void @testLowerToUHADD4s(<4 x i32> %src1, <4 x i32> %src2, <4 x i32>* %dest) nounwind {
+; CHECK-LABEL: testLowerToUHADD4s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd.4s v0, v0, v1
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+  %zextsrc1 = zext <4 x i32> %src1 to <4 x i64>
+  %zextsrc2 = zext <4 x i32> %src2 to <4 x i64>
+  %add = add <4 x i64> %zextsrc1, %zextsrc2
+  %resulti16 = lshr <4 x i64> %add, <i64 1, i64 1, i64 1, i64 1>
+  %result = trunc <4 x i64> %resulti16 to <4 x i32>
+  store <4 x i32> %result, <4 x i32>* %dest, align 16
+  ret void
+}
+
 declare <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8>, <8 x i8>) nounwind readnone
 declare <4 x i16> @llvm.aarch64.neon.srhadd.v4i16(<4 x i16>, <4 x i16>) nounwind readnone
 declare <2 x i32> @llvm.aarch64.neon.srhadd.v2i32(<2 x i32>, <2 x i32>) nounwind readnone
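For reference, a scalar loop of the following shape is the kind of source that, once vectorized, produces the truncate(vlshr(add(zext(OpA), zext(OpB)), 1)) pattern handled by the new lowering. This is a minimal sketch for illustration only: the function name and element types are not part of the patch, and whether the vectorizer emits exactly this IR depends on the target and optimization level. A signed variant using int8_t/int16_t would map to SHADD in the same way.

#include <stdint.h>

// Unsigned halving average without rounding: each element becomes
// (a[i] + b[i]) >> 1 computed in a wider type, then truncated back.
// After vectorization this is trunc(lshr(add(zext(a), zext(b)), 1)),
// which the patch lowers to a single UHADD per vector.
void halving_avg_u8(uint8_t *dst, const uint8_t *a, const uint8_t *b, int n) {
  for (int i = 0; i < n; ++i)
    dst[i] = (uint8_t)(((uint16_t)a[i] + (uint16_t)b[i]) >> 1);
}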