diff --git a/llvm/include/llvm/CodeGen/SelectionDAGTargetInfo.h b/llvm/include/llvm/CodeGen/SelectionDAGTargetInfo.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGTargetInfo.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGTargetInfo.h
@@ -157,7 +157,8 @@
   // Return true when the decision to generate FMA's (or FMS, FMLA etc) rather
   // than FMUL and ADD is delegated to the machine combiner.
-  virtual bool generateFMAsInMachineCombiner(CodeGenOpt::Level OptLevel) const {
+  virtual bool generateFMAsInMachineCombiner(SelectionDAG &DAG,
+                                             CodeGenOpt::Level OptLevel) const {
     return false;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12539,7 +12539,7 @@
   if (!AllowFusionGlobally && !isContractable(N))
     return SDValue();
 
-  if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
+  if (STI && STI->generateFMAsInMachineCombiner(DAG, OptLevel))
     return SDValue();
 
   // Always prefer FMAD to FMA for precision.
@@ -12748,7 +12748,7 @@
   if (!AllowFusionGlobally && !isContractable(N))
     return SDValue();
 
-  if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
+  if (STI && STI->generateFMAsInMachineCombiner(DAG, OptLevel))
     return SDValue();
 
   // Always prefer FMAD to FMA for precision.
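Note: the only reason the hook grows a SelectionDAG parameter is so a target override can reach the MachineFunction and, through it, the subtarget, and make the decision per function. The AArch64 override further down in this patch does exactly that. A minimal sketch of such an override for a hypothetical target (MyTargetSelectionDAGInfo is not part of the patch):

#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"

// Hypothetical target: only delegate FMA formation to the machine combiner
// when optimising aggressively; a real target would also test a feature bit
// on the subtarget here, as AArch64 does with hasSVE().
class MyTargetSelectionDAGInfo : public llvm::SelectionDAGTargetInfo {
public:
  bool generateFMAsInMachineCombiner(llvm::SelectionDAG &DAG,
                                     llvm::CodeGenOpt::Level OptLevel) const override {
    const auto &ST = DAG.getMachineFunction().getSubtarget();
    (void)ST; // query subtarget features here as needed
    return OptLevel >= llvm::CodeGenOpt::Aggressive;
  }
};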
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -406,7 +406,7 @@
     assert(DstReg != MI.getOperand(3).getReg());
 
   bool UseRev = false;
-  unsigned PredIdx, DOPIdx, SrcIdx;
+  unsigned PredIdx, DOPIdx, SrcIdx, Src2Idx;
   switch (DType) {
   case AArch64::DestructiveBinaryComm:
   case AArch64::DestructiveBinaryCommWithRev:
@@ -420,7 +420,19 @@
   case AArch64::DestructiveBinary:
   case AArch64::DestructiveBinaryImm:
     std::tie(PredIdx, DOPIdx, SrcIdx) = std::make_tuple(1, 2, 3);
-    break;
+    break;
+  case AArch64::DestructiveTernaryCommWithRev:
+    std::tie(PredIdx, DOPIdx, SrcIdx, Src2Idx) = std::make_tuple(1, 2, 3, 4);
+    if (DstReg == MI.getOperand(3).getReg()) {
+      // FMLA Zd, Pg, Za, Zd, Zm ==> FMAD Zdn, Pg, Zm, Za
+      std::tie(PredIdx, DOPIdx, SrcIdx, Src2Idx) = std::make_tuple(1, 3, 4, 2);
+      UseRev = true;
+    } else if (DstReg == MI.getOperand(4).getReg()) {
+      // FMLA Zd, Pg, Za, Zm, Zd ==> FMAD Zdn, Pg, Zm, Za
+      std::tie(PredIdx, DOPIdx, SrcIdx, Src2Idx) = std::make_tuple(1, 4, 3, 2);
+      UseRev = true;
+    }
+    break;
   default:
     llvm_unreachable("Unsupported Destructive Operand type");
   }
@@ -440,6 +452,12 @@
   case AArch64::DestructiveBinaryImm:
     DOPRegIsUnique = true;
     break;
+  case AArch64::DestructiveTernaryCommWithRev:
+    DOPRegIsUnique =
+        DstReg != MI.getOperand(DOPIdx).getReg() ||
+        (MI.getOperand(DOPIdx).getReg() != MI.getOperand(SrcIdx).getReg() &&
+         MI.getOperand(DOPIdx).getReg() != MI.getOperand(Src2Idx).getReg());
+    break;
   }
 #endif
@@ -522,6 +540,12 @@
        .addReg(MI.getOperand(DOPIdx).getReg(), RegState::Kill)
        .add(MI.getOperand(SrcIdx));
     break;
+  case AArch64::DestructiveTernaryCommWithRev:
+    DOP.add(MI.getOperand(PredIdx))
+       .addReg(MI.getOperand(DOPIdx).getReg(), RegState::Kill)
+       .add(MI.getOperand(SrcIdx))
+       .add(MI.getOperand(Src2Idx));
+    break;
   }
 
   if (PRFX) {
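Condensed, the new ternary case above picks the destructive operand and the (reversed) FMAD-style encoding depending on which source the destination register happens to alias after register allocation. This standalone helper is only an illustration of that logic, not code from the patch (the names are hypothetical):

// Operands of the FMLA-style pseudo are (0)=Zd, (1)=Pg, (2)=Za, (3)=Zn, (4)=Zm.
struct TernaryExpansion {
  unsigned PredIdx, DOPIdx, SrcIdx, Src2Idx;
  bool UseRev; // true when the FMAD/FMSB-style reverse opcode must be used
};

static TernaryExpansion selectTernaryOperands(unsigned DstReg, unsigned ZaReg,
                                              unsigned ZnReg, unsigned ZmReg) {
  if (DstReg == ZnReg)
    return {1, 3, 4, 2, true};  // Zd aliases Zn: reuse it as Zdn of FMAD
  if (DstReg == ZmReg)
    return {1, 4, 3, 2, true};  // Zd aliases Zm: reuse it as Zdn of FMAD
  return {1, 2, 3, 4, false};   // default: Za stays the destructive operand (FMLA)
}

If the destination matches none of the sources, the existing MOVPRFX path copies Za into Zd first and the accumulating FMLA form is kept.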
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -11544,12 +11544,14 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
     const MachineFunction &MF, EVT VT) const {
-  VT = VT.getScalarType();
+  EVT ScalarVT = VT.getScalarType();
 
-  if (!VT.isSimple())
+  if (!ScalarVT.isSimple())
     return false;
 
-  switch (VT.getSimpleVT().SimpleTy) {
+  switch (ScalarVT.getSimpleVT().SimpleTy) {
+  case MVT::f16:
+    return Subtarget->hasFullFP16() && VT.isScalableVector();
   case MVT::f32:
   case MVT::f64:
     return true;
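For orientation, generic DAG combine code reaches this hook through the target lowering when deciding whether contracting fmul+fadd into fma is worthwhile; a rough sketch of that query (hypothetical helper, not code from the patch):

#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"

// With this change, scalable f16 types such as nxv8f16 report true here when
// the subtarget implements the full fp16 extension.
static bool preferFMA(llvm::SelectionDAG &DAG, llvm::EVT VT) {
  const llvm::TargetLowering &TLI = DAG.getTargetLoweringInfo();
  return TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT);
}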
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -461,31 +461,34 @@
   defm FCADD_ZPmZ : sve_fp_fcadd<"fcadd", int_aarch64_sve_fcadd>;
   defm FCMLA_ZPmZZ : sve_fp_fcmla<"fcmla", int_aarch64_sve_fcmla>;
 
-  defm FMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b00, "fmla", int_aarch64_sve_fmla>;
-  defm FMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b01, "fmls", int_aarch64_sve_fmls>;
-  defm FNMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b10, "fnmla", int_aarch64_sve_fnmla>;
-  defm FNMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b11, "fnmls", int_aarch64_sve_fnmls>;
+  defm FMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b00, "fmla", "FMLA_ZPZZZ", int_aarch64_sve_fmla, "FMAD_ZPmZZ">;
+  defm FMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b01, "fmls", "FMLS_ZPZZZ", int_aarch64_sve_fmls, "FMSB_ZPmZZ">;
+  defm FNMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b10, "fnmla", "FNMLA_ZPZZZ", int_aarch64_sve_fnmla, "FNMAD_ZPmZZ">;
+  defm FNMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b11, "fnmls", "FNMLS_ZPZZZ", int_aarch64_sve_fnmls, "FNMSB_ZPmZZ">;
 
-  defm FMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b00, "fmad", int_aarch64_sve_fmad>;
-  defm FMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b01, "fmsb", int_aarch64_sve_fmsb>;
-  defm FNMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b10, "fnmad", int_aarch64_sve_fnmad>;
-  defm FNMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b11, "fnmsb", int_aarch64_sve_fnmsb>;
+  defm FMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b00, "fmad", int_aarch64_sve_fmad, "FMLA_ZPmZZ", /*isReverseInstr*/ 1>;
+  defm FMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b01, "fmsb", int_aarch64_sve_fmsb, "FMLS_ZPmZZ", /*isReverseInstr*/ 1>;
+  defm FNMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b10, "fnmad", int_aarch64_sve_fnmad, "FNMLA_ZPmZZ", /*isReverseInstr*/ 1>;
+  defm FNMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b11, "fnmsb", int_aarch64_sve_fnmsb, "FNMLS_ZPmZZ", /*isReverseInstr*/ 1>;
+
+  defm FMLA_ZPZZZ : sve_fp_3op_p_zds_zx;
+  defm FMLS_ZPZZZ : sve_fp_3op_p_zds_zx;
+  defm FNMLA_ZPZZZ : sve_fp_3op_p_zds_zx;
+  defm FNMLS_ZPZZZ : sve_fp_3op_p_zds_zx;
 
   // Add patterns for FMA where disabled lanes are undef.
-  // FIXME: Implement a pseudo so we can choose a better instruction after
-  // regalloc.
   def : Pat<(nxv8f16 (AArch64fma_p nxv8i1:$P, nxv8f16:$Op1, nxv8f16:$Op2, nxv8f16:$Op3)),
-            (FMLA_ZPmZZ_H $P, $Op3, $Op1, $Op2)>;
+            (FMLA_ZPZZZ_UNDEF_H $P, $Op3, $Op1, $Op2)>;
   def : Pat<(nxv4f16 (AArch64fma_p nxv4i1:$P, nxv4f16:$Op1, nxv4f16:$Op2, nxv4f16:$Op3)),
-            (FMLA_ZPmZZ_H $P, $Op3, $Op1, $Op2)>;
+            (FMLA_ZPZZZ_UNDEF_H $P, $Op3, $Op1, $Op2)>;
   def : Pat<(nxv2f16 (AArch64fma_p nxv2i1:$P, nxv2f16:$Op1, nxv2f16:$Op2, nxv2f16:$Op3)),
-            (FMLA_ZPmZZ_H $P, $Op3, $Op1, $Op2)>;
+            (FMLA_ZPZZZ_UNDEF_H $P, $Op3, $Op1, $Op2)>;
   def : Pat<(nxv4f32 (AArch64fma_p nxv4i1:$P, nxv4f32:$Op1, nxv4f32:$Op2, nxv4f32:$Op3)),
-            (FMLA_ZPmZZ_S $P, $Op3, $Op1, $Op2)>;
+            (FMLA_ZPZZZ_UNDEF_S $P, $Op3, $Op1, $Op2)>;
   def : Pat<(nxv2f32 (AArch64fma_p nxv2i1:$P, nxv2f32:$Op1, nxv2f32:$Op2, nxv2f32:$Op3)),
-            (FMLA_ZPmZZ_S $P, $Op3, $Op1, $Op2)>;
+            (FMLA_ZPZZZ_UNDEF_S $P, $Op3, $Op1, $Op2)>;
   def : Pat<(nxv2f64 (AArch64fma_p nxv2i1:$P, nxv2f64:$Op1, nxv2f64:$Op2, nxv2f64:$Op3)),
-            (FMLA_ZPmZZ_D $P, $Op3, $Op1, $Op2)>;
+            (FMLA_ZPZZZ_UNDEF_D $P, $Op3, $Op1, $Op2)>;
 
   defm FTMAD_ZZI : sve_fp_ftmad<"ftmad", int_aarch64_sve_ftmad_x>;
@@ -2126,6 +2129,66 @@
   // 16-element contiguous store
   defm : st1;
 
+  // Zd = Za + Zn * Zm
+  def : Pat<(nxv8f16 (AArch64fma_p nxv8i1:$P, nxv8f16:$Zn, nxv8f16:$Zm, nxv8f16:$Za)),
+            (FMLA_ZPZZZ_UNDEF_H $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv4f32 (AArch64fma_p nxv4i1:$P, nxv4f32:$Zn, nxv4f32:$Zm, nxv4f32:$Za)),
+            (FMLA_ZPZZZ_UNDEF_S $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv2f32 (AArch64fma_p nxv2i1:$P, nxv2f32:$Zn, nxv2f32:$Zm, nxv2f32:$Za)),
+            (FMLA_ZPZZZ_UNDEF_S $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv2f64 (AArch64fma_p nxv2i1:$P, nxv2f64:$Zn, nxv2f64:$Zm, nxv2f64:$Za)),
+            (FMLA_ZPZZZ_UNDEF_D $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+
+  // Zd = Za + -Zn * Zm
+  def : Pat<(nxv8f16 (AArch64fma_p nxv8i1:$P, (AArch64fneg_mt nxv8i1:$P, nxv8f16:$Zn, (nxv8f16 (undef))), nxv8f16:$Zm, nxv8f16:$Za)),
+            (FMLS_ZPZZZ_UNDEF_H $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv4f32 (AArch64fma_p nxv4i1:$P, (AArch64fneg_mt nxv4i1:$P, nxv4f32:$Zn, (nxv4f32 (undef))), nxv4f32:$Zm, nxv4f32:$Za)),
+            (FMLS_ZPZZZ_UNDEF_S $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv2f32 (AArch64fma_p nxv2i1:$P, (AArch64fneg_mt nxv2i1:$P, nxv2f32:$Zn, (nxv2f32 (undef))), nxv2f32:$Zm, nxv2f32:$Za)),
+            (FMLS_ZPZZZ_UNDEF_S $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv2f64 (AArch64fma_p nxv2i1:$P, (AArch64fneg_mt nxv2i1:$P, nxv2f64:$Zn, (nxv2f64 (undef))), nxv2f64:$Zm, nxv2f64:$Za)),
+            (FMLS_ZPZZZ_UNDEF_D $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+
+  // Zd = -Za + Zn * Zm
+  def : Pat<(nxv8f16 (AArch64fma_p nxv8i1:$P, nxv8f16:$Zn, nxv8f16:$Zm, (AArch64fneg_mt nxv8i1:$P, nxv8f16:$Za, (nxv8f16 (undef))))),
+            (FNMLS_ZPZZZ_UNDEF_H $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv4f32 (AArch64fma_p nxv4i1:$P, nxv4f32:$Zn, nxv4f32:$Zm, (AArch64fneg_mt nxv4i1:$P, nxv4f32:$Za, (nxv4f32 (undef))))),
+            (FNMLS_ZPZZZ_UNDEF_S $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv2f32 (AArch64fma_p nxv2i1:$P, nxv2f32:$Zn, nxv2f32:$Zm, (AArch64fneg_mt nxv2i1:$P, nxv2f32:$Za, (nxv2f32 (undef))))),
+            (FNMLS_ZPZZZ_UNDEF_S $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv2f64 (AArch64fma_p nxv2i1:$P, nxv2f64:$Zn, nxv2f64:$Zm, (AArch64fneg_mt nxv2i1:$P, nxv2f64:$Za, (nxv2f64 (undef))))),
+            (FNMLS_ZPZZZ_UNDEF_D $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+
+  // Zd = -Za + -Zn * Zm
+  def : Pat<(nxv8f16 (AArch64fma_p nxv8i1:$P, (AArch64fneg_mt nxv8i1:$P, nxv8f16:$Zn, (nxv8f16 (undef))), nxv8f16:$Zm, (AArch64fneg_mt nxv8i1:$P, nxv8f16:$Za, (nxv8f16 (undef))))),
+            (FNMLA_ZPZZZ_UNDEF_H $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv4f32 (AArch64fma_p nxv4i1:$P, (AArch64fneg_mt nxv4i1:$P, nxv4f32:$Zn, (nxv4f32 (undef))), nxv4f32:$Zm, (AArch64fneg_mt nxv4i1:$P, nxv4f32:$Za, (nxv4f32 (undef))))),
+            (FNMLA_ZPZZZ_UNDEF_S $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv2f32 (AArch64fma_p nxv2i1:$P, (AArch64fneg_mt nxv2i1:$P, nxv2f32:$Zn, (nxv2f32 (undef))), nxv2f32:$Zm, (AArch64fneg_mt nxv2i1:$P, nxv2f32:$Za, (nxv2f32 (undef))))),
+            (FNMLA_ZPZZZ_UNDEF_S $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(nxv2f64 (AArch64fma_p nxv2i1:$P, (AArch64fneg_mt nxv2i1:$P, nxv2f64:$Zn, (nxv2f64 (undef))), nxv2f64:$Zm, (AArch64fneg_mt nxv2i1:$P, nxv2f64:$Za, (nxv2f64 (undef))))),
+            (FNMLA_ZPZZZ_UNDEF_D $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+
+  // Zda = Zda + Zn * Zm
+  def : Pat<(vselect (nxv8i1 PPR:$Pg), (nxv8f16 (AArch64fma_p (nxv8i1 (AArch64ptrue 31)), ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za),
+            (FMLA_ZPmZZ_H PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(vselect (nxv4i1 PPR:$Pg), (nxv4f32 (AArch64fma_p (nxv4i1 (AArch64ptrue 31)), ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za),
+            (FMLA_ZPmZZ_S PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(vselect (nxv2i1 PPR:$Pg), (nxv2f32 (AArch64fma_p (nxv2i1 (AArch64ptrue 31)), ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za),
+            (FMLA_ZPmZZ_S PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(vselect (nxv2i1 PPR:$Pg), (nxv2f64 (AArch64fma_p (nxv2i1 (AArch64ptrue 31)), ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za),
+            (FMLA_ZPmZZ_D PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+
+  // Zda = Zda + -Zn * Zm
+  def : Pat<(vselect (nxv8i1 PPR:$Pg), (nxv8f16 (AArch64fma_p (nxv8i1 (AArch64ptrue 31)), (AArch64fneg_mt (nxv8i1 (AArch64ptrue 31)), nxv8f16:$Zn, (nxv8f16 (undef))), ZPR:$Zm, ZPR:$Za)), ZPR:$Za),
+            (FMLS_ZPmZZ_H PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(vselect (nxv4i1 PPR:$Pg), (nxv4f32 (AArch64fma_p (nxv4i1 (AArch64ptrue 31)), (AArch64fneg_mt (nxv4i1 (AArch64ptrue 31)), nxv4f32:$Zn, (nxv4f32 (undef))), ZPR:$Zm, ZPR:$Za)), ZPR:$Za),
+            (FMLS_ZPmZZ_S PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(vselect (nxv2i1 PPR:$Pg), (nxv2f32 (AArch64fma_p (nxv2i1 (AArch64ptrue 31)), (AArch64fneg_mt (nxv2i1 (AArch64ptrue 31)), nxv2f32:$Zn, (nxv2f32 (undef))), ZPR:$Zm, ZPR:$Za)), ZPR:$Za),
+            (FMLS_ZPmZZ_S PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+  def : Pat<(vselect (nxv2i1 PPR:$Pg), (nxv2f64 (AArch64fma_p (nxv2i1 (AArch64ptrue 31)), (AArch64fneg_mt (nxv2i1 (AArch64ptrue 31)), nxv2f64:$Zn, (nxv2f64 (undef))), ZPR:$Zm, ZPR:$Za)), ZPR:$Za),
+            (FMLS_ZPmZZ_D PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+
   def : Pat<(nxv16i8 (vector_insert (nxv16i8 (undef)), (i32 FPR32:$src), 0)),
             (INSERT_SUBREG (nxv16i8 (IMPLICIT_DEF)), FPR32:$src, ssub)>;
   def : Pat<(nxv8i16 (vector_insert (nxv8i16 (undef)), (i32 FPR32:$src), 0)),
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
@@ -28,7 +28,8 @@
                                   SDValue Chain, SDValue Op1, SDValue Op2,
                                   MachinePointerInfo DstPtrInfo,
                                   bool ZeroData) const override;
-  bool generateFMAsInMachineCombiner(CodeGenOpt::Level OptLevel) const override;
+  bool generateFMAsInMachineCombiner(SelectionDAG &DAG,
+                                     CodeGenOpt::Level OptLevel) const override;
 };
 }
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -53,8 +53,9 @@
   return SDValue();
 }
 
 bool AArch64SelectionDAGInfo::generateFMAsInMachineCombiner(
-    CodeGenOpt::Level OptLevel) const {
-  return OptLevel >= CodeGenOpt::Aggressive;
+    SelectionDAG &DAG, CodeGenOpt::Level OptLevel) const {
+  const auto &STI = DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+  return (OptLevel >= CodeGenOpt::Aggressive) && !STI.hasSVE();
 }
 static const int kSetTagLoopThreshold = 176;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -491,6 +491,13 @@
     Pseudo<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zs1, immty:$imm), []> {
     let FalseLanes = flags;
   }
+
+  class PredThreeOpPseudo<string name, ZPRRegOp zprty,
+                          FalseLanesEnum flags = FalseLanesNone>
+  : SVEPseudo2Instr<name, 0>,
+    Pseudo<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zs1, zprty:$Zs2, zprty:$Zs3), []> {
+    let FalseLanes = flags;
+  }
 }
 
 //===----------------------------------------------------------------------===//
@@ -1762,14 +1769,20 @@
   let Inst{4-0} = Zda;
 
   let Constraints = "$Zda = $_Zda";
-  let DestructiveInstType = DestructiveOther;
   let ElementSize = zprty.ElementSize;
 }
 
-multiclass sve_fp_3op_p_zds_a<bits<2> opc, string asm, SDPatternOperator op> {
-  def _H : sve_fp_3op_p_zds_a<0b01, opc, asm, ZPR16>;
-  def _S : sve_fp_3op_p_zds_a<0b10, opc, asm, ZPR32>;
-  def _D : sve_fp_3op_p_zds_a<0b11, opc, asm, ZPR64>;
+multiclass sve_fp_3op_p_zds_a<bits<2> opc, string asm, string Ps,
+                              SDPatternOperator op, string revname,
+                              bit isReverseInstr=0> {
+  let DestructiveInstType = DestructiveTernaryCommWithRev in {
+  def _H : sve_fp_3op_p_zds_a<0b01, opc, asm, ZPR16>,
+           SVEPseudo2Instr, SVEInstr2Rev;
+  def _S : sve_fp_3op_p_zds_a<0b10, opc, asm, ZPR32>,
+           SVEPseudo2Instr, SVEInstr2Rev;
+  def _D : sve_fp_3op_p_zds_a<0b11, opc, asm, ZPR64>,
+           SVEPseudo2Instr, SVEInstr2Rev;
+  }
 
   def : SVE_4_Op_Pat(NAME # _H)>;
   def : SVE_4_Op_Pat(NAME # _S)>;
@@ -1801,16 +1814,26 @@
   let ElementSize = zprty.ElementSize;
 }
 
-multiclass sve_fp_3op_p_zds_b<bits<2> opc, string asm, SDPatternOperator op> {
-  def _H : sve_fp_3op_p_zds_b<0b01, opc, asm, ZPR16>;
-  def _S : sve_fp_3op_p_zds_b<0b10, opc, asm, ZPR32>;
-  def _D : sve_fp_3op_p_zds_b<0b11, opc, asm, ZPR64>;
+multiclass sve_fp_3op_p_zds_b<bits<2> opc, string asm, SDPatternOperator op,
+                              string revname, bit isReverseInstr> {
+  def _H : sve_fp_3op_p_zds_b<0b01, opc, asm, ZPR16>,
+           SVEInstr2Rev;
+  def _S : sve_fp_3op_p_zds_b<0b10, opc, asm, ZPR32>,
+           SVEInstr2Rev;
+  def _D : sve_fp_3op_p_zds_b<0b11, opc, asm, ZPR64>,
+           SVEInstr2Rev;
 
   def : SVE_4_Op_Pat(NAME # _H)>;
   def : SVE_4_Op_Pat(NAME # _S)>;
   def : SVE_4_Op_Pat(NAME # _D)>;
 }
 
+multiclass sve_fp_3op_p_zds_zx {
+  def _UNDEF_H : PredThreeOpPseudo<NAME # _H, ZPR16, FalseLanesUndef>;
+  def _UNDEF_S : PredThreeOpPseudo<NAME # _S, ZPR32, FalseLanesUndef>;
+  def _UNDEF_D : PredThreeOpPseudo<NAME # _D, ZPR64, FalseLanesUndef>;
+}
+
 //===----------------------------------------------------------------------===//
 // SVE Floating Point Multiply-Add - Indexed Group
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll --- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll +++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll @@ -572,8 +572,8 @@ ; CHECK-DAG: ld1h { [[OP1:z[0-9]+]].h }, [[PG]]/z, [x0] ; CHECK-DAG: ld1h { [[OP2:z[0-9]+]].h }, [[PG]]/z, [x1] ; CHECK-DAG: ld1h { [[OP3:z[0-9]+]].h }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].h, [[PG]]/m, [[OP1]].h, [[OP2]].h -; CHECK: st1h { [[OP3]].h }, [[PG]], [x0] +; CHECK: fmad [[OP1]].h, [[PG]]/m, [[OP2]].h,
[[OP3]].h +; CHECK: st1h { [[OP1]].h }, [[PG]], [x0] ; CHECK: ret %op1 = load <16 x half>, <16 x half>* %a %op2 = load <16 x half>, <16 x half>* %b @@ -589,8 +589,8 @@ ; CHECK-DAG: ld1h { [[OP1:z[0-9]+]].h }, [[PG]]/z, [x0] ; CHECK-DAG: ld1h { [[OP2:z[0-9]+]].h }, [[PG]]/z, [x1] ; CHECK-DAG: ld1h { [[OP3:z[0-9]+]].h }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].h, [[PG]]/m, [[OP1]].h, [[OP2]].h -; CHECK: st1h { [[OP3]].h }, [[PG]], [x0] +; CHECK: fmad [[OP1]].h, [[PG]]/m, [[OP2]].h, [[OP3]].h +; CHECK: st1h { [[OP1]].h }, [[PG]], [x0] ; CHECK: ret %op1 = load <32 x half>, <32 x half>* %a %op2 = load <32 x half>, <32 x half>* %b @@ -606,8 +606,8 @@ ; CHECK-DAG: ld1h { [[OP1:z[0-9]+]].h }, [[PG]]/z, [x0] ; CHECK-DAG: ld1h { [[OP2:z[0-9]+]].h }, [[PG]]/z, [x1] ; CHECK-DAG: ld1h { [[OP3:z[0-9]+]].h }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].h, [[PG]]/m, [[OP1]].h, [[OP2]].h -; CHECK: st1h { [[OP3]].h }, [[PG]], [x0] +; CHECK: fmad [[OP1]].h, [[PG]]/m, [[OP2]].h, [[OP3]].h +; CHECK: st1h { [[OP1]].h }, [[PG]], [x0] ; CHECK: ret %op1 = load <64 x half>, <64 x half>* %a %op2 = load <64 x half>, <64 x half>* %b @@ -623,8 +623,8 @@ ; CHECK-DAG: ld1h { [[OP1:z[0-9]+]].h }, [[PG]]/z, [x0] ; CHECK-DAG: ld1h { [[OP2:z[0-9]+]].h }, [[PG]]/z, [x1] ; CHECK-DAG: ld1h { [[OP3:z[0-9]+]].h }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].h, [[PG]]/m, [[OP1]].h, [[OP2]].h -; CHECK: st1h { [[OP3]].h }, [[PG]], [x0] +; CHECK: fmad [[OP1]].h, [[PG]]/m, [[OP2]].h, [[OP3]].h +; CHECK: st1h { [[OP1]].h }, [[PG]], [x0] ; CHECK: ret %op1 = load <128 x half>, <128 x half>* %a %op2 = load <128 x half>, <128 x half>* %b @@ -658,8 +658,8 @@ ; CHECK-DAG: ld1w { [[OP1:z[0-9]+]].s }, [[PG]]/z, [x0] ; CHECK-DAG: ld1w { [[OP2:z[0-9]+]].s }, [[PG]]/z, [x1] ; CHECK-DAG: ld1w { [[OP3:z[0-9]+]].s }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].s, [[PG]]/m, [[OP1]].s, [[OP2]].s -; CHECK: st1w { [[OP3]].s }, [[PG]], [x0] +; CHECK: fmad [[OP1]].s, [[PG]]/m, [[OP2]].s, [[OP3]].s +; CHECK: st1w { [[OP1]].s }, [[PG]], [x0] ; CHECK: ret %op1 = load <8 x float>, <8 x float>* %a %op2 = load <8 x float>, <8 x float>* %b @@ -675,8 +675,8 @@ ; CHECK-DAG: ld1w { [[OP1:z[0-9]+]].s }, [[PG]]/z, [x0] ; CHECK-DAG: ld1w { [[OP2:z[0-9]+]].s }, [[PG]]/z, [x1] ; CHECK-DAG: ld1w { [[OP3:z[0-9]+]].s }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].s, [[PG]]/m, [[OP1]].s, [[OP2]].s -; CHECK: st1w { [[OP3]].s }, [[PG]], [x0] +; CHECK: fmad [[OP1]].s, [[PG]]/m, [[OP2]].s, [[OP3]].s +; CHECK: st1w { [[OP1]].s }, [[PG]], [x0] ; CHECK: ret %op1 = load <16 x float>, <16 x float>* %a %op2 = load <16 x float>, <16 x float>* %b @@ -692,8 +692,8 @@ ; CHECK-DAG: ld1w { [[OP1:z[0-9]+]].s }, [[PG]]/z, [x0] ; CHECK-DAG: ld1w { [[OP2:z[0-9]+]].s }, [[PG]]/z, [x1] ; CHECK-DAG: ld1w { [[OP3:z[0-9]+]].s }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].s, [[PG]]/m, [[OP1]].s, [[OP2]].s -; CHECK: st1w { [[OP3]].s }, [[PG]], [x0] +; CHECK: fmad [[OP1]].s, [[PG]]/m, [[OP2]].s, [[OP3]].s +; CHECK: st1w { [[OP1]].s }, [[PG]], [x0] ; CHECK: ret %op1 = load <32 x float>, <32 x float>* %a %op2 = load <32 x float>, <32 x float>* %b @@ -709,8 +709,8 @@ ; CHECK-DAG: ld1w { [[OP1:z[0-9]+]].s }, [[PG]]/z, [x0] ; CHECK-DAG: ld1w { [[OP2:z[0-9]+]].s }, [[PG]]/z, [x1] ; CHECK-DAG: ld1w { [[OP3:z[0-9]+]].s }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].s, [[PG]]/m, [[OP1]].s, [[OP2]].s -; CHECK: st1w { [[OP3]].s }, [[PG]], [x0] +; CHECK: fmad [[OP1]].s, [[PG]]/m, [[OP2]].s, [[OP3]].s +; CHECK: st1w { [[OP1]].s }, [[PG]], [x0] ; CHECK: ret %op1 = load <64 x float>, <64 x float>* %a %op2 = load <64 x float>, <64 x float>* %b @@ -744,8 
+744,8 @@ ; CHECK-DAG: ld1d { [[OP1:z[0-9]+]].d }, [[PG]]/z, [x0] ; CHECK-DAG: ld1d { [[OP2:z[0-9]+]].d }, [[PG]]/z, [x1] ; CHECK-DAG: ld1d { [[OP3:z[0-9]+]].d }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].d, [[PG]]/m, [[OP1]].d, [[OP2]].d -; CHECK: st1d { [[OP3]].d }, [[PG]], [x0] +; CHECK: fmad [[OP1]].d, [[PG]]/m, [[OP2]].d, [[OP3]].d +; CHECK: st1d { [[OP1]].d }, [[PG]], [x0] ; CHECK: ret %op1 = load <4 x double>, <4 x double>* %a %op2 = load <4 x double>, <4 x double>* %b @@ -761,8 +761,8 @@ ; CHECK-DAG: ld1d { [[OP1:z[0-9]+]].d }, [[PG]]/z, [x0] ; CHECK-DAG: ld1d { [[OP2:z[0-9]+]].d }, [[PG]]/z, [x1] ; CHECK-DAG: ld1d { [[OP3:z[0-9]+]].d }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].d, [[PG]]/m, [[OP1]].d, [[OP2]].d -; CHECK: st1d { [[OP3]].d }, [[PG]], [x0] +; CHECK: fmad [[OP1]].d, [[PG]]/m, [[OP2]].d, [[OP3]].d +; CHECK: st1d { [[OP1]].d }, [[PG]], [x0] ; CHECK: ret %op1 = load <8 x double>, <8 x double>* %a %op2 = load <8 x double>, <8 x double>* %b @@ -778,8 +778,8 @@ ; CHECK-DAG: ld1d { [[OP1:z[0-9]+]].d }, [[PG]]/z, [x0] ; CHECK-DAG: ld1d { [[OP2:z[0-9]+]].d }, [[PG]]/z, [x1] ; CHECK-DAG: ld1d { [[OP3:z[0-9]+]].d }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].d, [[PG]]/m, [[OP1]].d, [[OP2]].d -; CHECK: st1d { [[OP3]].d }, [[PG]], [x0] +; CHECK: fmad [[OP1]].d, [[PG]]/m, [[OP2]].d, [[OP3]].d +; CHECK: st1d { [[OP1]].d }, [[PG]], [x0] ; CHECK: ret %op1 = load <16 x double>, <16 x double>* %a %op2 = load <16 x double>, <16 x double>* %b @@ -795,8 +795,8 @@ ; CHECK-DAG: ld1d { [[OP1:z[0-9]+]].d }, [[PG]]/z, [x0] ; CHECK-DAG: ld1d { [[OP2:z[0-9]+]].d }, [[PG]]/z, [x1] ; CHECK-DAG: ld1d { [[OP3:z[0-9]+]].d }, [[PG]]/z, [x2] -; CHECK: fmla [[OP3]].d, [[PG]]/m, [[OP1]].d, [[OP2]].d -; CHECK: st1d { [[OP3]].d }, [[PG]], [x0] +; CHECK: fmad [[OP1]].d, [[PG]]/m, [[OP2]].d, [[OP3]].d +; CHECK: st1d { [[OP1]].d }, [[PG]], [x0] ; CHECK: ret %op1 = load <32 x double>, <32 x double>* %a %op2 = load <32 x double>, <32 x double>* %b diff --git a/llvm/test/CodeGen/AArch64/sve-fma-dagcombine.ll b/llvm/test/CodeGen/AArch64/sve-fma-dagcombine.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-fma-dagcombine.ll @@ -0,0 +1,12 @@ +; RUN: llc -march=aarch64 -mattr=+sve -fp-contract=fast < %s | FileCheck %s + +target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" +target triple = "aarch64-none--elf" + +define @use_fma( %a, %b, %c) { +; CHECK-LABEL: use_fma +; CHECK: fmad + %mul = fmul fast %a, %b + %res = fadd fast %mul, %c + ret %res +} diff --git a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll @@ -0,0 +1,420 @@ +; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve -fp-contract=fast < %s | FileCheck %s + +; NOTE: -fp-contract=fast required for fmla + +define @fmla_h_sel( %pred, %acc, + %m1, %m2) { +; CHECK-LABEL: fmla_h_sel: +; CHECK: fmla z0.h, p0/m, z1.h, z2.h +; CHECK: ret + %mul = fmul %m1, %m2 + %add = fadd %acc, %mul + %res = select %pred, %add, %acc + ret %res +} + +define @fmla_s_sel( %pred, %acc, + %m1, %m2) { +; CHECK-LABEL: fmla_s_sel: +; CHECK: fmla z0.s, p0/m, z1.s, z2.s +; CHECK: ret + %mul = fmul %m1, %m2 + %add = fadd %acc, %mul + %res = select %pred, %add, %acc + ret %res +} + +define @fmla_sx2_sel( %pred, %acc, + %m1, %m2) { +; CHECK-LABEL: fmla_sx2_sel: +; CHECK: fmla z0.s, p0/m, z1.s, z2.s +; CHECK: ret + %mul = fmul %m1, %m2 + %add = fadd %acc, %mul + %res = select %pred, %add, %acc + ret %res +} + +define @fmla_d_sel( %pred, %acc, + %m1, %m2) { +; 
CHECK-LABEL: fmla_d_sel: +; CHECK: fmla z0.d, p0/m, z1.d, z2.d +; CHECK: ret + %mul = fmul %m1, %m2 + %add = fadd %acc, %mul + %res = select %pred, %add, %acc + ret %res +} + +define @fmls_h_sel( %pred, %acc, + %m1, %m2) { +; CHECK-LABEL: fmls_h_sel: +; CHECK: fmls z0.h, p0/m, z1.h, z2.h +; CHECK: ret + %mul = fmul %m1, %m2 + %sub = fsub %acc, %mul + %res = select %pred, %sub, %acc + ret %res +} + +define @fmls_s_sel( %pred, %acc, + %m1, %m2) { +; CHECK-LABEL: fmls_s_sel: +; CHECK: fmls z0.s, p0/m, z1.s, z2.s +; CHECK: ret + %mul = fmul %m1, %m2 + %sub = fsub %acc, %mul + %res = select %pred, %sub, %acc + ret %res +} + +define @fmls_sx2_sel( %pred, %acc, + %m1, %m2) { +; CHECK-LABEL: fmls_sx2_sel: +; CHECK: fmls z0.s, p0/m, z1.s, z2.s +; CHECK: ret + %mul = fmul %m1, %m2 + %sub = fsub %acc, %mul + %res = select %pred, %sub, %acc + ret %res +} + +define @fmls_d_sel( %pred, %acc, + %m1, %m2) { +; CHECK-LABEL: fmls_d_sel: +; CHECK: fmls z0.d, p0/m, z1.d, z2.d +; CHECK: ret + %mul = fmul %m1, %m2 + %sub = fsub %acc, %mul + %res = select %pred, %sub, %acc + ret %res +} + +define @fmad_h( %m1, %m2, %acc) { +; CHECK-LABEL: fmad_h: +; CHECK: ptrue [[PG:p[0-9]+]].h +; CHECK-NEXT: fmad z0.h, [[PG]]/m, z1.h, z2.h +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fadd %acc, %mul + ret %res +} + +define @fmad_s( %m1, %m2, %acc) { +; CHECK-LABEL: fmad_s: +; CHECK: ptrue [[PG:p[0-9]+]].s +; CHECK-NEXT: fmad z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fadd %acc, %mul + ret %res +} + +define @fmad_sx2( %m1, %m2, %acc) { +; CHECK-LABEL: fmad_sx2: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fmad z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fadd %acc, %mul + ret %res +} + +define @fmad_d( %m1, %m2, %acc) { +; CHECK-LABEL: fmad_d: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fmad z0.d, [[PG]]/m, z1.d, z2.d +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fadd %acc, %mul + ret %res +} + +define @fmla_s( %acc, %m1, %m2) { +; CHECK-LABEL: fmla_s: +; CHECK: ptrue [[PG:p[0-9]+]].s +; CHECK-NEXT: fmla z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fadd %acc, %mul + ret %res +} + +define @fmla_sx2( %acc, %m1, %m2) { +; CHECK-LABEL: fmla_sx2: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fmla z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fadd %acc, %mul + ret %res +} + +define @fmla_d( %acc, %m1, %m2) { +; CHECK-LABEL: fmla_d: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fmla z0.d, [[PG]]/m, z1.d, z2.d +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fadd %acc, %mul + ret %res +} + +define @fmls_h( %acc, %m1, %m2) { +; CHECK-LABEL: fmls_h: +; CHECK: ptrue [[PG:p[0-9]+]].h +; CHECK-NEXT: fmls z0.h, [[PG]]/m, z1.h, z2.h +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %acc, %mul + ret %res +} + +define @fmls_s( %acc, %m1, %m2) { +; CHECK-LABEL: fmls_s: +; CHECK: ptrue [[PG:p[0-9]+]].s +; CHECK-NEXT: fmls z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %acc, %mul + ret %res +} + +define @fmls_sx2( %acc, %m1, %m2) { +; CHECK-LABEL: fmls_sx2: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fmls z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %acc, %mul + ret %res +} + +define @fmls_d( %acc, %m1, %m2) { +; CHECK-LABEL: fmls_d: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fmls z0.d, [[PG]]/m, z1.d, z2.d +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %acc, %mul + ret %res +} + +define @fmsb_s( 
%m1, %m2, %acc) { +; CHECK-LABEL: fmsb_s: +; CHECK: ptrue [[PG:p[0-9]+]].s +; CHECK-NEXT: fmsb z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %acc, %mul + ret %res +} + +define @fmsb_sx2( %m1, %m2, %acc) { +; CHECK-LABEL: fmsb_sx2: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fmsb z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %acc, %mul + ret %res +} + +define @fmsb_d( %m1, %m2, %acc) { +; CHECK-LABEL: fmsb_d: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fmsb z0.d, [[PG]]/m, z1.d, z2.d +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %acc, %mul + ret %res +} + +define @fnmad_s( %m1, %m2, %acc) { +; CHECK-LABEL: fnmad_s: +; CHECK: ptrue [[PG:p[0-9]+]].s +; CHECK-NEXT: fnmad z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %neg_m1 = fsub + shufflevector ( + insertelement ( undef, float -0.000000e+00, i32 0), + undef, + zeroinitializer), + %m1 + + %mul = fmul %neg_m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmad_sx2( %m1, %m2, %acc) { +; CHECK-LABEL: fnmad_sx2: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fnmad z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %neg_m1 = fsub + shufflevector ( + insertelement ( undef, float -0.000000e+00, i32 0), + undef, + zeroinitializer), + %m1 + + %mul = fmul %neg_m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmad_d( %m1, %m2, %acc) { +; CHECK-LABEL: fnmad_d: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fnmad z0.d, [[PG]]/m, z1.d, z2.d +; CHECK-NEXT: ret + %neg_m1 = fsub + shufflevector ( + insertelement ( undef, double -0.000000e+00, i32 0), + undef, + zeroinitializer), + %m1 + + %mul = fmul %neg_m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmla_h( %acc, %m1, %m2) { +; CHECK-LABEL: fnmla_h: +; CHECK: ptrue [[PG:p[0-9]+]].h +; CHECK-NEXT: fnmla z0.h, [[PG]]/m, z1.h, z2.h +; CHECK-NEXT: ret + %neg_m1 = fsub + shufflevector ( + insertelement ( undef, half -0.000000e+00, i32 0), + undef, + zeroinitializer), + %m1 + + %mul = fmul %neg_m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmla_s( %acc, %m1, %m2) { +; CHECK-LABEL: fnmla_s: +; CHECK: ptrue [[PG:p[0-9]+]].s +; CHECK-NEXT: fnmla z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %neg_m1 = fsub + shufflevector ( + insertelement ( undef, float -0.000000e+00, i32 0), + undef, + zeroinitializer), + %m1 + + %mul = fmul %neg_m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmla_sx2( %acc, %m1, %m2) { +; CHECK-LABEL: fnmla_sx2: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fnmla z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %neg_m1 = fsub + shufflevector ( + insertelement ( undef, float -0.000000e+00, i32 0), + undef, + zeroinitializer), + %m1 + + %mul = fmul %neg_m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmla_d( %acc, %m1, %m2) { +; CHECK-LABEL: fnmla_d: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fnmla z0.d, [[PG]]/m, z1.d, z2.d +; CHECK-NEXT: ret + %neg_m1 = fsub + shufflevector ( + insertelement ( undef, double -0.000000e+00, i32 0), + undef, + zeroinitializer), + %m1 + + %mul = fmul %neg_m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmls_h( %acc, %m1, %m2) { +; CHECK-LABEL: fnmls_h: +; CHECK: ptrue [[PG:p[0-9]+]].h +; CHECK-NEXT: fnmls z0.h, [[PG]]/m, z1.h, z2.h +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmls_s( %acc, %m1, %m2) { +; CHECK-LABEL: fnmls_s: +; CHECK: ptrue [[PG:p[0-9]+]].s +; CHECK-NEXT: fnmls z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, 
%m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmls_sx2( %acc, %m1, %m2) { +; CHECK-LABEL: fnmls_sx2: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fnmls z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmls_d( %acc, %m1, %m2) { +; CHECK-LABEL: fnmls_d: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fnmls z0.d, [[PG]]/m, z1.d, z2.d +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmsb_s( %m1, %m2, %acc) { +; CHECK-LABEL: fnmsb_s: +; CHECK: ptrue [[PG:p[0-9]+]].s +; CHECK-NEXT: fnmsb z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmsb_sx2( %m1, %m2, %acc) { +; CHECK-LABEL: fnmsb_sx2: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fnmsb z0.s, [[PG]]/m, z1.s, z2.s +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %mul, %acc + ret %res +} + +define @fnmsb_d( %m1, %m2, %acc) { +; CHECK-LABEL: fnmsb_d: +; CHECK: ptrue [[PG:p[0-9]+]].d +; CHECK-NEXT: fnmsb z0.d, [[PG]]/m, z1.d, z2.d +; CHECK-NEXT: ret + %mul = fmul %m1, %m2 + %res = fsub %mul, %acc + ret %res +} diff --git a/llvm/test/CodeGen/AArch64/sve-fp.ll b/llvm/test/CodeGen/AArch64/sve-fp.ll --- a/llvm/test/CodeGen/AArch64/sve-fp.ll +++ b/llvm/test/CodeGen/AArch64/sve-fp.ll @@ -240,8 +240,7 @@ ; CHECK-LABEL: fma_nxv8f16: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.h -; CHECK-NEXT: fmla z2.h, p0/m, z0.h, z1.h -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: fmad z0.h, p0/m, z1.h, z2.h ; CHECK-NEXT: ret %r = call @llvm.fma.nxv8f16( %a, %b, %c) ret %r @@ -251,8 +250,7 @@ ; CHECK-LABEL: fma_nxv4f16: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.s -; CHECK-NEXT: fmla z2.h, p0/m, z0.h, z1.h -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: fmad z0.h, p0/m, z1.h, z2.h ; CHECK-NEXT: ret %r = call @llvm.fma.nxv4f16( %a, %b, %c) ret %r @@ -262,8 +260,7 @@ ; CHECK-LABEL: fma_nxv2f16: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.d -; CHECK-NEXT: fmla z2.h, p0/m, z0.h, z1.h -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: fmad z0.h, p0/m, z1.h, z2.h ; CHECK-NEXT: ret %r = call @llvm.fma.nxv2f16( %a, %b, %c) ret %r @@ -273,8 +270,7 @@ ; CHECK-LABEL: fma_nxv4f32: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.s -; CHECK-NEXT: fmla z2.s, p0/m, z0.s, z1.s -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: fmad z0.s, p0/m, z1.s, z2.s ; CHECK-NEXT: ret %r = call @llvm.fma.nxv4f32( %a, %b, %c) ret %r @@ -284,8 +280,7 @@ ; CHECK-LABEL: fma_nxv2f32: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.d -; CHECK-NEXT: fmla z2.s, p0/m, z0.s, z1.s -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: fmad z0.s, p0/m, z1.s, z2.s ; CHECK-NEXT: ret %r = call @llvm.fma.nxv2f32( %a, %b, %c) ret %r @@ -295,8 +290,7 @@ ; CHECK-LABEL: fma_nxv2f64_1: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.d -; CHECK-NEXT: fmla z2.d, p0/m, z0.d, z1.d -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: fmad z0.d, p0/m, z1.d, z2.d ; CHECK-NEXT: ret %r = call @llvm.fma.nxv2f64( %a, %b, %c) ret %r @@ -306,8 +300,7 @@ ; CHECK-LABEL: fma_nxv2f64_2: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.d -; CHECK-NEXT: fmla z2.d, p0/m, z1.d, z0.d -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: fmad z0.d, p0/m, z1.d, z2.d ; CHECK-NEXT: ret %r = call @llvm.fma.nxv2f64( %b, %a, %c) ret %r