Index: include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- include/llvm/CodeGen/SelectionDAG.h
+++ include/llvm/CodeGen/SelectionDAG.h
@@ -1588,9 +1588,12 @@
   /// Extract. The reduction must use one of the opcodes listed in /p
   /// CandidateBinOps and on success /p BinOp will contain the matching opcode.
   /// Returns the vector that is being reduced on, or SDValue() if a reduction
-  /// was not matched.
+  /// was not matched. If \p AllowPartials is set then in the case of a
+  /// reduction pattern that only matches the first few stages, the extracted
+  /// subvector of the start of the reduction is returned.
   SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
-                              ArrayRef<ISD::NodeType> CandidateBinOps);
+                              ArrayRef<ISD::NodeType> CandidateBinOps,
+                              bool AllowPartials = false);
 
   /// Utility function used by legalize and lowering to
   /// "unroll" a vector operation by splitting out the scalars and operating
Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9005,7 +9005,8 @@
 
 SDValue
 SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
-                                  ArrayRef<ISD::NodeType> CandidateBinOps) {
+                                  ArrayRef<ISD::NodeType> CandidateBinOps,
+                                  bool AllowPartials) {
   // The pattern must end in an extract from index 0.
   if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
       !isNullConstant(Extract->getOperand(1)))
@@ -9013,12 +9014,30 @@
 
   SDValue Op = Extract->getOperand(0);
   unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
+  assert(Stages < 31 && "Too many reduction stages");
 
   // Match against one of the candidate binary ops.
   if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
         return Op.getOpcode() == unsigned(BinOp);
       }))
     return SDValue();
+  unsigned CandidateBinOp = Op.getOpcode();
+
+  // Matching failed - attempt to see if we did enough stages that a partial
+  // reduction from a subvector is possible.
+  auto PartialReduction = [&](SDValue Op, unsigned NumSubElts) {
+    if (!AllowPartials || !Op)
+      return SDValue();
+    EVT OpVT = Op.getValueType();
+    EVT OpSVT = OpVT.getScalarType();
+    EVT SubVT = EVT::getVectorVT(*getContext(), OpSVT, NumSubElts);
+    if (!TLI->isExtractSubvectorCheap(SubVT, OpVT, 0))
+      return SDValue();
+    BinOp = (ISD::NodeType)CandidateBinOp;
+    return getNode(
+        ISD::EXTRACT_SUBVECTOR, SDLoc(Op), SubVT, Op,
+        getConstant(0, SDLoc(Op), TLI->getVectorIdxTy(getDataLayout())));
+  };
 
   // At each stage, we're looking for something that looks like:
   // %s = shufflevector <8 x i32> %op, <8 x i32> undef,
@@ -9030,10 +9049,15 @@
   // %a = binop <8 x i32> %op, %s
   // Where the mask changes according to the stage. E.g. for a 3-stage pyramid,
   // we expect something like:
   // <4,5,6,7,u,u,u,u>
   // <2,3,u,u,u,u,u,u>
   // <1,u,u,u,u,u,u,u>
-  unsigned CandidateBinOp = Op.getOpcode();
+  // While a partial reduction match would be:
+  // <2,3,u,u,u,u,u,u>
+  // <1,u,u,u,u,u,u,u>
+  SDValue PrevOp;
   for (unsigned i = 0; i < Stages; ++i) {
+    unsigned MaskEnd = (1 << i);
+
     if (Op.getOpcode() != CandidateBinOp)
-      return SDValue();
+      return PartialReduction(PrevOp, MaskEnd);
 
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
@@ -9049,12 +9073,14 @@
     // The first operand of the shuffle should be the same as the other operand
     // of the binop.
     if (!Shuffle || Shuffle->getOperand(0) != Op)
-      return SDValue();
+      return PartialReduction(PrevOp, MaskEnd);
 
     // Verify the shuffle has the expected (at this stage of the pyramid) mask.
-    for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index)
-      if (Shuffle->getMaskElt(Index) != MaskEnd + Index)
-        return SDValue();
+    for (int Index = 0; Index < (int)MaskEnd; ++Index)
+      if (Shuffle->getMaskElt(Index) != (MaskEnd + Index))
+        return PartialReduction(PrevOp, MaskEnd);
+
+    PrevOp = Op;
   }
 
   BinOp = (ISD::NodeType)CandidateBinOp;
Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -35634,7 +35634,7 @@
 
   // TODO: Allow FADD with reduction and/or reassociation and no-signed-zeros.
   ISD::NodeType Opc;
-  SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD});
+  SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD}, true);
   if (!Rdx)
     return SDValue();
 
@@ -35643,7 +35643,7 @@
          "Reduction doesn't end in an extract from index 0");
 
   EVT VT = ExtElt->getValueType(0);
-  EVT VecVT = ExtElt->getOperand(0).getValueType();
+  EVT VecVT = Rdx.getValueType();
   if (VecVT.getScalarType() != VT)
     return SDValue();
 
@@ -35657,14 +35657,14 @@
   // vXi8 reduction - sum lo/hi halves then use PSADBW.
   if (VT == MVT::i8) {
     while (Rdx.getValueSizeInBits() > 128) {
-      EVT RdxVT = Rdx.getValueType();
-      unsigned HalfSize = RdxVT.getSizeInBits() / 2;
-      unsigned HalfElts = RdxVT.getVectorNumElements() / 2;
+      unsigned HalfSize = VecVT.getSizeInBits() / 2;
+      unsigned HalfElts = VecVT.getVectorNumElements() / 2;
       SDValue Lo = extractSubVector(Rdx, 0, DAG, DL, HalfSize);
       SDValue Hi = extractSubVector(Rdx, HalfElts, DAG, DL, HalfSize);
       Rdx = DAG.getNode(ISD::ADD, DL, Lo.getValueType(), Lo, Hi);
+      VecVT = Rdx.getValueType();
     }
-    assert(Rdx.getValueType() == MVT::v16i8 && "v16i8 reduction expected");
+    assert(VecVT == MVT::v16i8 && "v16i8 reduction expected");
 
     SDValue Hi = DAG.getVectorShuffle(
         MVT::v16i8, DL, Rdx, Rdx,
@@ -35692,15 +35692,14 @@
     unsigned NumElts = VecVT.getVectorNumElements();
     SDValue Hi = extract128BitVector(Rdx, NumElts / 2, DAG, DL);
     SDValue Lo = extract128BitVector(Rdx, 0, DAG, DL);
-    VecVT = EVT::getVectorVT(*DAG.getContext(), VT, NumElts / 2);
-    Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Hi, Lo);
+    Rdx = DAG.getNode(HorizOpcode, DL, Lo.getValueType(), Hi, Lo);
+    VecVT = Rdx.getValueType();
   }
   if (!((VecVT == MVT::v8i16 || VecVT == MVT::v4i32) && Subtarget.hasSSSE3()) &&
       !((VecVT == MVT::v4f32 || VecVT == MVT::v2f64) && Subtarget.hasSSE3()))
     return SDValue();
 
   // extract (add (shuf X), X), 0 --> extract (hadd X, X), 0
-  assert(Rdx.getValueType() == VecVT && "Unexpected reduction match");
   unsigned ReductionSteps = Log2_32(VecVT.getVectorNumElements());
   for (unsigned i = 0; i != ReductionSteps; ++i)
     Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Rdx, Rdx);
Index: test/CodeGen/X86/phaddsub-extract.ll
===================================================================
--- test/CodeGen/X86/phaddsub-extract.ll
+++ test/CodeGen/X86/phaddsub-extract.ll
@@ -1699,8 +1699,7 @@
 ;
 ; AVX-FAST-LABEL: partial_reduction_add_v8i32:
 ; AVX-FAST:       # %bb.0:
-; AVX-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-FAST-NEXT:    vmovd %xmm0, %eax
 ; AVX-FAST-NEXT:    vzeroupper
@@ -1741,34 +1740,13 @@
 ; AVX-SLOW-NEXT:    vzeroupper
 ; AVX-SLOW-NEXT:    retq
 ;
-; AVX1-FAST-LABEL: partial_reduction_add_v16i32:
-; AVX1-FAST:       # %bb.0:
-; AVX1-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
-; AVX1-FAST-NEXT:    vmovd %xmm0, %eax
-; AVX1-FAST-NEXT:    vzeroupper
-; AVX1-FAST-NEXT:    retq
-;
-; AVX2-FAST-LABEL: partial_reduction_add_v16i32:
-; AVX2-FAST:       # %bb.0:
-; AVX2-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX2-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT:    vmovd %xmm0, %eax
-; AVX2-FAST-NEXT:    vzeroupper
-; AVX2-FAST-NEXT:    retq
-;
-; AVX512-FAST-LABEL: partial_reduction_add_v16i32:
-; AVX512-FAST:       # %bb.0:
-; AVX512-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX512-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT:    vmovd %xmm0, %eax
-; AVX512-FAST-NEXT:    vzeroupper
-; AVX512-FAST-NEXT:    retq
+; AVX-FAST-LABEL: partial_reduction_add_v16i32:
+; AVX-FAST:       # %bb.0:
+; AVX-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
+; AVX-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
+; AVX-FAST-NEXT:    vmovd %xmm0, %eax
+; AVX-FAST-NEXT:    vzeroupper
+; AVX-FAST-NEXT:    retq
   %x23 = shufflevector <16 x i32> %x, <16 x i32> undef, <16 x i32>
   %x0213 = add <16 x i32> %x, %x23
   %x13 = shufflevector <16 x i32> %x0213, <16 x i32> undef, <16 x i32>
@@ -2010,8 +1988,7 @@
 ;
 ; AVX-FAST-LABEL: hadd32_8:
 ; AVX-FAST:       # %bb.0:
-; AVX-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-FAST-NEXT:    vmovd %xmm0, %eax
 ; AVX-FAST-NEXT:    vzeroupper
@@ -2052,34 +2029,13 @@
 ; AVX-SLOW-NEXT:    vzeroupper
 ; AVX-SLOW-NEXT:    retq
 ;
-; AVX1-FAST-LABEL: hadd32_16:
-; AVX1-FAST:       # %bb.0:
-; AVX1-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
-; AVX1-FAST-NEXT:    vmovd %xmm0, %eax
-; AVX1-FAST-NEXT:    vzeroupper
-; AVX1-FAST-NEXT:    retq
-;
-; AVX2-FAST-LABEL: hadd32_16:
-; AVX2-FAST:       # %bb.0:
-; AVX2-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX2-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT:    vmovd %xmm0, %eax
-; AVX2-FAST-NEXT:    vzeroupper
-; AVX2-FAST-NEXT:    retq
-;
-; AVX512-FAST-LABEL: hadd32_16:
-; AVX512-FAST:       # %bb.0:
-; AVX512-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX512-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT:    vmovd %xmm0, %eax
-; AVX512-FAST-NEXT:    vzeroupper
-; AVX512-FAST-NEXT:    retq
+; AVX-FAST-LABEL: hadd32_16:
+; AVX-FAST:       # %bb.0:
+; AVX-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
+; AVX-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
+; AVX-FAST-NEXT:    vmovd %xmm0, %eax
+; AVX-FAST-NEXT:    vzeroupper
+; AVX-FAST-NEXT:    retq
   %x226 = shufflevector <16 x i32> %x225, <16 x i32> undef, <16 x i32>
   %x227 = add <16 x i32> %x225, %x226
   %x228 = shufflevector <16 x i32> %x227, <16 x i32> undef, <16 x i32>
@@ -2149,8 +2105,7 @@
 ;
 ; AVX-LABEL: hadd32_8_optsize:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    vmovd %xmm0, %eax
 ; AVX-NEXT:    vzeroupper
@@ -2172,63 +2127,13 @@
 ; SSE3-NEXT:    movd %xmm1, %eax
 ; SSE3-NEXT:    retq
 ;
-; AVX1-SLOW-LABEL: hadd32_16_optsize:
-; AVX1-SLOW:       # %bb.0:
-; AVX1-SLOW-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-SLOW-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-SLOW-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
-; AVX1-SLOW-NEXT:    vmovd %xmm0, %eax
-; AVX1-SLOW-NEXT:    vzeroupper
-; AVX1-SLOW-NEXT:    retq
-;
-; AVX1-FAST-LABEL: hadd32_16_optsize:
-; AVX1-FAST:       # %bb.0:
-; AVX1-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-FAST-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
-; AVX1-FAST-NEXT:    vmovd %xmm0, %eax
-; AVX1-FAST-NEXT:    vzeroupper
-; AVX1-FAST-NEXT:    retq
-;
-; AVX2-SLOW-LABEL: hadd32_16_optsize:
-; AVX2-SLOW:       # %bb.0:
-; AVX2-SLOW-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-SLOW-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-SLOW-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX2-SLOW-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-SLOW-NEXT:    vmovd %xmm0, %eax
-; AVX2-SLOW-NEXT:    vzeroupper
-; AVX2-SLOW-NEXT:    retq
-;
-; AVX2-FAST-LABEL: hadd32_16_optsize:
-; AVX2-FAST:       # %bb.0:
-; AVX2-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX2-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT:    vmovd %xmm0, %eax
-; AVX2-FAST-NEXT:    vzeroupper
-; AVX2-FAST-NEXT:    retq
-;
-; AVX512-SLOW-LABEL: hadd32_16_optsize:
-; AVX512-SLOW:       # %bb.0:
-; AVX512-SLOW-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512-SLOW-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-SLOW-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX512-SLOW-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-SLOW-NEXT:    vmovd %xmm0, %eax
-; AVX512-SLOW-NEXT:    vzeroupper
-; AVX512-SLOW-NEXT:    retq
-;
-; AVX512-FAST-LABEL: hadd32_16_optsize:
-; AVX512-FAST:       # %bb.0:
-; AVX512-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX512-FAST-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT:    vmovd %xmm0, %eax
-; AVX512-FAST-NEXT:    vzeroupper
-; AVX512-FAST-NEXT:    retq
+; AVX-LABEL: hadd32_16_optsize:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
+; AVX-NEXT:    vphaddd %xmm0, %xmm0, %xmm0
+; AVX-NEXT:    vmovd %xmm0, %eax
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
   %x226 = shufflevector <16 x i32> %x225, <16 x i32> undef, <16 x i32>
   %x227 = add <16 x i32> %x225, %x226
   %x228 = shufflevector <16 x i32> %x227, <16 x i32> undef, <16 x i32>
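
For illustration only (not part of the patch): the shape of partial reduction that the new AllowPartials path matches is an add/shuffle pyramid that stops before the whole vector has been reduced, mirroring the partial_reduction_add_v8i32 test above. A minimal LLVM IR sketch, with a hypothetical function name:

; Only two shuffle+add stages are present for an <8 x i32> input, so the value
; extracted at index 0 is the sum of elements 0..3 rather than of all 8 lanes.
; With AllowPartials set, matchBinOpReduction returns the low <4 x i32>
; subvector of %x, which the X86 combine can then lower with PHADDD.
define i32 @partial_reduction_sketch(<8 x i32> %x) {
  %s23 = shufflevector <8 x i32> %x, <8 x i32> undef,
         <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
  %a0213 = add <8 x i32> %x, %s23
  %s13 = shufflevector <8 x i32> %a0213, <8 x i32> undef,
         <8 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
  %a0123 = add <8 x i32> %a0213, %s13
  %r = extractelement <8 x i32> %a0123, i32 0
  ret i32 %r
}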