Index: llvm/include/llvm/CodeGen/ISDOpcodes.h
===================================================================
--- llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -579,6 +579,13 @@
   RHADDS,
   RHADDU,
 
+  /// ABDS/ABDU - Absolute difference - Return the absolute difference between
+  /// two numbers interpreted as signed/unsigned.
+  /// i.e trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1)
+  ///  or trunc(abs(zext(Op0) - zext(Op1))) becomes abdu(Op0, Op1)
+  ABDS,
+  ABDU,
+
   /// [US]{MIN/MAX} - Binary minimum or maximum or signed or unsigned
   /// integers.
   SMIN,
Index: llvm/include/llvm/Target/TargetSelectionDAG.td
===================================================================
--- llvm/include/llvm/Target/TargetSelectionDAG.td
+++ llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -358,6 +358,8 @@
 def haddu      : SDNode<"ISD::HADDU"     , SDTIntBinOp, [SDNPCommutative]>;
 def rhadds     : SDNode<"ISD::RHADDS"    , SDTIntBinOp, [SDNPCommutative]>;
 def rhaddu     : SDNode<"ISD::RHADDU"    , SDTIntBinOp, [SDNPCommutative]>;
+def abds       : SDNode<"ISD::ABDS"      , SDTIntBinOp, [SDNPCommutative]>;
+def abdu       : SDNode<"ISD::ABDU"      , SDTIntBinOp, [SDNPCommutative]>;
 def smullohi   : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def umullohi   : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def sdiv       : SDNode<"ISD::SDIV"      , SDTIntBinOp>;
Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8845,6 +8845,43 @@
   return SDValue();
 }
 
+// Given an ABS node, detect the following pattern:
+// (ABS (SUB (EXTEND a), (EXTEND b))).
+// Generates an ISD::ABDS node when both operands are sign-extended, or an
+// ISD::ABDU node when both are zero-extended.
+static SDValue combineAbsToABD(SDNode *N, SelectionDAG &DAG,
+                               const TargetLowering &TLI) {
+  SDValue AbsOp1 = N->getOperand(0);
+  SDValue Op0, Op1;
+
+  if (AbsOp1.getOpcode() != ISD::SUB)
+    return SDValue();
+
+  Op0 = AbsOp1.getOperand(0);
+  Op1 = AbsOp1.getOperand(1);
+
+  unsigned Opc0 = Op0.getOpcode();
+  // Check if the operands of the sub are (zero|sign)-extended.
+  if (Opc0 != Op1.getOpcode() ||
+      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
+    return SDValue();
+
+  EVT VT1 = Op0.getOperand(0).getValueType();
+  EVT VT2 = Op1.getOperand(0).getValueType();
+  // Check if the operands have the same type and the target supports ABD.
+  unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
+  if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1))
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+  SDValue ABD =
+      DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
+  // The absolute difference of two N-bit values always fits in N bits when
+  // read as unsigned, so zero-extension is correct for ABDS as well as ABDU.
+  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
+}
+
 SDValue DAGCombiner::visitABS(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -8858,6 +8895,10 @@
   // fold (abs x) -> x iff not-negative
   if (DAG.SignBitIsZero(N0))
     return N0;
+
+  if (SDValue ABD = combineAbsToABD(N, DAG, TLI))
+    return ABD;
+
   return SDValue();
 }
 
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -235,6 +235,8 @@
   case ISD::HADDS:                      return "hadds";
   case ISD::RHADDU:                     return "rhaddu";
   case ISD::RHADDS:                     return "rhadds";
+  case ISD::ABDS:                       return "abds";
+  case ISD::ABDU:                       return "abdu";
   case ISD::SDIV:                       return "sdiv";
   case ISD::UDIV:                       return "udiv";
   case ISD::SREM:                       return "srem";
Index: llvm/lib/CodeGen/TargetLoweringBase.cpp
===================================================================
--- llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -773,6 +773,9 @@
     setOperationAction(ISD::HADDU, VT, Expand);
     setOperationAction(ISD::RHADDS, VT, Expand);
     setOperationAction(ISD::RHADDU, VT, Expand);
+    // Absolute difference
+    setOperationAction(ISD::ABDS, VT, Expand);
+    setOperationAction(ISD::ABDU, VT, Expand);
 
     // These default to Expand so they will be expanded to CTLZ/CTTZ by default.
     setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Expand);
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -217,10 +217,6 @@
   SADDV,
   UADDV,
 
