[AArch64][SVE] Lower SPLAT_VECTOR of scalable FP vectors to DUP

Mark ISD::SPLAT_VECTOR as Custom for every legal scalable FP vector type.
In LowerSPLAT_VECTOR, route f16/f32/f64 splat values to the same
AArch64ISD::DUP node as the integer cases. Integer values are still
any-extended or truncated to i32/i64 first; i1 keeps its whilelo path.
Add SVE selection patterns that use DUP_ZZI_{H,S,D} for a scalar held in
an FPR and DUP_ZI_{H,S,D} for a +0.0 splat. Set isPseudo on the AArch64
Pseudo instruction class. Add llc tests for every legal FP splat type,
including zero splats.

NOTE(review): this patch was recovered from a copy that had lost its line
breaks and its `<...>` IR type annotations.
- Line structure was rebuilt, and each @@ header was checked against the
  rebuilt hunk body.
- IR types in the added tests were restored from each test's name and
  scalar argument type.
- Whitespace on C++/TableGen context lines follows LLVM style and should
  be confirmed against the tree.
- The first three context lines of the sve-vector-splat.ll hunk belong to
  a pre-existing test that is not shown here. They are reproduced as seen,
  with their vector types still missing. Restore them from the tree before
  applying.

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -874,6 +874,12 @@
     }
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
+
+    for (MVT VT : MVT::fp_scalable_vector_valuetypes()) {
+      if (isTypeLegal(VT)) {
+        setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+      }
+    }
   }
 
   PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
@@ -7483,14 +7489,6 @@
   // Extend input splat value where needed to fit into a GPR (32b or 64b only)
   // FPRs don't have this restriction.
   switch (ElemVT.getSimpleVT().SimpleTy) {
-  case MVT::i8:
-  case MVT::i16:
-  case MVT::i32:
-    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32);
-    return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
-  case MVT::i64:
-    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
-    return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
   case MVT::i1: {
     // The general case of i1. There isn't any natural way to do this,
     // so we use some trickery with whilelo.
@@ -7503,13 +7501,24 @@
     return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID,
                        DAG.getConstant(0, dl, MVT::i64), SplatVal);
   }
-  // TODO: we can support float types, but haven't added patterns yet.
+  case MVT::i8:
+  case MVT::i16:
+  case MVT::i32:
+    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32);
+    break;
+  case MVT::i64:
+    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
+    break;
   case MVT::f16:
   case MVT::f32:
   case MVT::f64:
+    // FP scalars already live in an FPR, so no extension is needed.
+    break;
   default:
     report_fatal_error("Unsupported SPLAT_VECTOR input operand type");
   }
+
+  return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
 }
 
 static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -66,6 +66,7 @@
   dag InOperandList = iops;
   let Pattern = pattern;
   let isCodeGenOnly = 1;
+  let isPseudo = 1;
 }
 
 // Real instructions (have encoding information)
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -296,6 +296,28 @@
   defm CPY_ZPmR : sve_int_perm_cpy_r<"cpy", AArch64dup_pred>;
   defm CPY_ZPmV : sve_int_perm_cpy_v<"cpy", AArch64dup_pred>;
 
