Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -37010,6 +37010,95 @@
   return DAG.getNode(NewOpcode, SDLoc(N), VT, N->getOperand(0), AllOnesVec);
 }
 
+static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1, EVT VT,
+                            const X86Subtarget &Subtarget) {
+  // Example of pattern we try to detect:
+  // t := (v8i32 mul (sext (v8i16 x0)), (sext (v8i16 x1)))
+  // (add (build_vector (extract_elt t, 0),
+  //                    (extract_elt t, 2),
+  //                    (extract_elt t, 4),
+  //                    (extract_elt t, 6)),
+  //      (build_vector (extract_elt t, 1),
+  //                    (extract_elt t, 3),
+  //                    (extract_elt t, 5),
+  //                    (extract_elt t, 7)))
+
+  if (!Subtarget.hasSSE2())
+    return SDValue();
+  if (VT != MVT::v4i32 && (VT != MVT::v8i32 || !Subtarget.hasAVX2()) &&
+      (VT != MVT::v16i32 || !Subtarget.hasBWI()))
+    return SDValue();
+  unsigned ValNumElts = VT.getVectorNumElements();
+
+  // Helper for examining one ADD operand.
+  auto IsBuildVectorOfExtractsFromMul = [](
+      SDValue Op, ArrayRef<unsigned> ExpectedIndices, SDValue &RetMul) {
+    if (Op->getOpcode() != ISD::BUILD_VECTOR)
+      return false;
+    SDValue Mul;
+    for (unsigned i = 0, e = ExpectedIndices.size(); i != e; ++i) {
+      // TODO: Be more tolerant to undefs.
+      if (Op->getOperand(i)->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+        return false;
+      auto *Idx = dyn_cast<ConstantSDNode>(Op->getOperand(i)->getOperand(1));
+      if (!Idx || Idx->getZExtValue() != ExpectedIndices[i])
+        return false;
+      if (Mul) {
+        // Check that the extract is from the same MUL previously seen.
+        if (Mul != Op->getOperand(i)->getOperand(0))
+          return false;
+      } else {
+        // First time an extract_elt's source vector is visited. Must be a MUL
+        // with 2X the number of vector elements of the BUILD_VECTOR.
+        Mul = Op->getOperand(i)->getOperand(0);
+        if (Mul->getOpcode() != ISD::MUL ||
+            Mul.getValueType().getVectorNumElements() !=
+                2 * ExpectedIndices.size())
+          return false;
+      }
+    }
+    RetMul = Mul;
+    return true;
+  };
+
+  SDValue L, R;
+  const unsigned ExpectedEvenIndices[] = {0, 2, 4, 6, 8, 10, 12, 14,
+                                          16, 18, 20, 22, 24, 26, 28, 30};
+  const unsigned ExpectedOddIndices[] = {1, 3, 5, 7, 9, 11, 13, 15,
+                                         17, 19, 21, 23, 25, 27, 29, 31};
+  // Try the two possible orderings: (add even, odd), (add odd, even).
+  if (!(IsBuildVectorOfExtractsFromMul(
+            Op0, makeArrayRef(ExpectedEvenIndices, ValNumElts), L) &&
+        IsBuildVectorOfExtractsFromMul(
+            Op1, makeArrayRef(ExpectedOddIndices, ValNumElts), R) &&
+        L == R) &&
+      !(IsBuildVectorOfExtractsFromMul(
+            Op0, makeArrayRef(ExpectedOddIndices, ValNumElts), L) &&
+        IsBuildVectorOfExtractsFromMul(
+            Op1, makeArrayRef(ExpectedEvenIndices, ValNumElts), R) &&
+        L == R))
+    return SDValue();
+
+  // Check if the Mul source can be safely shrunk.
+  ShrinkMode Mode;
+  if (!canReduceVMulWidth(L.getNode(), DAG, Mode) || Mode == MULU16)
+    return SDValue();
+
+  // Shrink by adding truncate nodes and let DAGCombine fold them with the
+  // sources.
+  MVT TruncVT;
+  switch (ValNumElts) {
+  default: llvm_unreachable("Unexpected number of elements");
+  case 4: TruncVT = MVT::v8i16; break;
+  case 8: TruncVT = MVT::v16i16; break;
+  case 16: TruncVT = MVT::v32i16; break;
+  }
+  return DAG.getNode(X86ISD::VPMADDWD, SDLoc(L), VT,
+                     DAG.getNode(ISD::TRUNCATE, SDLoc(L.getOperand(0)),
+                                 TruncVT, L.getOperand(0)),
+                     DAG.getNode(ISD::TRUNCATE, SDLoc(L.getOperand(1)),
+                                 TruncVT, L.getOperand(1)));
+}
+
 static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
                           const X86Subtarget &Subtarget) {
   const SDNodeFlags Flags = N->getFlags();
@@ -37023,6 +37112,9 @@
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
 
+  if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, VT, Subtarget))
+    return MAdd;
+
   // Try to synthesize horizontal adds from adds of shuffles.
   if (((Subtarget.hasSSSE3() && (VT == MVT::v8i16 || VT == MVT::v4i32)) ||
        (Subtarget.hasInt256() && (VT == MVT::v16i16 || VT == MVT::v8i32))) &&
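
The IR shape this combine fires on is exercised by the updated and newly added tests below. As a minimal standalone sketch (the function name @pmaddwd_example is illustrative only, and the llc flags assume madd.ll's usual x86_64 triple with -mattr=+sse2; the RUN lines are not part of this hunk), the following now compiles to a single pmaddwd instead of the pmullw/pmulhw, unpack, shuffle and paddd sequence:

; Illustrative only: a sign-extended 16-bit multiply whose even- and odd-index
; products are summed pairwise, which is what (v)pmaddwd computes per 32-bit lane.
define <4 x i32> @pmaddwd_example(<8 x i16> %A, <8 x i16> %B) {
  %a = sext <8 x i16> %A to <8 x i32>
  %b = sext <8 x i16> %B to <8 x i32>
  %m = mul nsw <8 x i32> %a, %b
  %even = shufflevector <8 x i32> %m, <8 x i32> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
  %odd = shufflevector <8 x i32> %m, <8 x i32> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
  %ret = add <4 x i32> %even, %odd
  ret <4 x i32> %ret
}
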
Index: test/CodeGen/X86/madd.ll
===================================================================
--- test/CodeGen/X86/madd.ll
+++ test/CodeGen/X86/madd.ll
@@ -316,26 +316,12 @@
 define <4 x i32> @pmaddwd_8(<8 x i16> %A, <8 x i16> %B) {
 ; SSE2-LABEL: pmaddwd_8:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movdqa %xmm0, %xmm2
-; SSE2-NEXT:    pmulhw %xmm1, %xmm2
-; SSE2-NEXT:    pmullw %xmm1, %xmm0
-; SSE2-NEXT:    movdqa %xmm0, %xmm1
-; SSE2-NEXT:    punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm2[4],xmm1[5],xmm2[5],xmm1[6],xmm2[6],xmm1[7],xmm2[7]
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
-; SSE2-NEXT:    movdqa %xmm0, %xmm2
-; SSE2-NEXT:    shufps {{.*#+}} xmm2 = xmm2[0,2],xmm1[0,2]
-; SSE2-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,3],xmm1[1,3]
-; SSE2-NEXT:    paddd %xmm2, %xmm0
+; SSE2-NEXT:    pmaddwd %xmm1, %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: pmaddwd_8:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovsxwd %xmm0, %ymm0
-; AVX-NEXT:    vpmovsxwd %xmm1, %ymm1
-; AVX-NEXT:    vpmulld %ymm1, %ymm0, %ymm0
-; AVX-NEXT:    vextracti128 $1, %ymm0, %xmm1
-; AVX-NEXT:    vphaddd %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %a = sext <8 x i16> %A to <8 x i32>
   %b = sext <8 x i16> %B to <8 x i32>
@@ -346,6 +332,83 @@
   ret <4 x i32> %ret
 }
 
+define <4 x i32> @pmaddwd_8_swapped(<8 x i16> %A, <8 x i16> %B) {
+; SSE2-LABEL: pmaddwd_8_swapped:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pmaddwd %xmm1, %xmm0
+; SSE2-NEXT:    retq
+;
+; AVX-LABEL: pmaddwd_8_swapped:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %a = sext <8 x i16> %A to <8 x i32>
+  %b = sext <8 x i16> %B to <8 x i32>
+  %m = mul nsw <8 x i32> %a, %b
+  %odd = shufflevector <8 x i32> %m, <8 x i32> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %even = shufflevector <8 x i32> %m, <8 x i32> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %ret = add <4 x i32> %even, %odd
+  ret <4 x i32> %ret
+}
+
+define <4 x i32> @larger_mul(<16 x i16> %A, <16 x i16> %B) {
+; SSE2-LABEL: larger_mul:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa %xmm0, %xmm1
+; SSE2-NEXT:    pmulhw %xmm2, %xmm1
+; SSE2-NEXT:    pmullw %xmm2, %xmm0
+; SSE2-NEXT:    movdqa %xmm0, %xmm2
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm2 = xmm2[4],xmm1[4],xmm2[5],xmm1[5],xmm2[6],xmm1[6],xmm2[7],xmm1[7]
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; SSE2-NEXT:    movdqa %xmm0, %xmm1
+; SSE2-NEXT:    shufps {{.*#+}} xmm1 = xmm1[0,2],xmm2[0,2]
+; SSE2-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,3],xmm2[1,3]
+; SSE2-NEXT:    paddd %xmm1, %xmm0
+; SSE2-NEXT:    retq
+;
+; AVX2-LABEL: larger_mul:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovsxwd %xmm0, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT:    vpackssdw %xmm2, %xmm0, %xmm0
+; AVX2-NEXT:    vpmovsxwd %xmm1, %ymm1
+; AVX2-NEXT:    vextracti128 $1, %ymm1, %xmm2
+; AVX2-NEXT:    vpackssdw %xmm2, %xmm1, %xmm1
+; AVX2-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: larger_mul:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpmovsxwd %ymm0, %zmm0
+; AVX512-NEXT:    vpmovsxwd %ymm1, %zmm1
+; AVX512-NEXT:    vpmulld %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vpextrd $2, %xmm0, %eax
+; AVX512-NEXT:    vpinsrd $1, %eax, %xmm0, %xmm1
+; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm2
+; AVX512-NEXT:    vmovd %xmm2, %eax
+; AVX512-NEXT:    vpinsrd $2, %eax, %xmm1, %xmm1
+; AVX512-NEXT:    vpextrd $2, %xmm2, %eax
+; AVX512-NEXT:    vpinsrd $3, %eax, %xmm1, %xmm1
+; AVX512-NEXT:    vpextrd $3, %xmm0, %eax
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[1,1,2,3]
+; AVX512-NEXT:    vpinsrd $1, %eax, %xmm0, %xmm0
+; AVX512-NEXT:    vpextrd $1, %xmm2, %eax
+; AVX512-NEXT:    vpinsrd $2, %eax, %xmm0, %xmm0
+; AVX512-NEXT:    vpextrd $3, %xmm2, %eax
+; AVX512-NEXT:    vpinsrd $3, %eax, %xmm0, %xmm0
+; AVX512-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+  %a = sext <16 x i16> %A to <16 x i32>
+  %b = sext <16 x i16> %B to <16 x i32>
+  %m = mul nsw <16 x i32> %a, %b
+  %odd = shufflevector <16 x i32> %m, <16 x i32> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %even = shufflevector <16 x i32> %m, <16 x i32> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %ret = add <4 x i32> %odd, %even
+  ret <4 x i32> %ret
+}
+
 define <8 x i32> @pmaddwd_16(<16 x i16> %A, <16 x i16> %B) {
 ; SSE2-LABEL: pmaddwd_16:
 ; SSE2:       # %bb.0:
@@ -371,35 +434,10 @@
 ; SSE2-NEXT:    paddd %xmm5, %xmm0
 ; SSE2-NEXT:    retq
 ;
-; AVX2-LABEL: pmaddwd_16:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpmovsxwd %xmm0, %ymm2
-; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm0
-; AVX2-NEXT:    vpmovsxwd %xmm0, %ymm0
-; AVX2-NEXT:    vpmovsxwd %xmm1, %ymm3
-; AVX2-NEXT:    vpmulld %ymm3, %ymm2, %ymm2
-; AVX2-NEXT:    vextracti128 $1, %ymm1, %xmm1
-; AVX2-NEXT:    vpmovsxwd %xmm1, %ymm1
-; AVX2-NEXT:    vpmulld %ymm1, %ymm0, %ymm0
-; AVX2-NEXT:    vshufps {{.*#+}} ymm1 = ymm2[0,2],ymm0[0,2],ymm2[4,6],ymm0[4,6]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm1 = ymm1[0,2,1,3]
-; AVX2-NEXT:    vshufps {{.*#+}} ymm0 = ymm2[1,3],ymm0[1,3],ymm2[5,7],ymm0[5,7]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,2,1,3]
-; AVX2-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: pmaddwd_16:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpmovsxwd %ymm0, %zmm0
-; AVX512-NEXT:    vpmovsxwd %ymm1, %zmm1
-; AVX512-NEXT:    vpmulld %zmm1, %zmm0, %zmm0
-; AVX512-NEXT:    vextracti64x4 $1, %zmm0, %ymm1
-; AVX512-NEXT:    vshufps {{.*#+}} ymm2 = ymm0[0,2],ymm1[0,2],ymm0[4,6],ymm1[4,6]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm2 = ymm2[0,2,1,3]
-; AVX512-NEXT:    vshufps {{.*#+}} ymm0 = ymm0[1,3],ymm1[1,3],ymm0[5,7],ymm1[5,7]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,2,1,3]
-; AVX512-NEXT:    vpaddd %ymm0, %ymm2, %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: pmaddwd_16:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmaddwd %ymm1, %ymm0, %ymm0
+; AVX-NEXT:    retq
   %a = sext <16 x i16> %A to <16 x i32>
   %b = sext <16 x i16> %B to <16 x i32>
   %m = mul nsw <16 x i32> %a, %b
@@ -501,19 +539,7 @@
 ;
 ; AVX512BW-LABEL: pmaddwd_32:
 ; AVX512BW:       # %bb.0:
-; AVX512BW-NEXT:    vpmovsxwd %ymm0, %zmm2
-; AVX512BW-NEXT:    vextracti64x4 $1, %zmm0, %ymm0
-; AVX512BW-NEXT:    vpmovsxwd %ymm0, %zmm0
-; AVX512BW-NEXT:    vpmovsxwd %ymm1, %zmm3
-; AVX512BW-NEXT:    vpmulld %zmm3, %zmm2, %zmm2
-; AVX512BW-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
-; AVX512BW-NEXT:    vpmovsxwd %ymm1, %zmm1
-; AVX512BW-NEXT:    vpmulld %zmm1, %zmm0, %zmm0
-; AVX512BW-NEXT:    vmovdqa32 {{.*#+}} zmm1 = [0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30]
-; AVX512BW-NEXT:    vpermi2d %zmm0, %zmm2, %zmm1
-; AVX512BW-NEXT:    vmovdqa32 {{.*#+}} zmm3 = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31]
-; AVX512BW-NEXT:    vpermi2d %zmm0, %zmm2, %zmm3
-; AVX512BW-NEXT:    vpaddd %zmm3, %zmm1, %zmm0
+; AVX512BW-NEXT:    vpmaddwd %zmm1, %zmm0, %zmm0
 ; AVX512BW-NEXT:    retq
   %a = sext <32 x i16> %A to <32 x i32>
   %b = sext <32 x i16> %B to <32 x i32>
@@ -527,26 +553,12 @@
 define <4 x i32> @pmaddwd_const(<8 x i16> %A) {
 ; SSE2-LABEL: pmaddwd_const:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [32767,32768,0,0,1,7,42,32]
-; SSE2-NEXT:    movdqa %xmm0, %xmm2
-; SSE2-NEXT:    pmulhw %xmm1, %xmm2
-; SSE2-NEXT:    pmullw %xmm1, %xmm0
-; SSE2-NEXT:    movdqa %xmm0, %xmm1
-; SSE2-NEXT:    punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm2[4],xmm1[5],xmm2[5],xmm1[6],xmm2[6],xmm1[7],xmm2[7]
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
-; SSE2-NEXT:    movdqa %xmm0, %xmm2
-; SSE2-NEXT:    shufps {{.*#+}} xmm2 = xmm2[0,2],xmm1[0,2]
-; SSE2-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,3],xmm1[1,3]
-; SSE2-NEXT:    paddd %xmm2, %xmm0
+; SSE2-NEXT:    pmaddwd {{.*}}(%rip), %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: pmaddwd_const:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovsxwd %xmm0, %ymm0
-; AVX-NEXT:    vpmulld {{.*}}(%rip), %ymm0, %ymm0
-; AVX-NEXT:    vextracti128 $1, %ymm0, %xmm1
-; AVX-NEXT:    vphaddd %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    vpmaddwd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %a = sext <8 x i16> %A to <8 x i32>
   %m = mul nsw <8 x i32> %a,