Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -38816,6 +38816,127 @@
                           PMADDBuilder);
 }
 
+// Attempt to turn this pattern into PMADDWD.
+// (add (mul (sext (build_vector)), (sext (build_vector))),
+//      (mul (sext (build_vector)), (sext (build_vector)))
+static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
+                              const SDLoc &DL, EVT VT,
+                              const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasSSE2())
+    return SDValue();
+
+  if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
+      VT.getVectorNumElements() < 4 ||
+      !isPowerOf2_32(VT.getVectorNumElements()))
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  SDValue N01 = N0.getOperand(1);
+  SDValue N10 = N1.getOperand(0);
+  SDValue N11 = N1.getOperand(1);
+
+  // All inputs need to be sign extends.
+  // TODO: Support ZERO_EXTEND from known positive?
+  if (N00.getOpcode() != ISD::SIGN_EXTEND ||
+      N01.getOpcode() != ISD::SIGN_EXTEND ||
+      N10.getOpcode() != ISD::SIGN_EXTEND ||
+      N11.getOpcode() != ISD::SIGN_EXTEND)
+    return SDValue();
+
+  // Peek through the extends.
+  N00 = N00.getOperand(0);
+  N01 = N01.getOperand(0);
+  N10 = N10.getOperand(0);
+  N11 = N11.getOperand(0);
+
+  // Must be extending from vXi16.
+  EVT InVT = N00.getValueType();
+  if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
+      N10.getValueType() != InVT || N11.getValueType() != InVT)
+    return SDValue();
+
+  // All inputs should be build_vectors.
+  if (N00.getOpcode() != ISD::BUILD_VECTOR ||
+      N01.getOpcode() != ISD::BUILD_VECTOR ||
+      N10.getOpcode() != ISD::BUILD_VECTOR ||
+      N11.getOpcode() != ISD::BUILD_VECTOR)
+    return SDValue();
+
+  // For each element, we need to ensure we have an odd element from one vector
+  // multiplied by the odd element of another vector and the even element from
+  // one of the same vectors being multiplied by the even element from the
+  // other vector. So we need to make sure for each element i, this operator
+  // is being performed:
+  //  A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1]
+  SDValue In0, In1;
+  for (unsigned i = 0; i != N00.getNumOperands(); ++i) {
+    SDValue N00Elt = N00.getOperand(i);
+    SDValue N01Elt = N01.getOperand(i);
+    SDValue N10Elt = N10.getOperand(i);
+    SDValue N11Elt = N11.getOperand(i);
+    // TODO: Be more tolerant to undefs.
+    if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+        N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+        N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+        N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+      return SDValue();
+    auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
+    auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
+    auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
+    auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
+    if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
+      return SDValue();
+    unsigned IdxN00 = ConstN00Elt->getZExtValue();
+    unsigned IdxN01 = ConstN01Elt->getZExtValue();
+    unsigned IdxN10 = ConstN10Elt->getZExtValue();
+    unsigned IdxN11 = ConstN11Elt->getZExtValue();
+    // Add is commutative so indices can be reordered.
+    if (IdxN00 > IdxN10) {
+      std::swap(IdxN00, IdxN10);
+      std::swap(IdxN01, IdxN11);
+    }
+    // N0 indices must be the even element. N1 indices must be the next odd
+    // element.
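+    // For example, with i == 1 this requires IdxN00 == IdxN01 == 2 and
+    // IdxN10 == IdxN11 == 3: element 2 of one input is multiplied by element
+    // 2 of the other, and that product is added to the product of the two
+    // element 3 values, which is what PMADDWD computes for result lane 1.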
+    if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
+        IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
+      return SDValue();
+    SDValue N00In = N00Elt.getOperand(0);
+    SDValue N01In = N01Elt.getOperand(0);
+    SDValue N10In = N10Elt.getOperand(0);
+    SDValue N11In = N11Elt.getOperand(0);
+    // First time we find an input capture it.
+    if (!In0) {
+      In0 = N00In;
+      In1 = N01In;
+    }
+    // Mul is commutative so the input vectors can be in any order.
+    // Canonicalize to make the compares easier.
+    if (In0 != N00In)
+      std::swap(N00In, N01In);
+    if (In0 != N10In)
+      std::swap(N10In, N11In);
+    if (In0 != N00In || In1 != N01In || In0 != N10In || In1 != N11In)
+      return SDValue();
+  }
+
+  auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                         ArrayRef<SDValue> Ops) {
+    // Shrink by adding truncate nodes and let DAGCombine fold with the
+    // sources.
+    EVT InVT = Ops[0].getValueType();
+    assert(InVT.getScalarType() == MVT::i16 &&
+           "Unexpected scalar element type");
+    assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
+    EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
+                                 InVT.getVectorNumElements() / 2);
+    return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
+  };
+  return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 },
+                          PMADDBuilder);
+}
+
 static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
                           const X86Subtarget &Subtarget) {
   const SDNodeFlags Flags = N->getFlags();
@@ -38831,6 +38952,8 @@
 
   if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
     return MAdd;
+  if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
+    return MAdd;
 
   // Try to synthesize horizontal adds from adds of shuffles.
   if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 ||
Index: llvm/trunk/test/CodeGen/X86/madd.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/madd.ll
+++ llvm/trunk/test/CodeGen/X86/madd.ll
@@ -2299,3 +2299,373 @@
   %a = add <32 x i32> %sa, %sb
   ret <32 x i32> %a
 }
+
+; NOTE: We're testing with loads because ABI lowering creates a concat_vectors that extract_vector_elt creation can see through.
+; This would require the combine to recreate the concat_vectors.
+define <4 x i32> @pmaddwd_128(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_128:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movdqa (%rdi), %xmm0
+; SSE2-NEXT: pmaddwd (%rsi), %xmm0
+; SSE2-NEXT: retq
+;
+; AVX-LABEL: pmaddwd_128:
+; AVX: # %bb.0:
+; AVX-NEXT: vmovdqa (%rdi), %xmm0
+; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX-NEXT: retq
+  %A = load <8 x i16>, <8 x i16>* %Aptr
+  %B = load <8 x i16>, <8 x i16>* %Bptr
+  %A_even = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %A_odd = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %B_even = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %B_odd = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %A_even_ext = sext <4 x i16> %A_even to <4 x i32>
+  %B_even_ext = sext <4 x i16> %B_even to <4 x i32>
+  %A_odd_ext = sext <4 x i16> %A_odd to <4 x i32>
+  %B_odd_ext = sext <4 x i16> %B_odd to <4 x i32>
+  %even_mul = mul <4 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <4 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <4 x i32> %even_mul, %odd_mul
+  ret <4 x i32> %add
+}
+
+define <8 x i32> @pmaddwd_256(<16 x i16>* %Aptr, <16 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_256:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movdqa (%rdi), %xmm0
+; SSE2-NEXT: movdqa 16(%rdi), %xmm1
+; SSE2-NEXT: pmaddwd (%rsi), %xmm0
+; SSE2-NEXT: pmaddwd 16(%rsi), %xmm1
+; SSE2-NEXT: retq
+;
+; AVX1-LABEL: pmaddwd_256:
+; AVX1: # %bb.0:
+; AVX1-NEXT: vmovdqa (%rdi), %ymm0
+; AVX1-NEXT: vmovdqa (%rsi), %ymm1
+; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm2
+; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm3
+; AVX1-NEXT: vpmaddwd %xmm2, %xmm3, %xmm2
+; AVX1-NEXT: vpmaddwd %xmm1, %xmm0, %xmm0
+; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm0, %ymm0
+; AVX1-NEXT: retq
+;
+; AVX256-LABEL: pmaddwd_256:
+; AVX256: # %bb.0:
+; AVX256-NEXT: vmovdqa (%rdi), %ymm0
+; AVX256-NEXT: vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX256-NEXT: retq
+  %A = load <16 x i16>, <16 x i16>* %Aptr
+  %B = load <16 x i16>, <16 x i16>* %Bptr
+  %A_even = shufflevector <16 x i16> %A, <16 x i16> undef, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %A_odd = shufflevector <16 x i16> %A, <16 x i16> undef, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %B_even = shufflevector <16 x i16> %B, <16 x i16> undef, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %B_odd = shufflevector <16 x i16> %B, <16 x i16> undef, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %A_even_ext = sext <8 x i16> %A_even to <8 x i32>
+  %B_even_ext = sext <8 x i16> %B_even to <8 x i32>
+  %A_odd_ext = sext <8 x i16> %A_odd to <8 x i32>
+  %B_odd_ext = sext <8 x i16> %B_odd to <8 x i32>
+  %even_mul = mul <8 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <8 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <8 x i32> %even_mul, %odd_mul
+  ret <8 x i32> %add
+}
+
+define <16 x i32> @pmaddwd_512(<32 x i16>* %Aptr, <32 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_512:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movdqa (%rdi), %xmm0
+; SSE2-NEXT: movdqa 16(%rdi), %xmm1
+; SSE2-NEXT: movdqa 32(%rdi), %xmm2
+; SSE2-NEXT: movdqa 48(%rdi), %xmm3
+; SSE2-NEXT: pmaddwd (%rsi), %xmm0
+; SSE2-NEXT: pmaddwd 16(%rsi), %xmm1
+; SSE2-NEXT: pmaddwd 32(%rsi), %xmm2
+; SSE2-NEXT: pmaddwd 48(%rsi), %xmm3
+; SSE2-NEXT: retq
+;
+; AVX1-LABEL: pmaddwd_512:
+; AVX1: # %bb.0:
+; AVX1-NEXT: vmovdqa (%rdi), %ymm0
+; AVX1-NEXT: vmovdqa 32(%rdi), %ymm1
+; AVX1-NEXT: vmovdqa (%rsi), %ymm2
+; AVX1-NEXT: vmovdqa 32(%rsi), %ymm3
+; AVX1-NEXT: vextractf128 $1, %ymm2, %xmm4
+; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm5
+; AVX1-NEXT: vpmaddwd %xmm4, %xmm5, %xmm4
+; AVX1-NEXT: vpmaddwd %xmm2, %xmm0, %xmm0
+; AVX1-NEXT: vinsertf128 $1, %xmm4, %ymm0, %ymm0
+; AVX1-NEXT: vextractf128 $1, %ymm3, %xmm2
+; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm4
+; AVX1-NEXT: vpmaddwd %xmm2, %xmm4, %xmm2
+; AVX1-NEXT: vpmaddwd %xmm3, %xmm1, %xmm1
+; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm1, %ymm1
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: pmaddwd_512:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vmovdqa (%rdi), %ymm0
+; AVX2-NEXT: vmovdqa 32(%rdi), %ymm1
+; AVX2-NEXT: vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX2-NEXT: vpmaddwd 32(%rsi), %ymm1, %ymm1
+; AVX2-NEXT: retq
+;
+; AVX512F-LABEL: pmaddwd_512:
+; AVX512F: # %bb.0:
+; AVX512F-NEXT: vmovdqa (%rdi), %ymm0
+; AVX512F-NEXT: vmovdqa 32(%rdi), %ymm1
+; AVX512F-NEXT: vpmaddwd 32(%rsi), %ymm1, %ymm1
+; AVX512F-NEXT: vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX512F-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0
+; AVX512F-NEXT: retq
+;
+; AVX512BW-LABEL: pmaddwd_512:
+; AVX512BW: # %bb.0:
+; AVX512BW-NEXT: vmovdqa64 (%rdi), %zmm0
+; AVX512BW-NEXT: vpmaddwd (%rsi), %zmm0, %zmm0
+; AVX512BW-NEXT: retq
+  %A = load <32 x i16>, <32 x i16>* %Aptr
+  %B = load <32 x i16>, <32 x i16>* %Bptr
+  %A_even = shufflevector <32 x i16> %A, <32 x i16> undef, <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30>
+  %A_odd = shufflevector <32 x i16> %A, <32 x i16> undef, <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31>
+  %B_even = shufflevector <32 x i16> %B, <32 x i16> undef, <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30>
+  %B_odd = shufflevector <32 x i16> %B, <32 x i16> undef, <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31>
+  %A_even_ext = sext <16 x i16> %A_even to <16 x i32>
+  %B_even_ext = sext <16 x i16> %B_even to <16 x i32>
+  %A_odd_ext = sext <16 x i16> %A_odd to <16 x i32>
+  %B_odd_ext = sext <16 x i16> %B_odd to <16 x i32>
+  %even_mul = mul <16 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <16 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <16 x i32> %even_mul, %odd_mul
+  ret <16 x i32> %add
+}
+
+define <32 x i32> @pmaddwd_1024(<64 x i16>* %Aptr, <64 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_1024:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movdqa 112(%rsi), %xmm0
+; SSE2-NEXT: movdqa 96(%rsi), %xmm1
+; SSE2-NEXT: movdqa 80(%rsi), %xmm2
+; SSE2-NEXT: movdqa 64(%rsi), %xmm3
+; SSE2-NEXT: movdqa (%rsi), %xmm4
+; SSE2-NEXT: movdqa 16(%rsi), %xmm5
+; SSE2-NEXT: movdqa 32(%rsi), %xmm6
+; SSE2-NEXT: movdqa 48(%rsi), %xmm7
+; SSE2-NEXT: pmaddwd (%rdx), %xmm4
+; SSE2-NEXT: pmaddwd 16(%rdx), %xmm5
+; SSE2-NEXT: pmaddwd 32(%rdx), %xmm6
+; SSE2-NEXT: pmaddwd 48(%rdx), %xmm7
+; SSE2-NEXT: pmaddwd 64(%rdx), %xmm3
+; SSE2-NEXT: pmaddwd 80(%rdx), %xmm2
+; SSE2-NEXT: pmaddwd 96(%rdx), %xmm1
+; SSE2-NEXT: pmaddwd 112(%rdx), %xmm0
+; SSE2-NEXT: movdqa %xmm0, 112(%rdi)
+; SSE2-NEXT: movdqa %xmm1, 96(%rdi)
+; SSE2-NEXT: movdqa %xmm2, 80(%rdi)
+; SSE2-NEXT: movdqa %xmm3, 64(%rdi)
+; SSE2-NEXT: movdqa %xmm7, 48(%rdi)
+; SSE2-NEXT: movdqa %xmm6, 32(%rdi)
+; SSE2-NEXT: movdqa %xmm5, 16(%rdi)
+; SSE2-NEXT: movdqa %xmm4, (%rdi)
+; SSE2-NEXT: movq %rdi, %rax
+; SSE2-NEXT: retq
+;
+; AVX1-LABEL: pmaddwd_1024:
+; AVX1: # %bb.0:
+; AVX1-NEXT: vmovdqa (%rdi), %ymm0
+; AVX1-NEXT: vmovdqa 32(%rdi), %ymm1
+; AVX1-NEXT: vmovdqa 64(%rdi), %ymm2
+; AVX1-NEXT: vmovdqa 96(%rdi), %ymm8
+; AVX1-NEXT: vmovdqa (%rsi), %ymm4
+; AVX1-NEXT: vmovdqa 32(%rsi), %ymm5
+; AVX1-NEXT: vmovdqa 64(%rsi), %ymm6
+; AVX1-NEXT: vmovdqa 96(%rsi), %ymm9
+; AVX1-NEXT: vextractf128 $1, %ymm4, %xmm3
+; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm7
+; AVX1-NEXT: vpmaddwd %xmm3, %xmm7, %xmm3
+; AVX1-NEXT: vpmaddwd %xmm4, %xmm0, %xmm0
+; AVX1-NEXT: vinsertf128 $1, %xmm3, %ymm0, %ymm0
+; AVX1-NEXT: vextractf128 $1, %ymm5, %xmm3
+; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm4
+; AVX1-NEXT: vpmaddwd %xmm3, %xmm4, %xmm3
+; AVX1-NEXT: vpmaddwd %xmm5, %xmm1, %xmm1
+; AVX1-NEXT: vinsertf128 $1, %xmm3, %ymm1, %ymm1
+; AVX1-NEXT: vextractf128 $1, %ymm6, %xmm3
+; AVX1-NEXT: vextractf128 $1, %ymm2, %xmm4
+; AVX1-NEXT: vpmaddwd %xmm3, %xmm4, %xmm3
+; AVX1-NEXT: vpmaddwd %xmm6, %xmm2, %xmm2
+; AVX1-NEXT: vinsertf128 $1, %xmm3, %ymm2, %ymm2
+; AVX1-NEXT: vextractf128 $1, %ymm9, %xmm3
+; AVX1-NEXT: vextractf128 $1, %ymm8, %xmm4
+; AVX1-NEXT: vpmaddwd %xmm3, %xmm4, %xmm3
+; AVX1-NEXT: vpmaddwd %xmm9, %xmm8, %xmm4
+; AVX1-NEXT: vinsertf128 $1, %xmm3, %ymm4, %ymm3
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: pmaddwd_1024:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vmovdqa (%rdi), %ymm0
+; AVX2-NEXT: vmovdqa 32(%rdi), %ymm1
+; AVX2-NEXT: vmovdqa 64(%rdi), %ymm2
+; AVX2-NEXT: vmovdqa 96(%rdi), %ymm3
+; AVX2-NEXT: vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX2-NEXT: vpmaddwd 32(%rsi), %ymm1, %ymm1
+; AVX2-NEXT: vpmaddwd 64(%rsi), %ymm2, %ymm2
+; AVX2-NEXT: vpmaddwd 96(%rsi), %ymm3, %ymm3
+; AVX2-NEXT: retq
+;
+; AVX512F-LABEL: pmaddwd_1024:
+; AVX512F: # %bb.0:
+; AVX512F-NEXT: vmovdqa (%rdi), %ymm0
+; AVX512F-NEXT: vmovdqa 32(%rdi), %ymm1
+; AVX512F-NEXT: vmovdqa 64(%rdi), %ymm2
+; AVX512F-NEXT: vmovdqa 96(%rdi), %ymm3
+; AVX512F-NEXT: vpmaddwd 32(%rsi), %ymm1, %ymm1
+; AVX512F-NEXT: vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX512F-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0
+; AVX512F-NEXT: vpmaddwd 96(%rsi), %ymm3, %ymm1
+; AVX512F-NEXT: vpmaddwd 64(%rsi), %ymm2, %ymm2
+; AVX512F-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm1
+; AVX512F-NEXT: retq
+;
+; AVX512BW-LABEL: pmaddwd_1024:
+; AVX512BW: # %bb.0:
+; AVX512BW-NEXT: vmovdqa64 (%rdi), %zmm0
+; AVX512BW-NEXT: vmovdqa64 64(%rdi), %zmm1
+; AVX512BW-NEXT: vpmaddwd (%rsi), %zmm0, %zmm0
+; AVX512BW-NEXT: vpmaddwd 64(%rsi), %zmm1, %zmm1
+; AVX512BW-NEXT: retq
+  %A = load <64 x i16>, <64 x i16>* %Aptr
+  %B = load <64 x i16>, <64 x i16>* %Bptr
+  %A_even = shufflevector <64 x i16> %A, <64 x i16> undef, <32 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30, i32 32, i32 34, i32 36, i32 38, i32 40, i32 42, i32 44, i32 46, i32 48, i32 50, i32 52, i32 54, i32 56, i32 58, i32 60, i32 62>
+  %A_odd = shufflevector <64 x i16> %A, <64 x i16> undef, <32 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31, i32 33, i32 35, i32 37, i32 39, i32 41, i32 43, i32 45, i32 47, i32 49, i32 51, i32 53, i32 55, i32 57, i32 59, i32 61, i32 63>
+  %B_even = shufflevector <64 x i16> %B, <64 x i16> undef, <32 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30, i32 32, i32 34, i32 36, i32 38, i32 40, i32 42, i32 44, i32 46, i32 48, i32 50, i32 52, i32 54, i32 56, i32 58, i32 60, i32 62>
+  %B_odd = shufflevector <64 x i16> %B, <64 x i16> undef, <32 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31, i32 33, i32 35, i32 37, i32 39, i32 41, i32 43, i32 45, i32 47, i32 49, i32 51, i32 53, i32 55, i32 57, i32 59, i32 61, i32 63>
+  %A_even_ext = sext <32 x i16> %A_even to <32 x i32>
+  %B_even_ext = sext <32 x i16> %B_even to <32 x i32>
+  %A_odd_ext = sext <32 x i16> %A_odd to <32 x i32>
+  %B_odd_ext = sext <32 x i16> %B_odd to <32 x i32>
+  %even_mul = mul <32 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <32 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <32 x i32> %even_mul, %odd_mul
+  ret <32 x i32> %add
+}
+
+define <4 x i32> @pmaddwd_commuted_mul(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_commuted_mul:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movdqa (%rdi), %xmm0
+; SSE2-NEXT: pmaddwd (%rsi), %xmm0
+; SSE2-NEXT: retq
+;
+; AVX-LABEL: pmaddwd_commuted_mul:
+; AVX: # %bb.0:
+; AVX-NEXT: vmovdqa (%rdi), %xmm0
+; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX-NEXT: retq
+  %A = load <8 x i16>, <8 x i16>* %Aptr
+  %B = load <8 x i16>, <8 x i16>* %Bptr
+  %A_even = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %A_odd = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %B_even = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %B_odd = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %A_even_ext = sext <4 x i16> %A_even to <4 x i32>
+  %B_even_ext = sext <4 x i16> %B_even to <4 x i32>
+  %A_odd_ext = sext <4 x i16> %A_odd to <4 x i32>
+  %B_odd_ext = sext <4 x i16> %B_odd to <4 x i32>
+  %even_mul = mul <4 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <4 x i32> %B_odd_ext, %A_odd_ext ; Different order than previous mul
+  %add = add <4 x i32> %even_mul, %odd_mul
+  ret <4 x i32> %add
+}
+
+define <4 x i32> @pmaddwd_swapped_indices(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_swapped_indices:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movdqa (%rdi), %xmm0
+; SSE2-NEXT: pmaddwd (%rsi), %xmm0
+; SSE2-NEXT: retq
+;
+; AVX-LABEL: pmaddwd_swapped_indices:
+; AVX: # %bb.0:
+; AVX-NEXT: vmovdqa (%rdi), %xmm0
+; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX-NEXT: retq
+  %A = load <8 x i16>, <8 x i16>* %Aptr
+  %B = load <8 x i16>, <8 x i16>* %Bptr
+  %A_even = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 1, i32 2, i32 5, i32 6> ; indices aren't all even
+  %A_odd = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 0, i32 3, i32 4, i32 7> ; indices aren't all odd
+  %B_even = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 1, i32 2, i32 5, i32 6> ; same indices as A
+  %B_odd = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 0, i32 3, i32 4, i32 7> ; same indices as A
+  %A_even_ext = sext <4 x i16> %A_even to <4 x i32>
+  %B_even_ext = sext <4 x i16> %B_even to <4 x i32>
+  %A_odd_ext = sext <4 x i16> %A_odd to <4 x i32>
+  %B_odd_ext = sext <4 x i16> %B_odd to <4 x i32>
+  %even_mul = mul <4 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <4 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <4 x i32> %even_mul, %odd_mul
+  ret <4 x i32> %add
+}
+
+; Negative test where indices aren't paired properly
+define <4 x i32> @pmaddwd_bad_indices(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_bad_indices:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movdqa (%rdi), %xmm0
+; SSE2-NEXT: movdqa (%rsi), %xmm1
+; SSE2-NEXT: pshuflw {{.*#+}} xmm2 = xmm1[0,2,2,3,4,5,6,7]
+; SSE2-NEXT: pshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,6,6,7]
+; SSE2-NEXT: pshufd {{.*#+}} xmm2 = xmm2[0,2,2,3]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm3 = xmm0[2,1,2,3,4,5,6,7]
+; SSE2-NEXT: pshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,6,5,6,7]
+; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm3[0,2,2,3]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm3 = xmm3[1,0,3,2,4,5,6,7]
+; SSE2-NEXT: movdqa %xmm3, %xmm4
+; SSE2-NEXT: pmulhw %xmm2, %xmm4
+; SSE2-NEXT: pmullw %xmm2, %xmm3
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm3 = xmm3[0],xmm4[0],xmm3[1],xmm4[1],xmm3[2],xmm4[2],xmm3[3],xmm4[3]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm0 = xmm0[0,3,2,3,4,5,6,7]
+; SSE2-NEXT: pshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,7,6,7]
+; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm1 = xmm1[3,1,2,3,4,5,6,7]
+; SSE2-NEXT: pshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,7,5,6,7]
+; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm1 = xmm1[1,0,3,2,4,5,6,7]
+; SSE2-NEXT: movdqa %xmm0, %xmm2
+; SSE2-NEXT: pmulhw %xmm1, %xmm2
+; SSE2-NEXT: pmullw %xmm1, %xmm0
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
+; SSE2-NEXT: paddd %xmm3, %xmm0
+; SSE2-NEXT: retq
+;
+; AVX-LABEL: pmaddwd_bad_indices:
+; AVX: # %bb.0:
+; AVX-NEXT: vmovdqa (%rdi), %xmm0
+; AVX-NEXT: vmovdqa (%rsi), %xmm1
+; AVX-NEXT: vpshufb {{.*#+}} xmm2 = xmm0[2,3,4,5,10,11,12,13,12,13,10,11,12,13,14,15]
+; AVX-NEXT: vpmovsxwd %xmm2, %xmm2
+; AVX-NEXT: vpshufb {{.*#+}} xmm3 = xmm1[0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
+; AVX-NEXT: vpmovsxwd %xmm3, %xmm3
+; AVX-NEXT: vpmulld %xmm3, %xmm2, %xmm2
+; AVX-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,1,6,7,8,9,14,15,8,9,14,15,12,13,14,15]
+; AVX-NEXT: vpmovsxwd %xmm0, %xmm0
+; AVX-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
+; AVX-NEXT: vpmovsxwd %xmm1, %xmm1
+; AVX-NEXT: vpmulld %xmm1, %xmm0, %xmm0
+; AVX-NEXT: vpaddd %xmm0, %xmm2, %xmm0
+; AVX-NEXT: retq
+  %A = load <8 x i16>, <8 x i16>* %Aptr
+  %B = load <8 x i16>, <8 x i16>* %Bptr
+  %A_even = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 1, i32 2, i32 5, i32 6>
+  %A_odd = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 0, i32 3, i32 4, i32 7>
+  %B_even = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6> ; different indices than A
+  %B_odd = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7> ; different indices than A
+  %A_even_ext = sext <4 x i16> %A_even to <4 x i32>
+  %B_even_ext = sext <4 x i16> %B_even to <4 x i32>
+  %A_odd_ext = sext <4 x i16> %A_odd to <4 x i32>
+  %B_odd_ext = sext <4 x i16> %B_odd to <4 x i32>
+  %even_mul = mul <4 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <4 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <4 x i32> %even_mul, %odd_mul
+  ret <4 x i32> %add
+}
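
For context, each output lane in the tests above computes A[2*i]*B[2*i] + A[2*i+1]*B[2*i+1] from sign-extended i16 inputs, which is exactly the PMADDWD semantics the new combine recognizes. A minimal C++ sketch of a scalar loop with that shape (illustrative only, not part of the patch; the function and parameter names are made up):

// Illustrative only: each Out[i] is A[2*i]*B[2*i] + A[2*i+1]*B[2*i+1],
// computed in 32 bits from sign-extended 16-bit inputs, matching PMADDWD.
void madd_pairs(const short *A, const short *B, int *Out, int N) {
  for (int i = 0; i != N; ++i)
    Out[i] = int(A[2 * i]) * int(B[2 * i]) +
             int(A[2 * i + 1]) * int(B[2 * i + 1]);
}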