+  // Duplicate FP scalar into all vector elements
+  def : Pat<(nxv8f16 (AArch64dup (f16 FPR16:$src))),
+            (DUP_ZZI_H (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), 0)>;
+  def : Pat<(nxv4f16 (AArch64dup (f16 FPR16:$src))),
+            (DUP_ZZI_H (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), 0)>;
+  def : Pat<(nxv2f16 (AArch64dup (f16 FPR16:$src))),
+            (DUP_ZZI_H (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), 0)>;
+  def : Pat<(nxv4f32 (AArch64dup (f32 FPR32:$src))),
+            (DUP_ZZI_S (INSERT_SUBREG (IMPLICIT_DEF), FPR32:$src, ssub), 0)>;
+  def : Pat<(nxv2f32 (AArch64dup (f32 FPR32:$src))),
+            (DUP_ZZI_S (INSERT_SUBREG (IMPLICIT_DEF), FPR32:$src, ssub), 0)>;
+  def : Pat<(nxv2f64 (AArch64dup (f64 FPR64:$src))),
+            (DUP_ZZI_D (INSERT_SUBREG (IMPLICIT_DEF), FPR64:$src, dsub), 0)>;
+
+  // Duplicate +0.0 into all vector elements
+  def : Pat<(nxv8f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
+  def : Pat<(nxv4f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
+  def : Pat<(nxv2f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
+  def : Pat<(nxv4f32 (AArch64dup (f32 fpimm0))), (DUP_ZI_S 0, 0)>;
+  def : Pat<(nxv2f32 (AArch64dup (f32 fpimm0))), (DUP_ZI_S 0, 0)>;
+  def : Pat<(nxv2f64 (AArch64dup (f64 fpimm0))), (DUP_ZI_D 0, 0)>;
+
   // Select elements from either vector (predicated)
   defm SEL_ZPZZ : sve_int_sel_vvv<"sel", vselect>;
 
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-splat.ll b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll
--- a/llvm/test/CodeGen/AArch64/sve-vector-splat.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll
@@ -133,3 +133,104 @@
   %splat = shufflevector %ins, undef, zeroinitializer
   ret %splat
 }
+
+;; Splats of legal floating point vector types
+
+define <vscale x 8 x half> @splat_nxv8f16(half %val) {
+; CHECK-LABEL: splat_nxv8f16:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 8 x half> undef, half %val, i32 0
+  %2 = shufflevector <vscale x 8 x half> %1, <vscale x 8 x half> undef, <vscale x 8 x i32> zeroinitializer
+  ret <vscale x 8 x half> %2
+}
+
+define <vscale x 4 x half> @splat_nxv4f16(half %val) {
+; CHECK-LABEL: splat_nxv4f16:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 4 x half> undef, half %val, i32 0
+  %2 = shufflevector <vscale x 4 x half> %1, <vscale x 4 x half> undef, <vscale x 4 x i32> zeroinitializer
+  ret <vscale x 4 x half> %2
+}
+
+define <vscale x 2 x half> @splat_nxv2f16(half %val) {
+; CHECK-LABEL: splat_nxv2f16:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 2 x half> undef, half %val, i32 0
+  %2 = shufflevector <vscale x 2 x half> %1, <vscale x 2 x half> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x half> %2
+}
+
+define <vscale x 4 x float> @splat_nxv4f32(float %val) {
+; CHECK-LABEL: splat_nxv4f32:
+; CHECK: mov z0.s, s0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 4 x float> undef, float %val, i32 0
+  %2 = shufflevector <vscale x 4 x float> %1, <vscale x 4 x float> undef, <vscale x 4 x i32> zeroinitializer
+  ret <vscale x 4 x float> %2
+}
+
+define <vscale x 2 x float> @splat_nxv2f32(float %val) {
+; CHECK-LABEL: splat_nxv2f32:
+; CHECK: mov z0.s, s0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 2 x float> undef, float %val, i32 0
+  %2 = shufflevector <vscale x 2 x float> %1, <vscale x 2 x float> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x float> %2
+}
+
+define <vscale x 2 x double> @splat_nxv2f64(double %val) {
+; CHECK-LABEL: splat_nxv2f64:
+; CHECK: mov z0.d, d0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 2 x double> undef, double %val, i32 0
+  %2 = shufflevector <vscale x 2 x double> %1, <vscale x 2 x double> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x double> %2
+}
+
+; TODO: The f16 constant should be folded into the move.
+define <vscale x 8 x half> @splat_nxv8f16_zero() {
+; CHECK-LABEL: splat_nxv8f16_zero:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  ret <vscale x 8 x half> zeroinitializer
+}
+
+; TODO: The f16 constant should be folded into the move.
+define <vscale x 4 x half> @splat_nxv4f16_zero() {
+; CHECK-LABEL: splat_nxv4f16_zero:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  ret <vscale x 4 x half> zeroinitializer
+}
+
+; TODO: The f16 constant should be folded into the move.
+define <vscale x 2 x half> @splat_nxv2f16_zero() {
+; CHECK-LABEL: splat_nxv2f16_zero:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  ret <vscale x 2 x half> zeroinitializer
+}
+
+define <vscale x 4 x float> @splat_nxv4f32_zero() {
+; CHECK-LABEL: splat_nxv4f32_zero:
+; CHECK: mov z0.s, #0
+; CHECK-NEXT: ret
+  ret <vscale x 4 x float> zeroinitializer
+}
+
+define <vscale x 2 x float> @splat_nxv2f32_zero() {
+; CHECK-LABEL: splat_nxv2f32_zero:
+; CHECK: mov z0.s, #0
+; CHECK-NEXT: ret
+  ret <vscale x 2 x float> zeroinitializer
+}
+
+define <vscale x 2 x double> @splat_nxv2f64_zero() {
+; CHECK-LABEL: splat_nxv2f64_zero:
+; CHECK: mov z0.d, #0
+; CHECK-NEXT: ret
+  ret <vscale x 2 x double> zeroinitializer
+}