diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -759,6 +759,7 @@
   // Vector add and sub nodes may conceal a high-half opportunity.
   // Also, try to fold ADD into CSINC/CSINV..
   setTargetDAGCombine(ISD::ADD);
+  setTargetDAGCombine(ISD::ABS);
   setTargetDAGCombine(ISD::SUB);
   setTargetDAGCombine(ISD::SRL);
   setTargetDAGCombine(ISD::XOR);
@@ -10989,6 +10990,49 @@
   return SDValue();
 }

+// Given an ABS node, detect the following pattern:
+// (ABS (SUB (ZERO_EXTEND a), (ZERO_EXTEND b))).
+// This is useful as it is the input into a SAD pattern.
+static bool detectZextAbsDiff(const SDValue &Abs, SDValue &Op0, SDValue &Op1) {
+  SDValue AbsOp1 = Abs->getOperand(0);
+  if (AbsOp1.getOpcode() != ISD::SUB)
+    return false;
+
+  Op0 = AbsOp1.getOperand(0);
+  Op1 = AbsOp1.getOperand(1);
+
+  // Check if the operands of the sub are zero-extended from vectors of i8.
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
+      Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
+      Op1.getOpcode() != ISD::ZERO_EXTEND ||
+      Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
+    return false;
+
+  return true;
+}
+
+// Detect and combine the pattern for the unsigned absolute difference of
+// zero-extended i8 vectors. Generates a UABD instruction.
+static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v16i32 && VT != MVT::v16i16)
+    return SDValue();
+
+  SDValue Op0, Op1;
+  if (!detectZextAbsDiff(SDValue(N, 0), Op0, Op1))
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+  SDValue ABD = DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, SDLoc(N), Op0->getValueType(0),
+      DAG.getConstant(Intrinsic::aarch64_neon_uabd, SDLoc(N), MVT::i32), Op0,
+      Op1);
+  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, ABD);
+}
+
 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -14611,6 +14655,8 @@
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
+  case ISD::ABS:
+    return performABSCombine(N, DAG, DCI, Subtarget);
   case ISD::ADD:
   case ISD::SUB:
     return performAddSubLongCombine(N, DCI, DAG);
diff --git a/llvm/test/CodeGen/AArch64/arm64-vabs.ll b/llvm/test/CodeGen/AArch64/arm64-vabs.ll
--- a/llvm/test/CodeGen/AArch64/arm64-vabs.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vabs.ll
@@ -142,11 +142,11 @@
 }

 declare i16 @llvm.experimental.vector.reduce.add.v16i16(<16 x i16>)
+declare i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32>)

-define i16 @uabdl8h_rdx(<16 x i8>* %a, <16 x i8>* %b) {
-; CHECK-LABEL: uabdl8h_rdx
-; CHECK: uabdl2.8h
-; CHECK: uabdl.8h
+define i16 @uabd16b_rdx(<16 x i8>* %a, <16 x i8>* %b) {
+; CHECK-LABEL: uabd16b_rdx
+; CHECK: uabd.16b
   %aload = load <16 x i8>, <16 x i8>* %a, align 1
   %bload = load <16 x i8>, <16 x i8>* %b, align 1
   %aext = zext <16 x i8> %aload to <16 x i16>
@@ -159,6 +159,21 @@
   ret i16 %reduced_v
 }

+define i32 @uabd16b_rdx_i32(<16 x i8>* %a, <16 x i8>* %b) {
+; CHECK-LABEL: uabd16b_rdx_i32
+; CHECK: uabd.16b
+  %aload = load <16 x i8>, <16 x i8>* %a, align 1
+  %bload = load <16 x i8>, <16 x i8>* %b, align 1
+  %aext = zext <16 x i8> %aload to <16 x i32>
+  %bext = zext <16 x i8> %bload to <16 x i32>
+  %abdiff = sub nsw <16 x i32> %aext, %bext
+  %abcmp = icmp slt <16 x i32> %abdiff, zeroinitializer
+  %ababs = sub nsw <16 x i32> zeroinitializer, %abdiff
+  %absel = select <16 x i1> %abcmp, <16 x i32> %ababs, <16 x i32> %abdiff
+  %reduced_v = call i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32> %absel)
+  ret i32 %reduced_v
+}
+
 declare i32 @llvm.experimental.vector.reduce.add.v8i32(<8 x i32>)

 define i32 @uabdl4s_rdx(<8 x i16>* %a, <8 x i16>* %b) {