Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp =================================================================== --- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp +++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp @@ -39059,17 +39059,8 @@ if (!Subtarget.hasSSE2()) return SDValue(); - SDValue MulOp = N->getOperand(0); - SDValue Phi = N->getOperand(1); - - if (MulOp.getOpcode() != ISD::MUL) - std::swap(MulOp, Phi); - if (MulOp.getOpcode() != ISD::MUL) - return SDValue(); - - ShrinkMode Mode; - if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16) - return SDValue(); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); EVT VT = N->getValueType(0); @@ -39078,28 +39069,49 @@ if (VT.getVectorNumElements() < 8) return SDValue(); + if (Op0.getOpcode() != ISD::MUL) + std::swap(Op0, Op1); + if (Op0.getOpcode() != ISD::MUL) + return SDValue(); + + ShrinkMode Mode; + if (!canReduceVMulWidth(Op0.getNode(), DAG, Mode) || Mode == MULU16) + return SDValue(); + SDLoc DL(N); EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements()); EVT MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, VT.getVectorNumElements() / 2); - // Shrink the operands of mul. - SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(0)); - SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1)); - // Madd vector size is half of the original vector size auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, ArrayRef Ops) { MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32); return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops); }; - SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, - PMADDWDBuilder); - // Fill the rest of the output with 0 - SDValue Zero = getZeroVector(Madd.getSimpleValueType(), Subtarget, DAG, DL); - SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero); - return DAG.getNode(ISD::ADD, DL, VT, Concat, Phi); + + auto BuildPMADDWD = [&](SDValue Mul) { + // Shrink the operands of mul. + SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(0)); + SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(1)); + + SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, + PMADDWDBuilder); + // Fill the rest of the output with 0 + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, + DAG.getConstant(0, DL, MAddVT)); + }; + + Op0 = BuildPMADDWD(Op0); + + // It's possible that Op1 is also a mul we can reduce. + if (Op1.getOpcode() == ISD::MUL && + canReduceVMulWidth(Op1.getNode(), DAG, Mode) && Mode != MULU16) { + Op1 = BuildPMADDWD(Op1); + } + + return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1); } static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, Index: llvm/trunk/test/CodeGen/X86/madd.ll =================================================================== --- llvm/trunk/test/CodeGen/X86/madd.ll +++ llvm/trunk/test/CodeGen/X86/madd.ll @@ -2671,17 +2671,11 @@ ; SSE2: # %bb.0: ; SSE2-NEXT: movdqu (%rdi), %xmm0 ; SSE2-NEXT: movdqu (%rsi), %xmm1 -; SSE2-NEXT: movdqa %xmm0, %xmm2 -; SSE2-NEXT: pmulhw %xmm1, %xmm2 -; SSE2-NEXT: pmullw %xmm1, %xmm0 -; SSE2-NEXT: movdqa %xmm0, %xmm1 -; SSE2-NEXT: punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm2[4],xmm1[5],xmm2[5],xmm1[6],xmm2[6],xmm1[7],xmm2[7] -; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3] -; SSE2-NEXT: paddd %xmm1, %xmm0 -; SSE2-NEXT: movdqu (%rdx), %xmm1 +; SSE2-NEXT: pmaddwd %xmm0, %xmm1 +; SSE2-NEXT: movdqu (%rdx), %xmm0 ; SSE2-NEXT: movdqu (%rcx), %xmm2 -; SSE2-NEXT: pmaddwd %xmm1, %xmm2 -; SSE2-NEXT: paddd %xmm0, %xmm2 +; SSE2-NEXT: pmaddwd %xmm0, %xmm2 +; SSE2-NEXT: paddd %xmm1, %xmm2 ; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm2[2,3,0,1] ; SSE2-NEXT: paddd %xmm2, %xmm0 ; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] @@ -2691,14 +2685,9 @@ ; ; AVX1-LABEL: madd_double_reduction: ; AVX1: # %bb.0: -; AVX1-NEXT: vpmovsxwd (%rdi), %xmm0 -; AVX1-NEXT: vpmovsxwd 8(%rdi), %xmm1 -; AVX1-NEXT: vpmovsxwd (%rsi), %xmm2 -; AVX1-NEXT: vpmulld %xmm2, %xmm0, %xmm0 -; AVX1-NEXT: vpmovsxwd 8(%rsi), %xmm2 -; AVX1-NEXT: vpmulld %xmm2, %xmm1, %xmm1 -; AVX1-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX1-NEXT: vmovdqu (%rdi), %xmm0 ; AVX1-NEXT: vmovdqu (%rdx), %xmm1 +; AVX1-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0 ; AVX1-NEXT: vpmaddwd (%rcx), %xmm1, %xmm1 ; AVX1-NEXT: vpaddd %xmm0, %xmm1, %xmm0 ; AVX1-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] @@ -2709,10 +2698,9 @@ ; ; AVX256-LABEL: madd_double_reduction: ; AVX256: # %bb.0: -; AVX256-NEXT: vpmovsxwd (%rdi), %ymm0 -; AVX256-NEXT: vpmovsxwd (%rsi), %ymm1 -; AVX256-NEXT: vpmulld %ymm1, %ymm0, %ymm0 +; AVX256-NEXT: vmovdqu (%rdi), %xmm0 ; AVX256-NEXT: vmovdqu (%rdx), %xmm1 +; AVX256-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0 ; AVX256-NEXT: vpmaddwd (%rcx), %xmm1, %xmm1 ; AVX256-NEXT: vpaddd %ymm0, %ymm1, %ymm0 ; AVX256-NEXT: vextracti128 $1, %ymm0, %xmm1