Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -34583,6 +34583,75 @@
                      DAG.getConstant(0, DL, VT), NewCmp);
 }
 
+static SDValue combineMAddPattern(SDNode *N, SelectionDAG &DAG,
+                                  const X86Subtarget &Subtarget) {
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+
+  SDValue MulOp, Phi;
+  if (Op0.getOpcode() == ISD::MUL) {
+    MulOp = Op0;
+    Phi = Op1;
+  } else if (Op1.getOpcode() == ISD::MUL) {
+    MulOp = Op1;
+    Phi = Op0;
+  } else
+    return SDValue();
+
+  ShrinkMode Mode;
+  if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode))
+    return SDValue();
+
+  // The 8-bit PMADDUBSW requires SSSE3; otherwise fall back to 16-bit PMADDWD.
+  if (!Subtarget.hasSSSE3() && (Mode == MULS8 || Mode == MULU8))
+    Mode = (Mode == MULS8) ? MULS16 : MULU16;
+
+  MVT::SimpleValueType ReducedType =
+      (Mode == MULS8 || Mode == MULU8) ? MVT::i8 : MVT::i16;
+  EVT VT = N->getValueType(0);
+
+  unsigned RegSize = 128;
+  if (Subtarget.hasBWI())
+    RegSize = 512;
+  else if (Subtarget.hasAVX2())
+    RegSize = 256;
+  unsigned VectorSize =
+      VT.getVectorNumElements() * (ReducedType == MVT::i8 ? 8 : 16);
+  // If the vector size is less than 128 bits, or greater than the supported
+  // register size, do not use PMADD.
+  if (VectorSize < 128 || VectorSize > RegSize)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue N0 = MulOp->getOperand(0);
+  SDValue N1 = MulOp->getOperand(1);
+  EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), ReducedType,
+                                   VT.getVectorNumElements());
+  EVT MAddVT = EVT::getVectorVT(*DAG.getContext(),
+                                ReducedType == MVT::i8 ? MVT::i16 : MVT::i32,
+                                VT.getVectorNumElements() / 2);
+
+  // Shrink the operands of the multiply.
+  SDValue NewN0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N0);
+  SDValue NewN1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N1);
+
+  // The MAdd result vector has half as many elements as the original one.
+  SDValue Madd = DAG.getNode(ReducedType == MVT::i8 ? X86ISD::VPMADDUBSW
+                                                    : X86ISD::VPMADDWD,
+                             DL, MAddVT, NewN0, NewN1);
+  if (ReducedType == MVT::i8) {
+    // The PMADDUBSW result is i16; sign/zero-extend it to i32.
+    MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
+                              VT.getVectorNumElements() / 2);
+    Madd = DAG.getNode(Mode == MULS8 ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND,
+                       DL, MAddVT, Madd);
+  }
+  // Fill the upper half of the output vector with zeroes.
+  SDValue Zero = getZeroVector(Madd.getSimpleValueType(), Subtarget, DAG, DL);
+  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero);
+  return DAG.getNode(ISD::ADD, DL, VT, Concat, Phi);
+}
+
 static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
                                      const X86Subtarget &Subtarget) {
   SDLoc DL(N);
@@ -34660,6 +34729,8 @@
   if (Flags->hasVectorReduction()) {
     if (SDValue Sad = combineLoopSADPattern(N, DAG, Subtarget))
       return Sad;
+    if (SDValue MAdd = combineMAddPattern(N, DAG, Subtarget))
+      return MAdd;
   }
   EVT VT = N->getValueType(0);
   SDValue Op0 = N->getOperand(0);
Index: test/CodeGen/X86/madd.ll
===================================================================
--- /dev/null
+++ test/CodeGen/X86/madd.ll
@@ -0,0 +1,205 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+sse2 | FileCheck %s --check-prefix=SSE2
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+ssse3 | FileCheck %s --check-prefix=SSSE3
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx2 | FileCheck %s --check-prefix=AVX2
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512f | FileCheck %s --check-prefix=AVX512
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512bw | FileCheck %s --check-prefix=AVX512
+
+;SSE2-LABEL: @_Z10test_shortPsS_i
+;SSE2: movdqu
+;SSE2-NEXT: movdqu
+;SSE2-NEXT: pmaddwd
+;SSE2-NEXT: paddd
+
+;SSSE3-LABEL: @_Z10test_shortPsS_i
+;SSSE3: movdqu
+;SSSE3-NEXT: movdqu
+;SSSE3-NEXT: pmaddwd
+;SSSE3-NEXT: paddd
+
+;AVX2-LABEL: @_Z10test_shortPsS_i
+;AVX2: vmovdqu
+;AVX2-NEXT: vpmaddwd
+;AVX2-NEXT: vinserti128
+;AVX2-NEXT: vpaddd
+
+;AVX512-LABEL: @_Z10test_shortPsS_i
+;AVX512: vmovdqu
+;AVX512-NEXT: vpmaddwd
+;AVX512-NEXT: vinserti128
+;AVX512-NEXT: vpaddd
+
+define i32 @_Z10test_shortPsS_i(i16* nocapture readonly, i16* nocapture readonly, i32) local_unnamed_addr #0 {
+  %4 = icmp sgt i32 %2, 0
+  br i1 %4, label %.lr.ph.preheader, label %._crit_edge
+
+.lr.ph.preheader:                                 ; preds = %3
+  %wide.trip.count = zext i32 %2 to i64
+  %min.iters.check = icmp ult i32 %2, 8
+  br i1 %min.iters.check, label %.lr.ph.preheader19, label %min.iters.checked
+
+.lr.ph.preheader19:                               ; preds = %middle.block, %min.iters.checked, %.lr.ph.preheader
+  %indvars.iv.ph = phi i64 [ 0, %min.iters.checked ], [ 0, %.lr.ph.preheader ], [ %n.vec, %middle.block ]
+  %.01112.ph = phi i32 [ 0, %min.iters.checked ], [ 0, %.lr.ph.preheader ], [ %15, %middle.block ]
+  br label %.lr.ph
+
+min.iters.checked:                                ; preds = %.lr.ph.preheader
+  %5 = and i32 %2, 7
+  %n.mod.vf = zext i32 %5 to i64
+  %n.vec = sub nsw i64 %wide.trip.count, %n.mod.vf
+  %cmp.zero = icmp eq i64 %n.vec, 0
+  br i1 %cmp.zero, label %.lr.ph.preheader19, label %vector.body.preheader
+
+vector.body.preheader:                            ; preds = %min.iters.checked
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body.preheader, %vector.body
+  %index = phi i64 [ %index.next, %vector.body ], [ 0, %vector.body.preheader ]
+  %vec.phi = phi <8 x i32> [ %13, %vector.body ], [ zeroinitializer, %vector.body.preheader ]
+  %6 = getelementptr inbounds i16, i16* %0, i64 %index
+  %7 = bitcast i16* %6 to <8 x i16>*
+  %wide.load = load <8 x i16>, <8 x i16>* %7, align 2
+  %8 = sext <8 x i16> %wide.load to <8 x i32>
+  %9 = getelementptr inbounds i16, i16* %1, i64 %index
+  %10 = bitcast i16* %9 to <8 x i16>*
+  %wide.load14 = load <8 x i16>, <8 x i16>* %10, align 2
+  %11 = sext <8 x i16> %wide.load14 to <8 x i32>
+  %12 = mul nsw <8 x i32> %11, %8
+  %13 = add nsw <8 x i32> %12, %vec.phi
+  %index.next = add i64 %index, 8
+  %14 = icmp eq i64 %index.next, %n.vec
+  br i1 %14, label %middle.block, label %vector.body
+
+middle.block:                                     ; preds = %vector.body
+  %rdx.shuf = shufflevector <8 x i32> %13, <8 x i32> undef, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx = add <8 x i32> %13, %rdx.shuf
+  %rdx.shuf15 = shufflevector <8 x i32> %bin.rdx, <8 x i32> undef, <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx16 = add <8 x i32> %bin.rdx, %rdx.shuf15
+  %rdx.shuf17 = shufflevector <8 x i32> %bin.rdx16, <8 x i32> undef, <8 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx18 = add <8 x i32> %bin.rdx16, %rdx.shuf17
+  %15 = extractelement <8 x i32> %bin.rdx18, i32 0
+  %cmp.n = icmp eq i32 %5, 0
+  br i1 %cmp.n, label %._crit_edge, label %.lr.ph.preheader19
+
+._crit_edge.loopexit:                             ; preds = %.lr.ph
+  br label %._crit_edge
+
+._crit_edge:                                      ; preds = %._crit_edge.loopexit, %middle.block, %3
+  %.011.lcssa = phi i32 [ 0, %3 ], [ %15, %middle.block ], [ %23, %._crit_edge.loopexit ]
+  ret i32 %.011.lcssa
+
+.lr.ph:                                           ; preds = %.lr.ph.preheader19, %.lr.ph
+  %indvars.iv = phi i64 [ %indvars.iv.next, %.lr.ph ], [ %indvars.iv.ph, %.lr.ph.preheader19 ]
+  %.01112 = phi i32 [ %23, %.lr.ph ], [ %.01112.ph, %.lr.ph.preheader19 ]
+  %16 = getelementptr inbounds i16, i16* %0, i64 %indvars.iv
+  %17 = load i16, i16* %16, align 2
+  %18 = sext i16 %17 to i32
+  %19 = getelementptr inbounds i16, i16* %1, i64 %indvars.iv
+  %20 = load i16, i16* %19, align 2
+  %21 = sext i16 %20 to i32
+  %22 = mul nsw i32 %21, %18
+  %23 = add nsw i32 %22, %.01112
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %._crit_edge.loopexit, label %.lr.ph
+}
+
+;SSSE3-LABEL: @_Z9test_charPcS_i
+;SSSE3: movdqu
+;SSSE3-NEXT: movdqu
+;SSSE3-NEXT: pmaddubsw
+;SSSE3-NEXT: punpckhwd
+;SSSE3-NEXT: psrad
+;SSSE3-NEXT: paddd
+;SSSE3-NEXT: punpcklwd
+;SSSE3-NEXT: psrad
+;SSSE3-NEXT: paddd
+
+;AVX2-LABEL: @_Z9test_charPcS_i
+;AVX2: vmovdqu
+;AVX2-NEXT: vpmaddubsw
+;AVX2-NEXT: vpmovsxwd
+;AVX2-NEXT: vpaddd
+
+;AVX512-LABEL: @_Z9test_charPcS_i
+;AVX512: vmovdqu
+;AVX512-NEXT: vpmaddubsw
+;AVX512-NEXT: vpmovsxwd
+;AVX512-NEXT: vinserti64x4
+;AVX512-NEXT: vpaddd
+
+define i32 @_Z9test_charPcS_i(i8* nocapture readonly, i8* nocapture readonly, i32) local_unnamed_addr #0 {
+  %4 = icmp sgt i32 %2, 0
+  br i1 %4, label %.lr.ph.preheader, label %._crit_edge
+
+.lr.ph.preheader:                                 ; preds = %3
+  %wide.trip.count = zext i32 %2 to i64
+  %min.iters.check = icmp ult i32 %2, 16
+  br i1 %min.iters.check, label %.lr.ph.preheader21, label %min.iters.checked
+
+.lr.ph.preheader21:                               ; preds = %middle.block, %min.iters.checked, %.lr.ph.preheader
+  %indvars.iv.ph = phi i64 [ 0, %min.iters.checked ], [ 0, %.lr.ph.preheader ], [ %n.vec, %middle.block ]
+  %.01112.ph = phi i32 [ 0, %min.iters.checked ], [ 0, %.lr.ph.preheader ], [ %15, %middle.block ]
+  br label %.lr.ph
+
+min.iters.checked:                                ; preds = %.lr.ph.preheader
+  %5 = and i32 %2, 15
+  %n.mod.vf = zext i32 %5 to i64
+  %n.vec = sub nsw i64 %wide.trip.count, %n.mod.vf
+  %cmp.zero = icmp eq i64 %n.vec, 0
+  br i1 %cmp.zero, label %.lr.ph.preheader21, label %vector.body.preheader
+
+vector.body.preheader:                            ; preds = %min.iters.checked
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body.preheader, %vector.body
+  %index = phi i64 [ %index.next, %vector.body ], [ 0, %vector.body.preheader ]
+  %vec.phi = phi <16 x i32> [ %13, %vector.body ], [ zeroinitializer, %vector.body.preheader ]
+  %6 = getelementptr inbounds i8, i8* %0, i64 %index
+  %7 = bitcast i8* %6 to <16 x i8>*
+  %wide.load = load <16 x i8>, <16 x i8>* %7, align 1
+  %8 = sext <16 x i8> %wide.load to <16 x i32>
+  %9 = getelementptr inbounds i8, i8* %1, i64 %index
+  %10 = bitcast i8* %9 to <16 x i8>*
+  %wide.load14 = load <16 x i8>, <16 x i8>* %10, align 1
+  %11 = sext <16 x i8> %wide.load14 to <16 x i32>
+  %12 = mul nsw <16 x i32> %11, %8
+  %13 = add nsw <16 x i32> %12, %vec.phi
+  %index.next = add i64 %index, 16
+  %14 = icmp eq i64 %index.next, %n.vec
+  br i1 %14, label %middle.block, label %vector.body
+
+middle.block:                                     ; preds = %vector.body
+  %rdx.shuf = shufflevector <16 x i32> %13, <16 x i32> undef, <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx = add <16 x i32> %13, %rdx.shuf
+  %rdx.shuf15 = shufflevector <16 x i32> %bin.rdx, <16 x i32> undef, <16 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx16 = add <16 x i32> %bin.rdx, %rdx.shuf15
+  %rdx.shuf17 = shufflevector <16 x i32> %bin.rdx16, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx18 = add <16 x i32> %bin.rdx16, %rdx.shuf17
+  %rdx.shuf19 = shufflevector <16 x i32> %bin.rdx18, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx20 = add <16 x i32> %bin.rdx18, %rdx.shuf19
+  %15 = extractelement <16 x i32> %bin.rdx20, i32 0
+  %cmp.n = icmp eq i32 %5, 0
+  br i1 %cmp.n, label %._crit_edge, label %.lr.ph.preheader21
+
+._crit_edge.loopexit:                             ; preds = %.lr.ph
+  br label %._crit_edge
+
+._crit_edge:                                      ; preds = %._crit_edge.loopexit, %middle.block, %3
+  %.011.lcssa = phi i32 [ 0, %3 ], [ %15, %middle.block ], [ %23, %._crit_edge.loopexit ]
+  ret i32 %.011.lcssa
+
+.lr.ph:                                           ; preds = %.lr.ph.preheader21, %.lr.ph
+  %indvars.iv = phi i64 [ %indvars.iv.next, %.lr.ph ], [ %indvars.iv.ph, %.lr.ph.preheader21 ]
+  %.01112 = phi i32 [ %23, %.lr.ph ], [ %.01112.ph, %.lr.ph.preheader21 ]
+  %16 = getelementptr inbounds i8, i8* %0, i64 %indvars.iv
+  %17 = load i8, i8* %16, align 1
+  %18 = sext i8 %17 to i32
+  %19 = getelementptr inbounds i8, i8* %1, i64 %indvars.iv
+  %20 = load i8, i8* %19, align 1
+  %21 = sext i8 %20 to i32
+  %22 = mul nsw i32 %21, %18
+  %23 = add nsw i32 %22, %.01112
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %._crit_edge.loopexit, label %.lr.ph
+}
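
Note on the transform: the combine fires on a vector-reduction add of the form add(mul(x, y), phi), where canReduceVMulWidth proves both multiplicands fit in 8 or 16 bits. Taking the first test as an example, the v8i32 step add(mul(sext(a), sext(b)), %vec.phi) becomes: truncate the operands back to v8i16, emit a single VPMADDWD yielding v4i32 partial sums, concatenate a zero v4i32 to rebuild a v8i32, and add that into the reduction phi. The horizontal reduction in middle.block is unchanged; it still sums every lane, and the zeroed upper half contributes nothing.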
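For reference, the mangled names in the test demangle to test_short(short*, short*, int) and test_char(char*, char*, int); the IR is the vectorized form of dot-product reduction loops along the lines of the following C++ sketch (for orientation only, not part of the patch):

  int test_short(short *a, short *b, int n) {
    int sum = 0;
    for (int i = 0; i < n; ++i)
      sum += a[i] * b[i]; // i16 * i16 products accumulated into an i32
    return sum;
  }

  int test_char(char *a, char *b, int n) {
    int sum = 0;
    for (int i = 0; i < n; ++i)
      sum += a[i] * b[i]; // i8 * i8 products accumulated into an i32
    return sum;
  }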