diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -225,6 +225,10 @@
   SRHADD,
   URHADD,
 
+  // Absolute difference
+  UABD,
+  SABD,
+
   // Vector across-lanes min/max
   // Only the lower result lane is defined.
   SMINV,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -759,6 +759,7 @@
   // Vector add and sub nodes may conceal a high-half opportunity.
   // Also, try to fold ADD into CSINC/CSINV..
   setTargetDAGCombine(ISD::ADD);
+  setTargetDAGCombine(ISD::ABS);
   setTargetDAGCombine(ISD::SUB);
   setTargetDAGCombine(ISD::SRL);
   setTargetDAGCombine(ISD::XOR);
@@ -1830,6 +1831,8 @@
     MAKE_CASE(AArch64ISD::STNP)
     MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
+    MAKE_CASE(AArch64ISD::UABD)
+    MAKE_CASE(AArch64ISD::SABD)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -3659,6 +3662,15 @@
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2));
   }
+
+  case Intrinsic::aarch64_neon_uabd: {
+    return DAG.getNode(AArch64ISD::UABD, dl, Op.getValueType(),
+                       Op.getOperand(1), Op.getOperand(2));
+  }
+  case Intrinsic::aarch64_neon_sabd: {
+    return DAG.getNode(AArch64ISD::SABD, dl, Op.getValueType(),
+                       Op.getOperand(1), Op.getOperand(2));
+  }
   }
 }
 
@@ -11050,6 +11062,54 @@
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
 }
 
+// Given an ABS node, detect the following pattern:
+// (ABS (SUB (EXTEND a), (EXTEND b))).
+// Generates UABD/SABD instruction.
+static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  SDValue AbsOp1 = N->getOperand(0);
+  SDValue Op0, Op1;
+
+  if (AbsOp1.getOpcode() != ISD::SUB)
+    return SDValue();
+
+  Op0 = AbsOp1.getOperand(0);
+  Op1 = AbsOp1.getOperand(1);
+
+  unsigned Opc0 = Op0.getOpcode();
+  // Check if the operands of the sub are (zero|sign)-extended.
+  if (Opc0 != Op1.getOpcode() ||
+      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
+    return SDValue();
+
+  EVT VectorT1 = Op0.getOperand(0).getValueType();
+  EVT VectorT2 = Op1.getOperand(0).getValueType();
+  // Check if vectors are of same type and valid size. Scalar and scalable
+  // (SVE) types must be rejected first: getFixedSizeInBits() is invalid for
+  // scalable vectors and getVectorElementType() is invalid for scalars.
+  if (VectorT1 != VectorT2 || !VectorT1.isFixedLengthVector())
+    return SDValue();
+  uint64_t Size = VectorT1.getFixedSizeInBits();
+  if (Size != 64 && Size != 128)
+    return SDValue();
+
+  // Check if vector element types are valid.
+  EVT VT1 = VectorT1.getVectorElementType();
+  if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+  unsigned ABDOpcode =
+      (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
+  SDValue ABD =
+      DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
+  // |a - b| of two N-bit lanes always fits in N unsigned bits, so the result
+  // is zero-extended for both the signed and the unsigned variant.
+  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
+}
+
 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -12339,8 +12399,11 @@
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
-  SDValue LHS = N->getOperand(1);
-  SDValue RHS = N->getOperand(2);
+  // IID == Intrinsic::not_intrinsic means N is a target node such as
+  // AArch64ISD::UABD whose operands start at 0; an INTRINSIC_WO_CHAIN node
+  // carries the intrinsic ID as operand 0.
+  SDValue LHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 0 : 1);
+  SDValue RHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 1 : 2);
   assert(LHS.getValueType().is64BitVector() &&
          RHS.getValueType().is64BitVector() &&
          "unexpected shape for long operation");
@@ -12358,6 +12421,9 @@
     return SDValue();
   }
 
+  if (IID == Intrinsic::not_intrinsic)
+    return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), LHS, RHS);
+
   return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0),
                      N->getOperand(0), LHS, RHS);
 }
@@ -12868,18 +12934,15 @@
   // helps the backend to decide that an sabdl2 would be useful, saving a real
   // extract_high operation.
   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
