diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -5802,6 +5802,71 @@ return false; } +static std::pair<SDValue, SDValue> splitVector(SDValue Op, SelectionDAG &DAG, + const SDLoc &dl) { + MVT VT = Op.getSimpleValueType(); + unsigned NumElems = VT.getVectorNumElements(); + unsigned SizeInBits = VT.getSizeInBits(); + + SDValue Lo = extractSubVector(Op, 0, DAG, dl, SizeInBits / 2); + SDValue Hi = extractSubVector(Op, NumElems / 2, DAG, dl, SizeInBits / 2); + + return std::make_pair(Lo, Hi); +} + +// Split a unary integer op into 2 half sized ops. +static SDValue splitVectorIntUnary(SDValue Op, SelectionDAG &DAG) { + EVT VT = Op.getValueType(); + + // Make sure we only try to split 256/512-bit types to avoid creating + // narrow vectors. + assert((Op.getOperand(0).getValueType().is256BitVector() || + Op.getOperand(0).getValueType().is512BitVector()) && + (VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!"); + assert(Op.getOperand(0).getValueType().getVectorNumElements() == + VT.getVectorNumElements() && + "Unexpected VTs!"); + + SDLoc dl(Op); + + // Extract the Lo/Hi vectors + SDValue Lo, Hi; + std::tie(Lo, Hi) = splitVector(Op.getOperand(0), DAG, dl); + + EVT LoVT, HiVT; + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT); + return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, + DAG.getNode(Op.getOpcode(), dl, LoVT, Lo), + DAG.getNode(Op.getOpcode(), dl, HiVT, Hi)); +} + +/// Break a binary integer operation into 2 half sized ops and then +/// concatenate the result back. +static SDValue splitVectorIntBinary(SDValue Op, SelectionDAG &DAG) { + EVT VT = Op.getValueType(); + + // Sanity check that all the types match. 
+ assert(Op.getOperand(0).getValueType() == VT && + Op.getOperand(1).getValueType() == VT && "Unexpected VTs!"); + assert((VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!"); + + SDLoc dl(Op); + + // Extract the LHS Lo/Hi vectors + SDValue LHS1, LHS2; + std::tie(LHS1, LHS2) = splitVector(Op.getOperand(0), DAG, dl); + + // Extract the RHS Lo/Hi vectors + SDValue RHS1, RHS2; + std::tie(RHS1, RHS2) = splitVector(Op.getOperand(1), DAG, dl); + + EVT LoVT, HiVT; + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT); + return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, + DAG.getNode(Op.getOpcode(), dl, LoVT, LHS1, RHS1), + DAG.getNode(Op.getOpcode(), dl, HiVT, LHS2, RHS2)); +} + // Helper for splitting operands of an operation to legal target size and // apply a function on each part. // Useful for operations that are available on SSE2 in 128-bit, on AVX2 in @@ -21820,32 +21885,30 @@ /// Break a VSETCC 256-bit integer VSETCC into two new 128 ones and then /// concatenate the result back. 
-static SDValue Lower256IntVSETCC(SDValue Op, SelectionDAG &DAG) { - MVT VT = Op.getSimpleValueType(); +static SDValue splitIntVSETCC(SDValue Op, SelectionDAG &DAG) { + EVT VT = Op.getValueType(); - assert(VT.is256BitVector() && Op.getOpcode() == ISD::SETCC && - "Unsupported value type for operation"); + assert(Op.getOpcode() == ISD::SETCC && "Unsupported operation"); + assert(Op.getOperand(0).getValueType().isInteger() && + VT == Op.getOperand(0).getValueType() && "Unsupported VTs!"); - unsigned NumElems = VT.getVectorNumElements(); SDLoc dl(Op); SDValue CC = Op.getOperand(2); - // Extract the LHS vectors - SDValue LHS = Op.getOperand(0); - SDValue LHS1 = extract128BitVector(LHS, 0, DAG, dl); - SDValue LHS2 = extract128BitVector(LHS, NumElems / 2, DAG, dl); + // Extract the LHS Lo/Hi vectors + SDValue LHS1, LHS2; + std::tie(LHS1, LHS2) = splitVector(Op.getOperand(0), DAG, dl); - // Extract the RHS vectors - SDValue RHS = Op.getOperand(1); - SDValue RHS1 = extract128BitVector(RHS, 0, DAG, dl); - SDValue RHS2 = extract128BitVector(RHS, NumElems / 2, DAG, dl); + // Extract the RHS Lo/Hi vectors + SDValue RHS1, RHS2; + std::tie(RHS1, RHS2) = splitVector(Op.getOperand(1), DAG, dl); // Issue the operation on the smaller types and concatenate the result back - MVT EltVT = VT.getVectorElementType(); - MVT NewVT = MVT::getVectorVT(EltVT, NumElems/2); + EVT LoVT, HiVT; + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT); return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, - DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, RHS1, CC), - DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2, CC)); + DAG.getNode(ISD::SETCC, dl, LoVT, LHS1, RHS1, CC), + DAG.getNode(ISD::SETCC, dl, HiVT, LHS2, RHS2, CC)); } static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) { @@ -22187,7 +22250,7 @@ // Break 256-bit integer vector compare into smaller ones. 
if (VT.is256BitVector() && !Subtarget.hasInt256()) - return Lower256IntVSETCC(Op, DAG); + return splitIntVSETCC(Op, DAG); // If this is a SETNE against the signed minimum value, change it to SETGT. // If this is a SETNE against the signed maximum value, change it to SETLT. @@ -25922,43 +25985,6 @@ return DAG.getMergeValues({RetVal, Chain}, DL); } -// Split an unary integer op into 2 half sized ops. -static SDValue LowerVectorIntUnary(SDValue Op, SelectionDAG &DAG) { - MVT VT = Op.getSimpleValueType(); - unsigned NumElems = VT.getVectorNumElements(); - unsigned SizeInBits = VT.getSizeInBits(); - MVT EltVT = VT.getVectorElementType(); - SDValue Src = Op.getOperand(0); - assert(EltVT == Src.getSimpleValueType().getVectorElementType() && - "Src and Op should have the same element type!"); - - // Extract the Lo/Hi vectors - SDLoc dl(Op); - SDValue Lo = extractSubVector(Src, 0, DAG, dl, SizeInBits / 2); - SDValue Hi = extractSubVector(Src, NumElems / 2, DAG, dl, SizeInBits / 2); - - MVT NewVT = MVT::getVectorVT(EltVT, NumElems / 2); - return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, - DAG.getNode(Op.getOpcode(), dl, NewVT, Lo), - DAG.getNode(Op.getOpcode(), dl, NewVT, Hi)); -} - -// Decompose 256-bit ops into smaller 128-bit ops. -static SDValue Lower256IntUnary(SDValue Op, SelectionDAG &DAG) { - assert(Op.getSimpleValueType().is256BitVector() && - Op.getSimpleValueType().isInteger() && - "Only handle AVX 256-bit vector integer operation"); - return LowerVectorIntUnary(Op, DAG); -} - -// Decompose 512-bit ops into smaller 256-bit ops. -static SDValue Lower512IntUnary(SDValue Op, SelectionDAG &DAG) { - assert(Op.getSimpleValueType().is512BitVector() && - Op.getSimpleValueType().isInteger() && - "Only handle AVX 512-bit vector integer operation"); - return LowerVectorIntUnary(Op, DAG); -} - /// Lower a vector CTLZ using native supported vector CTLZ instruction. 
// // i8/i16 vector implemented using dword LZCNT vector instruction @@ -25979,7 +26005,7 @@ // Split vector, it's Lo and Hi parts will be handled in next iteration. if (NumElems > 16 || (NumElems == 16 && !Subtarget.canExtendTo512DQ())) - return LowerVectorIntUnary(Op, DAG); + return splitVectorIntUnary(Op, DAG); MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems); assert((NewVT.is256BitVector() || NewVT.is512BitVector()) && @@ -26089,11 +26115,11 @@ // Decompose 256-bit ops into smaller 128-bit ops. if (VT.is256BitVector() && !Subtarget.hasInt256()) - return Lower256IntUnary(Op, DAG); + return splitVectorIntUnary(Op, DAG); // Decompose 512-bit ops into smaller 256-bit ops. if (VT.is512BitVector() && !Subtarget.hasBWI()) - return Lower512IntUnary(Op, DAG); + return splitVectorIntUnary(Op, DAG); assert(Subtarget.hasSSSE3() && "Expected SSSE3 support for PSHUFB"); return LowerVectorCTLZInRegLUT(Op, DL, Subtarget, DAG); @@ -26159,48 +26185,6 @@ return DAG.getNode(X86ISD::CMOV, dl, VT, Ops); } -/// Break a binary integer operation into 2 half sized ops and then -/// concatenate the result back. 
-static SDValue splitVectorIntBinary(SDValue Op, SelectionDAG &DAG) { - MVT VT = Op.getSimpleValueType(); - unsigned NumElems = VT.getVectorNumElements(); - unsigned SizeInBits = VT.getSizeInBits(); - SDLoc dl(Op); - - // Extract the LHS Lo/Hi vectors - SDValue LHS = Op.getOperand(0); - SDValue LHS1 = extractSubVector(LHS, 0, DAG, dl, SizeInBits / 2); - SDValue LHS2 = extractSubVector(LHS, NumElems / 2, DAG, dl, SizeInBits / 2); - - // Extract the RHS Lo/Hi vectors - SDValue RHS = Op.getOperand(1); - SDValue RHS1 = extractSubVector(RHS, 0, DAG, dl, SizeInBits / 2); - SDValue RHS2 = extractSubVector(RHS, NumElems / 2, DAG, dl, SizeInBits / 2); - - MVT NewVT = MVT::getVectorVT(VT.getVectorElementType(), NumElems / 2); - return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, - DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, RHS1), - DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2)); -} - -/// Break a 256-bit integer operation into two new 128-bit ones and then -/// concatenate the result back. -static SDValue split256IntArith(SDValue Op, SelectionDAG &DAG) { - assert(Op.getSimpleValueType().is256BitVector() && - Op.getSimpleValueType().isInteger() && - "Unsupported value type for operation"); - return splitVectorIntBinary(Op, DAG); -} - -/// Break a 512-bit integer operation into two new 256-bit ones and then -/// concatenate the result back. 
-static SDValue split512IntArith(SDValue Op, SelectionDAG &DAG) { - assert(Op.getSimpleValueType().is512BitVector() && - Op.getSimpleValueType().isInteger() && - "Unsupported value type for operation"); - return splitVectorIntBinary(Op, DAG); -} - static SDValue lowerAddSub(SDValue Op, SelectionDAG &DAG, const X86Subtarget &Subtarget) { MVT VT = Op.getSimpleValueType(); @@ -26214,7 +26198,7 @@ assert(Op.getSimpleValueType().is256BitVector() && Op.getSimpleValueType().isInteger() && "Only handle AVX 256-bit vector integer operation"); - return split256IntArith(Op, DAG); + return splitVectorIntBinary(Op, DAG); } static SDValue LowerADDSAT_SUBSAT(SDValue Op, SelectionDAG &DAG, @@ -26262,7 +26246,7 @@ assert(Op.getSimpleValueType().is256BitVector() && Op.getSimpleValueType().isInteger() && "Only handle AVX 256-bit vector integer operation"); - return split256IntArith(Op, DAG); + return splitVectorIntBinary(Op, DAG); } static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget, @@ -26292,7 +26276,7 @@ if (VT.is256BitVector() && !Subtarget.hasInt256()) { assert(VT.isInteger() && "Only handle AVX 256-bit vector integer operation"); - return Lower256IntUnary(Op, DAG); + return splitVectorIntUnary(Op, DAG); } // Default to expand. @@ -26304,7 +26288,7 @@ // For AVX1 cases, split to use legal ops (everything but v4i64). if (VT.getScalarType() != MVT::i64 && VT.is256BitVector()) - return split256IntArith(Op, DAG); + return splitVectorIntBinary(Op, DAG); SDLoc DL(Op); unsigned Opcode = Op.getOpcode(); @@ -26348,7 +26332,7 @@ // Decompose 256-bit ops into 128-bit ops. if (VT.is256BitVector() && !Subtarget.hasInt256()) - return split256IntArith(Op, DAG); + return splitVectorIntBinary(Op, DAG); SDValue A = Op.getOperand(0); SDValue B = Op.getOperand(1); @@ -26494,7 +26478,7 @@ // Decompose 256-bit ops into 128-bit ops. 
if (VT.is256BitVector() && !Subtarget.hasInt256()) - return split256IntArith(Op, DAG); + return splitVectorIntBinary(Op, DAG); if (VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32) { assert((VT == MVT::v4i32 && Subtarget.hasSSE2()) || @@ -26586,7 +26570,7 @@ // For signed 512-bit vectors, split into 256-bit vectors to allow the // sign-extension to occur. if (VT == MVT::v64i8 && IsSigned) - return split512IntArith(Op, DAG); + return splitVectorIntBinary(Op, DAG); // Signed AVX2 implementation - extend xmm subvectors to ymm. if (VT == MVT::v32i8 && IsSigned) { @@ -27560,7 +27544,7 @@ // Decompose 256-bit shifts into 128-bit shifts. if (VT.is256BitVector()) - return split256IntArith(Op, DAG); + return splitVectorIntBinary(Op, DAG); return SDValue(); } @@ -27606,7 +27590,7 @@ // XOP implicitly uses modulo rotation amounts. if (Subtarget.hasXOP()) { if (VT.is256BitVector()) - return split256IntArith(Op, DAG); + return splitVectorIntBinary(Op, DAG); assert(VT.is128BitVector() && "Only rotate 128-bit vectors!"); // Attempt to rotate by immediate. @@ -27622,7 +27606,7 @@ // Split 256-bit integers on pre-AVX2 targets. if (VT.is256BitVector() && !Subtarget.hasAVX2()) - return split256IntArith(Op, DAG); + return splitVectorIntBinary(Op, DAG); assert((VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8 || ((VT == MVT::v8i32 || VT == MVT::v16i16 || VT == MVT::v32i8) && @@ -28287,11 +28271,11 @@ // Decompose 256-bit ops into smaller 128-bit ops. if (VT.is256BitVector() && !Subtarget.hasInt256()) - return Lower256IntUnary(Op, DAG); + return splitVectorIntUnary(Op, DAG); // Decompose 512-bit ops into smaller 256-bit ops. if (VT.is512BitVector() && !Subtarget.hasBWI()) - return Lower512IntUnary(Op, DAG); + return splitVectorIntUnary(Op, DAG); // For element types greater than i8, do vXi8 pop counts and a bytesum. if (VT.getScalarType() != MVT::i8) { @@ -28335,7 +28319,7 @@ // Decompose 256-bit ops into smaller 128-bit ops. 
if (VT.is256BitVector()) - return Lower256IntUnary(Op, DAG); + return splitVectorIntUnary(Op, DAG); assert(VT.is128BitVector() && "Only 128-bit vector bitreverse lowering supported."); @@ -28376,7 +28360,7 @@ // lowering. if (VT == MVT::v8i64 || VT == MVT::v16i32) { assert(!Subtarget.hasBWI() && "BWI should Expand BITREVERSE"); - return Lower512IntUnary(Op, DAG); + return splitVectorIntUnary(Op, DAG); } unsigned NumElts = VT.getVectorNumElements(); @@ -28385,7 +28369,7 @@ // Decompose 256-bit ops into smaller 128-bit ops on pre-AVX2. if (VT.is256BitVector() && !Subtarget.hasInt256()) - return Lower256IntUnary(Op, DAG); + return splitVectorIntUnary(Op, DAG); // Perform BITREVERSE using PSHUFB lookups. Each byte is split into // two nibbles and a PSHUFB lookup to find the bitreverse of each @@ -47137,7 +47121,7 @@ if (isConcatenatedNot(InVecBC.getOperand(0)) || isConcatenatedNot(InVecBC.getOperand(1))) { // extract (and v4i64 X, (not (concat Y1, Y2))), n -> andnp v2i64 X(n), Y1 - SDValue Concat = split256IntArith(InVecBC, DAG); + SDValue Concat = splitVectorIntBinary(InVecBC, DAG); return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, DAG.getBitcast(InVecVT, Concat), N->getOperand(1)); }