diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -695,6 +695,42 @@
   return None;
 }
 
+static Optional<Instruction *> instCombineSVEVectorFmla(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  Value *p, *a, *b, *c;
+  auto m_SVEFAdd = [](auto p, auto w, auto x) {
+    return m_CombineOr(m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, w, x),
+                       m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, x, w));
+  };
+  auto m_SVEFMul = [](auto p, auto y, auto z) {
+    return m_Intrinsic<Intrinsic::aarch64_sve_fmul>(p, y, z);
+  };
+  if (!match(&II, m_SVEFAdd(m_Value(p),
+                            m_SVEFMul(m_Deferred(p), m_Value(a), m_Value(b)),
+                            m_Value(c))))
+    return None;
+
+  Value *AddOp1 = II.getOperand(1);
+  Value *AddOp2 = II.getOperand(2);
+  if (!match(AddOp1, m_Intrinsic<Intrinsic::aarch64_sve_fmul>()))
+    std::swap(AddOp1, AddOp2);
+  if (!match(AddOp1, m_Intrinsic<Intrinsic::aarch64_sve_fmul>()))
+    return None;
+
+  auto FMulInst = dyn_cast<IntrinsicInst>(AddOp1);
+  if (!FMulInst->hasOneUse())
+    return None;
+  llvm::FastMathFlags flags = II.getFastMathFlags();
+  flags &= FMulInst->getFastMathFlags();
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  auto FmlaInst = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
+                                          {II.getType()}, {p, a, b, c}, &II);
+  FmlaInst->setFastMathFlags(flags);
+  return IC.replaceInstUsesWith(II, FmlaInst);
+}
+
 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
   switch (Intrinsic) {
   case Intrinsic::aarch64_sve_fmul:
@@ -710,10 +746,17 @@
 
 static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
                                                          IntrinsicInst &II) {
-  auto *OpPredicate = II.getOperand(0);
   auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
-  if (BinOpCode == Instruction::BinaryOpsEnd ||
-      !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
+  if (BinOpCode == Instruction::BinaryOpsEnd)
+    return None;
+  else if (BinOpCode == Instruction::FAdd) {
+    auto FmlaResult = instCombineSVEVectorFmla(IC, II);
+    if (FmlaResult.hasValue())
+      return FmlaResult;
+  }
+
+  auto *OpPredicate = II.getOperand(0);
+  if (!match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
                               m_ConstantInt<AArch64SVEPredPattern::all>())))
     return None;
   IRBuilder<> Builder(II.getContext());
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll
@@ -0,0 +1,84 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = call <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3, <vscale x 8 x half> %1)
+; CHECK-NEXT: ret <vscale x 8 x half> %6
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @combine_fmla_neg(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla_neg
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT: ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @combine_fmla_neg_nequal_pred(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla_neg_nequal_pred
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+; CHECK-NEXT: %7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
+; CHECK-NEXT: %8 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: %9 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
+; ret %9
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+  %7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
+  %8 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %9 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
+  ret <vscale x 8 x half> %9
+}
+
+define dso_local <vscale x 8 x half> @combine_fmla_two_uses(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla_two_uses
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT: %8 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
+; ret %8
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  %8 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %8
+}
+
+define dso_local <vscale x 8 x half> @combine_fmla_copy_fast_flag(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla_copy_fast_flag
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3, <vscale x 8 x half> %1)
+; CHECK-NEXT: ret <vscale x 8 x half> %6
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @combine_fmla_dont_copy_fast_flag_and(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla_dont_copy_fast_flag_and
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = call <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3, <vscale x 8 x half> %1)
+; CHECK-NEXT: ret <vscale x 8 x half> %6
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+attributes #0 = { "target-features"="+sve" }