Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -30093,7 +30093,8 @@
 // the elements of a vector.
 // Returns the vector that is being reduced on, or SDValue() if a reduction
 // was not matched.
-static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
+static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
+                                   ArrayRef<ISD::NodeType> CandidateBinOps) {
   // The pattern must end in an extract from index 0.
   if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
       !isNullConstant(Extract->getOperand(1)))
     return SDValue();
@@ -30113,8 +30114,16 @@
   // <4,5,6,7,u,u,u,u>
   // <2,3,u,u,u,u,u,u>
   // <1,u,u,u,u,u,u,u>
+  unsigned CandidateBinOp = 0;
   for (unsigned i = 0; i < Stages; ++i) {
-    if (Op.getOpcode() != BinOp)
+    // Match against one of the candidate binary ops.
+    if (i == 0) {
+      if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
+            return Op.getOpcode() == BinOp;
+          }))
+        return SDValue();
+      CandidateBinOp = Op.getOpcode();
+    } else if (Op.getOpcode() != CandidateBinOp)
       return SDValue();
 
     ShuffleVectorSDNode *Shuffle =
@@ -30137,6 +30146,7 @@
       return SDValue();
   }
 
+  BinOp = CandidateBinOp;
   return Op;
 }
 
@@ -30250,66 +30260,63 @@
     return SDValue();
 
   // Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
-  for (ISD::NodeType Op : {ISD::OR, ISD::AND}) {
-    SDValue Match = matchBinOpReduction(Extract, Op);
-    if (!Match)
-      continue;
+  unsigned BinOp = 0;
+  SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
+  if (!Match)
+    return SDValue();
 
-    // EXTRACT_VECTOR_ELT can require implicit extension of the vector element
-    // which we can't support here for now.
-    if (Match.getScalarValueSizeInBits() != BitWidth)
-      continue;
+  // EXTRACT_VECTOR_ELT can require implicit extension of the vector element
+  // which we can't support here for now.
+  if (Match.getScalarValueSizeInBits() != BitWidth)
+    return SDValue();
 
-    // We require AVX2 for PMOVMSKB for v16i16/v32i8;
-    unsigned MatchSizeInBits = Match.getValueSizeInBits();
-    if (!(MatchSizeInBits == 128 ||
-          (MatchSizeInBits == 256 &&
-           ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()))))
-      return SDValue();
+  // We require AVX2 for PMOVMSKB for v16i16/v32i8;
+  unsigned MatchSizeInBits = Match.getValueSizeInBits();
+  if (!(MatchSizeInBits == 128 ||
+        (MatchSizeInBits == 256 &&
+         ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()))))
+    return SDValue();
 
-    // Don't bother performing this for 2-element vectors.
-    if (Match.getValueType().getVectorNumElements() <= 2)
-      return SDValue();
+  // Don't bother performing this for 2-element vectors.
+  if (Match.getValueType().getVectorNumElements() <= 2)
+    return SDValue();
 
-    // Check that we are extracting a reduction of all sign bits.
-    if (DAG.ComputeNumSignBits(Match) != BitWidth)
-      return SDValue();
+  // Check that we are extracting a reduction of all sign bits.
+  if (DAG.ComputeNumSignBits(Match) != BitWidth)
+    return SDValue();
 
-    // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
-    MVT MaskVT;
-    if (64 == BitWidth || 32 == BitWidth)
-      MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
-                                MatchSizeInBits / BitWidth);
-    else
-      MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);
+  // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
+  MVT MaskVT;
+  if (64 == BitWidth || 32 == BitWidth)
+    MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
+                              MatchSizeInBits / BitWidth);
+  else
+    MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);
 
-    APInt CompareBits;
-    ISD::CondCode CondCode;
-    if (Op == ISD::OR) {
-      // any_of -> MOVMSK != 0
-      CompareBits = APInt::getNullValue(32);
-      CondCode = ISD::CondCode::SETNE;
-    } else {
-      // all_of -> MOVMSK == ((1 << NumElts) - 1)
-      CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
-      CondCode = ISD::CondCode::SETEQ;
-    }
-
-    // Perform the select as i32/i64 and then truncate to avoid partial register
-    // stalls.
-    unsigned ResWidth = std::max(BitWidth, 32u);
-    EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);
-    SDLoc DL(Extract);
-    SDValue Zero = DAG.getConstant(0, DL, ResVT);
-    SDValue Ones = DAG.getAllOnesConstant(DL, ResVT);
-    SDValue Res = DAG.getBitcast(MaskVT, Match);
-    Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res);
-    Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
-                          Ones, Zero, CondCode);
-    return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
+  APInt CompareBits;
+  ISD::CondCode CondCode;
+  if (BinOp == ISD::OR) {
+    // any_of -> MOVMSK != 0
+    CompareBits = APInt::getNullValue(32);
+    CondCode = ISD::CondCode::SETNE;
+  } else {
+    // all_of -> MOVMSK == ((1 << NumElts) - 1)
+    CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
+    CondCode = ISD::CondCode::SETEQ;
   }
-  return SDValue();
 
+  // Perform the select as i32/i64 and then truncate to avoid partial register
+  // stalls.
+  unsigned ResWidth = std::max(BitWidth, 32u);
+  EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);
+  SDLoc DL(Extract);
+  SDValue Zero = DAG.getConstant(0, DL, ResVT);
+  SDValue Ones = DAG.getAllOnesConstant(DL, ResVT);
+  SDValue Res = DAG.getBitcast(MaskVT, Match);
+  Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res);
+  Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
+                        Ones, Zero, CondCode);
+  return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
 }
 
 static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
@@ -30336,7 +30343,8 @@
     return SDValue();
 
   // Match shuffle + add pyramid.
-  SDValue Root = matchBinOpReduction(Extract, ISD::ADD);
+  unsigned BinOp = 0;
+  SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
 
   // The operand is expected to be zero extended from i8
   // (verified in detectZextAbsDiff).
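
For reference, below is a minimal standalone sketch of the candidate-opcode matching idea the patch introduces: the first reduction stage may use any of the candidate opcodes, every later stage must repeat whichever one was seen first, and the chosen opcode is reported back to the caller. It is illustration only, not LLVM code: Node, matchReductionOpcode, and the plain unsigned opcode values are made up for this sketch, and the shuffle-mask/DAG checks that matchBinOpReduction performs are omitted.

// Standalone sketch (hypothetical names); mirrors the CandidateBinOp logic above.
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <vector>

namespace sketch {

struct Node { unsigned Opcode; };

// Returns the opcode shared by every stage of the reduction pyramid, or 0 if
// the first stage uses none of the candidates or a later stage switches ops.
unsigned matchReductionOpcode(const std::vector<Node> &Stages,
                              std::initializer_list<unsigned> Candidates) {
  unsigned CandidateBinOp = 0;
  for (std::size_t i = 0; i < Stages.size(); ++i) {
    unsigned Opc = Stages[i].Opcode;
    if (i == 0) {
      // Stage 0: accept any of the candidate opcodes and lock onto it.
      if (std::none_of(Candidates.begin(), Candidates.end(),
                       [Opc](unsigned C) { return Opc == C; }))
        return 0;
      CandidateBinOp = Opc;
    } else if (Opc != CandidateBinOp) {
      // Later stages must keep using the opcode chosen at stage 0.
      return 0;
    }
  }
  return CandidateBinOp;
}

} // namespace sketch

Returning the matched opcode through the out-parameter (BinOp in the patch) is what lets combineHorizontalPredicateResult try OR and AND in a single matchBinOpReduction call and choose the MOVMSK comparison (any_of vs. all_of) afterwards, instead of looping over the candidates itself.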