Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -874,6 +874,12 @@
     }
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
+
+    for (MVT VT : MVT::fp_scalable_vector_valuetypes()) {
+      if (isTypeLegal(VT)) {
+        setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+      }
+    }
   }
 
   PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
@@ -7443,14 +7449,6 @@
   // Extend input splat value where needed to fit into a GPR (32b or 64b only)
   // FPRs don't have this restriction.
   switch (ElemVT.getSimpleVT().SimpleTy) {
-  case MVT::i8:
-  case MVT::i16:
-  case MVT::i32:
-    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32);
-    return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
-  case MVT::i64:
-    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
-    return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
   case MVT::i1: {
     // The general case of i1. There isn't any natural way to do this,
     // so we use some trickery with whilelo.
@@ -7463,13 +7461,24 @@
     return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID,
                        DAG.getConstant(0, dl, MVT::i64), SplatVal);
   }
-  // TODO: we can support float types, but haven't added patterns yet.
+  case MVT::i8:
+  case MVT::i16:
+  case MVT::i32:
+    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32);
+    break;
+  case MVT::i64:
+    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
+    break;
   case MVT::f16:
   case MVT::f32:
   case MVT::f64:
+    // FP scalars are DUP'd straight from their FPR; no extension needed.
+    break;
   default:
     report_fatal_error("Unsupported SPLAT_VECTOR input operand type");
   }
+
+  return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
 }
 
 static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits,
Index: llvm/lib/Target/AArch64/AArch64InstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -66,6 +66,7 @@
   dag InOperandList = iops;
   let Pattern = pattern;
   let isCodeGenOnly = 1;
+  let isPseudo = 1;
 }
 
 // Real instructions (have encoding information)
Index: llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1474,14 +1474,53 @@
 }
 
 bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
+  const TargetRegisterInfo *TRI = &getRegisterInfo();
+  MachineBasicBlock &MBB = *MI.getParent();
+  DebugLoc DL = MI.getDebugLoc();
+  unsigned Opcode = MI.getOpcode();
+
+  // Expand DUP_ZV_{H,S,D} into DUP_ZZI_{H,S,D} with lane index 0, reading
+  // from the Z register whose hsub/ssub/dsub sub-register is the FP source.
+  if ((Opcode == AArch64::DUP_ZV_H) ||
+      (Opcode == AArch64::DUP_ZV_S) ||
+      (Opcode == AArch64::DUP_ZV_D)) {
+    auto RC = &AArch64::ZPRRegClass;
+    unsigned Dst = MI.getOperand(0).getReg();
+    unsigned Src = MI.getOperand(1).getReg();
+    bool SrcIsKill = MI.getOperand(1).isKill();
+
+    unsigned NewOpcode;
+    unsigned NewSrc;
+    switch (Opcode) {
+    default:
+      llvm_unreachable("Unexpected DUP_ZV pseudo opcode");
+    case AArch64::DUP_ZV_H:
+      NewOpcode = AArch64::DUP_ZZI_H;
+      NewSrc = TRI->getMatchingSuperReg(Src, AArch64::hsub, RC);
+      break;
+    case AArch64::DUP_ZV_S:
+      NewOpcode = AArch64::DUP_ZZI_S;
+      NewSrc = TRI->getMatchingSuperReg(Src, AArch64::ssub, RC);
+      break;
+    case AArch64::DUP_ZV_D:
+      NewOpcode = AArch64::DUP_ZZI_D;
+      NewSrc = TRI->getMatchingSuperReg(Src, AArch64::dsub, RC);
+      break;
+    }
+    assert(NewSrc && "FP source register has no containing Z register");
+
+    BuildMI(MBB, MI, DL, get(NewOpcode), Dst)
+      .addReg(NewSrc, getKillRegState(SrcIsKill))
+      .addImm(0);
+    MBB.erase(MI);
+    return true;
+  }
+
   if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
       MI.getOpcode() != AArch64::CATCHRET)
     return false;
 
-  MachineBasicBlock &MBB = *MI.getParent();
   auto &Subtarget = MBB.getParent()->getSubtarget<AArch64Subtarget>();
