[X86][AVX] Restrict LowerVectorBroadcast to BUILD_VECTOR nodes.

LowerVectorBroadcast now takes the BuildVectorSDNode directly instead of a
generic SDValue. The opcode switch is removed together with its
ISD::VECTOR_SHUFFLE case, which handled splat shuffles of a
SCALAR_TO_VECTOR/BUILD_VECTOR operand and the AVX2 (hasInt256) and AVX-512
register-form broadcasts. The splat/undef-count and single-user checks of
the old BUILD_VECTOR case are kept unchanged. The updated call site passes
the BUILD_VECTOR node (BV) it already holds.

NOTE(review): the signature change turns any remaining SDValue-passing
caller into a compile error; confirm no VECTOR_SHUFFLE caller depends on
the removed path.

Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -6023,14 +6023,12 @@
   return SDValue();
 }
 
-/// Attempt to use the vbroadcast instruction to generate a splat value for the
-/// following cases:
-/// 1. A splat BUILD_VECTOR which uses a single scalar load, or a constant.
-/// 2. A splat shuffle which uses a scalar_to_vector node which comes from
-/// a scalar load, or a constant.
+/// Attempt to use the vbroadcast instruction to generate a splat value for a
+/// splat BUILD_VECTOR which uses a single scalar load, or a constant.
 /// The VBROADCAST node is returned when a pattern is found,
 /// or SDValue() otherwise.
-static SDValue LowerVectorBroadcast(SDValue Op, const X86Subtarget &Subtarget,
+static SDValue LowerVectorBroadcast(BuildVectorSDNode *BVOp,
+                                    const X86Subtarget &Subtarget,
                                     SelectionDAG &DAG) {
   // VBROADCAST requires AVX.
   // TODO: Splats could be generated for non-AVX CPUs using SSE
@@ -6038,79 +6036,27 @@
   if (!Subtarget.hasAVX())
     return SDValue();
 
-  MVT VT = Op.getSimpleValueType();
-  SDLoc dl(Op);
+  MVT VT = BVOp->getSimpleValueType(0);
+  SDLoc dl(BVOp);
 
   assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
          "Unsupported vector type for broadcast.");
 
-  SDValue Ld;
-  bool ConstSplatVal;
-
-  switch (Op.getOpcode()) {
-  default:
-    // Unknown pattern found.
-    return SDValue();
-
-  case ISD::BUILD_VECTOR: {
-    auto *BVOp = cast<BuildVectorSDNode>(Op.getNode());
-    BitVector UndefElements;
-    SDValue Splat = BVOp->getSplatValue(&UndefElements);
-
-    // We need a splat of a single value to use broadcast, and it doesn't
-    // make any sense if the value is only in one element of the vector.
-    if (!Splat || (VT.getVectorNumElements() - UndefElements.count()) <= 1)
-      return SDValue();
-
-    Ld = Splat;
-    ConstSplatVal = (Ld.getOpcode() == ISD::Constant ||
-                     Ld.getOpcode() == ISD::ConstantFP);
-
-    // Make sure that all of the users of a non-constant load are from the
-    // BUILD_VECTOR node.
-    if (!ConstSplatVal && !BVOp->isOnlyUserOf(Ld.getNode()))
-      return SDValue();
-    break;
-  }
-
-  case ISD::VECTOR_SHUFFLE: {
-    ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(Op);
-
-    // Shuffles must have a splat mask where the first element is
-    // broadcasted.
-    if ((!SVOp->isSplat()) || SVOp->getMaskElt(0) != 0)
-      return SDValue();
-
-    SDValue Sc = Op.getOperand(0);
-    if (Sc.getOpcode() != ISD::SCALAR_TO_VECTOR &&
-        Sc.getOpcode() != ISD::BUILD_VECTOR) {
-
-      if (!Subtarget.hasInt256())
-        return SDValue();
-
-      // Use the register form of the broadcast instruction available on AVX2.
-      if (VT.getSizeInBits() >= 256)
-        Sc = extract128BitVector(Sc, 0, DAG, dl);
-      return DAG.getNode(X86ISD::VBROADCAST, dl, VT, Sc);
-    }
+  BitVector UndefElements;
+  SDValue Ld = BVOp->getSplatValue(&UndefElements);
 
-    Ld = Sc.getOperand(0);
-    ConstSplatVal = (Ld.getOpcode() == ISD::Constant ||
-                     Ld.getOpcode() == ISD::ConstantFP);
+  // We need a splat of a single value to use broadcast, and it doesn't
+  // make any sense if the value is only in one element of the vector.
+  if (!Ld || (VT.getVectorNumElements() - UndefElements.count()) <= 1)
+    return SDValue();
 
-    // The scalar_to_vector node and the suspected
-    // load node must have exactly one user.
-    // Constants may have multiple users.
+  bool ConstSplatVal =
+      (Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);
 
-    // AVX-512 has register version of the broadcast
-    bool hasRegVer = Subtarget.hasAVX512() && VT.is512BitVector() &&
-      Ld.getValueSizeInBits() >= 32;
-    if (!ConstSplatVal && ((!Sc.hasOneUse() || !Ld.hasOneUse()) &&
-        !hasRegVer))
-      return SDValue();
-    break;
-  }
-  }
+  // Make sure that all of the users of a non-constant load are from the
+  // BUILD_VECTOR node.
+  if (!ConstSplatVal && !BVOp->isOnlyUserOf(Ld.getNode()))
+    return SDValue();
 
   unsigned ScalarSize = Ld.getValueSizeInBits();
   bool IsGE256 = (VT.getSizeInBits() >= 256);
@@ -6881,7 +6827,7 @@
     return AddSub;
   if (SDValue HorizontalOp = LowerToHorizontalOp(BV, Subtarget, DAG))
     return HorizontalOp;
-  if (SDValue Broadcast = LowerVectorBroadcast(Op, Subtarget, DAG))
+  if (SDValue Broadcast = LowerVectorBroadcast(BV, Subtarget, DAG))
     return Broadcast;
   if (SDValue BitOp = lowerBuildVectorToBitOp(Op, DAG))
     return BitOp;