Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20678,9 +20678,96 @@
   return DAG.getNode(ISD::OR, DL, VT, Sel, SelInv);
 }
 
+// Simplify a repeating complex pattern from vector insert inside DUPLANE128
+// to a splat of one complex number, i.e.:
+// nxv4f32 DUPLANE128(f32 x, f32 y, f32 x, f32 y) -> nxv2f64 SPLAT_VECTOR(x <<
+// 32 | y)
+SDValue simplifyDupLane128RepeatedComplexPattern(SDNode *Op,
+                                                 SelectionDAG &DAG) {
+  SDValue InsertSubvector = Op->getOperand(0);
+  if (InsertSubvector.getOpcode() != ISD::INSERT_SUBVECTOR)
+    return SDValue();
+  if (!InsertSubvector.getOperand(0).isUndef())
+    return SDValue();
+
+  SDValue InsertVecElt = InsertSubvector.getOperand(1);
+  if (InsertVecElt.getOpcode() != ISD::INSERT_VECTOR_ELT)
+    return SDValue();
+  EVT VecTy = InsertVecElt.getValueType();
+  if (!VecTy.is128BitVector())
+    return SDValue();
+  EVT VecElTy = VecTy.getScalarType();
+  if (VecElTy == MVT::f64 || VecElTy == MVT::i64 || VecElTy == MVT::i8)
+    return SDValue();
+
+  // Walk the chain of INSERT_VECTOR_ELTs until a SCALAR_TO_VECTOR is hit,
+  // capturing the scalar registers being inserted. This will result in a
+  // reversed sequence (e.g. y, x, y, x) or (3, 2, 1, 0).
+  unsigned int NumElements = VecTy.getVectorNumElements() - 1;
+  SmallVector<SDValue> RSequence;
+  do {
+    if (InsertVecElt.getOpcode() == ISD::INSERT_VECTOR_ELT) {
+      RSequence.push_back(InsertVecElt.getOperand(1));
+      uint64_t Index = InsertVecElt.getConstantOperandVal(2);
+      if (Index != NumElements)
+        return SDValue();
+      NumElements--;
+      InsertVecElt = InsertVecElt.getOperand(0);
+    } else if (InsertVecElt.getOpcode() == ISD::SCALAR_TO_VECTOR) {
+      RSequence.push_back(InsertVecElt.getOperand(0));
+      break;
+    } else
+      return SDValue();
+  } while (true);
+
+  // We can't simplify a repeat complex pattern if there's less than 4 elements
+  if (RSequence.size() < 4)
+    return SDValue();
+  // We can't simplify a repeat complex pattern if there's an odd number of
+  // elements
+  if (RSequence.size() % 2 != 0)
+    return SDValue();
+
+  // Check the "real" and "imaginary" components for equality across each
+  // complex number
+  for (unsigned i = 0; i < RSequence.size() - 2; i++)
+    if (RSequence[i] != RSequence[i + 2])
+      return SDValue();
+
+  MVT WideElTy = MVT::getFloatingPointVT(VecElTy.getScalarSizeInBits() * 2);
+  MVT WideVecTy =
+      MVT::getVectorVT(WideElTy, VecTy.getVectorNumElements() / 2, false);
+  MVT ScalableWideVecTy =
+      MVT::getVectorVT(WideElTy, VecTy.getVectorNumElements() / 2, true);
+
+  int SRIdx = 0;
+  if (VecElTy == MVT::f32 || VecElTy == MVT::i32)
+    SRIdx = AArch64::ssub;
+  else if (VecElTy == MVT::f16 || VecElTy == MVT::i16)
+    SRIdx = AArch64::hsub;
+  SDLoc DL(Op);
+
+  SDValue X = RSequence[1], Y = RSequence[0];
+  SDValue InsertSubreg =
+      DAG.getTargetInsertSubreg(SRIdx, DL, VecTy, DAG.getUNDEF(VecTy), X);
+  InsertVecElt = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecTy, InsertSubreg, Y,
+                             DAG.getConstant(1, DL, MVT::i64));
+  SDValue WideVecBitcast =
+      DAG.getNode(ISD::BITCAST, DL, WideVecTy, InsertVecElt);
+  SDValue ExtractWideVectorElt =
+      DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, WideElTy, WideVecBitcast,
+                  DAG.getConstant(0, DL, MVT::i64));
+  SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, ScalableWideVecTy,
+                              ExtractWideVectorElt);
+  return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Splat);
+}
+
 static SDValue performDupLane128Combine(SDNode *N, SelectionDAG &DAG) {
-  EVT VT = N->getValueType(0);
+  SDValue WideSplat = simplifyDupLane128RepeatedComplexPattern(N, DAG);
+  if (WideSplat)
+    return WideSplat;
 
+  EVT VT = N->getValueType(0);
   SDValue Insert = N->getOperand(0);
   if (Insert.getOpcode() != ISD::INSERT_SUBVECTOR)
     return SDValue();
Index: llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
+++ llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
@@ -584,16 +584,13 @@
 ; EXT
 ;
 
-define dso_local <vscale x 4 x float> @dupq_f32_repeat_complex(float %x, float %y) {
+define <vscale x 4 x float> @dupq_f32_repeat_complex(float %x, float %y) {
 ; CHECK-LABEL: dupq_f32_repeat_complex:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $s0 killed $s0 def $q0
-; CHECK-NEXT:    mov v2.16b, v0.16b
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
 ; CHECK-NEXT:    // kill: def $s1 killed $s1 def $q1
-; CHECK-NEXT:    mov v2.s[1], v1.s[0]
-; CHECK-NEXT:    mov v2.s[2], v0.s[0]
-; CHECK-NEXT:    mov v2.s[3], v1.s[0]
-; CHECK-NEXT:    mov z0.q, q2
+; CHECK-NEXT:    mov v0.s[1], v1.s[0]
+; CHECK-NEXT:    mov z0.d, d0
 ; CHECK-NEXT:    ret
   %1 = insertelement <4 x float> undef, float %x, i64 0
   %2 = insertelement <4 x float> %1, float %y, i64 1
@@ -604,20 +601,149 @@
   ret <vscale x 4 x float> %6
 }
 
-define dso_local <vscale x 8 x half> @dupq_f16_repeat_complex(half %a, half %b) {
+define <vscale x 4 x float> @dupq_f32_repeat_complex_unordered(float %x, float %y) {
+; CHECK-LABEL: dupq_f32_repeat_complex_unordered:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
+; CHECK-NEXT:    // kill: def $s1 killed $s1 def $q1
+; CHECK-NEXT:    mov v0.s[1], v0.s[0]
+; CHECK-NEXT:    mov v0.s[2], v1.s[0]
+; CHECK-NEXT:    mov v0.s[3], v1.s[0]
+; CHECK-NEXT:    mov z0.q, q0
+; CHECK-NEXT:    ret
+  %1 = insertelement <4 x float> undef, float %x, i64 0
+  %2 = insertelement <4 x float> %1, float %y, i64 3
+  %3 = insertelement <4 x float> %2, float %x, i64 1
+  %4 = insertelement <4 x float> %3, float %y, i64 2
+  %5 = tail call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> undef, <4 x float> %4, i64 0)
+  %6 = tail call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %5, i64 0)
+  ret <vscale x 4 x float> %6
+}
+
+define <vscale x 4 x float> @dupq_f32_repeat_complex_rev(float %x, float %y) {
+; CHECK-LABEL: dupq_f32_repeat_complex_rev:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $s1 killed $s1 def $z1
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $q0
+; CHECK-NEXT:    mov v1.s[1], v0.s[0]
+; CHECK-NEXT:    mov z0.d, d1
+; CHECK-NEXT:    ret
+  %1 = insertelement <4 x float> undef, float %x, i64 3
+  %2 = insertelement <4 x float> %1, float %y, i64 2
+  %3 = insertelement <4 x float> %2, float %x, i64 1
+  %4 = insertelement <4 x float> %3, float %y, i64 0
+  %5 = tail call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> undef, <4 x float> %4, i64 0)
+  %6 = tail call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %5, i64 0)
+  ret <vscale x 4 x float> %6
+}
+
+define <vscale x 4 x float> @dupq_f32_repeat_complex_rev_unordered(float %x, float %y) {
+; CHECK-LABEL: dupq_f32_repeat_complex_rev_unordered:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $s1 killed $s1 def $z1
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $q0
+; CHECK-NEXT:    mov v1.s[1], v1.s[0]
+; CHECK-NEXT:    mov v1.s[2], v0.s[0]
+; CHECK-NEXT:    mov v1.s[3], v0.s[0]
+; CHECK-NEXT:    mov z0.q, q1
+; CHECK-NEXT:    ret
+  %1 = insertelement <4 x float> undef, float %x, i64 3
+  %2 = insertelement <4 x float> %1, float %y, i64 0
+  %3 = insertelement <4 x float> %2, float %x, i64 2
+  %4 = insertelement <4 x float> %3, float %y, i64 1
+  %5 = tail call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> undef, <4 x float> %4, i64 0)
+  %6 = tail call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %5, i64 0)
+  ret <vscale x 4 x float> %6
+}
+
+define <vscale x 8 x half> @dupq_f16_repeat_complex(half %a, half %b) {
 ; CHECK-LABEL: dupq_f16_repeat_complex:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    // kill: def $h1 killed $h1 def $q1
+; CHECK-NEXT:    mov v0.h[1], v1.h[0]
+; CHECK-NEXT:    mov z0.s, s0
+; CHECK-NEXT:    ret
+  %1 = insertelement <8 x half> undef, half %a, i64 0
+  %2 = insertelement <8 x half> %1, half %b, i64 1
+  %3 = insertelement <8 x half> %2, half %a, i64 2
+  %4 = insertelement <8 x half> %3, half %b, i64 3
+  %5 = insertelement <8 x half> %4, half %a, i64 4
+  %6 = insertelement <8 x half> %5, half %b, i64 5
+  %7 = insertelement <8 x half> %6, half %a, i64 6
+  %8 = insertelement <8 x half> %7, half %b, i64 7
+  %9 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> undef, <8 x half> %8, i64 0)
+  %10 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %9, i64 0)
+  ret <vscale x 8 x half> %10
+}
+
+define <vscale x 8 x half> @dupq_f16_repeat_complex_omit_pair(half %a, half %b) {
+; CHECK-LABEL: dupq_f16_repeat_complex_omit_pair:
+; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $h0 killed $h0 def $q0
 ; CHECK-NEXT:    mov v2.16b, v0.16b
 ; CHECK-NEXT:    // kill: def $h1 killed $h1 def $q1
 ; CHECK-NEXT:    mov v2.h[1], v1.h[0]
 ; CHECK-NEXT:    mov v2.h[2], v0.h[0]
 ; CHECK-NEXT:    mov v2.h[3], v1.h[0]
+; CHECK-NEXT:    mov v2.h[6], v0.h[0]
+; CHECK-NEXT:    mov v2.h[7], v1.h[0]
+; CHECK-NEXT:    mov z0.q, q2
+; CHECK-NEXT:    ret
+  %1 = insertelement <8 x half> undef, half %a, i64 0
+  %2 = insertelement <8 x half> %1, half %b, i64 1
+  %3 = insertelement <8 x half> %2, half %a, i64 2
+  %4 = insertelement <8 x half> %3, half %b, i64 3
+  %5 = insertelement <8 x half> %4, half %a, i64 6
+  %6 = insertelement <8 x half> %5, half %b, i64 7
+  %7 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> undef, <8 x half> %6, i64 0)
+  %8 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %7, i64 0)
+  ret <vscale x 8 x half> %8
+}
+
+define <vscale x 8 x half> @dupq_f16_repeat_complex_mismatched_front(half %a, half %b, half %c) {
+; CHECK-LABEL: dupq_f16_repeat_complex_mismatched_front:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $h2 killed $h2 def $z2
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $q0
+; CHECK-NEXT:    // kill: def $h1 killed $h1 def $q1
+; CHECK-NEXT:    mov v2.h[1], v2.h[0]
+; CHECK-NEXT:    mov v2.h[2], v0.h[0]
+; CHECK-NEXT:    mov v2.h[3], v1.h[0]
 ; CHECK-NEXT:    mov v2.h[4], v0.h[0]
 ; CHECK-NEXT:    mov v2.h[5], v1.h[0]
 ; CHECK-NEXT:    mov v2.h[6], v0.h[0]
 ; CHECK-NEXT:    mov v2.h[7], v1.h[0]
 ; CHECK-NEXT:    mov z0.q, q2
+; CHECK-NEXT:    ret
+  %1 = insertelement <8 x half> undef, half %c, i64 0
+  %2 = insertelement <8 x half> %1, half %c, i64 1
+  %3 = insertelement <8 x half> %2, half %a, i64 2
+  %4 = insertelement <8 x half> %3, half %b, i64 3
+  %5 = insertelement <8 x half> %4, half %a, i64 4
+  %6 = insertelement <8 x half> %5, half %b, i64 5
+  %7 = insertelement <8 x half> %6, half %a, i64 6
+  %8 = insertelement <8 x half> %7, half %b, i64 7
+  %9 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> undef, <8 x half> %8, i64 0)
+  %10 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %9, i64 0)
+  ret <vscale x 8 x half> %10
+}
+
+define <vscale x 8 x half> @dupq_f16_repeat_complex_mismatched_end(half %a, half %b, half %c) {
+; CHECK-LABEL: dupq_f16_repeat_complex_mismatched_end:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $q0
+; CHECK-NEXT:    mov v3.16b, v0.16b
+; CHECK-NEXT:    // kill: def $h1 killed $h1 def $q1
+; CHECK-NEXT:    // kill: def $h2 killed $h2 def $q2
+; CHECK-NEXT:    mov v3.h[1], v1.h[0]
+; CHECK-NEXT:    mov v3.h[2], v0.h[0]
+; CHECK-NEXT:    mov v3.h[3], v1.h[0]
+; CHECK-NEXT:    mov v3.h[4], v0.h[0]
+; CHECK-NEXT:    mov v3.h[5], v1.h[0]
+; CHECK-NEXT:    mov v3.h[6], v2.h[0]
+; CHECK-NEXT:    mov v3.h[7], v2.h[0]
+; CHECK-NEXT:    mov z0.q, q3
 ; CHECK-NEXT:    ret
   %1 = insertelement <8 x half> undef, half %a, i64 0
   %2 = insertelement <8 x half> %1, half %b, i64 1
@@ -625,6 +751,35 @@
   %4 = insertelement <8 x half> %3, half %b, i64 3
   %5 = insertelement <8 x half> %4, half %a, i64 4
   %6 = insertelement <8 x half> %5, half %b, i64 5
+  %7 = insertelement <8 x half> %6, half %c, i64 6
+  %8 = insertelement <8 x half> %7, half %c, i64 7
+  %9 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> undef, <8 x half> %8, i64 0)
+  %10 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %9, i64 0)
+  ret <vscale x 8 x half> %10
+}
+
+define <vscale x 8 x half> @dupq_f16_repeat_complex_mismatched_middle(half %a, half %b, half %c) {
+; CHECK-LABEL: dupq_f16_repeat_complex_mismatched_middle:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $q0
+; CHECK-NEXT:    mov v3.16b, v0.16b
+; CHECK-NEXT:    // kill: def $h1 killed $h1 def $q1
+; CHECK-NEXT:    // kill: def $h2 killed $h2 def $q2
+; CHECK-NEXT:    mov v3.h[1], v1.h[0]
+; CHECK-NEXT:    mov v3.h[2], v0.h[0]
+; CHECK-NEXT:    mov v3.h[3], v1.h[0]
+; CHECK-NEXT:    mov v3.h[4], v2.h[0]
+; CHECK-NEXT:    mov v3.h[5], v2.h[0]
+; CHECK-NEXT:    mov v3.h[6], v0.h[0]
+; CHECK-NEXT:    mov v3.h[7], v1.h[0]
+; CHECK-NEXT:    mov z0.q, q3
+; CHECK-NEXT:    ret
+  %1 = insertelement <8 x half> undef, half %a, i64 0
+  %2 = insertelement <8 x half> %1, half %b, i64 1
+  %3 = insertelement <8 x half> %2, half %a, i64 2
+  %4 = insertelement <8 x half> %3, half %b, i64 3
+  %5 = insertelement <8 x half> %4, half %c, i64 4
+  %6 = insertelement <8 x half> %5, half %c, i64 5
   %7 = insertelement <8 x half> %6, half %a, i64 6
   %8 = insertelement <8 x half> %7, half %b, i64 7
   %9 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> undef, <8 x half> %8, i64 0)