Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -28748,9 +28748,146 @@
                        DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp);
 }
 
+// Try to replace a vector-reduction add fed by an absolute-difference of
+// zero-extended i8 vectors with a PSADBW node plus a partial reduction add.
+static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG,
+                                const X86Subtarget &Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+
+  // The reduction vector must be a simple vector of i32; the i32 elements are
+  // what the zext-from-i8 ABSDIFF pattern below produces.
+  if (!VT.isVector() || !VT.isSimple() ||
+      !(VT.getVectorElementType() == MVT::i32))
+    return SDValue();
+
+  unsigned RegSize = 128;
+  if (Subtarget.hasAVX512())
+    RegSize = 512;
+  else if (Subtarget.hasAVX2())
+    RegSize = 256;
+
+  // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512.
+  if (VT.getSizeInBits() / 4 > RegSize)
+    return SDValue();
+
+  // Detect the following pattern:
+  //
+  // 1:    %2 = zext <N x i8> %0 to <N x i32>
+  // 2:    %3 = zext <N x i8> %1 to <N x i32>
+  // 3:    %4 = sub nsw <N x i32> %2, %3
+  // 4:    %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N]
+  // 5:    %6 = sub nsw <N x i32> zeroinitializer, %4
+  // 6:    %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6
+  // 7:    %8 = add nsw <N x i32> %7, %vec.phi
+  //
+  // The last instruction must be a reduction add. The instructions 3-6 form an
+  // ABSDIFF pattern.
+
+  // The two operands of reduction add are from PHI and a select-op as in line 7
+  // above.
+  SDValue SelectOp, Phi;
+  if (Op0.getOpcode() == ISD::VSELECT) {
+    SelectOp = Op0;
+    Phi = Op1;
+  } else if (Op1.getOpcode() == ISD::VSELECT) {
+    SelectOp = Op1;
+    Phi = Op0;
+  } else
+    return SDValue();
+
+  // Check the condition of the select instruction is greater-than.
+  SDValue SetCC = SelectOp->getOperand(0);
+  if (SetCC.getOpcode() != ISD::SETCC)
+    return SDValue();
+  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
+  if (CC != ISD::SETGT)
+    return SDValue();
+
+  Op0 = SelectOp->getOperand(1);
+  Op1 = SelectOp->getOperand(2);
+
+  // The second operand of SelectOp Op1 is the negation of the first operand
+  // Op0, which is implemented as 0 - Op0.
+  if (!(Op1.getOpcode() == ISD::SUB &&
+        ISD::isBuildVectorAllZeros(Op1.getOperand(0).getNode()) &&
+        Op1.getOperand(1) == Op0))
+    return SDValue();
+
+  // The first operand of SetCC is the first operand of SelectOp, which is the
+  // difference between two input vectors.
+  if (SetCC.getOperand(0) != Op0)
+    return SDValue();
+
+  // The second operand of > comparison can be either -1 or 0.
+  if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
+        ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode())))
+    return SDValue();
+
+  // The first operand of SelectOp is the difference between two input vectors.
+  if (Op0.getOpcode() != ISD::SUB)
+    return SDValue();
+
+  Op1 = Op0.getOperand(1);
+  Op0 = Op0.getOperand(0);
+
+  // Check if the operands of the diff are zero-extended from vectors of i8.
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
+      Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
+      Op1.getOpcode() != ISD::ZERO_EXTEND ||
+      Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
+    return SDValue();
+
+  // SAD pattern detected. Now build a SAD instruction and an addition for
+  // reduction. Note that the number of elements of the result of SAD is less
+  // than the number of elements of its input. Therefore, we could only update
+  // part of elements in the reduction vector.
+
+  // Legalize the type of the inputs of PSADBW.
+  EVT InVT = Op0.getOperand(0).getValueType();
+  if (InVT.getSizeInBits() <= 128)
+    RegSize = 128;
+  else if (InVT.getSizeInBits() <= 256)
+    RegSize = 256;
+
+  // Pad each input out to a full register by concatenating with zero vectors.
+  unsigned NumConcat = RegSize / InVT.getSizeInBits();
+  SmallVector<SDValue, 16> Ops(
+      NumConcat, DAG.getConstant(0, DL, Op0.getOperand(0).getValueType()));
+  Ops[0] = Op0.getOperand(0);
+  MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
+  Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+  Ops[0] = Op1.getOperand(0);
+  Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+
+  // The output of PSADBW is a vector of i64.
+  MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
+  SDValue Sad = DAG.getNode(X86ISD::PSADBW, DL, SadVT, Op0, Op1);
+
+  // We need to turn the vector of i64 into a vector of i32.
+  MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
+  Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
+
+  NumConcat = VT.getSizeInBits() / ResVT.getSizeInBits();
+  if (NumConcat > 1) {
+    // Update part of elements of the reduction vector. This is done by first
+    // extracting a sub-vector from it, updating this sub-vector, and inserting
+    // it back.
+    SDValue SubPhi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Phi,
+                                 DAG.getIntPtrConstant(0, DL));
+    SDValue Res = DAG.getNode(ISD::ADD, DL, ResVT, Sad, SubPhi);
+    return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Phi, Res,
+                       DAG.getIntPtrConstant(0, DL));
+  } else
+    return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi);
+}
+
 /// PerformADDCombine - Do target-specific dag combines on integer adds.
 static SDValue PerformAddCombine(SDNode *N, SelectionDAG &DAG,
                                  const X86Subtarget &Subtarget) {
+  const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+  if (Flags->hasVectorReduction()) {
+    if (SDValue Sad = detectSADPattern(N, DAG, Subtarget))
+      return Sad;
+  }
+
   EVT VT = N->getValueType(0);
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
Index: test/CodeGen/X86/sad-avx2.ll
===================================================================
--- /dev/null
+++ test/CodeGen/X86/sad-avx2.ll
@@ -0,0 +1,44 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx2 | FileCheck %s --check-prefix=AVX2
+
+@a = global [1024 x i8] zeroinitializer, align 16
+@b = global [1024 x i8] zeroinitializer, align 16
+
+define i32 @sad_avx2() {
+; AVX2-LABEL: sad_avx2
+; AVX2: vpsadbw %xmm1, %xmm2, %xmm1
+; AVX2: vpaddd %xmm0, %xmm1, %xmm1
+;
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.phi = phi <8 x i32> [ zeroinitializer, %entry ], [ %10, %vector.body ]
+  %0 = getelementptr inbounds [1024 x i8], [1024 x i8]* @a, i64 0, i64 %index
+  %1 = bitcast i8* %0 to <8 x i8>*
+  %wide.load = load <8 x i8>, <8 x i8>* %1, align 8
+  %2 = zext <8 x i8> %wide.load to <8 x i32>
+  %3 = getelementptr inbounds [1024 x i8], [1024 x i8]* @b, i64 0, i64 %index
+  %4 = bitcast i8* %3 to <8 x i8>*
+  %wide.load1 = load <8 x i8>, <8 x i8>* %4, align 8
+  %5 = zext <8 x i8> %wide.load1 to <8 x i32>
+  %6 = sub nsw <8 x i32> %2, %5
+  %7 = icmp sgt <8 x i32> %6, <i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1>
+  %8 = sub nsw <8 x i32> zeroinitializer, %6
+  %9 = select <8 x i1> %7, <8 x i32> %6, <8 x i32> %8
+  %10 = add nsw <8 x i32> %9, %vec.phi
+  %index.next = add i64 %index, 8
+  %11 = icmp eq i64 %index.next, 1024
+  br i1 %11, label %middle.block, label %vector.body
+
+middle.block:
+  %.lcssa = phi <8 x i32> [ %10, %vector.body ]
+  %rdx.shuf = shufflevector <8 x i32> %.lcssa, <8 x i32> undef, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx = add <8 x i32> %.lcssa, %rdx.shuf
+  %rdx.shuf2 = shufflevector <8 x i32> %bin.rdx, <8 x i32> undef, <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx3 = add <8 x i32> %bin.rdx, %rdx.shuf2
+  %rdx.shuf4 = shufflevector <8 x i32> %bin.rdx3, <8 x i32> undef, <8 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx5 = add <8 x i32> %bin.rdx3, %rdx.shuf4
+  %12 = extractelement <8 x i32> %bin.rdx5, i32 0
+  ret i32 %12
+}
Index: test/CodeGen/X86/sad-avx512.ll
===================================================================
--- /dev/null
+++ test/CodeGen/X86/sad-avx512.ll
@@ -0,0 +1,46 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512bw | FileCheck %s --check-prefix=AVX512BW
+
+@a = global [1024 x i8] zeroinitializer, align 16
+@b = global [1024 x i8] zeroinitializer, align 16
+
+define i32 @sad_avx512() {
+; AVX512BW-LABEL: sad_avx512
+; AVX512BW: vpsadbw {{.*}}, %xmm1, %xmm1
+; AVX512BW: vpaddd %xmm0, %xmm1, %xmm1
+;
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.phi = phi <16 x i32> [ zeroinitializer, %entry ], [ %10, %vector.body ]
+  %0 = getelementptr inbounds [1024 x i8], [1024 x i8]* @a, i64 0, i64 %index
+  %1 = bitcast i8* %0 to <16 x i8>*
+  %wide.load = load <16 x i8>, <16 x i8>* %1, align 16
+  %2 = zext <16 x i8> %wide.load to <16 x i32>
+  %3 = getelementptr inbounds [1024 x i8], [1024 x i8]* @b, i64 0, i64 %index
+  %4 = bitcast i8* %3 to <16 x i8>*
+  %wide.load1 = load <16 x i8>, <16 x i8>* %4, align 16
+  %5 = zext <16 x i8> %wide.load1 to <16 x i32>
+  %6 = sub nsw <16 x i32> %2, %5
+  %7 = icmp sgt <16 x i32> %6, <i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1>
+  %8 = sub nsw <16 x i32> zeroinitializer, %6
+  %9 = select <16 x i1> %7, <16 x i32> %6, <16 x i32> %8
+  %10 = add nsw <16 x i32> %9, %vec.phi
+  %index.next = add i64 %index, 16
+  %11 = icmp eq i64 %index.next, 1024
+  br i1 %11, label %middle.block, label %vector.body
+
+middle.block:
+  %.lcssa = phi <16 x i32> [ %10, %vector.body ]
+  %rdx.shuf = shufflevector <16 x i32> %.lcssa, <16 x i32> undef, <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx = add <16 x i32> %.lcssa, %rdx.shuf
+  %rdx.shuf2 = shufflevector <16 x i32> %bin.rdx, <16 x i32> undef, <16 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx3 = add <16 x i32> %bin.rdx, %rdx.shuf2
+  %rdx.shuf4 = shufflevector <16 x i32> %bin.rdx3, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx5 = add <16 x i32> %bin.rdx3, %rdx.shuf4
+  %rdx.shuf6 = shufflevector <16 x i32> %bin.rdx5, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx7 = add <16 x i32> %bin.rdx5, %rdx.shuf6
+  %12 = extractelement <16 x i32> %bin.rdx7, i32 0
+  ret i32 %12
+}
Index: test/CodeGen/X86/sad-sse2.ll
===================================================================
--- /dev/null
+++ test/CodeGen/X86/sad-sse2.ll
@@ -0,0 +1,42 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+sse2 | FileCheck %s --check-prefix=SSE2
+
+@a = global [1024 x i8] zeroinitializer, align 16
+@b = global [1024 x i8] zeroinitializer, align 16
+
+define i32 @sad_sse2() {
+; SSE2-LABEL: sad_sse2
+; SSE2: psadbw %xmm1, %xmm2
+; SSE2: paddd %xmm2, %xmm0
+;
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.phi = phi <4 x i32> [ zeroinitializer, %entry ], [ %10, %vector.body ]
+  %0 = getelementptr inbounds [1024 x i8], [1024 x i8]* @a, i64 0, i64 %index
+  %1 = bitcast i8* %0 to <4 x i8>*
+  %wide.load = load <4 x i8>, <4 x i8>* %1, align 4
+  %2 = zext <4 x i8> %wide.load to <4 x i32>
+  %3 = getelementptr inbounds [1024 x i8], [1024 x i8]* @b, i64 0, i64 %index
+  %4 = bitcast i8* %3 to <4 x i8>*
+  %wide.load1 = load <4 x i8>, <4 x i8>* %4, align 4
+  %5 = zext <4 x i8> %wide.load1 to <4 x i32>
+  %6 = sub nsw <4 x i32> %2, %5
+  %7 = icmp sgt <4 x i32> %6, <i32 -1, i32 -1, i32 -1, i32 -1>
+  %8 = sub nsw <4 x i32> zeroinitializer, %6
+  %9 = select <4 x i1> %7, <4 x i32> %6, <4 x i32> %8
+  %10 = add nsw <4 x i32> %9, %vec.phi
+  %index.next = add i64 %index, 4
+  %11 = icmp eq i64 %index.next, 1024
+  br i1 %11, label %middle.block, label %vector.body
+
+middle.block:
+  %.lcssa = phi <4 x i32> [ %10, %vector.body ]
+  %rdx.shuf = shufflevector <4 x i32> %.lcssa, <4 x i32> undef, <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
+  %bin.rdx = add <4 x i32> %.lcssa, %rdx.shuf
+  %rdx.shuf2 = shufflevector <4 x i32> %bin.rdx, <4 x i32> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+  %bin.rdx3 = add <4 x i32> %bin.rdx, %rdx.shuf2
+  %12 = extractelement <4 x i32> %bin.rdx3, i32 0
+  ret i32 %12
+}