diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -19537,33 +19537,35 @@
     }
   }
 
-  // Make sure all but the first op are undef.
-  auto ConcatWithUndef = [](SDValue Concat) {
+  // Make sure all but the first op are undef or constant.
+  auto ConcatWithConstantOrUndef = [](SDValue Concat) {
     return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
            std::all_of(std::next(Concat->op_begin()), Concat->op_end(),
-                       [](const SDValue &Op) {
-                         return Op.isUndef();
-                       });
+                     [](const SDValue &Op) {
+                       return Op.isUndef() ||
+                              ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
+                     });
   };
 
   // The following pattern is likely to emerge with vector reduction ops. Moving
   // the binary operation ahead of the concat may allow using a narrower vector
   // instruction that has better performance than the wide version of the op:
-  // VBinOp (concat X, undef), (concat Y, undef) --> concat (VBinOp X, Y), VecC
-  if (ConcatWithUndef(LHS) && ConcatWithUndef(RHS) &&
+  // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
+  //   concat (VBinOp X, Y), VecC
+  if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
       (LHS.hasOneUse() || RHS.hasOneUse())) {
-    SDValue X = LHS.getOperand(0);
-    SDValue Y = RHS.getOperand(0);
-    EVT NarrowVT = X.getValueType();
-    if (NarrowVT == Y.getValueType() &&
+    EVT NarrowVT = LHS.getOperand(0).getValueType();
+    if (NarrowVT == RHS.getOperand(0).getValueType() &&
         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
-      // (binop undef, undef) may not return undef, so compute that result.
       SDLoc DL(N);
-      SDValue VecC =
-          DAG.getNode(Opcode, DL, NarrowVT, DAG.getUNDEF(NarrowVT),
-                      DAG.getUNDEF(NarrowVT));
-      SmallVector<SDValue, 4> Ops(LHS.getNumOperands(), VecC);
-      Ops[0] = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
+      unsigned NumOperands = LHS.getNumOperands();
+      SmallVector<SDValue, 4> Ops;
+      for (unsigned i = 0; i != NumOperands; ++i) {
+        // This constant folds for operands 1 and up.
+        Ops.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
+                                  RHS.getOperand(i)));
+      }
+
       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Ops);
     }
   }
diff --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll
--- a/llvm/test/CodeGen/X86/madd.ll
+++ b/llvm/test/CodeGen/X86/madd.ll
@@ -2720,52 +2720,27 @@
 ; SSE2-NEXT:    movd %xmm1, %eax
 ; SSE2-NEXT:    retq
 ;
-; AVX1-LABEL: madd_quad_reduction:
-; AVX1:       # %bb.0:
-; AVX1-NEXT:    movq {{[0-9]+}}(%rsp), %r10
-; AVX1-NEXT:    movq {{[0-9]+}}(%rsp), %rax
-; AVX1-NEXT:    vmovdqu (%rdi), %xmm0
-; AVX1-NEXT:    vmovdqu (%rdx), %xmm1
-; AVX1-NEXT:    vpmaddwd (%rcx), %xmm1, %xmm1
-; AVX1-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
-; AVX1-NEXT:    vmovdqu (%r8), %xmm2
-; AVX1-NEXT:    vpmaddwd (%r9), %xmm2, %xmm2
-; AVX1-NEXT:    vpaddd %xmm2, %xmm1, %xmm1
-; AVX1-NEXT:    vmovdqu (%rax), %xmm2
-; AVX1-NEXT:    vpmaddwd (%r10), %xmm2, %xmm2
-; AVX1-NEXT:    vpaddd %xmm2, %xmm1, %xmm1
-; AVX1-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX1-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-NEXT:    vmovd %xmm0, %eax
-; AVX1-NEXT:    retq
-;
-; AVX256-LABEL: madd_quad_reduction:
-; AVX256:       # %bb.0:
-; AVX256-NEXT:    movq {{[0-9]+}}(%rsp), %r10
-; AVX256-NEXT:    movq {{[0-9]+}}(%rsp), %rax
-; AVX256-NEXT:    vmovdqu (%rdi), %xmm0
-; AVX256-NEXT:    vmovdqu (%rdx), %xmm1
-; AVX256-NEXT:    vpmaddwd (%rcx), %xmm1, %xmm1
-; AVX256-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
-; AVX256-NEXT:    vmovdqu (%r8), %xmm2
-; AVX256-NEXT:    vpmaddwd (%r9), %xmm2, %xmm2
-; AVX256-NEXT:    vpaddd %ymm2, %ymm1, %ymm1
-; AVX256-NEXT:    vmovdqu (%rax), %xmm2
-; AVX256-NEXT:    vpmaddwd (%r10), %xmm2, %xmm2
-; AVX256-NEXT:    vpaddd %ymm2, %ymm1, %ymm1
-; AVX256-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
-; AVX256-NEXT:    vextracti128 $1, %ymm0, %xmm1
-; AVX256-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX256-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX256-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX256-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX256-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX256-NEXT:    vmovd %xmm0, %eax
-; AVX256-NEXT:    vzeroupper
-; AVX256-NEXT:    retq
+; AVX-LABEL: madd_quad_reduction:
+; AVX:       # %bb.0:
+; AVX-NEXT:    movq {{[0-9]+}}(%rsp), %r10
+; AVX-NEXT:    movq {{[0-9]+}}(%rsp), %rax
+; AVX-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX-NEXT:    vmovdqu (%rdx), %xmm1
+; AVX-NEXT:    vpmaddwd (%rcx), %xmm1, %xmm1
+; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX-NEXT:    vmovdqu (%r8), %xmm2
+; AVX-NEXT:    vpmaddwd (%r9), %xmm2, %xmm2
+; AVX-NEXT:    vpaddd %xmm2, %xmm1, %xmm1
+; AVX-NEXT:    vmovdqu (%rax), %xmm2
+; AVX-NEXT:    vpmaddwd (%r10), %xmm2, %xmm2
+; AVX-NEXT:    vpaddd %xmm2, %xmm1, %xmm1
+; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
+; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vmovd %xmm0, %eax
+; AVX-NEXT:    retq
   %tmp = load <8 x i16>, <8 x i16>* %arg, align 1
   %tmp6 = load <8 x i16>, <8 x i16>* %arg1, align 1
   %tmp7 = sext <8 x i16> %tmp to <8 x i32>