Index: include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- include/llvm/CodeGen/SelectionDAG.h
+++ include/llvm/CodeGen/SelectionDAG.h
@@ -1503,6 +1503,15 @@
   /// allow an 'add' to be transformed into an 'or'.
   bool haveNoCommonBitsSet(SDValue A, SDValue B) const;
 
+  /// Match a binop + shuffle pyramid that represents a horizontal reduction
+  /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node
+  /// \p Extract. The reduction must use one of the opcodes listed in
+  /// \p CandidateBinOps and on success \p BinOp will contain the matching
+  /// opcode. Returns the vector that is being reduced on, or SDValue() if a
+  /// reduction was not matched.
+  SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
+                              ArrayRef<ISD::NodeType> CandidateBinOps);
+
   /// Utility function used by legalize and lowering to
   /// "unroll" a vector operation by splitting out the scalars and operating
   /// on each element individually. If the ResNE is 0, fully unroll the vector
Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8318,6 +8318,64 @@
   this->Flags.intersectWith(Flags);
 }
 
+SDValue
+SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
+                                  ArrayRef<ISD::NodeType> CandidateBinOps) {
+  // The pattern must end in an extract from index 0.
+  if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+      !isNullConstant(Extract->getOperand(1)))
+    return SDValue();
+
+  SDValue Op = Extract->getOperand(0);
+  unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
+
+  // Match against one of the candidate binary ops.
+  if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
+        return Op.getOpcode() == unsigned(BinOp);
+      }))
+    return SDValue();
+
+  // At each stage, we're looking for something that looks like:
+  // %s = shufflevector <8 x i32> %op, <8 x i32> undef,
+  //                    <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
+  //                               i32 undef, i32 undef, i32 undef, i32 undef>
+  // %a = binop <8 x i32> %op, %s
+  // Where the mask changes according to the stage. E.g. for a 3-stage pyramid,
+  // we expect something like:
+  // <4,5,6,7,u,u,u,u>
+  // <2,3,u,u,u,u,u,u>
+  // <1,u,u,u,u,u,u,u>
+  unsigned CandidateBinOp = Op.getOpcode();
+  for (unsigned i = 0; i < Stages; ++i) {
+    if (Op.getOpcode() != CandidateBinOp)
+      return SDValue();
+
+    SDValue Op0 = Op.getOperand(0);
+    SDValue Op1 = Op.getOperand(1);
+
+    ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(Op0);
+    if (Shuffle) {
+      Op = Op1;
+    } else {
+      Shuffle = dyn_cast<ShuffleVectorSDNode>(Op1);
+      Op = Op0;
+    }
+
+    // The first operand of the shuffle should be the same as the other operand
+    // of the binop.
+    if (!Shuffle || Shuffle->getOperand(0) != Op)
+      return SDValue();
+
+    // Verify the shuffle has the expected (at this stage of the pyramid) mask.
+    for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index)
+      if (Shuffle->getMaskElt(Index) != MaskEnd + Index)
+        return SDValue();
+  }
+
+  BinOp = (ISD::NodeType)CandidateBinOp;
+  return Op;
+}
+
 SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
   assert(N->getNumValues() == 1 &&
          "Can't unroll a vector with multiple results!");
Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -31807,65 +31807,6 @@
   return SDValue();
 }
 
-// Match a binop + shuffle pyramid that represents a horizontal reduction over
-// the elements of a vector.
-// Returns the vector that is being reduced on, or SDValue() if a reduction
-// was not matched.
-static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
-                                   ArrayRef<ISD::NodeType> CandidateBinOps) {
-  // The pattern must end in an extract from index 0.
-  if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
-      !isNullConstant(Extract->getOperand(1)))
-    return SDValue();
-
-  SDValue Op = Extract->getOperand(0);
-  unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
-
-  // Match against one of the candidate binary ops.
-  if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
-        return Op.getOpcode() == unsigned(BinOp);
-      }))
-    return SDValue();
-
-  // At each stage, we're looking for something that looks like:
-  // %s = shufflevector <8 x i32> %op, <8 x i32> undef,
-  //                    <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
-  //                               i32 undef, i32 undef, i32 undef, i32 undef>
-  // %a = binop <8 x i32> %op, %s
-  // Where the mask changes according to the stage. E.g. for a 3-stage pyramid,
-  // we expect something like:
-  // <4,5,6,7,u,u,u,u>
-  // <2,3,u,u,u,u,u,u>
-  // <1,u,u,u,u,u,u,u>
-  unsigned CandidateBinOp = Op.getOpcode();
-  for (unsigned i = 0; i < Stages; ++i) {
-    if (Op.getOpcode() != CandidateBinOp)
-      return SDValue();
-
-    ShuffleVectorSDNode *Shuffle =
-        dyn_cast<ShuffleVectorSDNode>(Op.getOperand(0).getNode());
-    if (Shuffle) {
-      Op = Op.getOperand(1);
-    } else {
-      Shuffle = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(1).getNode());
-      Op = Op.getOperand(0);
-    }
-
-    // The first operand of the shuffle should be the same as the other operand
-    // of the binop.
-    if (!Shuffle || Shuffle->getOperand(0) != Op)
-      return SDValue();
-
-    // Verify the shuffle has the expected (at this stage of the pyramid) mask.
-    for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index)
-      if (Shuffle->getMaskElt(Index) != MaskEnd + Index)
-        return SDValue();
-  }
-
-  BinOp = CandidateBinOp;
-  return Op;
-}
-
 // Given a select, detect the following pattern:
 // 1:    %2 = zext <N x i8> %0 to <N x i32>
 // 2:    %3 = zext <N x i8> %1 to <N x i32>
@@ -31980,8 +31921,8 @@
     return SDValue();
 
   // Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns.
-  unsigned BinOp;
-  SDValue Src = matchBinOpReduction(
+  ISD::NodeType BinOp;
+  SDValue Src = DAG.matchBinOpReduction(
       Extract, BinOp, {ISD::SMAX, ISD::SMIN, ISD::UMAX, ISD::UMIN});
   if (!Src)
     return SDValue();
@@ -32060,8 +32001,8 @@
     return SDValue();
 
   // Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
-  unsigned BinOp = 0;
-  SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
+  ISD::NodeType BinOp;
+  SDValue Match = DAG.matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
   if (!Match)
     return SDValue();
 
@@ -32143,8 +32084,8 @@
     return SDValue();
 
   // Match shuffle + add pyramid.
-  unsigned BinOp = 0;
-  SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
+  ISD::NodeType BinOp;
+  SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});
 
   // The operand is expected to be zero extended from i8
   // (verified in detectZextAbsDiff).