Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16310,18 +16310,49 @@
   return SDValue();
 }
 
-static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
+static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG,
+                                            bool LegalOperations) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   EVT OpVT = N->getOperand(0).getValueType();
 
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  // concat_vectors(bitcast(scalar_to_vector %A), undef, ...) -->
+  //     bitcast(scalar_to_vector %A)
+  // Only the low bits of the result are defined, so rebuilding the scalar in
+  // a vector of the full result width just refines the undef upper lanes.
+  if (!LegalOperations && N->getNumOperands() > 1) {
+    SDValue Op0 = N->getOperand(0);
+    if (Op0.hasOneUse() && Op0.getOpcode() == ISD::BITCAST &&
+        Op0.getOperand(0).hasOneUse() &&
+        Op0.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR &&
+        std::all_of(N->op_begin() + 1, N->op_end(),
+                    [](const SDValue &U) { return U.isUndef(); })) {
+      SDValue Scalar = Op0.getOperand(0).getOperand(0);
+      EVT SVT = Scalar.getValueType();
+      unsigned VTBits = VT.getSizeInBits();
+      unsigned SVTBits = SVT.getSizeInBits();
+
+      // SCALAR_TO_VECTOR may implicitly truncate an integer operand, so the
+      // scalar can be wider than the original element. Only fold when the
+      // result holds an exact multiple (>= 2) of the scalar, and never create
+      // an illegal type: this runs after type legalization too.
+      if (VTBits % SVTBits == 0 && VTBits / SVTBits >= 2) {
+        EVT NewVT = EVT::getVectorVT(*DAG.getContext(), SVT, VTBits / SVTBits);
+        if (TLI.isTypeLegal(SVT) && TLI.isTypeLegal(NewVT)) {
+          SDValue STV = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, NewVT, Scalar);
+          return DAG.getBitcast(VT, STV);
+        }
+      }
+    }
+  }
+
   // If the operands are legal vectors, leave them alone.
   if (TLI.isTypeLegal(OpVT))
     return SDValue();
 
-  SDLoc DL(N);
-  EVT VT = N->getValueType(0);
   SmallVector<SDValue, 8> Ops;
-
   EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
   SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
 
@@ -16551,7 +16582,7 @@
   }
 
   // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
-  if (SDValue V = combineConcatVectorOfScalars(N, DAG))
+  if (SDValue V = combineConcatVectorOfScalars(N, DAG, LegalOperations))
     return V;
 
   // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
Index: test/CodeGen/X86/simplify_concat_vectors.ll
===================================================================
--- test/CodeGen/X86/simplify_concat_vectors.ll
+++ test/CodeGen/X86/simplify_concat_vectors.ll
@@ -0,0 +1,19 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx < %s | FileCheck %s
+
+; Two loaded lanes inserted into an otherwise-undef <8 x float> should lower to one 64-bit load and one 256-bit store.
+define void @PR32957(<2 x float>* %in, <8 x float>* %out) {
+; CHECK-LABEL: PR32957:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    vmovaps %ymm0, (%rsi)
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %ld = load <2 x float>, <2 x float>* %in, align 8
+  %ext = extractelement <2 x float> %ld, i64 0
+  %ext2 = extractelement <2 x float> %ld, i64 1
+  %ins = insertelement <8 x float> undef, float %ext, i64 0
+  %ins2 = insertelement <8 x float> %ins, float %ext2, i64 1
+  store <8 x float> %ins2, <8 x float>* %out, align 32
+  ret void
+}