-  // Absolute difference
-  UABD,
-  SABD,
-
   // Vector across-lanes min/max
   // Only the lower result lane is defined.
   SMINV,
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -978,6 +978,8 @@
       setOperationAction(ISD::HADDU, VT, Legal);
       setOperationAction(ISD::RHADDS, VT, Legal);
       setOperationAction(ISD::RHADDU, VT, Legal);
+      setOperationAction(ISD::ABDS, VT, Legal);
+      setOperationAction(ISD::ABDU, VT, Legal);
     }
 
     // Vector reductions
@@ -1914,8 +1916,6 @@
     MAKE_CASE(AArch64ISD::STNP)
     MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
-    MAKE_CASE(AArch64ISD::UABD)
-    MAKE_CASE(AArch64ISD::SABD)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -3750,11 +3750,11 @@
   }
 
   case Intrinsic::aarch64_neon_uabd: {
-    return DAG.getNode(AArch64ISD::UABD, dl, Op.getValueType(),
+    return DAG.getNode(ISD::ABDU, dl, Op.getValueType(),
                        Op.getOperand(1), Op.getOperand(2));
   }
   case Intrinsic::aarch64_neon_sabd: {
-    return DAG.getNode(AArch64ISD::SABD, dl, Op.getValueType(),
+    return DAG.getNode(ISD::ABDS, dl, Op.getValueType(),
                        Op.getOperand(1), Op.getOperand(2));
   }
   }
@@ -11382,48 +11382,6 @@
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
 }
 
-// Given a ABS node, detect the following pattern:
-// (ABS (SUB (EXTEND a), (EXTEND b))).
-// Generates UABD/SABD instruction.
-static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 const AArch64Subtarget *Subtarget) {
-  SDValue AbsOp1 = N->getOperand(0);
-  SDValue Op0, Op1;
-
-  if (AbsOp1.getOpcode() != ISD::SUB)
-    return SDValue();
-
-  Op0 = AbsOp1.getOperand(0);
-  Op1 = AbsOp1.getOperand(1);
-
-  unsigned Opc0 = Op0.getOpcode();
-  // Check if the operands of the sub are (zero|sign)-extended.
-  if (Opc0 != Op1.getOpcode() ||
-      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
-    return SDValue();
-
-  EVT VectorT1 = Op0.getOperand(0).getValueType();
-  EVT VectorT2 = Op1.getOperand(0).getValueType();
-  // Check if vectors are of same type and valid size.
-  uint64_t Size = VectorT1.getFixedSizeInBits();
-  if (VectorT1 != VectorT2 || (Size != 64 && Size != 128))
-    return SDValue();
-
-  // Check if vector element types are valid.
-  EVT VT1 = VectorT1.getVectorElementType();
-  if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
-    return SDValue();
-
-  Op0 = Op0.getOperand(0);
-  Op1 = Op1.getOperand(0);
-  unsigned ABDOpcode =
-      (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
-  SDValue ABD =
-      DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
-  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
-}
-
 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -13209,8 +13167,8 @@
   // helps the backend to decide that an sabdl2 would be useful, saving a real
   // extract_high operation.
   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
-      (N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
-       N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
+      (N->getOperand(0).getOpcode() == ISD::ABDU ||
+       N->getOperand(0).getOpcode() == ISD::ABDS)) {
     SDNode *ABDNode = N->getOperand(0).getNode();
     SDValue NewABD =
         tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
@@ -15010,8 +14968,6 @@
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
-  case ISD::ABS:
-    return performABSCombine(N, DAG, DCI, Subtarget);
   case ISD::ADD:
   case ISD::SUB:
     return performAddSubCombine(N, DCI, DAG);
Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -550,14 +550,11 @@
 def AArch64smaxv    : SDNode<"AArch64ISD::SMAXV", SDT_AArch64UnaryVec>;
 def AArch64umaxv    : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>;
 
-def AArch64uabd_n   : SDNode<"AArch64ISD::UABD", SDT_AArch64binvec>;
-def AArch64sabd_n   : SDNode<"AArch64ISD::SABD", SDT_AArch64binvec>;
-
 def AArch64uabd     : PatFrags<(ops node:$lhs, node:$rhs),
-                               [(AArch64uabd_n node:$lhs, node:$rhs),
+                               [(abdu node:$lhs, node:$rhs),
                                 (int_aarch64_neon_uabd node:$lhs, node:$rhs)]>;
 def AArch64sabd     : PatFrags<(ops node:$lhs, node:$rhs),
-                               [(AArch64sabd_n node:$lhs, node:$rhs),
+                               [(abds node:$lhs, node:$rhs),
                                 (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;
 
 def SDT_AArch64SETTAG : SDTypeProfile<0, 2, [SDTCisPtrTy<0>, SDTCisPtrTy<1>]>;