diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20027,9 +20027,90 @@
   return DAG.getNode(ISD::OR, DL, VT, Sel, SelInv);
 }
 
+// Simplify a repeating complex pattern that is built up out of a chain of
+// vector-element inserts inside a DUPLANE128, i.e.:
+//   nxv4f32 DUPLANE128(f32 a, f32 b, f32 a, f32 b)
+//     -> nxv2f64 SPLAT_VECTOR(f64 {a, b})
+// The wider splat selects exactly the same bytes but is cheaper to
+// materialise (one pair insert plus a DUP instead of a full build-vector).
+static SDValue simplifyDupLane128RepeatedComplexPattern(SDNode *Op,
+                                                        SelectionDAG &DAG) {
+  SDValue InsertSubvector = Op->getOperand(0);
+  if (InsertSubvector.getOpcode() != ISD::INSERT_SUBVECTOR)
+    return SDValue();
+  if (!InsertSubvector.getOperand(0).isUndef())
+    return SDValue();
+
+  SDValue InsertVecElt = InsertSubvector.getOperand(1);
+  if (InsertVecElt.getOpcode() != ISD::INSERT_VECTOR_ELT)
+    return SDValue();
+  EVT VecTy = InsertVecElt.getValueType();
+  if (!VecTy.is128BitVector())
+    return SDValue();
+
+  // Walk the insert chain. The outermost insert (highest lane) is visited
+  // first, so Sequence collects the lane values in reverse lane order. Every
+  // lane NumElts-1..1 must be written by an insert with the matching constant
+  // index and lane 0 by the terminating SCALAR_TO_VECTOR; otherwise we could
+  // splat values taken from the wrong lanes.
+  unsigned NumElts = VecTy.getVectorNumElements();
+  SmallVector<SDValue, 16> Sequence;
+  do {
+    if (InsertVecElt.getOpcode() == ISD::INSERT_VECTOR_ELT) {
+      auto *Idx = dyn_cast<ConstantSDNode>(InsertVecElt.getOperand(2));
+      if (!Idx || Idx->getZExtValue() != NumElts - 1 - Sequence.size())
+        return SDValue();
+      Sequence.push_back(InsertVecElt.getOperand(1));
+      InsertVecElt = InsertVecElt.getOperand(0);
+    } else if (InsertVecElt.getOpcode() == ISD::SCALAR_TO_VECTOR) {
+      // SCALAR_TO_VECTOR only defines lane 0, so it must be the last lane.
+      if (Sequence.size() != NumElts - 1)
+        return SDValue();
+      Sequence.push_back(InsertVecElt.getOperand(0));
+      break;
+    } else
+      return SDValue();
+  } while (true);
+
+  // Require a repeating pattern of period two, i.e. {a, b, a, b, ...}.
+  // Note the bound: every adjacent pair (I, I+2) must be compared, not just
+  // the first half, or the top lanes of an 8-element vector go unchecked.
+  if (Sequence.size() < 4 || Sequence.size() % 2 != 0)
+    return SDValue();
+  for (unsigned I = 0; I + 2 < Sequence.size(); ++I)
+    if (Sequence[I] != Sequence[I + 2])
+      return SDValue();
+
+  EVT VecElTy = VecTy.getScalarType();
+  MVT WideElTy = MVT::getFloatingPointVT(VecElTy.getScalarSizeInBits() * 2);
+  MVT WideVecTy =
+      MVT::getVectorVT(WideElTy, NumElts / 2, /*IsScalable=*/false);
+  MVT ScalableWideVecTy =
+      MVT::getVectorVT(WideElTy, NumElts / 2, /*IsScalable=*/true);
+
+  // Sequence is in reverse lane order: the even-lane value is at the back and
+  // the odd-lane value at the front. Pack a single {a, b} pair (lane 0 seeded
+  // via SCALAR_TO_VECTOR so it is not left undef), read it back as one wide
+  // element and splat that across the scalable vector.
+  SDLoc DL(Op);
+  SDValue Pair = DAG.getNode(
+      ISD::INSERT_VECTOR_ELT, DL, VecTy,
+      DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecTy, Sequence.back()),
+      Sequence.front(), DAG.getConstant(1, DL, MVT::i64));
+  SDValue WideVecBitcast = DAG.getNode(ISD::BITCAST, DL, WideVecTy, Pair);
+  SDValue WideElt =
+      DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, WideElTy, WideVecBitcast,
+                  DAG.getConstant(0, DL, MVT::i64));
+  SDValue Splat =
+      DAG.getNode(ISD::SPLAT_VECTOR, DL, ScalableWideVecTy, WideElt);
+  return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Splat);
+}
+
 static SDValue performDupLane128Combine(SDNode *N, SelectionDAG &DAG) {
-  EVT VT = N->getValueType(0);
+  // Try the wide-element splat of a repeating complex pattern first; it
+  // produces strictly cheaper code than the generic 128-bit broadcast.
+  if (SDValue WideSplat = simplifyDupLane128RepeatedComplexPattern(N, DAG))
+    return WideSplat;
+
+  EVT VT = N->getValueType(0);
 
   SDValue Insert = N->getOperand(0);
   if (Insert.getOpcode() != ISD::INSERT_SUBVECTOR)
     return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
@@ -587,13 +587,10 @@ define dso_local <vscale x 4 x float> @dupq_f32_repeat_complex(float %x, float %y) {
 ; CHECK-LABEL: dupq_f32_repeat_complex:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $s0 killed $s0 def $q0
-; CHECK-NEXT:    mov v2.16b, v0.16b
 ; CHECK-NEXT:    // kill: def $s1 killed $s1 def $q1
-; CHECK-NEXT:    mov v2.s[1], v1.s[0]
-; CHECK-NEXT:    mov v2.s[2], v0.s[0]
-; CHECK-NEXT:    mov v2.s[3], v1.s[0]
-; CHECK-NEXT:    mov z0.q, q2
+; CHECK-NEXT:    mov v0.s[1], v1.s[0]
+; CHECK-NEXT:    mov z0.d, d0
 ; CHECK-NEXT:    ret
   %1 = insertelement <4 x float> undef, float %x, i64 0
   %2 = insertelement <4 x float> %1, float %y, i64 1
   %3 = insertelement <4 x float> %2, float %x, i64 2
@@ -607,17 +604,10 @@ define dso_local <vscale x 8 x half> @dupq_f16_repeat_complex(half %a, half %b) {
 ; CHECK-LABEL: dupq_f16_repeat_complex:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $h0 killed $h0 def $q0
-; CHECK-NEXT:    mov v2.16b, v0.16b
 ; CHECK-NEXT:    // kill: def $h1 killed $h1 def $q1
-; CHECK-NEXT:    mov v2.h[1], v1.h[0]
-; CHECK-NEXT:    mov v2.h[2], v0.h[0]
-; CHECK-NEXT:    mov v2.h[3], v1.h[0]
-; CHECK-NEXT:    mov v2.h[4], v0.h[0]
-; CHECK-NEXT:    mov v2.h[5], v1.h[0]
-; CHECK-NEXT:    mov v2.h[6], v0.h[0]
-; CHECK-NEXT:    mov v2.h[7], v1.h[0]
-; CHECK-NEXT:    mov z0.q, q2
+; CHECK-NEXT:    mov v0.h[1], v1.h[0]
+; CHECK-NEXT:    mov z0.s, s0
 ; CHECK-NEXT:    ret
   %1 = insertelement <8 x half> undef, half %a, i64 0
   %2 = insertelement <8 x half> %1, half %b, i64 1