-      N->getOperand(0).getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
+      (N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
+       N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
     SDNode *ABDNode = N->getOperand(0).getNode();
-    unsigned IID = getIntrinsicID(ABDNode);
-    if (IID == Intrinsic::aarch64_neon_sabd ||
-        IID == Intrinsic::aarch64_neon_uabd) {
-      SDValue NewABD = tryCombineLongOpWithDup(IID, ABDNode, DCI, DAG);
-      if (!NewABD.getNode())
-        return SDValue();
+    SDValue NewABD =
+        tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
+    if (!NewABD.getNode())
+      return SDValue();
 
-      return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0),
-                         NewABD);
-    }
+    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD);
   }
 
   // This is effectively a custom type legalization for AArch64.
@@ -14672,6 +14726,8 @@
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
+  case ISD::ABS:
+    return performABSCombine(N, DAG, DCI, Subtarget);
   case ISD::ADD:
   case ISD::SUB:
     return performAddSubLongCombine(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -559,6 +559,18 @@
 def AArch64shadd : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>;
 def AArch64uhadd : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>;
 
+def AArch64uabd_n : SDNode<"AArch64ISD::UABD", SDT_AArch64binvec>;
+def AArch64sabd_n : SDNode<"AArch64ISD::SABD", SDT_AArch64binvec>;
+
+// Absolute-difference fragments matching either the AArch64ISD node (built by
+// the ABS combine and by NEON intrinsic lowering) or the intrinsic itself.
+def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs),
+                           [(AArch64uabd_n node:$lhs, node:$rhs),
+                            (int_aarch64_neon_uabd node:$lhs, node:$rhs)]>;
+def AArch64sabd : PatFrags<(ops node:$lhs, node:$rhs),
+                           [(AArch64sabd_n node:$lhs, node:$rhs),
+                            (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;
+
 def SDT_AArch64SETTAG : SDTypeProfile<0, 2, [SDTCisPtrTy<0>, SDTCisPtrTy<1>]>;
 def AArch64stg : SDNode<"AArch64ISD::STG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
 def AArch64stzg : SDNode<"AArch64ISD::STZG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
@@ -3812,7 +3824,7 @@
 //===----------------------------------------------------------------------===//
 
 defm UABDL : SIMDLongThreeVectorBHSabdl<1, 0b0111, "uabdl",
-                                          int_aarch64_neon_uabd>;
+                                          AArch64uabd>;
 // Match UABDL in log2-shuffle patterns.
 def : Pat<(abs (v8i16 (sub (zext (v8i8 V64:$opA)),
                            (zext (v8i8 V64:$opB))))),
@@ -4082,8 +4094,8 @@
 defm MUL : SIMDThreeSameVectorBHS<0, 0b10011, "mul", mul>;
 defm PMUL : SIMDThreeSameVectorB<1, 0b10011, "pmul", int_aarch64_neon_pmul>;
 defm SABA : SIMDThreeSameVectorBHSTied<0, 0b01111, "saba",
-      TriOpFrag<(add node:$LHS, (int_aarch64_neon_sabd node:$MHS, node:$RHS))> >;
-defm SABD : SIMDThreeSameVectorBHS<0,0b01110,"sabd", int_aarch64_neon_sabd>;
+      TriOpFrag<(add node:$LHS, (AArch64sabd node:$MHS, node:$RHS))> >;
+defm SABD : SIMDThreeSameVectorBHS<0,0b01110,"sabd", AArch64sabd>;
 defm SHADD : SIMDThreeSameVectorBHS<0,0b00000,"shadd", AArch64shadd>;
 defm SHSUB : SIMDThreeSameVectorBHS<0,0b00100,"shsub", int_aarch64_neon_shsub>;
 defm SMAXP : SIMDThreeSameVectorBHS<0,0b10100,"smaxp", int_aarch64_neon_smaxp>;
@@ -4101,8 +4113,8 @@
 defm SSHL : SIMDThreeSameVector<0,0b01000,"sshl", int_aarch64_neon_sshl>;
 defm SUB : SIMDThreeSameVector<1,0b10000,"sub", sub>;
 defm UABA : SIMDThreeSameVectorBHSTied<1, 0b01111, "uaba",
-      TriOpFrag<(add node:$LHS, (int_aarch64_neon_uabd node:$MHS, node:$RHS))> >;
-defm UABD : SIMDThreeSameVectorBHS<1,0b01110,"uabd", int_aarch64_neon_uabd>;
+      TriOpFrag<(add node:$LHS, (AArch64uabd node:$MHS, node:$RHS))> >;
+defm UABD : SIMDThreeSameVectorBHS<1,0b01110,"uabd", AArch64uabd>;
 defm UHADD : SIMDThreeSameVectorBHS<1,0b00000,"uhadd", AArch64uhadd>;
 defm UHSUB : SIMDThreeSameVectorBHS<1,0b00100,"uhsub", int_aarch64_neon_uhsub>;
 defm UMAXP : SIMDThreeSameVectorBHS<1,0b10100,"umaxp", int_aarch64_neon_umaxp>;
@@ -4676,9 +4688,9 @@
 defm RSUBHN : SIMDNarrowThreeVectorBHS<1,0b0110,"rsubhn",int_aarch64_neon_rsubhn>;
 defm PMULL : SIMDDifferentThreeVectorBD<0,0b1110,"pmull",int_aarch64_neon_pmull>;
 defm SABAL : SIMDLongThreeVectorTiedBHSabal<0,0b0101,"sabal",
-                                             int_aarch64_neon_sabd>;
+                                             AArch64sabd>;
 defm SABDL : SIMDLongThreeVectorBHSabdl<0, 0b0111, "sabdl",
-                                          int_aarch64_neon_sabd>;
+                                          AArch64sabd>;
 defm SADDL : SIMDLongThreeVectorBHS< 0, 0b0000, "saddl",
                  BinOpFrag<(add (sext node:$LHS), (sext node:$RHS))>>;
 defm SADDW : SIMDWideThreeVectorBHS< 0, 0b0001, "saddw",
@@ -4699,7 +4711,7 @@
 defm SSUBW : SIMDWideThreeVectorBHS<0, 0b0011, "ssubw",
                  BinOpFrag<(sub node:$LHS, (sext node:$RHS))>>;
 defm UABAL : SIMDLongThreeVectorTiedBHSabal<1, 0b0101, "uabal",
-                                             int_aarch64_neon_uabd>;
+                                             AArch64uabd>;
 defm UADDL : SIMDLongThreeVectorBHS<1, 0b0000, "uaddl",
                  BinOpFrag<(add (zext node:$LHS), (zext node:$RHS))>>;
 defm UADDW : SIMDWideThreeVectorBHS<1, 0b0001, "uaddw",
diff --git a/llvm/test/CodeGen/AArch64/arm64-vabs.ll b/llvm/test/CodeGen/AArch64/arm64-vabs.ll
--- a/llvm/test/CodeGen/AArch64/arm64-vabs.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vabs.ll
@@ -142,11 +142,11 @@
 }
 
 declare i16 @llvm.vector.reduce.add.v16i16(<16 x i16>)
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
 
-define i16 @uabdl8h_rdx(<16 x i8>* %a, <16 x i8>* %b) {
-; CHECK-LABEL: uabdl8h_rdx
-; CHECK: uabdl2.8h
-; CHECK: uabdl.8h
+define i16 @uabd16b_rdx(<16 x i8>* %a, <16 x i8>* %b) {
+; CHECK-LABEL: uabd16b_rdx
+; CHECK: uabd.16b
   %aload = load <16 x i8>, <16 x i8>* %a, align 1
   %bload = load <16 x i8>, <16 x i8>* %b, align 1
   %aext = zext <16 x i8> %aload to <16 x i16>
@@ -159,12 +159,38 @@
   ret i16 %reduced_v
 }
 
+define i32 @uabd16b_rdx_i32(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: uabd16b_rdx_i32
+; CHECK: uabd.16b
+  %aext = zext <16 x i8> %a to <16 x i32>
+  %bext = zext <16 x i8> %b to <16 x i32>
+  %abdiff = sub nsw <16 x i32> %aext, %bext
+  %abcmp = icmp slt <16 x i32> %abdiff, zeroinitializer
+  %ababs = sub nsw <16 x i32> zeroinitializer, %abdiff
+  %absel = select <16 x i1> %abcmp, <16 x i32> %ababs, <16 x i32> %abdiff
+  %reduced_v = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %absel)
+  ret i32 %reduced_v
+}
+
+define i32 @sabd16b_rdx_i32(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: sabd16b_rdx_i32
+; CHECK: sabd.16b
+  %aext = sext <16 x i8> %a to <16 x i32>
+  %bext = sext <16 x i8> %b to <16 x i32>
+  %abdiff = sub nsw <16 x i32> %aext, %bext
+  %abcmp = icmp slt <16 x i32> %abdiff, zeroinitializer
+  %ababs = sub nsw <16 x i32> zeroinitializer, %abdiff
+  %absel = select <16 x i1> %abcmp, <16 x i32> %ababs, <16 x i32> %abdiff
+  %reduced_v = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %absel)
+  ret i32 %reduced_v
+}
+
 declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
 
