Index: include/llvm/CodeGen/SelectionDAGNodes.h
===================================================================
--- include/llvm/CodeGen/SelectionDAGNodes.h
+++ include/llvm/CodeGen/SelectionDAGNodes.h
@@ -328,6 +328,7 @@
   bool NoInfs : 1;
   bool NoSignedZeros : 1;
   bool AllowReciprocal : 1;
+  bool Reduction : 1;

 public:
   /// Default constructor turns off all optimization flags.
@@ -340,6 +341,7 @@
     NoInfs = false;
     NoSignedZeros = false;
     AllowReciprocal = false;
+    Reduction = false;
   }

   // These are mutators for each flag.
@@ -351,6 +353,7 @@
   void setNoInfs(bool b) { NoInfs = b; }
   void setNoSignedZeros(bool b) { NoSignedZeros = b; }
   void setAllowReciprocal(bool b) { AllowReciprocal = b; }
+  void setReduction(bool b) { Reduction = b; }

   // These are accessors for each flag.
   bool hasNoUnsignedWrap() const { return NoUnsignedWrap; }
@@ -361,6 +364,7 @@
   bool hasNoInfs() const { return NoInfs; }
   bool hasNoSignedZeros() const { return NoSignedZeros; }
   bool hasAllowReciprocal() const { return AllowReciprocal; }
+  bool hasReduction() const { return Reduction; }

   /// Return a raw encoding of the flags.
   /// This function should only be used to add data to the NodeID value.
Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2330,10 +2330,21 @@
   if (const FPMathOperator *FPOp = dyn_cast<FPMathOperator>(&I))
     FMF = FPOp->getFastMathFlags();

+  // Check if this binary op is a reduction.
+  auto IsReductionPHI = [](const Value *V) {
+    const PHINode *PN = dyn_cast<PHINode>(V);
+    if (PN && PN->getMetadata("llvm.loop.vectorize.reduction"))
+      return true;
+    return false;
+  };
+  bool reduction =
+      IsReductionPHI(I.getOperand(0)) || IsReductionPHI(I.getOperand(1));
+
   SDNodeFlags Flags;
   Flags.setExact(exact);
   Flags.setNoSignedWrap(nsw);
   Flags.setNoUnsignedWrap(nuw);
+  Flags.setReduction(reduction);
   if (EnableFMFInDAG) {
     Flags.setAllowReciprocal(FMF.allowReciprocal());
     Flags.setNoInfs(FMF.noInfs());
Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -26621,9 +26621,153 @@
                      DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp);
 }

+// Check if the given SDValue is a constant vector whose elements all equal Val.
+static bool isConstVectorOf(SDValue V, int Val) {
+  BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V);
+  if (!BV || !BV->isConstant())
+    return false;
+  auto NumOperands = V.getNumOperands();
+  for (unsigned i = 0; i < NumOperands; i++) {
+    ConstantSDNode *C = dyn_cast<ConstantSDNode>(V.getOperand(i));
+    if (!C || C->getSExtValue() != Val)
+      return false;
+  }
+  return true;
+}
+
+static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG,
+                                const X86Subtarget *Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+
+  if (!VT.isVector() || !VT.isSimple())
+    return SDValue();
+  unsigned NumElems = VT.getVectorNumElements();
+
+  if (!(VT.getVectorElementType() == MVT::i32 && isPowerOf2_32(NumElems)))
+    return SDValue();
+
+  unsigned RegSize = 128;
+  if (Subtarget->hasAVX512())
+    RegSize = 512;
+  else if (Subtarget->hasAVX2())
+    RegSize = 256;
+
+  if (VT.getSizeInBits() > RegSize)
+    return SDValue();
+
+  // Detect the following pattern:
+  //
+  // 1:    %2 = zext <N x i8> %0 to <N x i32>
+  // 2:    %3 = zext <N x i8> %1 to <N x i32>
+  // 3:    %4 = sub nsw <N x i32> %2, %3
+  // 4:    %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N]
+  // 5:    %6 = sub nsw <N x i32> zeroinitializer, %4
+  // 6:    %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6
+  // 7:    %8 = add nsw <N x i32> %7, %vec.phi
+  //
+  // The last instruction must be a reduction add. Instructions 3-6 form an
+  // ABSDIFF pattern.
+
+  // The two operands of the reduction add are a PHI and a select-op, as in
+  // line 7 above.
+  SDValue SelectOp, Phi;
+  if (Op0.getOpcode() == ISD::CopyFromReg && Op1.getOpcode() == ISD::VSELECT) {
+    SelectOp = Op1;
+    Phi = Op0;
+  } else if (Op1.getOpcode() == ISD::CopyFromReg &&
+             Op0.getOpcode() == ISD::VSELECT) {
+    SelectOp = Op0;
+    Phi = Op1;
+  } else
+    return SDValue();
+
+  // Check that the condition of the select instruction is greater-than.
+  SDValue SetCC = SelectOp->getOperand(0);
+  if (SetCC.getOpcode() != ISD::SETCC)
+    return SDValue();
+  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
+  if (CC != ISD::SETGT)
+    return SDValue();
+
+  Op0 = SelectOp->getOperand(1);
+  Op1 = SelectOp->getOperand(2);
+
+  // The second operand of SelectOp, Op1, is the negation of its first operand
+  // Op0, which is implemented as 0 - Op0.
+  if (!(Op1.getOpcode() == ISD::SUB && isConstVectorOf(Op1.getOperand(0), 0) &&
+        Op1.getOperand(1) == Op0))
+    return SDValue();
+
+  // The first operand of SetCC is the first operand of SelectOp, which is the
+  // difference between the two input vectors.
+  if (SetCC.getOperand(0) != Op0)
+    return SDValue();
+
+  // The second operand of the > comparison can be either -1 or 0.
+  if (!(isConstVectorOf(SetCC.getOperand(1), 0) ||
+        isConstVectorOf(SetCC.getOperand(1), -1)))
+    return SDValue();
+
+  // The first operand of SelectOp is the difference between two input vectors.
+  if (Op0.getOpcode() != ISD::SUB)
+    return SDValue();
+
+  Op1 = Op0.getOperand(1);
+  Op0 = Op0.getOperand(0);
+
+  // Check if the operands of the diff are zero-extended from vectors of i8.
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
+      Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
+      Op1.getOpcode() != ISD::ZERO_EXTEND ||
+      Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
+    return SDValue();
+
+  EVT InVT = Op0.getOperand(0).getValueType();
+  if (InVT.getSizeInBits() <= 128)
+    RegSize = 128;
+  else if (InVT.getSizeInBits() <= 256)
+    RegSize = 256;
+
+  // SAD pattern detected. Now build a SAD instruction and an addition for
+  // reduction.
+  unsigned NumConcat =
+      RegSize / Op0.getOperand(0).getValueType().getSizeInBits();
+  SmallVector<SDValue, 16> Ops(
+      NumConcat, DAG.getConstant(0, DL, Op0.getOperand(0).getValueType()));
+  Ops[0] = Op0.getOperand(0);
+  MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
+  Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+  Ops[0] = Op1.getOperand(0);
+  Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+
+  SDValue Sad = DAG.getNode(X86ISD::PSADBW, DL, ExtendedVT, Op0, Op1);
+  MVT SadVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
+  Sad = DAG.getNode(ISD::BITCAST, DL, SadVT, Sad);
+
+  NumConcat = VT.getSizeInBits() / SadVT.getSizeInBits();
+  if (NumConcat > 1) {
+    SDValue SubPhi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SadVT, Phi,
+                                 DAG.getIntPtrConstant(0, DL));
+    SDValue Res = DAG.getNode(ISD::ADD, DL, SadVT, Sad, SubPhi);
+    return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Phi, Res,
+                       DAG.getIntPtrConstant(0, DL));
+  } else
+    return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi);
+}
+
 /// PerformADDCombine - Do target-specific dag combines on integer adds.
 static SDValue PerformAddCombine(SDNode *N, SelectionDAG &DAG,
                                  const X86Subtarget *Subtarget) {
+  const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+  if (Flags->hasReduction()) {
+    SDValue SAD = detectSADPattern(N, DAG, Subtarget);
+    if (SAD.getNode())
+      return SAD;
+  }
+
   EVT VT = N->getValueType(0);
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
Index: lib/Target/X86/X86InstrSSE.td
===================================================================
--- lib/Target/X86/X86InstrSSE.td
+++ lib/Target/X86/X86InstrSSE.td
@@ -4062,6 +4062,8 @@
                             SSE_INTALU_ITINS_P, 1, NoVLX_Or_NoBWI>;
 defm PMAXSW : PDI_binop_all<0xEE, "pmaxsw", smax, v8i16, v16i16,
                             SSE_INTALU_ITINS_P, 1, NoVLX_Or_NoBWI>;
+defm PSADBW : PDI_binop_all<0xF6, "psadbw", X86psadbw, v16i8, v32i8,
+                            SSE_INTALU_ITINS_P, 1, NoVLX_Or_NoBWI>;

 // Intrinsic forms
 defm PSUBSB : PDI_binop_all_int<0xE8, "psubsb", int_x86_sse2_psubs_b,
@@ -4082,8 +4084,6 @@
                                 int_x86_avx2_pavg_b, SSE_INTALU_ITINS_P, 1>;
 defm PAVGW : PDI_binop_all_int<0xE3, "pavgw", int_x86_sse2_pavg_w,
                                int_x86_avx2_pavg_w, SSE_INTALU_ITINS_P, 1>;
-defm PSADBW : PDI_binop_all_int<0xF6, "psadbw", int_x86_sse2_psad_bw,
-                                int_x86_avx2_psad_bw, SSE_PMADD, 1>;

 let Predicates = [HasAVX2] in
   def : Pat<(v32i8 (X86psadbw (v32i8 VR256:$src1),
Index: lib/Target/X86/X86IntrinsicsInfo.h
===================================================================
--- lib/Target/X86/X86IntrinsicsInfo.h
+++ lib/Target/X86/X86IntrinsicsInfo.h
@@ -282,6 +282,7 @@
   X86_INTRINSIC_DATA(avx2_pmulh_w, INTR_TYPE_2OP, ISD::MULHS, 0),
   X86_INTRINSIC_DATA(avx2_pmulhu_w, INTR_TYPE_2OP, ISD::MULHU, 0),
   X86_INTRINSIC_DATA(avx2_pmulu_dq, INTR_TYPE_2OP, X86ISD::PMULUDQ, 0),
+  X86_INTRINSIC_DATA(avx2_psad_bw, INTR_TYPE_2OP, X86ISD::PSADBW, 0),
   X86_INTRINSIC_DATA(avx2_pshuf_b, INTR_TYPE_2OP, X86ISD::PSHUFB, 0),
   X86_INTRINSIC_DATA(avx2_psign_b, INTR_TYPE_2OP, X86ISD::PSIGN, 0),
   X86_INTRINSIC_DATA(avx2_psign_d, INTR_TYPE_2OP, X86ISD::PSIGN, 0),
@@ -1694,6 +1695,7 @@
   X86_INTRINSIC_DATA(sse2_pmulh_w, INTR_TYPE_2OP, ISD::MULHS, 0),
   X86_INTRINSIC_DATA(sse2_pmulhu_w, INTR_TYPE_2OP, ISD::MULHU, 0),
   X86_INTRINSIC_DATA(sse2_pmulu_dq, INTR_TYPE_2OP, X86ISD::PMULUDQ, 0),
+  X86_INTRINSIC_DATA(sse2_psad_bw, INTR_TYPE_2OP, X86ISD::PSADBW, 0),
   X86_INTRINSIC_DATA(sse2_pshuf_d, INTR_TYPE_2OP, X86ISD::PSHUFD, 0),
   X86_INTRINSIC_DATA(sse2_pshufh_w, INTR_TYPE_2OP, X86ISD::PSHUFHW, 0),
   X86_INTRINSIC_DATA(sse2_pshufl_w, INTR_TYPE_2OP, X86ISD::PSHUFLW, 0),
Index: lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- lib/Transforms/Vectorize/LoopVectorize.cpp
+++ lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -3611,8 +3611,11 @@
       // This is phase one of vectorizing PHIs.
       Type *VecTy = (VF == 1) ? PN->getType() :
       VectorType::get(PN->getType(), VF);
-      Entry[part] = PHINode::Create(
+      PHINode *PHI = PHINode::Create(
           VecTy, 2, "vec.phi", &*LoopVectorBody.back()->getFirstInsertionPt());
+      MDNode *MD = MDNode::get(PHI->getContext(), None);
+      PHI->setMetadata("llvm.loop.vectorize.reduction", MD);
+      Entry[part] = PHI;
     }
     PV->push_back(P);
     return;
Index: test/CodeGen/X86/sad.ll
===================================================================
--- /dev/null
+++ test/CodeGen/X86/sad.ll
@@ -0,0 +1,47 @@
+; RUN: opt < %s -O2 -mtriple=x86_64-unknown-unknown -mcpu=x86-64 -mattr=+sse2 -force-target-max-vector-interleave=1 -unroll-count=1 | llc | FileCheck %s --check-prefix=SSE2
+; RUN: opt < %s -O2 -mtriple=x86_64-unknown-unknown -mcpu=x86-64 -mattr=+avx2 -force-target-max-vector-interleave=1 -unroll-count=1 | llc | FileCheck %s --check-prefix=AVX2
+; RUN: opt < %s -O2 -mtriple=x86_64-unknown-unknown -mcpu=x86-64 -mattr=+avx512bw -force-target-max-vector-interleave=1 -unroll-count=1 | llc | FileCheck %s --check-prefix=AVX512BW
+
+@a = global [1024 x i8] zeroinitializer, align 16
+@b = global [1024 x i8] zeroinitializer, align 16
+
+define i32 @sad() {
+; SSE2-LABEL: sad
+; SSE2: # BB#0:
+; SSE2: psadbw %xmm1, %xmm2
+; SSE2: paddd %xmm2, %xmm0
+;
+; AVX2-LABEL: sad
+; AVX2: # BB#0:
+; AVX2: vpsadbw %xmm1, %xmm2, %xmm1
+; AVX2: vpaddd %xmm0, %xmm1, %xmm1
+;
+; AVX512BW-LABEL: sad
+; AVX512BW: # BB#0:
+; AVX512BW: vpsadbw {{.*}}, %xmm1, %xmm1
+; AVX512BW: vpaddd %xmm0, %xmm1, %xmm1
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret i32 %add
+
+for.body:
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %s.010 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds [1024 x i8], [1024 x i8]* @a, i64 0, i64 %indvars.iv
+  %0 = load i8, i8* %arrayidx, align 1
+  %conv = zext i8 %0 to i32
+  %arrayidx2 = getelementptr inbounds [1024 x i8], [1024 x i8]* @b, i64 0, i64 %indvars.iv
+  %1 = load i8, i8* %arrayidx2, align 1
+  %conv3 = zext i8 %1 to i32
+  %sub = sub nsw i32 %conv, %conv3
+  %ispos = icmp sgt i32 %sub, -1
+  %neg = sub nsw i32 0, %sub
+  %2 = select i1 %ispos, i32 %sub, i32 %neg
+  %add = add nsw i32 %2, %s.010
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, 1024
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
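
For context (not part of the patch): the IR in test/CodeGen/X86/sad.ll corresponds to a
scalar sum-of-absolute-differences loop along the lines of the C++ sketch below. The
names a, b, and sad and the 1024 trip count mirror the test; everything else is
illustrative. With the opt -O2 pipeline used in the RUN lines, the loop vectorizer
widens the accumulator into the vec.phi that LoopVectorize.cpp now tags with
llvm.loop.vectorize.reduction, SelectionDAGBuilder then sets the new Reduction flag on
the widened add, and PerformAddCombine/detectSADPattern rewrite the
zext/sub/icmp/select/add chain into X86ISD::PSADBW plus a reduction add.

  #include <cstdlib>

  unsigned char a[1024], b[1024];

  // Scalar equivalent of the loop in sad.ll: accumulate |a[i] - b[i]| over two
  // byte arrays. The abs() corresponds to the icmp sgt / sub / select triple
  // that detectSADPattern looks for once the loop has been vectorized.
  int sad() {
    int s = 0;
    for (int i = 0; i < 1024; ++i)
      s += std::abs(int(a[i]) - int(b[i]));
    return s;
  }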