diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -10782,6 +10782,52 @@
   return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, Zero, SplatVal);
 }
 
+// Simplify patterns like
+// INSERT_SUBVECTOR(BUILD_VECTOR(f32 a, f32 b, f32 a, f32 b))
+// to SPLAT_VECTOR(f64(a, b))
+SDValue useWideSplatForBuildVectorRepeatedComplexPattern(SDValue Op,
+                                                         SelectionDAG &DAG) {
+  SDValue Insert = Op.getOperand(1);
+  if (Insert.getOpcode() != ISD::INSERT_SUBVECTOR)
+    return SDValue();
+
+  SDValue BuildVector = Insert.getOperand(1);
+  if (BuildVector.getOpcode() != ISD::BUILD_VECTOR)
+    return SDValue();
+
+  EVT VecTy = BuildVector.getValueType();
+  if (!VecTy.is128BitVector())
+    return SDValue();
+
+  unsigned NumOperands = BuildVector.getNumOperands();
+  if (NumOperands < 4 || NumOperands % 2 != 0)
+    return SDValue();
+  for (unsigned i = 0; i + 2 < NumOperands; i++) {
+    if (BuildVector.getOperand(i) != BuildVector.getOperand(i + 2))
+      return SDValue();
+  }
+
+  EVT VecElTy = VecTy.getScalarType();
+  MVT WideElTy = MVT::getFloatingPointVT(VecElTy.getScalarSizeInBits() * 2);
+  MVT WideVecTy =
+      MVT::getVectorVT(WideElTy, VecTy.getVectorNumElements() / 2, false);
+  MVT ScalableWideVecTy =
+      MVT::getVectorVT(WideElTy, VecTy.getVectorNumElements() / 2, true);
+
+  SDLoc DL(Op);
+  SDValue InsertVectorElt =
+      DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecTy, DAG.getUNDEF(VecTy),
+                  BuildVector.getOperand(1), DAG.getConstant(1, DL, MVT::i64));
+  SDValue WideVecBitcast =
+      DAG.getNode(ISD::BITCAST, DL, WideVecTy, InsertVectorElt);
+  SDValue ExtractWideVectorElt =
+      DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, WideElTy, WideVecBitcast,
+                  DAG.getConstant(0, DL, MVT::i64));
+  SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, ScalableWideVecTy,
+                              ExtractWideVectorElt);
+  return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Splat);
+}
+
 SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
                                              SelectionDAG &DAG) const {
   SDLoc DL(Op);
@@ -10799,9 +10845,16 @@
 
   // DUPQ can be used when idx is in range.
   auto *CIdx = dyn_cast<ConstantSDNode>(Idx128);
-  if (CIdx && (CIdx->getZExtValue() <= 3)) {
-    SDValue CI = DAG.getTargetConstant(CIdx->getZExtValue(), DL, MVT::i64);
-    return DAG.getNode(AArch64ISD::DUPLANE128, DL, VT, Op.getOperand(1), CI);
+  if (CIdx) {
+    uint64_t CIdxVal = CIdx->getZExtValue();
+    if (CIdxVal <= 3) {
+      if (SDValue WideSplat =
+              useWideSplatForBuildVectorRepeatedComplexPattern(Op, DAG);
+          CIdxVal == 0 && WideSplat)
+        return WideSplat;
+      SDValue CI = DAG.getTargetConstant(CIdx->getZExtValue(), DL, MVT::i64);
+      return DAG.getNode(AArch64ISD::DUPLANE128, DL, VT, Op.getOperand(1), CI);
+    }
   }
 
   SDValue V = DAG.getNode(ISD::BITCAST, DL, MVT::nxv2i64, Op.getOperand(1));
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
@@ -587,13 +587,9 @@
 define dso_local <vscale x 4 x float> @dupq_f32_repeat_complex(float %x, float %y) {
 ; CHECK-LABEL: dupq_f32_repeat_complex:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $s0 killed $s0 def $q0
-; CHECK-NEXT:    mov v2.16b, v0.16b
 ; CHECK-NEXT:    // kill: def $s1 killed $s1 def $q1
-; CHECK-NEXT:    mov v2.s[1], v1.s[0]
-; CHECK-NEXT:    mov v2.s[2], v0.s[0]
-; CHECK-NEXT:    mov v2.s[3], v1.s[0]
-; CHECK-NEXT:    mov z0.q, q2
+; CHECK-NEXT:    mov v0.s[1], v1.s[0]
+; CHECK-NEXT:    mov z0.d, d0
 ; CHECK-NEXT:    ret
   %1 = insertelement <4 x float> undef, float %x, i64 0
   %2 = insertelement <4 x float> %1, float %y, i64 1
@@ -607,17 +603,9 @@
 define dso_local <vscale x 8 x half> @dupq_f16_repeat_complex(half %a, half %b) {
 ; CHECK-LABEL: dupq_f16_repeat_complex:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $h0 killed $h0 def $q0
-; CHECK-NEXT:    mov v2.16b, v0.16b
 ; CHECK-NEXT:    // kill: def $h1 killed $h1 def $q1
-; CHECK-NEXT:    mov v2.h[1], v1.h[0]
-; CHECK-NEXT:    mov v2.h[2], v0.h[0]
-; CHECK-NEXT:    mov v2.h[3], v1.h[0]
-; CHECK-NEXT:    mov v2.h[4], v0.h[0]
-; CHECK-NEXT:    mov v2.h[5], v1.h[0]
-; CHECK-NEXT:    mov v2.h[6], v0.h[0]
-; CHECK-NEXT:    mov v2.h[7], v1.h[0]
-; CHECK-NEXT:    mov z0.q, q2
+; CHECK-NEXT:    mov v0.h[1], v1.h[0]
+; CHECK-NEXT:    mov z0.s, s0
 ; CHECK-NEXT:    ret
   %1 = insertelement <8 x half> undef, half %a, i64 0
   %2 = insertelement <8 x half> %1, half %b, i64 1