Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16404,7 +16404,15 @@
 
   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
                                VT.getSizeInBits() / SVT.getSizeInBits());
-  return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
+  // If all operands excluding the first operand are undefs, then use a
+  // SCALAR_TO_VECTOR instead of a BUILD_VECTOR.
+  bool AllUndefExcludingFirstOp =
+      std::all_of(std::next(Ops.begin()), Ops.end(),
+                  [](const SDValue &Op) { return Op.isUndef(); });
+  SDValue BV = AllUndefExcludingFirstOp
+                   ? DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, Ops[0])
+                   : DAG.getBuildVector(VecVT, DL, Ops);
+  return DAG.getBitcast(VT, BV);
 }
 
 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
@@ -16485,6 +16493,80 @@
                               DAG.getBitcast(VT, SV1), Mask);
 }
 
+/// Fold a CONCAT_VECTORS whose operands are all UNDEF except the first into a
+/// SCALAR_TO_VECTOR of the scalar that first operand was built from, bitcast
+/// back to the concat's type when needed. Returns an empty SDValue when the
+/// fold does not apply.
+static SDValue simplifyConcatVectors(SDNode *N, SelectionDAG &DAG,
+                                     bool LegalOperations,
+                                     const TargetLowering &TLI) {
+  // Bail out immediately if operands of this concat_vectors (excluding the
+  // first operand) are not undef.
+  if (!std::all_of(std::next(N->op_begin()), N->op_end(),
+                   [](const SDValue &Op) { return Op.isUndef(); }))
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  SDValue In = N->getOperand(0);
+  assert(In.getValueType().isVector() && "Must concat vectors");
+
+  SDValue Scalar = peekThroughOneUseBitcasts(In);
+
+  // peekThroughOneUseBitcasts can stop at a bitcast because of extra uses.
+  // Forming a SCALAR_TO_VECTOR does not duplicate the scalar, so still look
+  // through a bitcast of a scalar regardless of use counts.
+  if (Scalar.getOpcode() == ISD::BITCAST &&
+      !Scalar.getOperand(0).getValueType().isVector())
+    Scalar = Scalar.getOperand(0);
+
+  if (Scalar.getValueType().isVector()) {
+    // concat_vectors(bitcast(scalar_to_vector(A)), undef) ->
+    //     bitcast(scalar_to_vector(A))
+    if (LegalOperations || Scalar.getOpcode() != ISD::SCALAR_TO_VECTOR ||
+        !Scalar.hasOneUse())
+      return SDValue();
+
+    // SCALAR_TO_VECTOR implicitly truncates an integer operand that is
+    // wider than the element type. Reusing that wider scalar as the new
+    // element type would not keep the truncated bits in lane 0 on every
+    // target (e.g. big-endian), so require an exact element type match.
+    SDValue Elt = Scalar.getOperand(0);
+    if (Elt.getValueType() != Scalar.getValueType().getVectorElementType())
+      return SDValue();
+    Scalar = Elt;
+  } else if (Scalar.getOpcode() == ISD::TRUNCATE &&
+             !TLI.isTypeLegal(Scalar.getValueType()) &&
+             TLI.isTypeLegal(Scalar.getOperand(0).getValueType())) {
+    // If the bitcast type isn't legal, it might be a trunc of a legal type;
+    // look through the trunc so we can still do the transform:
+    //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
+    Scalar = Scalar.getOperand(0);
+  }
+
+  // Transform: concat_vectors(scalar, undef) -> scalar_to_vector(scalar).
+  // Query the value's own type; Scalar need not be result 0 of its node.
+  EVT SclTy = Scalar.getValueType();
+  if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
+    return SDValue();
+
+  // Bail out if the vector size is not a multiple of the scalar size.
+  if (VT.getSizeInBits() % SclTy.getSizeInBits())
+    return SDValue();
+
+  unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
+  if (VNTNumElms < 2)
+    return SDValue();
+
+  EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
+  if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(SclTy))
+    return SDValue();
+
+  SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
+  if (VT != NVT)
+    return DAG.getBitcast(VT, Res);
+  return Res;
+}
+
 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
   // If we only have one input vector, we don't need to do any concatenation.
   if (N->getNumOperands() == 1)
@@ -16496,46 +16578,8 @@
     return DAG.getUNDEF(VT);
 
   // Optimize concat_vectors where all but the first of the vectors are undef.
-  if (std::all_of(std::next(N->op_begin()), N->op_end(), [](const SDValue &Op) {
-        return Op.isUndef();
-      })) {
-    SDValue In = N->getOperand(0);
-    assert(In.getValueType().isVector() && "Must concat vectors");
-
-    // Transform: concat_vectors(scalar, undef) -> scalar_to_vector(sclr).
-    if (In->getOpcode() == ISD::BITCAST &&
-        !In->getOperand(0).getValueType().isVector()) {
-      SDValue Scalar = In->getOperand(0);
-
-      // If the bitcast type isn't legal, it might be a trunc of a legal type;
-      // look through the trunc so we can still do the transform:
-      //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
-      if (Scalar->getOpcode() == ISD::TRUNCATE &&
-          !TLI.isTypeLegal(Scalar.getValueType()) &&
-          TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
-        Scalar = Scalar->getOperand(0);
-
-      EVT SclTy = Scalar->getValueType(0);
-
-      if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
-        return SDValue();
-
-      // Bail out if the vector size is not a multiple of the scalar size.
-      if (VT.getSizeInBits() % SclTy.getSizeInBits())
-        return SDValue();
-
-      unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
-      if (VNTNumElms < 2)
-        return SDValue();
-
-      EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
-      if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
-        return SDValue();
-
-      SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
-      return DAG.getBitcast(VT, Res);
-    }
-  }
+  if (SDValue V = simplifyConcatVectors(N, DAG, LegalOperations, TLI))
+    return V;
 
   // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
   // We have already tested above for an UNDEF only concatenation.
Index: test/CodeGen/X86/combine-concatvectors.ll
===================================================================
--- test/CodeGen/X86/combine-concatvectors.ll
+++ test/CodeGen/X86/combine-concatvectors.ll
@@ -0,0 +1,18 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx < %s | FileCheck %s
+
+define void @PR32957(<2 x float>* %in, <8 x float>* %out) {
+; CHECK-LABEL: PR32957:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    vmovaps %ymm0, (%rsi)
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %ld = load <2 x float>, <2 x float>* %in, align 8
+  %ext = extractelement <2 x float> %ld, i64 0
+  %ext2 = extractelement <2 x float> %ld, i64 1
+  %ins = insertelement <8 x float> <float undef, float undef, float 0.0, float 0.0, float 0.0, float 0.0, float 0.0, float 0.0>, float %ext, i64 0
+  %ins2 = insertelement <8 x float> %ins, float %ext2, i64 1
+  store <8 x float> %ins2, <8 x float>* %out, align 32
+  ret void
+}