Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -1344,6 +1344,8 @@
       setOperationAction(ISD::UINT_TO_FP, MVT::v8i64, Legal);
       setOperationAction(ISD::FP_TO_SINT, MVT::v8i64, Legal);
       setOperationAction(ISD::FP_TO_UINT, MVT::v8i64, Legal);
+
+      setOperationAction(ISD::MUL, MVT::v8i64, Legal);
     }
 
     if (Subtarget.hasCDI()) {
@@ -1432,6 +1434,8 @@
         setOperationAction(ISD::UINT_TO_FP, VT, Legal);
         setOperationAction(ISD::FP_TO_SINT, VT, Legal);
         setOperationAction(ISD::FP_TO_UINT, VT, Legal);
+
+        setOperationAction(ISD::MUL, VT, Legal);
       }
     }
 
@@ -5105,10 +5109,11 @@
 template <typename F>
 SDValue SplitOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
                          const SDLoc &DL, EVT VT, ArrayRef<SDValue> Ops,
-                         F Builder) {
+                         F Builder, bool CheckBWI = true) {
   assert(Subtarget.hasSSE2() && "Target assumed to support at least SSE2");
   unsigned NumSubs = 1;
-  if (Subtarget.useBWIRegs()) {
+  if ((CheckBWI && Subtarget.useBWIRegs()) ||
+      (!CheckBWI && Subtarget.useAVX512Regs())) {
     if (VT.getSizeInBits() > 512) {
       NumSubs = VT.getSizeInBits() / 512;
       assert((VT.getSizeInBits() % 512) == 0 && "Illegal vector size");
@@ -5147,9 +5152,10 @@
 template <typename F>
 SDValue SplitBinaryOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
                                const SDLoc &DL, EVT VT, SDValue Op0,
-                               SDValue Op1, F Builder) {
+                               SDValue Op1, F Builder, bool CheckBWI = true) {
   SDValue Ops[] = {Op0, Op1};
-  return SplitOpsAndApply(DAG, Subtarget, DL, VT, makeArrayRef(Ops), Builder);
+  return SplitOpsAndApply(DAG, Subtarget, DL, VT, makeArrayRef(Ops), Builder,
+                          CheckBWI);
 }
 
 // Return true if the instruction zeroes the unused upper part of the
@@ -22515,13 +22521,7 @@
 
   assert((VT == MVT::v2i64 || VT == MVT::v4i64 || VT == MVT::v8i64) &&
          "Only know how to lower V2I64/V4I64/V8I64 multiply");
-
-  // MULDQ returns the 64-bit result of the signed multiplication of the lower
-  // 32-bits. We can lower with this if the sign bits stretch that far.
-  if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(A) > 32 &&
-      DAG.ComputeNumSignBits(B) > 32) {
-    return DAG.getNode(X86ISD::PMULDQ, dl, VT, A, B);
-  }
+  assert(!Subtarget.hasDQI() && "DQI should use MULLQ");
 
   // Ahi = psrlqi(a, 32);
   // Bhi = psrlqi(b, 32);
@@ -22540,11 +22540,6 @@
   bool AHiIsZero = DAG.MaskedValueIsZero(A, UpperBitsMask);
   bool BHiIsZero = DAG.MaskedValueIsZero(B, UpperBitsMask);
 
-  // If DQI is supported we can use MULLQ, but MULUDQ is still better if the
-  // the high bits are known to be zero.
-  if (Subtarget.hasDQI() && (!AHiIsZero || !BHiIsZero))
-    return Op;
-
   SDValue Zero = getZeroVector(VT, Subtarget, DAG, dl);
 
   // Only multiply lo/hi halves that aren't known to be zero.
@@ -33089,6 +33084,47 @@
                           DAG.getBitcast(WVT, N1), PMADDWDBuilder);
 }
 
+static SDValue combineMulToPMULDQ(SDNode *N, SelectionDAG &DAG,
+                                  const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasSSE2())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  // Only support vXi64 vectors.
+  if (!VT.isVector() || VT.getVectorElementType() != MVT::i64 ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // MULDQ returns the 64-bit result of the signed multiplication of the lower
+  // 32-bits. We can lower with this if the sign bits stretch that far.
+  if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(N0) > 32 &&
+      DAG.ComputeNumSignBits(N1) > 32) {
+    auto PMULDQBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                            ArrayRef<SDValue> Ops) {
+      return DAG.getNode(X86ISD::PMULDQ, DL, Ops[0].getValueType(), Ops);
+    };
+    return SplitBinaryOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
+                                  N0, N1, PMULDQBuilder, /*CheckBWI*/false);
+  }
+
+  // If the upper bits are zero we can use a single pmuludq.
+  APInt Mask = APInt::getHighBitsSet(64, 32);
+  if (DAG.MaskedValueIsZero(N0, Mask) && DAG.MaskedValueIsZero(N1, Mask)) {
+    auto PMULUDQBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                             ArrayRef<SDValue> Ops) {
+      return DAG.getNode(X86ISD::PMULUDQ, DL, Ops[0].getValueType(), Ops);
+    };
+    return SplitBinaryOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
+                                  N0, N1, PMULUDQBuilder, /*CheckBWI*/false);
+  }
+
+  return SDValue();
+}
+
 /// Optimize a single multiply with constant into two operations in order to
 /// implement it with two cheaper instructions, e.g. LEA + SHL, LEA + LEA.
 static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
