Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -28985,6 +28985,7 @@
   return DAG.getVectorShuffle(VT, DL, Concat, DAG.getUNDEF(VT), Mask);
 }
 
+
 static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
                               TargetLowering::DAGCombinerInfo &DCI,
                               const X86Subtarget &Subtarget) {
@@ -33380,7 +33381,8 @@
 /// set to A, RHS to B, and the routine returns 'true'.
 /// Note that the binary operation should have the property that if one of the
 /// operands is UNDEF then the result is UNDEF.
-static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, bool IsCommutative) {
+static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, bool IsCommutative,
+                              bool AllowAVX512VT = false) {
   // Look for the following pattern: if
   //   A = < float a0, float a1, float a2, float a3 >
   //   B = < float b0, float b1, float b2, float b3 >
@@ -33397,8 +33399,8 @@
 
   MVT VT = LHS.getSimpleValueType();
 
-  assert((VT.is128BitVector() || VT.is256BitVector()) &&
-         "Unsupported vector type for horizontal add/sub");
+  assert((AllowAVX512VT || VT.is128BitVector() || VT.is256BitVector()) &&
+         "Unsupported vector type for horizontal add/sub");
 
   // Handle 128 and 256-bit vector lengths. AVX defines horizontal add/sub to
   // operate independently on 128-bit lanes.
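For reference, isHorizontalBinOp matches the element pairing documented in its comment block, and the new AllowAVX512VT parameter merely lets that same check accept 512-bit types. A scalar C++ model of the pattern for a 4-float horizontal add follows; it is not part of the patch, and the helper name hadd4 is ours, for illustration only:

    #include <array>
    #include <cstdio>

    // Scalar model of the pattern isHorizontalBinOp recognizes:
    //   result = { a0+a1, a2+a3, b0+b1, b2+b3 }
    static std::array<float, 4> hadd4(const std::array<float, 4> &A,
                                      const std::array<float, 4> &B) {
      return {A[0] + A[1], A[2] + A[3], B[0] + B[1], B[2] + B[3]};
    }

    int main() {
      std::array<float, 4> R = hadd4({1, 2, 3, 4}, {5, 6, 7, 8});
      std::printf("%g %g %g %g\n", R[0], R[1], R[2], R[3]); // 3 7 11 15
    }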
@@ -35275,8 +35277,64 @@
   return DAG.getNode(NewOpcode, SDLoc(N), VT, N->getOperand(0), AllOnesVec);
 }
 
+// Return true if the upper (checkUpperHalfF == true) or lower half of N's
+// 512-bit result is known to be unused: N's sole user must extract only from
+// the other half, and any shuffle operand must leave the unused half undef.
+static bool hasUnusedLanesAVX512(SDNode *N, SDValue &Op0, SDValue &Op1,
+                                 bool checkUpperHalfF) {
+  bool Op0SV = Op0.getOpcode() == ISD::VECTOR_SHUFFLE;
+  bool Op1SV = Op1.getOpcode() == ISD::VECTOR_SHUFFLE;
+  // At least one of the operands should be a vector shuffle.
+  if (!Op0SV && !Op1SV)
+    return false;
+
+  EVT VT = N->getValueType(0);
+  int NumElems = VT.getVectorNumElements();
+
+  // Check that every shuffle mask entry in the given half is undef.
+  auto checkIsUnusedLane = [&](SDValue &Node, bool checkUpperHalfF) -> bool {
+    int Start = checkUpperHalfF ? NumElems / 2 : 0;
+    int End = checkUpperHalfF ? NumElems : NumElems / 2;
+    ShuffleVectorSDNode *SV = dyn_cast<ShuffleVectorSDNode>(Node.getNode());
+    if (!SV)
+      return false;
+    ArrayRef<int> Mask = SV->getMask();
+    for (int i = Start; i < End; i++) {
+      if (Mask[i] >= 0)
+        return false;
+    }
+    return true;
+  };
+
+  if (VT.getSizeInBits() == 512 && N->hasOneUse()) {
+    SDNode *User = *(N->use_begin());
+    EVT UserVT = User->getValueType(0);
+    // EXTRACT_VECTOR_ELT reads a single lane.
+    int64_t SubVecSize = 1;
+    switch (User->getOpcode()) {
+    default:
+      return false;
+    case ISD::EXTRACT_SUBVECTOR:
+      SubVecSize = User->getValueType(0).getVectorNumElements();
+      LLVM_FALLTHROUGH; // The extract index is validated below as well.
+    case ISD::EXTRACT_VECTOR_ELT: {
+      SDValue Idx = User->getOperand(1);
+      if (!isa<ConstantSDNode>(Idx))
+        return false;
+      int64_t Index = cast<ConstantSDNode>(Idx)->getSExtValue();
+      int64_t HalfNumElems = NumElems / 2;
+      // The extract must not read from the half we claim is unused.
+      if ((checkUpperHalfF && Index + SubVecSize > HalfNumElems) ||
+          (!checkUpperHalfF && Index < HalfNumElems))
+        return false;
+      break;
+    }
+    }
+    if (UserVT.bitsGE(VT) ||
+        (Op0SV && !checkIsUnusedLane(Op0, checkUpperHalfF)) ||
+        (Op1SV && !checkIsUnusedLane(Op1, checkUpperHalfF)))
+      return false;
+    return true;
+  }
+  return false;
+}
+
 static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
                           const X86Subtarget &Subtarget) {
+  SDLoc DL(N);
   const SDNodeFlags Flags = N->getFlags();
   if (Flags.hasVectorReduction()) {
     if (SDValue Sad = combineLoopSADPattern(N, DAG, Subtarget))
@@ -35294,6 +35352,24 @@
       isHorizontalBinOp(Op0, Op1, true))
     return DAG.getNode(X86ISD::HADD, SDLoc(N), VT, Op0, Op1);
 
+  // If the upper half of a 512-bit result is unused, perform the horizontal
+  // add on the lower 256 bits only and widen the result back with undef.
+  if (Subtarget.hasAVX512() && hasUnusedLanesAVX512(N, Op0, Op1, true) &&
+      isHorizontalBinOp(Op0, Op1, true, true)) {
+    EVT ElemType = VT.getVectorElementType();
+    unsigned HalfLength = VT.getVectorNumElements() / 2;
+    EVT NewVT = EVT::getVectorVT(*DAG.getContext(), ElemType, HalfLength);
+
+    SDValue SubVecOp0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewVT, Op0,
+                                    DAG.getIntPtrConstant(0, DL));
+    SDValue SubVecOp1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewVT, Op1,
+                                    DAG.getIntPtrConstant(0, DL));
+    SDValue HADDNode =
+        DAG.getNode(X86ISD::HADD, DL, NewVT, SubVecOp0, SubVecOp1);
+
+    SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(NewVT));
+    ConcatOps[0] = HADDNode;
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
+  }
+
   if (SDValue V = combineIncDecVector(N, DAG))
     return V;
 
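The core of checkIsUnusedLane is a plain scan of a shuffle mask: a half is unused when every mask entry covering that half is undef, which LLVM encodes as a negative index. A minimal standalone sketch of that predicate; the name halfIsUndef and the std::vector signature are illustrative, not part of the patch:

    #include <cstdio>
    #include <vector>

    // Illustrative stand-in for checkIsUnusedLane: true when every mask
    // entry in the chosen half is undef (negative), so those lanes are dead.
    static bool halfIsUndef(const std::vector<int> &Mask, bool UpperHalf) {
      int NumElems = static_cast<int>(Mask.size());
      int Start = UpperHalf ? NumElems / 2 : 0;
      int End = UpperHalf ? NumElems : NumElems / 2;
      for (int i = Start; i < End; ++i)
        if (Mask[i] >= 0) // A non-negative entry reads a real lane.
          return false;
      return true;
    }

    int main() {
      // A v16i32 shuffle mask from a reduction tail: the upper eight lanes
      // are undef, so a 256-bit horizontal add covers all live lanes.
      std::vector<int> Mask = {2,  3,  2,  3,  6,  7,  6,  7,
                               -1, -1, -1, -1, -1, -1, -1, -1};
      std::printf("upper half undef: %d\n", halfIsUndef(Mask, true)); // 1
    }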
Index: test/CodeGen/X86/madd.ll
===================================================================
--- test/CodeGen/X86/madd.ll
+++ test/CodeGen/X86/madd.ll
@@ -329,8 +329,7 @@
 ; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
 ; AVX512-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[2,3,2,3,6,7,6,7,10,11,10,11,14,15,14,15]
 ; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
-; AVX512-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[1,1,2,3,5,5,6,7,9,9,10,11,13,13,14,15]
-; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
 ; AVX512-NEXT:    vmovd %xmm0, %eax
 ; AVX512-NEXT:    vzeroupper
 ; AVX512-NEXT:    retq
Index: test/CodeGen/X86/sad.ll
===================================================================
--- test/CodeGen/X86/sad.ll
+++ test/CodeGen/X86/sad.ll
@@ -78,8 +78,7 @@
 ; AVX512F-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
 ; AVX512F-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[2,3,2,3,6,7,6,7,10,11,10,11,14,15,14,15]
 ; AVX512F-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
-; AVX512F-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[1,1,2,3,5,5,6,7,9,9,10,11,13,13,14,15]
-; AVX512F-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512F-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
 ; AVX512F-NEXT:    vmovd %xmm0, %eax
 ; AVX512F-NEXT:    vzeroupper
 ; AVX512F-NEXT:    retq
@@ -104,8 +103,7 @@
 ; AVX512BW-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
 ; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[2,3,2,3,6,7,6,7,10,11,10,11,14,15,14,15]
 ; AVX512BW-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
-; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[1,1,2,3,5,5,6,7,9,9,10,11,13,13,14,15]
-; AVX512BW-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512BW-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
 ; AVX512BW-NEXT:    vmovd %xmm0, %eax
 ; AVX512BW-NEXT:    vzeroupper
 ; AVX512BW-NEXT:    retq
@@ -327,8 +325,7 @@
 ; AVX512F-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
 ; AVX512F-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[2,3,2,3,6,7,6,7,10,11,10,11,14,15,14,15]
 ; AVX512F-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
-; AVX512F-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[1,1,2,3,5,5,6,7,9,9,10,11,13,13,14,15]
-; AVX512F-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512F-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
 ; AVX512F-NEXT:    vmovd %xmm0, %eax
 ; AVX512F-NEXT:    vzeroupper
 ; AVX512F-NEXT:    retq
@@ -355,8 +352,7 @@
 ; AVX512BW-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
 ; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[2,3,2,3,6,7,6,7,10,11,10,11,14,15,14,15]
 ; AVX512BW-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
-; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[1,1,2,3,5,5,6,7,9,9,10,11,13,13,14,15]
-; AVX512BW-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512BW-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
 ; AVX512BW-NEXT:    vmovd %xmm0, %eax
 ; AVX512BW-NEXT:    vzeroupper
 ; AVX512BW-NEXT:    retq
@@ -800,8 +796,7 @@
 ; AVX512F-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
 ; AVX512F-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[2,3,2,3,6,7,6,7,10,11,10,11,14,15,14,15]
 ; AVX512F-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
-; AVX512F-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[1,1,2,3,5,5,6,7,9,9,10,11,13,13,14,15]
-; AVX512F-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512F-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
 ; AVX512F-NEXT:    vmovd %xmm0, %eax
 ; AVX512F-NEXT:    vzeroupper
 ; AVX512F-NEXT:    retq
@@ -829,8 +824,7 @@
 ; AVX512BW-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
 ; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[2,3,2,3,6,7,6,7,10,11,10,11,14,15,14,15]
 ; AVX512BW-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
-; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm1 = zmm0[1,1,2,3,5,5,6,7,9,9,10,11,13,13,14,15]
-; AVX512BW-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512BW-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
 ; AVX512BW-NEXT:    vmovd %xmm0, %eax
 ; AVX512BW-NEXT:    vzeroupper
 ; AVX512BW-NEXT:    retq
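The reduction tails in these tests come from vectorized multiply-add and sum-of-absolute-differences loops: the vectorizer keeps a 512-bit accumulator, and the final horizontal reduction of that zmm value is where the new combine can emit a 256-bit vphaddd instead of another full-width shuffle and add. A rough source-level example of such a loop follows; whether a given clang release emits exactly the FileCheck sequences above depends on its vectorizer and cost model, so this is an illustration only and the function name dot16 is ours:

    // Built with something like: clang++ -O2 -mavx512f
    // A multiply-add reduction in the spirit of the madd.ll tests; the
    // scalar epilogue of the vectorized loop reduces a zmm accumulator.
    int dot16(const short *a, const short *b, int n) {
      int s = 0;
      for (int i = 0; i < n; ++i)
        s += a[i] * b[i];
      return s;
    }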