Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -39181,45 +39181,53 @@
   // We know N is a reduction add, which means one of its operands is a phi.
   // To match SAD, we need the other operand to be a vector select.
-  SDValue SelectOp, Phi;
-  if (Op0.getOpcode() == ISD::VSELECT) {
-    SelectOp = Op0;
-    Phi = Op1;
-  } else if (Op1.getOpcode() == ISD::VSELECT) {
-    SelectOp = Op1;
-    Phi = Op0;
-  } else
-    return SDValue();
+  if (Op0.getOpcode() != ISD::VSELECT)
+    std::swap(Op0, Op1);
+  if (Op0.getOpcode() != ISD::VSELECT)
+    return SDValue();
+
+  auto BuildPSADBW = [&](SDValue Op0, SDValue Op1) {
+    // SAD pattern detected. Now build a SAD instruction and an addition for
+    // reduction. Note that the number of elements of the result of SAD is less
+    // than the number of elements of its input. Therefore, we could only update
+    // part of elements in the reduction vector.
+    SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget);
+
+    // The output of PSADBW is a vector of i64.
+    // We need to turn the vector of i64 into a vector of i32.
+    // If the reduction vector is at least as wide as the psadbw result, just
+    // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
+    // anyway.
+    MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32);
+    if (VT.getSizeInBits() >= ResVT.getSizeInBits())
+      Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
+    else
+      Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad);
+
+    if (VT.getSizeInBits() > ResVT.getSizeInBits()) {
+      // Fill the upper elements with zero to match the add width.
+      SDValue Zero = DAG.getConstant(0, DL, VT);
+      Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad,
+                        DAG.getIntPtrConstant(0, DL));
+    }
+
+    return Sad;
+  };
 
   // Check whether we have an abs-diff pattern feeding into the select.
-  if(!detectZextAbsDiff(SelectOp, Op0, Op1))
-    return SDValue();
-
-  // SAD pattern detected. Now build a SAD instruction and an addition for
-  // reduction. Note that the number of elements of the result of SAD is less
-  // than the number of elements of its input. Therefore, we could only update
-  // part of elements in the reduction vector.
-  SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget);
-
-  // The output of PSADBW is a vector of i64.
-  // We need to turn the vector of i64 into a vector of i32.
-  // If the reduction vector is at least as wide as the psadbw result, just
-  // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
-  // anyway.
-  MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32);
-  if (VT.getSizeInBits() >= ResVT.getSizeInBits())
-    Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
-  else
-    Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad);
+  SDValue SadOp0, SadOp1;
+  if (!detectZextAbsDiff(Op0, SadOp0, SadOp1))
+    return SDValue();
+
+  Op0 = BuildPSADBW(SadOp0, SadOp1);
 
-  if (VT.getSizeInBits() > ResVT.getSizeInBits()) {
-    // Fill the upper elements with zero to match the add width.
-    SDValue Zero = DAG.getConstant(0, DL, VT);
-    Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad,
-                      DAG.getIntPtrConstant(0, DL));
+  // It's possible we have a sad on the other side too.
+  if (Op1.getOpcode() == ISD::VSELECT &&
+      detectZextAbsDiff(Op1, SadOp0, SadOp1)) {
+    Op1 = BuildPSADBW(SadOp0, SadOp1);
   }
 
-  return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi);
+  return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1);
 }
 
 /// Convert vector increment or decrement to sub/add with an all-ones constant:
Index: test/CodeGen/X86/sad.ll
===================================================================
--- test/CodeGen/X86/sad.ll
+++ test/CodeGen/X86/sad.ll
@@ -1395,3 +1395,99 @@
   %sum = extractelement <32 x i32> %sum3, i32 0
   ret i32 %sum
 }
+
+; This test contains two absolute difference patterns joined by an add. The result of that add is then reduced to a single element.
+; SelectionDAGBuilder should tag the joining add as a vector reduction. We need to recognize that both sides can use psadbw.
+define i32 @sad_double_reduction(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x i8>* %arg2, <16 x i8>* %arg3) {
+; SSE2-LABEL: sad_double_reduction:
+; SSE2:       # %bb.0: # %bb
+; SSE2-NEXT:    movdqu (%rdi), %xmm0
+; SSE2-NEXT:    movdqu (%rsi), %xmm1
+; SSE2-NEXT:    psadbw %xmm0, %xmm1
+; SSE2-NEXT:    movdqu (%rdx), %xmm0
+; SSE2-NEXT:    movdqu (%rcx), %xmm2
+; SSE2-NEXT:    psadbw %xmm0, %xmm2
+; SSE2-NEXT:    paddd %xmm1, %xmm2
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,3,0,1]
+; SSE2-NEXT:    paddd %xmm2, %xmm0
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
+; SSE2-NEXT:    paddd %xmm0, %xmm1
+; SSE2-NEXT:    movd %xmm1, %eax
+; SSE2-NEXT:    retq
+;
+; AVX1-LABEL: sad_double_reduction:
+; AVX1:       # %bb.0: # %bb
+; AVX1-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX1-NEXT:    vmovdqu (%rdx), %xmm1
+; AVX1-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
+; AVX1-NEXT:    vpsadbw (%rcx), %xmm1, %xmm1
+; AVX1-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX1-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
+; AVX1-NEXT:    vmovd %xmm0, %eax
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: sad_double_reduction:
+; AVX2:       # %bb.0: # %bb
+; AVX2-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX2-NEXT:    vmovdqu (%rdx), %xmm1
+; AVX2-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
+; AVX2-NEXT:    vpsadbw (%rcx), %xmm1, %xmm1
+; AVX2-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX2-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX2-NEXT:    vmovd %xmm0, %eax
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: sad_double_reduction:
+; AVX512:       # %bb.0: # %bb
+; AVX512-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX512-NEXT:    vmovdqu (%rdx), %xmm1
+; AVX512-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
+; AVX512-NEXT:    vpsadbw (%rcx), %xmm1, %xmm1
+; AVX512-NEXT:    vpaddd %zmm0, %zmm1, %zmm0
+; AVX512-NEXT:    vextracti64x4 $1, %zmm0, %ymm1
+; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
+; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vmovd %xmm0, %eax
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+bb:
+  %tmp = load <16 x i8>, <16 x i8>* %arg, align 1
+  %tmp4 = load <16 x i8>, <16 x i8>* %arg1, align 1
+  %tmp5 = zext <16 x i8> %tmp to <16 x i32>
+  %tmp6 = zext <16 x i8> %tmp4 to <16 x i32>
+  %tmp7 = sub nsw <16 x i32> %tmp5, %tmp6
+  %tmp8 = icmp slt <16 x i32> %tmp7, zeroinitializer
+  %tmp9 = sub nsw <16 x i32> zeroinitializer, %tmp7
+  %tmp10 = select <16 x i1> %tmp8, <16 x i32> %tmp9, <16 x i32> %tmp7
+  %tmp11 = load <16 x i8>, <16 x i8>* %arg2, align 1
+  %tmp12 = load <16 x i8>, <16 x i8>* %arg3, align 1
+  %tmp13 = zext <16 x i8> %tmp11 to <16 x i32>
+  %tmp14 = zext <16 x i8> %tmp12 to <16 x i32>
+  %tmp15 = sub nsw <16 x i32> %tmp13, %tmp14
+  %tmp16 = icmp slt <16 x i32> %tmp15, zeroinitializer
+  %tmp17 = sub nsw <16 x i32> zeroinitializer, %tmp15
+  %tmp18 = select <16 x i1> %tmp16, <16 x i32> %tmp17, <16 x i32> %tmp15
+  %tmp19 = add nuw nsw <16 x i32> %tmp18, %tmp10
+  %tmp20 = shufflevector <16 x i32> %tmp19, <16 x i32> undef, <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %tmp21 = add <16 x i32> %tmp19, %tmp20
+  %tmp22 = shufflevector <16 x i32> %tmp21, <16 x i32> undef, <16 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %tmp23 = add <16 x i32> %tmp21, %tmp22
+  %tmp24 = shufflevector <16 x i32> %tmp23, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %tmp25 = add <16 x i32> %tmp23, %tmp24
+  %tmp26 = shufflevector <16 x i32> %tmp25, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %tmp27 = add <16 x i32> %tmp25, %tmp26
+  %tmp28 = extractelement <16 x i32> %tmp27, i64 0
+  ret i32 %tmp28
+}
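
For reference, a scalar sketch of the computation the new test exercises (a hypothetical source shape, not taken from this patch): each side is a zext/sub/abs over 16 bytes, and the two sums are joined by a single add before the horizontal reduction - exactly the shape that the combine now lowers to two psadbw ops feeding one paddd.

// Hypothetical scalar equivalent of @sad_double_reduction; the function name
// and signature are illustrative assumptions, not part of the test suite.
#include <cstdint>
#include <cstdlib>

uint32_t sad_double_reduction_ref(const uint8_t *a, const uint8_t *b,
                                  const uint8_t *c, const uint8_t *d) {
  uint32_t sum = 0;
  for (int i = 0; i < 16; ++i) {
    // Two independent absolute-difference sides joined by one add - the
    // pattern detectZextAbsDiff can now match on both operands of the add.
    sum += std::abs(int(a[i]) - int(b[i]));
    sum += std::abs(int(c[i]) - int(d[i]));
  }
  return sum;
}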