Index: llvm/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49045,10 +49045,10 @@
       In0 = N00In;
       In1 = N01In;
 
-      // The input vector sizes must match the output.
-      // TODO: Insert cast ops to allow different types.
-      if (In0.getValueSizeInBits() != VT.getSizeInBits() ||
-          In1.getValueSizeInBits() != VT.getSizeInBits())
+      // The input vectors must be at least as wide as the output.
+      // If they are larger than the output, we extract subvector below.
+      if (In0.getValueSizeInBits() < VT.getSizeInBits() ||
+          In1.getValueSizeInBits() < VT.getSizeInBits())
         return SDValue();
     }
     // Mul is commutative so the input vectors can be in any order.
@@ -49073,6 +49073,18 @@
                                  OpVT.getVectorNumElements() / 2);
     return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
   };
+
+  // If either input vector is wider than the output, extract the low part.
+  EVT Vec16VT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+                                 VT.getVectorNumElements() * 2);
+  if (Vec16VT.bitsLT(In0.getValueType())) {
+    In0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Vec16VT, In0,
+                      DAG.getIntPtrConstant(0, DL));
+  }
+  if (Vec16VT.bitsLT(In1.getValueType())) {
+    In1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Vec16VT, In1,
+                      DAG.getIntPtrConstant(0, DL));
+  }
   return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 }, PMADDBuilder);
 }
 
Index: llvm/test/CodeGen/X86/madd.ll
===================================================================
--- llvm/test/CodeGen/X86/madd.ll
+++ llvm/test/CodeGen/X86/madd.ll
@@ -3052,48 +3052,12 @@
 define <4 x i32> @input_size_mismatch(<16 x i16> %x, <16 x i16>* %p) {
 ; SSE2-LABEL: input_size_mismatch:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movdqa (%rdi), %xmm1
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm2 = xmm0[0,2,2,3,4,5,6,7]
-; SSE2-NEXT:    pshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,6,6,7]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[0,2,2,3]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm0[3,1,2,3,4,5,6,7]
-; SSE2-NEXT:    pshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,7,5,6,7]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm0[1,0,3,2,4,5,6,7]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm1[0,2,2,3,4,5,6,7]
-; SSE2-NEXT:    pshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,4,6,6,7]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm3[0,2,2,3]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[3,1,2,3,4,5,6,7]
-; SSE2-NEXT:    pshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,7,5,6,7]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[1,0,3,2,4,5,6,7]
-; SSE2-NEXT:    movdqa %xmm2, %xmm4
-; SSE2-NEXT:    pmulhw %xmm3, %xmm4
-; SSE2-NEXT:    pmullw %xmm3, %xmm2
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm4[0],xmm2[1],xmm4[1],xmm2[2],xmm4[2],xmm2[3],xmm4[3]
-; SSE2-NEXT:    movdqa %xmm0, %xmm3
-; SSE2-NEXT:    pmulhw %xmm1, %xmm3
-; SSE2-NEXT:    pmullw %xmm1, %xmm0
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm3[0],xmm0[1],xmm3[1],xmm0[2],xmm3[2],xmm0[3],xmm3[3]
-; SSE2-NEXT:    paddd %xmm2, %xmm0
+; SSE2-NEXT:    pmaddwd (%rdi), %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: input_size_mismatch:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm1 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
-; AVX-NEXT:    vpshufb %xmm1, %xmm0, %xmm2
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm3 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX-NEXT:    vpshufb %xmm3, %xmm0, %xmm0
-; AVX-NEXT:    vmovdqa (%rdi), %xmm4
-; AVX-NEXT:    vpshufb %xmm1, %xmm4, %xmm1
-; AVX-NEXT:    vpshufb %xmm3, %xmm4, %xmm3
-; AVX-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX-NEXT:    vpmulld %xmm1, %xmm2, %xmm1
-; AVX-NEXT:    vpmovsxwd %xmm3, %xmm2
-; AVX-NEXT:    vpmulld %xmm2, %xmm0, %xmm0
-; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vpmaddwd (%rdi), %xmm0, %xmm0
 ; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
   %y = load <16 x i16>, <16 x i16>* %p, align 32
@@ -3119,19 +3083,7 @@
 ;
 ; AVX-LABEL: output_size_mismatch:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
-; AVX-NEXT:    vpshufb %xmm2, %xmm0, %xmm3
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX-NEXT:    vpshufb %xmm4, %xmm0, %xmm0
-; AVX-NEXT:    vpshufb %xmm2, %xmm1, %xmm2
-; AVX-NEXT:    vpshufb %xmm4, %xmm1, %xmm1
-; AVX-NEXT:    vpmovsxwd %xmm3, %xmm3
-; AVX-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX-NEXT:    vpmulld %xmm2, %xmm3, %xmm2
-; AVX-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
   %x0 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32>
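For reference, the code change above makes the pmaddwd matcher accept multiply inputs that come from vectors wider than the i32 result; instead of bailing out, it now extracts the low VT.getVectorNumElements() * 2 i16 elements and feeds those to VPMADDWD. A minimal IR sketch of the shape this enables is below; the function name and shuffle masks are illustrative assumptions, not copied from madd.ll:

; Illustrative only: two 256-bit i16 inputs reduced to a 128-bit i32 result.
; Even/odd lanes of the low halves are sign-extended, multiplied, and
; pairwise-added, which should now select a single pmaddwd of the low
; 8 x i16 subvectors rather than a shuffle/sext/pmulld sequence.
define <4 x i32> @pmaddwd_wide_inputs(<16 x i16> %a, <16 x i16> %b) {
  %a0 = shufflevector <16 x i16> %a, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
  %a1 = shufflevector <16 x i16> %a, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
  %b0 = shufflevector <16 x i16> %b, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
  %b1 = shufflevector <16 x i16> %b, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
  %sa0 = sext <4 x i16> %a0 to <4 x i32>
  %sa1 = sext <4 x i16> %a1 to <4 x i32>
  %sb0 = sext <4 x i16> %b0 to <4 x i32>
  %sb1 = sext <4 x i16> %b1 to <4 x i32>
  %m0 = mul <4 x i32> %sa0, %sb0
  %m1 = mul <4 x i32> %sa1, %sb1
  %r = add <4 x i32> %m0, %m1
  ret <4 x i32> %r
}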