@@ -33099,6 +33135,9 @@
   if (SDValue V = combineMulToPMADDWD(N, DAG, Subtarget))
     return V;
 
+  if (SDValue V = combineMulToPMULDQ(N, DAG, Subtarget))
+    return V;
+
   if (DCI.isBeforeLegalize() && VT.isVector())
     return reduceVMULWidth(N, DAG, Subtarget);
 
@@ -35701,7 +35740,7 @@
   // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - its
   // better to truncate if we have the chance.
   if (SrcVT.getScalarType() == MVT::i64 && TLI.isOperationLegal(Opcode, VT) &&
-      !Subtarget.hasDQI())
+      !TLI.isOperationLegal(Opcode, SrcVT))
     return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1));
   LLVM_FALLTHROUGH;
 case ISD::ADD: {
Index: test/CodeGen/X86/mulvi32.ll
===================================================================
--- test/CodeGen/X86/mulvi32.ll
+++ test/CodeGen/X86/mulvi32.ll
@@ -179,25 +179,55 @@
 ;
 ; AVX1-LABEL: _mul4xi32toi64a:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vpxor %xmm2, %xmm2, %xmm2
-; AVX1-NEXT:    vpunpckhdq {{.*#+}} xmm3 = xmm1[2],xmm2[2],xmm1[3],xmm2[3]
-; AVX1-NEXT:    vpunpckhdq {{.*#+}} xmm2 = xmm0[2],xmm2[2],xmm0[3],xmm2[3]
-; AVX1-NEXT:    vpmuludq %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX1-NEXT:    vpmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
-; AVX1-NEXT:    vpmuludq %xmm1, %xmm0, %xmm0
-; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm0, %ymm0
+; AVX1-NEXT:    vmovd %xmm0, %r8d
+; AVX1-NEXT:    vpextrd $1, %xmm0, %r9d
+; AVX1-NEXT:    vpextrd $2, %xmm0, %edx
+; AVX1-NEXT:    vpextrd $3, %xmm0, %esi
+; AVX1-NEXT:    vmovd %xmm1, %r10d
+; AVX1-NEXT:    vpextrd $1, %xmm1, %eax
+; AVX1-NEXT:    vpextrd $2, %xmm1, %ecx
+; AVX1-NEXT:    vpextrd $3, %xmm1, %edi
+; AVX1-NEXT:    vmovq %rdi, %xmm0
+; AVX1-NEXT:    vmovq %rcx, %xmm1
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
+; AVX1-NEXT:    vmovq %rsi, %xmm1
+; AVX1-NEXT:    vmovq %rdx, %xmm2
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX1-NEXT:    vpmuludq %xmm0, %xmm1, %xmm0
+; AVX1-NEXT:    vmovq %rax, %xmm1
+; AVX1-NEXT:    vmovq %r10, %xmm2
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX1-NEXT:    vmovq %r9, %xmm2
+; AVX1-NEXT:    vmovq %r8, %xmm3
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm2 = xmm3[0],xmm2[0]
+; AVX1-NEXT:    vpmuludq %xmm1, %xmm2, %xmm1
+; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm1, %ymm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: _mul4xi32toi64a:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpxor %xmm2, %xmm2, %xmm2
-; AVX2-NEXT:    vpunpckhdq {{.*#+}} xmm3 = xmm0[2],xmm2[2],xmm0[3],xmm2[3]
-; AVX2-NEXT:    vpmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
-; AVX2-NEXT:    vinserti128 $1, %xmm3, %ymm0, %ymm0
-; AVX2-NEXT:    vpunpckhdq {{.*#+}} xmm2 = xmm1[2],xmm2[2],xmm1[3],xmm2[3]
-; AVX2-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
-; AVX2-NEXT:    vinserti128 $1, %xmm2, %ymm1, %ymm1
+; AVX2-NEXT:    vmovd %xmm0, %r10d
+; AVX2-NEXT:    vpextrd $1, %xmm0, %ecx
+; AVX2-NEXT:    vpextrd $2, %xmm0, %edx
+; AVX2-NEXT:    vpextrd $3, %xmm0, %esi
+; AVX2-NEXT:    vmovd %xmm1, %r8d
+; AVX2-NEXT:    vpextrd $1, %xmm1, %r9d
+; AVX2-NEXT:    vpextrd $2, %xmm1, %edi
+; AVX2-NEXT:    vpextrd $3, %xmm1, %eax
+; AVX2-NEXT:    vmovq %rsi, %xmm0
+; AVX2-NEXT:    vmovq %rdx, %xmm1
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
+; AVX2-NEXT:    vmovq %rcx, %xmm1
+; AVX2-NEXT:    vmovq %r10, %xmm2
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX2-NEXT:    vinserti128 $1, %xmm0, %ymm1, %ymm0
+; AVX2-NEXT:    vmovq %rax, %xmm1
+; AVX2-NEXT:    vmovq %rdi, %xmm2
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX2-NEXT:    vmovq %r9, %xmm2
+; AVX2-NEXT:    vmovq %r8, %xmm3
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm2 = xmm3[0],xmm2[0]
+; AVX2-NEXT:    vinserti128 $1, %xmm1, %ymm2, %ymm1
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    retq
   %f00 = extractelement <4 x i32> %0, i32 0