Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -1344,6 +1344,8 @@
       setOperationAction(ISD::UINT_TO_FP,     MVT::v8i64, Legal);
       setOperationAction(ISD::FP_TO_SINT,     MVT::v8i64, Legal);
       setOperationAction(ISD::FP_TO_UINT,     MVT::v8i64, Legal);
+
+      setOperationAction(ISD::MUL,            MVT::v8i64, Legal);
     }
 
     if (Subtarget.hasCDI()) {
@@ -1432,6 +1434,8 @@
         setOperationAction(ISD::UINT_TO_FP,   VT, Legal);
         setOperationAction(ISD::FP_TO_SINT,   VT, Legal);
         setOperationAction(ISD::FP_TO_UINT,   VT, Legal);
+
+        setOperationAction(ISD::MUL,          VT, Legal);
       }
     }
 
@@ -33050,6 +33054,27 @@
   return SDValue();
 }
 
+static SDValue getNodeAVXSplit(unsigned Opcode, const SDLoc &DL, EVT VT,
+                               ArrayRef<SDValue> Ops, SelectionDAG &DAG,
+                               const X86Subtarget &Subtarget) {
+  assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) && "Expected legal type");
+  assert(VT.isVector() && VT.isInteger() && "Expected integer vector type");
+
+  if (VT.getSimpleVT().is256BitVector() && !Subtarget.hasAVX2()) {
+    SmallVector<SDValue, 2> OpsLo(Ops.size()), OpsHi(Ops.size());
+    for (unsigned i = 0; i != Ops.size(); ++i)
+      std::tie(OpsLo[i], OpsHi[i]) = DAG.SplitVector(Ops[i], DL);
+
+    EVT LoVT, HiVT;
+    std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+    SDValue Lo = DAG.getNode(Opcode, DL, LoVT, OpsLo);
+    SDValue Hi = DAG.getNode(Opcode, DL, HiVT, OpsHi);
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+  }
+
+  return DAG.getNode(Opcode, DL, VT, Ops);
+}
+
 // If the upper 17 bits of each element are zero then we can use PMADDWD,
 // which is always at least as quick as PMULLD, except on KNL.
 static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
@@ -33078,15 +33103,42 @@
       !DAG.MaskedValueIsZero(N0, Mask17))
     return SDValue();
 
-  // Use SplitBinaryOpsAndApply to handle AVX splitting.
-  auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
-                           ArrayRef<SDValue> Ops) {
-    MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
-    return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops);
-  };
-  return SplitBinaryOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
-                                DAG.getBitcast(WVT, N0),
-                                DAG.getBitcast(WVT, N1), PMADDWDBuilder);
+  return getNodeAVXSplit(X86ISD::VPMADDWD, SDLoc(N), VT,
+                         { DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1) },
+                         DAG, Subtarget);
+}
+
+static SDValue combineMulToPMULDQ(SDNode *N, SelectionDAG &DAG,
+                                  const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasSSE2())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  // Only support vXi64 vectors.
+  if (!VT.isVector() || VT.getVectorElementType() != MVT::i64 ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // PMULDQ returns the 64-bit result of the signed multiplication of the lower
+  // 32 bits. We can lower with this if the sign bits stretch that far.
+  if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(N0) > 32 &&
+      DAG.ComputeNumSignBits(N1) > 32) {
+    return getNodeAVXSplit(X86ISD::PMULDQ, SDLoc(N), VT, { N0, N1 },
+                           DAG, Subtarget);
+  }
+
+  // If the upper bits are zero we can use a single pmuludq.
+  APInt Mask = APInt::getHighBitsSet(64, 32);
+  if (DAG.MaskedValueIsZero(N0, Mask) && DAG.MaskedValueIsZero(N1, Mask)) {
+    return getNodeAVXSplit(X86ISD::PMULUDQ, SDLoc(N), VT, { N0, N1 },
+                           DAG, Subtarget);
+  }
+
+  return SDValue();
 }
 
 /// Optimize a single multiply with constant into two operations in order to
@@ -33099,6 +33151,9 @@
   if (SDValue V = combineMulToPMADDWD(N, DAG, Subtarget))
     return V;
 
+  if (SDValue V = combineMulToPMULDQ(N, DAG, Subtarget))
+    return V;
+
   if (DCI.isBeforeLegalize() && VT.isVector())
     return reduceVMULWidth(N, DAG, Subtarget);
 
@@ -35701,7 +35756,7 @@
     // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - it's
     // better to truncate if we have the chance.
    if (SrcVT.getScalarType() == MVT::i64 && TLI.isOperationLegal(Opcode, VT) &&
-        !Subtarget.hasDQI())
+        !TLI.isOperationLegal(Opcode, SrcVT))
       return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1));
     LLVM_FALLTHROUGH;
   case ISD::ADD: {
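For context, and not part of the patch itself: the two paths in combineMulToPMULDQ are aimed at IR roughly like the sketch below (the function names are made up for illustration). In the first function the operands are sign-extended from i32, so ComputeNumSignBits reports more than 32 sign bits and the PMULDQ path can fire on SSE4.1 targets; in the second they are zero-extended, so MaskedValueIsZero proves the upper 32 bits are zero and the PMULUDQ path applies on any SSE2 target.

; Illustrative only: sign-extended operands, a candidate for the PMULDQ path.
define <2 x i64> @mul_sext_v2i64(<2 x i32> %a, <2 x i32> %b) {
  %sa = sext <2 x i32> %a to <2 x i64>
  %sb = sext <2 x i32> %b to <2 x i64>
  %m = mul <2 x i64> %sa, %sb
  ret <2 x i64> %m
}

; Illustrative only: zero-extended operands, a candidate for the PMULUDQ path.
define <2 x i64> @mul_zext_v2i64(<2 x i32> %a, <2 x i32> %b) {
  %za = zext <2 x i32> %a to <2 x i64>
  %zb = zext <2 x i32> %b to <2 x i64>
  %m = mul <2 x i64> %za, %zb
  ret <2 x i64> %m
}

Running such a file through llc (e.g. -mtriple=x86_64-- -mattr=+sse4.1) should show a single pmuldq/pmuludq per 128-bit half for the multiply if the combine fires; the surrounding shuffles depend on how the extends are legalized.
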
Index: test/CodeGen/X86/mulvi32.ll
===================================================================
--- test/CodeGen/X86/mulvi32.ll
+++ test/CodeGen/X86/mulvi32.ll
@@ -179,25 +179,55 @@
 ;
 ; AVX1-LABEL: _mul4xi32toi64a:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vpxor %xmm2, %xmm2, %xmm2
-; AVX1-NEXT:    vpunpckhdq {{.*#+}} xmm3 = xmm1[2],xmm2[2],xmm1[3],xmm2[3]
-; AVX1-NEXT:    vpunpckhdq {{.*#+}} xmm2 = xmm0[2],xmm2[2],xmm0[3],xmm2[3]
-; AVX1-NEXT:    vpmuludq %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX1-NEXT:    vpmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
-; AVX1-NEXT:    vpmuludq %xmm1, %xmm0, %xmm0
-; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm0, %ymm0
+; AVX1-NEXT:    vmovd %xmm0, %r8d
+; AVX1-NEXT:    vpextrd $1, %xmm0, %r9d
+; AVX1-NEXT:    vpextrd $2, %xmm0, %edx
+; AVX1-NEXT:    vpextrd $3, %xmm0, %esi
+; AVX1-NEXT:    vmovd %xmm1, %r10d
+; AVX1-NEXT:    vpextrd $1, %xmm1, %eax
+; AVX1-NEXT:    vpextrd $2, %xmm1, %ecx
+; AVX1-NEXT:    vpextrd $3, %xmm1, %edi
+; AVX1-NEXT:    vmovq %rdi, %xmm0
+; AVX1-NEXT:    vmovq %rcx, %xmm1
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
+; AVX1-NEXT:    vmovq %rsi, %xmm1
+; AVX1-NEXT:    vmovq %rdx, %xmm2
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX1-NEXT:    vpmuludq %xmm0, %xmm1, %xmm0
+; AVX1-NEXT:    vmovq %rax, %xmm1
+; AVX1-NEXT:    vmovq %r10, %xmm2
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX1-NEXT:    vmovq %r9, %xmm2
+; AVX1-NEXT:    vmovq %r8, %xmm3
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm2 = xmm3[0],xmm2[0]
+; AVX1-NEXT:    vpmuludq %xmm1, %xmm2, %xmm1
+; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm1, %ymm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: _mul4xi32toi64a:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpxor %xmm2, %xmm2, %xmm2
-; AVX2-NEXT:    vpunpckhdq {{.*#+}} xmm3 = xmm0[2],xmm2[2],xmm0[3],xmm2[3]
-; AVX2-NEXT:    vpmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
-; AVX2-NEXT:    vinserti128 $1, %xmm3, %ymm0, %ymm0
-; AVX2-NEXT:    vpunpckhdq {{.*#+}} xmm2 = xmm1[2],xmm2[2],xmm1[3],xmm2[3]
-; AVX2-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX2-NEXT:    vinserti128 $1, %xmm2, %ymm1, %ymm1
+; AVX2-NEXT:    vmovd %xmm0, %r10d
+; AVX2-NEXT:    vpextrd $1, %xmm0, %ecx
+; AVX2-NEXT:    vpextrd $2, %xmm0, %edx
+; AVX2-NEXT:    vpextrd $3, %xmm0, %esi
+; AVX2-NEXT:    vmovd %xmm1, %r8d
+; AVX2-NEXT:    vpextrd $1, %xmm1, %r9d
+; AVX2-NEXT:    vpextrd $2, %xmm1, %edi
+; AVX2-NEXT:    vpextrd $3, %xmm1, %eax
+; AVX2-NEXT:    vmovq %rsi, %xmm0
+; AVX2-NEXT:    vmovq %rdx, %xmm1
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
+; AVX2-NEXT:    vmovq %rcx, %xmm1
+; AVX2-NEXT:    vmovq %r10, %xmm2
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX2-NEXT:    vinserti128 $1, %xmm0, %ymm1, %ymm0
+; AVX2-NEXT:    vmovq %rax, %xmm1
+; AVX2-NEXT:    vmovq %rdi, %xmm2
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX2-NEXT:    vmovq %r9, %xmm2
+; AVX2-NEXT:    vmovq %r8, %xmm3
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm2 = xmm3[0],xmm2[0]
+; AVX2-NEXT:    vinserti128 $1, %xmm1, %ymm2, %ymm1
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    retq
   %f00 = extractelement <4 x i32> %0, i32 0
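One more illustrative sketch, again not taken from the patch: the rerouted combineMulToPMADDWD keeps its existing trigger, namely that the upper 17 bits of every element must be known zero. That bound is what makes vpmaddwd safe here: if both inputs fit in 15 bits, each 32-bit element splits into a non-negative low 16-bit lane and a zero high lane, so the pairwise multiply-add collapses to the single low product and cannot overflow. A hypothetical input satisfying the condition:

; Illustrative only: masking to 15 bits leaves the upper 17 bits of each
; 32-bit element known zero, the precondition combineMulToPMADDWD checks.
define <4 x i32> @mul_pmaddwd_v4i32(<4 x i32> %a, <4 x i32> %b) {
  %ma = and <4 x i32> %a, <i32 32767, i32 32767, i32 32767, i32 32767>
  %mb = and <4 x i32> %b, <i32 32767, i32 32767, i32 32767, i32 32767>
  %m = mul <4 x i32> %ma, %mb
  ret <4 x i32> %m
}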