-  auto TRI = Subtarget.getRegisterInfo();
-  DebugLoc DL = MI.getDebugLoc();
 
   if (MI.getOpcode() == AArch64::CATCHRET) {
     // Skip to the first instruction before the epilog.
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -293,6 +293,27 @@
   defm CPY_ZPmR : sve_int_perm_cpy_r<"cpy", AArch64dup_pred>;
   defm CPY_ZPmV : sve_int_perm_cpy_v<"cpy", AArch64dup_pred>;
 
+  // Create variants of DUP_ZZI (I=0) that can be used without INSERT_SUBREG.
+  def DUP_ZV_H : Pseudo<(outs ZPR16:$Zd), (ins FPR16:$Vn), []>, Sched<[]>;
+  def DUP_ZV_S : Pseudo<(outs ZPR32:$Zd), (ins FPR32:$Vn), []>, Sched<[]>;
+  def DUP_ZV_D : Pseudo<(outs ZPR64:$Zd), (ins FPR64:$Vn), []>, Sched<[]>;
+
+  // Duplicate FP scalar into all vector elements
+  def : Pat<(nxv8f16 (AArch64dup (f16 FPR16:$src))), (DUP_ZV_H $src)>;
+  def : Pat<(nxv4f16 (AArch64dup (f16 FPR16:$src))), (DUP_ZV_H $src)>;
+  def : Pat<(nxv2f16 (AArch64dup (f16 FPR16:$src))), (DUP_ZV_H $src)>;
+  def : Pat<(nxv4f32 (AArch64dup (f32 FPR32:$src))), (DUP_ZV_S $src)>;
+  def : Pat<(nxv2f32 (AArch64dup (f32 FPR32:$src))), (DUP_ZV_S $src)>;
+  def : Pat<(nxv2f64 (AArch64dup (f64 FPR64:$src))), (DUP_ZV_D $src)>;
+
+  // Duplicate +0.0 into all vector elements
+  def : Pat<(nxv8f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
+  def : Pat<(nxv4f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
+  def : Pat<(nxv2f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
+  def : Pat<(nxv4f32 (AArch64dup (f32 fpimm0))), (DUP_ZI_S 0, 0)>;
+  def : Pat<(nxv2f32 (AArch64dup (f32 fpimm0))), (DUP_ZI_S 0, 0)>;
+  def : Pat<(nxv2f64 (AArch64dup (f64 fpimm0))), (DUP_ZI_D 0, 0)>;
+
   // Select elements from either vector (predicated)
   defm SEL_ZPZZ : sve_int_sel_vvv<"sel", vselect>;
 
Index: llvm/test/CodeGen/AArch64/vector-splat-sve.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/vector-splat-sve.ll
@@ -0,0 +1,100 @@
+; RUN: llc < %s -mtriple=aarch64--linux-gnu -mattr=+sve | FileCheck %s
+
+define <vscale x 8 x half> @splat_nxv8f16(half %val) {
+; CHECK-LABEL: splat_nxv8f16:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 8 x half> undef, half %val, i32 0
+  %2 = shufflevector <vscale x 8 x half> %1, <vscale x 8 x half> undef, <vscale x 8 x i32> zeroinitializer
+  ret <vscale x 8 x half> %2
+}
+
+define <vscale x 4 x half> @splat_nxv4f16(half %val) {
+; CHECK-LABEL: splat_nxv4f16:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 4 x half> undef, half %val, i32 0
+  %2 = shufflevector <vscale x 4 x half> %1, <vscale x 4 x half> undef, <vscale x 4 x i32> zeroinitializer
+  ret <vscale x 4 x half> %2
+}
+
+define <vscale x 2 x half> @splat_nxv2f16(half %val) {
+; CHECK-LABEL: splat_nxv2f16:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 2 x half> undef, half %val, i32 0
+  %2 = shufflevector <vscale x 2 x half> %1, <vscale x 2 x half> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x half> %2
+}
+
+define <vscale x 4 x float> @splat_nxv4f32(float %val) {
+; CHECK-LABEL: splat_nxv4f32:
+; CHECK: mov z0.s, s0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 4 x float> undef, float %val, i32 0
+  %2 = shufflevector <vscale x 4 x float> %1, <vscale x 4 x float> undef, <vscale x 4 x i32> zeroinitializer
+  ret <vscale x 4 x float> %2
+}
+
+define <vscale x 2 x float> @splat_nxv2f32(float %val) {
+; CHECK-LABEL: splat_nxv2f32:
+; CHECK: mov z0.s, s0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 2 x float> undef, float %val, i32 0
+  %2 = shufflevector <vscale x 2 x float> %1, <vscale x 2 x float> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x float> %2
+}
+
+define <vscale x 2 x double> @splat_nxv2f64(double %val) {
+; CHECK-LABEL: splat_nxv2f64:
+; CHECK: mov z0.d, d0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 2 x double> undef, double %val, i32 0
+  %2 = shufflevector <vscale x 2 x double> %1, <vscale x 2 x double> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x double> %2
+}
+
+; TODO: The f16 constant load isn't folded correctly.
+define <vscale x 8 x half> @splat_nxv8f16_zero() {
+; CHECK-LABEL: splat_nxv8f16_zero:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  ret <vscale x 8 x half> zeroinitializer
+}
+
+; TODO: The f16 constant load isn't folded correctly.
+define <vscale x 4 x half> @splat_nxv4f16_zero() {
+; CHECK-LABEL: splat_nxv4f16_zero:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  ret <vscale x 4 x half> zeroinitializer
+}
+
+; TODO: The f16 constant load isn't folded correctly.
+define <vscale x 2 x half> @splat_nxv2f16_zero() {
+; CHECK-LABEL: splat_nxv2f16_zero:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  ret <vscale x 2 x half> zeroinitializer
+}
+
+define <vscale x 4 x float> @splat_nxv4f32_zero() {
+; CHECK-LABEL: splat_nxv4f32_zero:
+; CHECK: mov z0.s, #0
+; CHECK-NEXT: ret
+  ret <vscale x 4 x float> zeroinitializer
+}
+
+define <vscale x 2 x float> @splat_nxv2f32_zero() {
+; CHECK-LABEL: splat_nxv2f32_zero:
+; CHECK: mov z0.s, #0
+; CHECK-NEXT: ret
+  ret <vscale x 2 x float> zeroinitializer
+}
+
+define <vscale x 2 x double> @splat_nxv2f64_zero() {
+; CHECK-LABEL: splat_nxv2f64_zero:
+; CHECK: mov z0.d, #0
+; CHECK-NEXT: ret
+  ret <vscale x 2 x double> zeroinitializer
+}