Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -7851,13 +7851,14 @@
   return DstVec;
 }
 
-/// Return true if \p N implements a horizontal binop and return the
-/// operands for the horizontal binop into V0 and V1.
-///
 /// This is a helper function of LowerToHorizontalOp().
 /// This function checks that the build_vector \p N in input implements a
-/// horizontal operation. Parameter \p Opcode defines the kind of horizontal
-/// operation to match.
+/// 128-bit partial horizontal operation on a 256-bit vector, but that operation
+/// may not match the layout of an x86 256-bit horizontal instruction.
+/// In other words, if this returns true, then some extraction/insertion will
+/// be required to produce a valid horizontal instruction.
+///
+/// Parameter \p Opcode defines the kind of horizontal operation to match.
 /// For example, if \p Opcode is equal to ISD::ADD, then this function
 /// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode
 /// is equal to ISD::SUB, then this function checks if this is a horizontal
@@ -7865,12 +7866,17 @@
 ///
 /// This function only analyzes elements of \p N whose indices are
 /// in range [BaseIdx, LastIdx).
-static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
-                              SelectionDAG &DAG,
-                              unsigned BaseIdx, unsigned LastIdx,
-                              SDValue &V0, SDValue &V1) {
+///
+/// TODO: This function was originally used to match both real and fake partial
+/// horizontal operations, but the index-matching logic is incorrect for that.
+/// See the corrected implementation in isHopBuildVector(). Can we reduce this
+/// code because it is only used for partial h-op matching now?
+static bool isHorizontalBinOpPart(const BuildVectorSDNode *N, unsigned Opcode,
+                                  SelectionDAG &DAG,
+                                  unsigned BaseIdx, unsigned LastIdx,
+                                  SDValue &V0, SDValue &V1) {
   EVT VT = N->getValueType(0);
-
+  assert(VT.is256BitVector() && "Only use for matching partial 256-bit h-ops");
   assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!");
   assert(VT.isVector() && VT.getVectorNumElements() >= LastIdx &&
          "Invalid Vector in input!");
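For reference, the 256-bit instruction layout that the comments above refer to can be written out as a scalar model. This is an illustrative sketch only (V8F32 and HAdd256 are hypothetical names, not code from this patch), assuming the documented vhaddps ymm behavior, in which each 128-bit half of the result is computed solely from the corresponding 128-bit halves of the two sources:

#include <array>

using V8F32 = std::array<float, 8>;

// Scalar model of vhaddps ymm, ymm, ymm (AVX 256-bit float horizontal add).
static V8F32 HAdd256(const V8F32 &A, const V8F32 &B) {
  return {A[0] + A[1], A[2] + A[3],  // result[0..1]: low half of A
          B[0] + B[1], B[2] + B[3],  // result[2..3]: low half of B
          A[4] + A[5], A[6] + A[7],  // result[4..5]: high half of A
          B[4] + B[5], B[6] + B[7]}; // result[6..7]: high half of B
}

A partial horizontal op matched by isHorizontalBinOpPart() uses linear element indices, so it lines up with this layout only after extraction and insertion, which is what the revised doc-comment now spells out.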
@@ -8211,17 +8217,148 @@
   return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1);
 }
 
+static bool isHopBuildVector(const BuildVectorSDNode *BV, SelectionDAG &DAG,
+                             unsigned &HOpcode, SDValue &V0, SDValue &V1) {
+  // Initialize outputs to known values.
+  MVT VT = BV->getSimpleValueType(0);
+  HOpcode = ISD::DELETED_NODE;
+  V0 = DAG.getUNDEF(VT);
+  V1 = DAG.getUNDEF(VT);
+
+  // x86 256-bit horizontal ops are defined in a non-obvious way. Each 128-bit
+  // half of the result is calculated independently from the 128-bit halves of
+  // the inputs, so that makes the index-checking logic below more complicated.
+  unsigned NumElts = VT.getVectorNumElements();
+  unsigned GenericOpcode = ISD::DELETED_NODE;
+  unsigned Num128BitChunks = VT.is256BitVector() ? 2 : 1;
+  unsigned NumEltsIn128Bits = NumElts / Num128BitChunks;
+  unsigned NumEltsIn64Bits = NumEltsIn128Bits / 2;
+  for (unsigned i = 0; i != Num128BitChunks; ++i) {
+    for (unsigned j = 0; j != NumEltsIn128Bits; ++j) {
+      // Ignore undef elements.
+      SDValue Op = BV->getOperand(i * NumEltsIn128Bits + j);
+      if (Op.isUndef())
+        continue;
+
+      // If there's an opcode mismatch, we're done.
+      if (HOpcode != ISD::DELETED_NODE && Op.getOpcode() != GenericOpcode)
+        return false;
+
+      // Initialize horizontal opcode.
+      if (HOpcode == ISD::DELETED_NODE) {
+        GenericOpcode = Op.getOpcode();
+        switch (GenericOpcode) {
+        case ISD::ADD: HOpcode = X86ISD::HADD; break;
+        case ISD::SUB: HOpcode = X86ISD::HSUB; break;
+        case ISD::FADD: HOpcode = X86ISD::FHADD; break;
+        case ISD::FSUB: HOpcode = X86ISD::FHSUB; break;
+        default: return false;
+        }
+      }
+
+      SDValue Op0 = Op.getOperand(0);
+      SDValue Op1 = Op.getOperand(1);
+      if (Op0.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+          Op1.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+          Op0.getOperand(0) != Op1.getOperand(0) ||
+          !isa<ConstantSDNode>(Op0.getOperand(1)) ||
+          !isa<ConstantSDNode>(Op1.getOperand(1)) || !Op.hasOneUse())
+        return false;
+
+      // The source vector is chosen based on which 64-bit half of the
+      // destination vector is being calculated.
+      if (j < NumEltsIn64Bits) {
+        if (V0.isUndef())
+          V0 = Op0.getOperand(0);
+      } else {
+        if (V1.isUndef())
+          V1 = Op0.getOperand(0);
+      }
+
+      SDValue SourceVec = (j < NumEltsIn64Bits) ? V0 : V1;
+      if (SourceVec != Op0.getOperand(0))
+        return false;
+
+      // op (extract_vector_elt A, I), (extract_vector_elt A, I+1)
+      unsigned ExtIndex0 = Op0.getConstantOperandVal(1);
+      unsigned ExtIndex1 = Op1.getConstantOperandVal(1);
+      unsigned ExpectedIndex = i * NumEltsIn128Bits +
+                               (j % NumEltsIn64Bits) * 2;
+      if (ExpectedIndex == ExtIndex0 && ExtIndex1 == ExtIndex0 + 1)
+        continue;
+
+      // If this is not a commutative op, this does not match.
+      if (GenericOpcode != ISD::ADD && GenericOpcode != ISD::FADD)
+        return false;
+
+      // Addition is commutative, so try swapping the extract indexes.
+      // op (extract_vector_elt A, I+1), (extract_vector_elt A, I)
+      if (ExpectedIndex == ExtIndex1 && ExtIndex0 == ExtIndex1 + 1)
+        continue;
+
+      // Extract indexes do not match horizontal requirement.
+      return false;
+    }
+  }
+  // We matched. Opcode and operands are returned by reference as arguments.
+  return true;
+}
+
+static SDValue getHopForBuildVector(const BuildVectorSDNode *BV,
+                                    SelectionDAG &DAG, unsigned HOpcode,
+                                    SDValue V0, SDValue V1) {
+  // TODO: We should extract/insert to match the size of the build vector.
+  MVT VT = BV->getSimpleValueType(0);
+  if (V0.getValueType() != VT || V1.getValueType() != VT)
+    return SDValue();
+
+  return DAG.getNode(HOpcode, SDLoc(BV), VT, V0, V1);
+}
+
 /// Lower BUILD_VECTOR to a horizontal add/sub operation if possible.
 static SDValue LowerToHorizontalOp(const BuildVectorSDNode *BV,
                                    const X86Subtarget &Subtarget,
                                    SelectionDAG &DAG) {
+  // We need at least 2 non-undef elements to make this worthwhile by default.
+  unsigned NumNonUndefs = 0;
+  for (const SDValue &V : BV->op_values())
+    if (!V.isUndef())
+      ++NumNonUndefs;
+
+  if (NumNonUndefs < 2)
+    return SDValue();
+
+  // There are 4 sets of horizontal math operations distinguished by type:
+  // int/FP at 128-bit/256-bit. Each type was introduced with a different
+  // subtarget feature. Try to match those "native" patterns first.
   MVT VT = BV->getSimpleValueType(0);
+  unsigned HOpcode;
+  SDValue V0, V1;
+  if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget.hasSSE3())
+    if (isHopBuildVector(BV, DAG, HOpcode, V0, V1))
+      return getHopForBuildVector(BV, DAG, HOpcode, V0, V1);
+
+  if ((VT == MVT::v8i16 || VT == MVT::v4i32) && Subtarget.hasSSSE3())
+    if (isHopBuildVector(BV, DAG, HOpcode, V0, V1))
+      return getHopForBuildVector(BV, DAG, HOpcode, V0, V1);
+
+  if ((VT == MVT::v8f32 || VT == MVT::v4f64) && Subtarget.hasAVX())
+    if (isHopBuildVector(BV, DAG, HOpcode, V0, V1))
+      return getHopForBuildVector(BV, DAG, HOpcode, V0, V1);
+
+  if ((VT == MVT::v16i16 || VT == MVT::v8i32) && Subtarget.hasAVX2())
+    if (isHopBuildVector(BV, DAG, HOpcode, V0, V1))
+      return getHopForBuildVector(BV, DAG, HOpcode, V0, V1);
+
+  // Try harder to match 256-bit ops by using extract/concat.
+  if (!Subtarget.hasAVX() || !VT.is256BitVector())
+    return SDValue();
+
+  // Count the number of UNDEF operands in the build_vector in input.
   unsigned NumElts = VT.getVectorNumElements();
+  unsigned Half = NumElts / 2;
   unsigned NumUndefsLO = 0;
   unsigned NumUndefsHI = 0;
-  unsigned Half = NumElts/2;
-
-  // Count the number of UNDEF operands in the build_vector in input.
   for (unsigned i = 0, e = Half; i != e; ++i)
     if (BV->getOperand(i)->isUndef())
       NumUndefsLO++;
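To see what the ExpectedIndex arithmetic in isHopBuildVector() accepts, the following standalone sketch (hypothetical code, not part of the patch) replays the computation for the v8i32 case, where NumElts = 8, Num128BitChunks = 2, NumEltsIn128Bits = 4, and NumEltsIn64Bits = 2:

#include <cstdio>

// Enumerate the extract-element indices that the matcher requires for each
// result element of a v8i32 horizontal op.
int main() {
  const unsigned NumEltsIn128Bits = 4, NumEltsIn64Bits = 2;
  for (unsigned i = 0; i != 2; ++i) {        // 128-bit chunk of the result
    for (unsigned j = 0; j != NumEltsIn128Bits; ++j) {
      unsigned Expected = i * NumEltsIn128Bits + (j % NumEltsIn64Bits) * 2;
      const char *Src = (j < NumEltsIn64Bits) ? "V0" : "V1";
      std::printf("result[%u] = %s[%u] + %s[%u]\n",
                  i * NumEltsIn128Bits + j, Src, Expected, Src, Expected + 1);
    }
  }
  return 0;
}

The output (result[0] = V0[0] + V0[1], result[2] = V1[0] + V1[1], result[4] = V0[4] + V0[5], and so on) is exactly the half-by-half layout of vhaddps/vphaddd, which, per the TODO above, the old linear index matching did not enforce.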
@@ -8230,72 +8367,31 @@
     if (BV->getOperand(i)->isUndef())
       NumUndefsHI++;
 
-  // Early exit if this is either a build_vector of all UNDEFs or all the
-  // operands but one are UNDEF.
-  if (NumUndefsLO + NumUndefsHI + 1 >= NumElts)
-    return SDValue();
-
   SDLoc DL(BV);
   SDValue InVec0, InVec1;
-  if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget.hasSSE3()) {
-    // Try to match an SSE3 float HADD/HSUB.
-    if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1))
-      return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);
-
-    if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1))
-      return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
-  } else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget.hasSSSE3()) {
-    // Try to match an SSSE3 integer HADD/HSUB.
-    if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1))
-      return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1);
-
-    if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1))
-      return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1);
-  }
-
-  if (!Subtarget.hasAVX())
-    return SDValue();
-
-  if ((VT == MVT::v8f32 || VT == MVT::v4f64)) {
-    // Try to match an AVX horizontal add/sub of packed single/double
-    // precision floating point values from 256-bit vectors.
-    SDValue InVec2, InVec3;
-    if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, Half, InVec0, InVec1) &&
-        isHorizontalBinOp(BV, ISD::FADD, DAG, Half, NumElts, InVec2, InVec3) &&
-        ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
-        ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
-      return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);
-
-    if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, Half, InVec0, InVec1) &&
-        isHorizontalBinOp(BV, ISD::FSUB, DAG, Half, NumElts, InVec2, InVec3) &&
-        ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
-        ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
-      return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
-  } else if (VT == MVT::v8i32 || VT == MVT::v16i16) {
+  if (VT == MVT::v8i32 || VT == MVT::v16i16) {
     // Try to match an AVX2 horizontal add/sub of signed integers.
     SDValue InVec2, InVec3;
     unsigned X86Opcode;
     bool CanFold = true;
-    if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, Half, InVec0, InVec1) &&
-        isHorizontalBinOp(BV, ISD::ADD, DAG, Half, NumElts, InVec2, InVec3) &&
+    if (isHorizontalBinOpPart(BV, ISD::ADD, DAG, 0, Half, InVec0, InVec1) &&
+        isHorizontalBinOpPart(BV, ISD::ADD, DAG, Half, NumElts, InVec2,
+                              InVec3) &&
         ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
         ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
       X86Opcode = X86ISD::HADD;
-    else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, Half, InVec0, InVec1) &&
-        isHorizontalBinOp(BV, ISD::SUB, DAG, Half, NumElts, InVec2, InVec3) &&
-        ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
-        ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
+    else if (isHorizontalBinOpPart(BV, ISD::SUB, DAG, 0, Half, InVec0,
+                                   InVec1) &&
+             isHorizontalBinOpPart(BV, ISD::SUB, DAG, Half, NumElts, InVec2,
+                                   InVec3) &&
+             ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
+             ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
       X86Opcode = X86ISD::HSUB;
     else
       CanFold = false;
 
     if (CanFold) {
-      // Fold this build_vector into a single horizontal add/sub.
-      // Do this only if the target has AVX2.
-      if (Subtarget.hasAVX2())
-        return DAG.getNode(X86Opcode, DL, VT, InVec0, InVec1);
-
       // Do not try to expand this build_vector into a pair of horizontal
       // add/sub if we can emit a pair of scalar add/sub.
       if (NumUndefsLO + 1 == Half || NumUndefsHI + 1 == Half)
@@ -8310,16 +8406,19 @@
     }
   }
 
-  if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
-       VT == MVT::v16i16) && Subtarget.hasAVX()) {
+  if (VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
+      VT == MVT::v16i16) {
     unsigned X86Opcode;
-    if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1))
+    if (isHorizontalBinOpPart(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1))
       X86Opcode = X86ISD::HADD;
-    else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1))
+    else if (isHorizontalBinOpPart(BV, ISD::SUB, DAG, 0, NumElts, InVec0,
+                                   InVec1))
       X86Opcode = X86ISD::HSUB;
-    else if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1))
+    else if (isHorizontalBinOpPart(BV, ISD::FADD, DAG, 0, NumElts, InVec0,
+                                   InVec1))
       X86Opcode = X86ISD::FHADD;
-    else if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1))
+    else if (isHorizontalBinOpPart(BV, ISD::FSUB, DAG, 0, NumElts, InVec0,
+                                   InVec1))
       X86Opcode = X86ISD::FHSUB;
     else
       return SDValue();
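The test updates below follow from the matcher now choosing a distinct source vector per 64-bit half of each result half: the second operand of the horizontal op names whichever register actually feeds those lanes, or is left undefined, rather than reusing %ymm0 for both operands. A self-contained sketch of the PR40243 situation, under the same hypothetical lane model as above:

#include <array>
#include <cassert>

using V8F32 = std::array<float, 8>;

// Illustrative scalar model of vhaddps ymm (same layout as sketched earlier).
static V8F32 HAdd256(const V8F32 &A, const V8F32 &B) {
  return {A[0] + A[1], A[2] + A[3], B[0] + B[1], B[2] + B[3],
          A[4] + A[5], A[6] + A[7], B[4] + B[5], B[6] + B[7]};
}

int main() {
  V8F32 A = {0, 1, 2, 3, 4, 5, 6, 7};
  V8F32 Junk = {}; // stands in for an undef second source
  V8F32 R = HAdd256(A, Junk);
  // Result lanes 4..5 depend only on A's high half, so a build_vector that
  // wants {A[4]+A[5], A[6]+A[7]} there can take any second operand at all.
  assert(R[4] == A[4] + A[5] && R[5] == A[6] + A[7]);
  return 0;
}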
Index: llvm/trunk/test/CodeGen/X86/haddsub-undef.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/haddsub-undef.ll
+++ llvm/trunk/test/CodeGen/X86/haddsub-undef.ll
@@ -300,7 +300,7 @@
 ;
 ; AVX-LABEL: test11_undef:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vhaddps %ymm0, %ymm0, %ymm0
+; AVX-NEXT:    vhaddps %ymm1, %ymm0, %ymm0
 ; AVX-NEXT:    retq
   %vecext = extractelement <8 x float> %a, i32 0
   %vecext1 = extractelement <8 x float> %a, i32 1
@@ -934,12 +934,12 @@
 ;
 ; AVX1-SLOW-LABEL: v16f32_inputs_v8f32_output_4567:
 ; AVX1-SLOW:       # %bb.0:
-; AVX1-SLOW-NEXT:    vhaddps %ymm0, %ymm0, %ymm0
+; AVX1-SLOW-NEXT:    vhaddps %ymm2, %ymm0, %ymm0
 ; AVX1-SLOW-NEXT:    retq
 ;
 ; AVX1-FAST-LABEL: v16f32_inputs_v8f32_output_4567:
 ; AVX1-FAST:       # %bb.0:
-; AVX1-FAST-NEXT:    vhaddps %ymm0, %ymm0, %ymm0
+; AVX1-FAST-NEXT:    vhaddps %ymm2, %ymm0, %ymm0
 ; AVX1-FAST-NEXT:    retq
 ;
 ; AVX512-LABEL: v16f32_inputs_v8f32_output_4567:
@@ -973,7 +973,7 @@
 ;
 ; AVX-LABEL: PR40243:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vhaddps %ymm0, %ymm0, %ymm0
+; AVX-NEXT:    vhaddps %ymm1, %ymm0, %ymm0
 ; AVX-NEXT:    retq
   %a4 = extractelement <8 x float> %a, i32 4
   %a5 = extractelement <8 x float> %a, i32 5
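The integer tests in the next file change in the same way. For completeness, a matching hypothetical sketch of the vphaddd ymm layout (PHAddD256 and V8I32 are illustrative names, not patch code):

#include <array>

using V8I32 = std::array<int, 8>;

// Scalar model of vphaddd ymm: the same per-128-bit-half layout as vhaddps,
// applied to eight 32-bit integers.
static V8I32 PHAddD256(const V8I32 &A, const V8I32 &B) {
  return {A[0] + A[1], A[2] + A[3], B[0] + B[1], B[2] + B[3],
          A[4] + A[5], A[6] + A[7], B[4] + B[5], B[6] + B[7]};
}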
Index: llvm/trunk/test/CodeGen/X86/phaddsub-undef.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/phaddsub-undef.ll
+++ llvm/trunk/test/CodeGen/X86/phaddsub-undef.ll
@@ -75,12 +75,12 @@
 ;
 ; AVX2-LABEL: test15_undef:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX2-NEXT:    vphaddd %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    retq
 ;
 ; AVX512-LABEL: test15_undef:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX512-NEXT:    vphaddd %ymm1, %ymm0, %ymm0
 ; AVX512-NEXT:    retq
   %vecext = extractelement <8 x i32> %a, i32 0
   %vecext1 = extractelement <8 x i32> %a, i32 1
@@ -105,12 +105,12 @@
 ;
 ; AVX2-LABEL: PR40243_alt:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX2-NEXT:    vphaddd %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    retq
 ;
 ; AVX512-LABEL: PR40243_alt:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX512-NEXT:    vphaddd %ymm1, %ymm0, %ymm0
 ; AVX512-NEXT:    retq
   %a4 = extractelement <8 x i32> %a, i32 4
   %a5 = extractelement <8 x i32> %a, i32 5