Changeset View
Changeset View
Standalone View
Standalone View
lib/Target/X86/X86ISelLowering.cpp
- This file is larger than 256 KB, so syntax highlighting is disabled by default.
Show First 20 Lines • Show All 34,255 Lines • ▼ Show 20 Lines | static SDValue combineHorizontalPredicateResult(SDNode *Extract, | ||||
const X86Subtarget &Subtarget) { | const X86Subtarget &Subtarget) { | ||||
// Bail without SSE2 or with AVX512VL (which uses predicate registers). | // Bail without SSE2 or with AVX512VL (which uses predicate registers). | ||||
if (!Subtarget.hasSSE2() || Subtarget.hasVLX()) | if (!Subtarget.hasSSE2() || Subtarget.hasVLX()) | ||||
return SDValue(); | return SDValue(); | ||||
EVT ExtractVT = Extract->getValueType(0); | EVT ExtractVT = Extract->getValueType(0); | ||||
unsigned BitWidth = ExtractVT.getSizeInBits(); | unsigned BitWidth = ExtractVT.getSizeInBits(); | ||||
if (ExtractVT != MVT::i64 && ExtractVT != MVT::i32 && ExtractVT != MVT::i16 && | if (ExtractVT != MVT::i64 && ExtractVT != MVT::i32 && ExtractVT != MVT::i16 && | ||||
ExtractVT != MVT::i8) | ExtractVT != MVT::i8 && ExtractVT != MVT::i1) | ||||
return SDValue(); | return SDValue(); | ||||
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns. | // Check for OR(any_of) and AND(all_of) horizontal reduction patterns. | ||||
ISD::NodeType BinOp; | ISD::NodeType BinOp; | ||||
SDValue Match = DAG.matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND}); | SDValue Match = DAG.matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND}); | ||||
if (!Match) | if (!Match) | ||||
return SDValue(); | return SDValue(); | ||||
// EXTRACT_VECTOR_ELT can require implicit extension of the vector element | // EXTRACT_VECTOR_ELT can require implicit extension of the vector element | ||||
// which we can't support here for now. | // which we can't support here for now. | ||||
if (Match.getScalarValueSizeInBits() != BitWidth) | if (Match.getScalarValueSizeInBits() != BitWidth) | ||||
return SDValue(); | return SDValue(); | ||||
SDValue Movmsk; | |||||
SDLoc DL(Extract); | |||||
unsigned NumElts = Match.getValueType().getVectorNumElements(); | |||||
if (ExtractVT == MVT::i1) { | |||||
// Special case for (pre-legalization) vXi1 reductions. | |||||
// Use combineBitcastvxi1 to create the MOVMSK. | |||||
if (NumElts > 32) | |||||
return SDValue(); | |||||
if (NumElts == 32 && !Subtarget.hasInt256()) { | |||||
SDValue Lo, Hi; | |||||
std::tie(Lo, Hi) = DAG.SplitVector(Match, DL); | |||||
Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi); | |||||
NumElts = 16; | |||||
} | |||||
EVT MovmskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts); | |||||
Movmsk = combineBitcastvxi1(DAG, MovmskVT, Match, DL, Subtarget); | |||||
if (!Movmsk) | |||||
return SDValue(); | |||||
Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, MVT::i32); | |||||
} else { | |||||
unsigned MatchSizeInBits = Match.getValueSizeInBits(); | unsigned MatchSizeInBits = Match.getValueSizeInBits(); | ||||
if (!(MatchSizeInBits == 128 || (MatchSizeInBits == 256 && Subtarget.hasAVX()))) | if (!(MatchSizeInBits == 128 || | ||||
(MatchSizeInBits == 256 && Subtarget.hasAVX()))) | |||||
return SDValue(); | return SDValue(); | ||||
// Make sure this isn't a vector of 1 element. The perf win from using MOVMSK | // Make sure this isn't a vector of 1 element. The perf win from using | ||||
// diminishes with less elements in the reduction, but it is generally better | // MOVMSK diminishes with less elements in the reduction, but it is | ||||
// to get the comparison over to the GPRs as soon as possible to reduce the | // generally better to get the comparison over to the GPRs as soon as | ||||
// number of vector ops. | // possible to reduce the number of vector ops. | ||||
if (Match.getValueType().getVectorNumElements() < 2) | if (Match.getValueType().getVectorNumElements() < 2) | ||||
return SDValue(); | return SDValue(); | ||||
// Check that we are extracting a reduction of all sign bits. | // Check that we are extracting a reduction of all sign bits. | ||||
if (DAG.ComputeNumSignBits(Match) != BitWidth) | if (DAG.ComputeNumSignBits(Match) != BitWidth) | ||||
return SDValue(); | return SDValue(); | ||||
SDLoc DL(Extract); | |||||
if (MatchSizeInBits == 256 && BitWidth < 32 && !Subtarget.hasInt256()) { | if (MatchSizeInBits == 256 && BitWidth < 32 && !Subtarget.hasInt256()) { | ||||
SDValue Lo, Hi; | SDValue Lo, Hi; | ||||
std::tie(Lo, Hi) = DAG.SplitVector(Match, DL); | std::tie(Lo, Hi) = DAG.SplitVector(Match, DL); | ||||
Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi); | Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi); | ||||
MatchSizeInBits = Match.getValueSizeInBits(); | MatchSizeInBits = Match.getValueSizeInBits(); | ||||
} | } | ||||
// For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB. | // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB. | ||||
MVT MaskSrcVT; | MVT MaskSrcVT; | ||||
if (64 == BitWidth || 32 == BitWidth) | if (64 == BitWidth || 32 == BitWidth) | ||||
MaskSrcVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth), | MaskSrcVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth), | ||||
MatchSizeInBits / BitWidth); | MatchSizeInBits / BitWidth); | ||||
else | else | ||||
MaskSrcVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8); | MaskSrcVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8); | ||||
SDValue BitcastLogicOp = DAG.getBitcast(MaskSrcVT, Match); | |||||
Movmsk = getPMOVMSKB(DL, BitcastLogicOp, DAG, Subtarget); | |||||
NumElts = MaskSrcVT.getVectorNumElements(); | |||||
} | |||||
assert(NumElts <= 32 && "Not expecting more than 32 elements"); | |||||
SDValue CmpC; | SDValue CmpC; | ||||
ISD::CondCode CondCode; | ISD::CondCode CondCode; | ||||
if (BinOp == ISD::OR) { | if (BinOp == ISD::OR) { | ||||
// any_of -> MOVMSK != 0 | // any_of -> MOVMSK != 0 | ||||
CmpC = DAG.getConstant(0, DL, MVT::i32); | CmpC = DAG.getConstant(0, DL, MVT::i32); | ||||
CondCode = ISD::CondCode::SETNE; | CondCode = ISD::CondCode::SETNE; | ||||
} else { | } else { | ||||
// all_of -> MOVMSK == ((1 << NumElts) - 1) | // all_of -> MOVMSK == ((1 << NumElts) - 1) | ||||
uint64_t NumElts = MaskSrcVT.getVectorNumElements(); | |||||
assert(NumElts <= 32 && "Not expecting more than 32 elements"); | |||||
CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, MVT::i32); | CmpC = DAG.getConstant((1ULL << NumElts) - 1, DL, MVT::i32); | ||||
CondCode = ISD::CondCode::SETEQ; | CondCode = ISD::CondCode::SETEQ; | ||||
} | } | ||||
// The setcc produces an i8 of 0/1, so extend that to the result width and | // The setcc produces an i8 of 0/1, so extend that to the result width and | ||||
// negate to get the final 0/-1 mask value. | // negate to get the final 0/-1 mask value. | ||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo(); | const TargetLowering &TLI = DAG.getTargetLoweringInfo(); | ||||
SDValue BitcastLogicOp = DAG.getBitcast(MaskSrcVT, Match); | EVT SetccVT = | ||||
SDValue Movmsk = getPMOVMSKB(DL, BitcastLogicOp, DAG, Subtarget); | TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i32); | ||||
EVT SetccVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), | |||||
MVT::i32); | |||||
SDValue Setcc = DAG.getSetCC(DL, SetccVT, Movmsk, CmpC, CondCode); | SDValue Setcc = DAG.getSetCC(DL, SetccVT, Movmsk, CmpC, CondCode); | ||||
SDValue Zext = DAG.getZExtOrTrunc(Setcc, DL, ExtractVT); | SDValue Zext = DAG.getZExtOrTrunc(Setcc, DL, ExtractVT); | ||||
SDValue Zero = DAG.getConstant(0, DL, ExtractVT); | SDValue Zero = DAG.getConstant(0, DL, ExtractVT); | ||||
return DAG.getNode(ISD::SUB, DL, ExtractVT, Zero, Zext); | return DAG.getNode(ISD::SUB, DL, ExtractVT, Zero, Zext); | ||||
} | } | ||||
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, | static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, | ||||
const X86Subtarget &Subtarget) { | const X86Subtarget &Subtarget) { | ||||
▲ Show 20 Lines • Show All 9,667 Lines • Show Last 20 Lines |