-define i32 @uabdl4s_rdx(<8 x i16>* %a, <8 x i16>* %b) {
-; CHECK-LABEL: uabdl4s_rdx
-; CHECK: uabdl2.4s
-; CHECK: uabdl.4s
+define i32 @uabd8h_rdx(<8 x i16>* %a, <8 x i16>* %b) {
+; CHECK-LABEL: uabd8h_rdx
+; CHECK: uabd.8h
   %aload = load <8 x i16>, <8 x i16>* %a, align 1
   %bload = load <8 x i16>, <8 x i16>* %b, align 1
   %aext = zext <8 x i16> %aload to <8 x i32>
@@ -177,12 +203,38 @@
   ret i32 %reduced_v
 }
 
+define i32 @sabd8h_rdx(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: sabd8h_rdx
+; CHECK: sabd.8h
+  %aext = sext <8 x i16> %a to <8 x i32>
+  %bext = sext <8 x i16> %b to <8 x i32>
+  %abdiff = sub nsw <8 x i32> %aext, %bext
+  %abcmp = icmp slt <8 x i32> %abdiff, zeroinitializer
+  %ababs = sub nsw <8 x i32> zeroinitializer, %abdiff
+  %absel = select <8 x i1> %abcmp, <8 x i32> %ababs, <8 x i32> %abdiff
+  %reduced_v = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %absel)
+  ret i32 %reduced_v
+}
+
+define i32 @uabdl4s_rdx_i32(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: uabdl4s_rdx_i32
+; CHECK: uabdl.4s
+  %aext = zext <4 x i16> %a to <4 x i32>
+  %bext = zext <4 x i16> %b to <4 x i32>
+  %abdiff = sub nsw <4 x i32> %aext, %bext
+  %abcmp = icmp slt <4 x i32> %abdiff, zeroinitializer
+  %ababs = sub nsw <4 x i32> zeroinitializer, %abdiff
+  %absel = select <4 x i1> %abcmp, <4 x i32> %ababs, <4 x i32> %abdiff
+  %reduced_v = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %absel)
+  ret i32 %reduced_v
+}
+
 declare i64 @llvm.vector.reduce.add.v4i64(<4 x i64>)
