Index: lib/Target/X86/X86ISelLowering.cpp =================================================================== --- lib/Target/X86/X86ISelLowering.cpp +++ lib/Target/X86/X86ISelLowering.cpp @@ -39108,17 +39108,8 @@ if (!Subtarget.hasSSE2()) return SDValue(); - SDValue MulOp = N->getOperand(0); - SDValue Phi = N->getOperand(1); - - if (MulOp.getOpcode() != ISD::MUL) - std::swap(MulOp, Phi); - if (MulOp.getOpcode() != ISD::MUL) - return SDValue(); - - ShrinkMode Mode; - if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16) - return SDValue(); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); EVT VT = N->getValueType(0); @@ -39127,28 +39118,49 @@ if (VT.getVectorNumElements() < 8) return SDValue(); + if (Op0.getOpcode() != ISD::MUL) + std::swap(Op0, Op1); + if (Op0.getOpcode() != ISD::MUL) + return SDValue(); + + ShrinkMode Mode; + if (!canReduceVMulWidth(Op0.getNode(), DAG, Mode) || Mode == MULU16) + return SDValue(); + SDLoc DL(N); EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements()); EVT MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, VT.getVectorNumElements() / 2); - // Shrink the operands of mul. 
- SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(0)); - SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1)); - // Madd vector size is half of the original vector size auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, ArrayRef Ops) { MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32); return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops); }; - SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, - PMADDWDBuilder); - // Fill the rest of the output with 0 - SDValue Zero = getZeroVector(Madd.getSimpleValueType(), Subtarget, DAG, DL); - SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero); - return DAG.getNode(ISD::ADD, DL, VT, Concat, Phi); + + auto BuildPMADDWD = [&](SDValue Mul) { + // Shrink the operands of mul. + SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(0)); + SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(1)); + + SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, + PMADDWDBuilder); + // Fill the rest of the output with 0 + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, + DAG.getConstant(0, DL, MAddVT)); + }; + + Op0 = BuildPMADDWD(Op0); + + // It's possible that Op1 is also a mul we can reduce. + if (Op1.getOpcode() == ISD::MUL && + canReduceVMulWidth(Op1.getNode(), DAG, Mode) && Mode != MULU16) { + Op1 = BuildPMADDWD(Op1); + } + + return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1); } static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, Index: test/CodeGen/X86/madd.ll =================================================================== --- test/CodeGen/X86/madd.ll +++ test/CodeGen/X86/madd.ll @@ -2663,3 +2663,71 @@ %add = add <4 x i32> %even_mul, %odd_mul ret <4 x i32> %add } + +; This test contains two multiplies joined by an add. The result of that add is then reduced to a single element. +; SelectionDAGBuilder should tag the joining add as a vector reduction. 
We need to recognize that both sides can use pmaddwd.
+define i32 @madd_double_reduction(<8 x i16>* %arg, <8 x i16>* %arg1, <8 x i16>* %arg2, <8 x i16>* %arg3) {
+; SSE2-LABEL: madd_double_reduction:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqu (%rdi), %xmm0
+; SSE2-NEXT:    movdqu (%rsi), %xmm1
+; SSE2-NEXT:    pmaddwd %xmm0, %xmm1
+; SSE2-NEXT:    movdqu (%rdx), %xmm0
+; SSE2-NEXT:    movdqu (%rcx), %xmm2
+; SSE2-NEXT:    pmaddwd %xmm0, %xmm2
+; SSE2-NEXT:    paddd %xmm1, %xmm2
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,3,0,1]
+; SSE2-NEXT:    paddd %xmm2, %xmm0
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
+; SSE2-NEXT:    paddd %xmm0, %xmm1
+; SSE2-NEXT:    movd %xmm1, %eax
+; SSE2-NEXT:    retq
+;
+; AVX1-LABEL: madd_double_reduction:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX1-NEXT:    vmovdqu (%rdx), %xmm1
+; AVX1-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX1-NEXT:    vpmaddwd (%rcx), %xmm1, %xmm1
+; AVX1-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX1-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
+; AVX1-NEXT:    vmovd %xmm0, %eax
+; AVX1-NEXT:    retq
+;
+; AVX256-LABEL: madd_double_reduction:
+; AVX256:       # %bb.0:
+; AVX256-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX256-NEXT:    vmovdqu (%rdx), %xmm1
+; AVX256-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX256-NEXT:    vpmaddwd (%rcx), %xmm1, %xmm1
+; AVX256-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
+; AVX256-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX256-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; AVX256-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX256-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; AVX256-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX256-NEXT:    vmovd %xmm0, %eax
+; AVX256-NEXT:    vzeroupper
+; AVX256-NEXT:    retq
+  %tmp = load <8 x i16>, <8 x i16>* %arg, align 1
+  %tmp6 = load <8 x i16>, <8 x i16>* %arg1, align 1
+  %tmp7 = sext <8 x i16> %tmp to <8 x i32>
+  %tmp17 = sext <8 x i16> %tmp6 to <8 x i32>
+  %tmp19 = mul nsw <8 x i32> %tmp7, %tmp17
+  %tmp20 = load <8 x i16>, <8 x i16>* %arg2, align 1
+  %tmp21 = load <8 x i16>, <8 x i16>* %arg3, align 1
+  %tmp22 = sext <8 x i16> %tmp20 to <8 x i32>
+  %tmp23 = sext <8 x i16> %tmp21 to <8 x i32>
+  %tmp25 = mul nsw <8 x i32> %tmp22, %tmp23
+  %tmp26 = add nuw nsw <8 x i32> %tmp25, %tmp19
+  %tmp29 = shufflevector <8 x i32> %tmp26, <8 x i32> undef, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, i32 undef, i32 undef, i32 undef>
+  %tmp30 = add <8 x i32> %tmp26, %tmp29
+  %tmp31 = shufflevector <8 x i32> %tmp30, <8 x i32> undef, <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %tmp32 = add <8 x i32> %tmp30, %tmp31
+  %tmp33 = shufflevector <8 x i32> %tmp32, <8 x i32> undef, <8 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %tmp34 = add <8 x i32> %tmp32, %tmp33
+  %tmp35 = extractelement <8 x i32> %tmp34, i64 0
+  ret i32 %tmp35
+}