+declare i64 @llvm.vector.reduce.add.v2i64(<2 x i64>)
 
-define i64 @uabdl2d_rdx(<4 x i32>* %a, <4 x i32>* %b, i32 %h) {
-; CHECK: uabdl2d_rdx
-; CHECK: uabdl2.2d
-; CHECK: uabdl.2d
+define i64 @uabd4s_rdx(<4 x i32>* %a, <4 x i32>* %b, i32 %h) {
+; CHECK-LABEL: uabd4s_rdx
+; CHECK: uabd.4s
   %aload = load <4 x i32>, <4 x i32>* %a, align 1
   %bload = load <4 x i32>, <4 x i32>* %b, align 1
   %aext = zext <4 x i32> %aload to <4 x i64>
@@ -195,6 +247,32 @@
   ret i64 %reduced_v
 }
 
+define i64 @sabd4s_rdx(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: sabd4s_rdx
+; CHECK: sabd.4s
+  %aext = sext <4 x i32> %a to <4 x i64>
+  %bext = sext <4 x i32> %b to <4 x i64>
+  %abdiff = sub nsw <4 x i64> %aext, %bext
+  %abcmp = icmp slt <4 x i64> %abdiff, zeroinitializer
+  %ababs = sub nsw <4 x i64> zeroinitializer, %abdiff
+  %absel = select <4 x i1> %abcmp, <4 x i64> %ababs, <4 x i64> %abdiff
+  %reduced_v = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %absel)
+  ret i64 %reduced_v
+}
+
+define i64 @uabdl2d_rdx_i64(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: uabdl2d_rdx_i64
+; CHECK: uabdl.2d
+  %aext = zext <2 x i32> %a to <2 x i64>
+  %bext = zext <2 x i32> %b to <2 x i64>
+  %abdiff = sub nsw <2 x i64> %aext, %bext
+  %abcmp = icmp slt <2 x i64> %abdiff, zeroinitializer
+  %ababs = sub nsw <2 x i64> zeroinitializer, %abdiff
+  %absel = select <2 x i1> %abcmp, <2 x i64> %ababs, <2 x i64> %abdiff
+  %reduced_v = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> %absel)
+  ret i64 %reduced_v
+}
+
 define <2 x float> @fabd_2s(<2 x float>* %A, <2 x float>* %B) nounwind {
 ;CHECK-LABEL: fabd_2s:
 ;CHECK: fabd.2s