diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -243,6 +243,9 @@
   void operator|=(const FastMathFlags &OtherFlags) {
     Flags |= OtherFlags.Flags;
   }
+  bool operator!=(const FastMathFlags &OtherFlags) const {
+    return Flags != OtherFlags.Flags;
+  }
 };
 
 /// Utility class for floating point operations which can have
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -695,6 +695,43 @@
   return None;
 }
 
+static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  // fold (fadd a (fmul b c)) -> (fma a b c)
+  Value *p, *FMul, *a, *b, *c;
+  auto m_SVEFAdd = [](auto p, auto w, auto x) {
+    return m_CombineOr(m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, w, x),
+                       m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, x, w));
+  };
+  auto m_SVEFMul = [](auto p, auto y, auto z) {
+    return m_Intrinsic<Intrinsic::aarch64_sve_fmul>(p, y, z);
+  };
+  if (!match(&II, m_SVEFAdd(m_Value(p), m_Value(a),
+                            m_CombineAnd(m_Value(FMul),
+                                         m_SVEFMul(m_Deferred(p), m_Value(b),
+                                                   m_Value(c))))))
+    return None;
+
+  if (!FMul->hasOneUse())
+    return None;
+
+  llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
+  llvm::FastMathFlags FMulFlags = cast<CallInst>(FMul)->getFastMathFlags();
+  if (!FAddFlags.allowReassoc() || !FMulFlags.allowReassoc())
+    return None;
+  if (!FAddFlags.allowContract() || !FMulFlags.allowContract())
+    return None;
+  if (FAddFlags != FMulFlags)
+    return None;
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
+                                      {II.getType()}, {p, a, b, c}, &II);
+  FMLA->setFastMathFlags(FAddFlags);
+  return IC.replaceInstUsesWith(II, FMLA);
+}
+
 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
   switch (Intrinsic) {
   case Intrinsic::aarch64_sve_fmul:
@@ -724,6 +761,19 @@
   return IC.replaceInstUsesWith(II, BinOp);
 }
 
+static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
+                                                        IntrinsicInst &II,
+                                                        const TargetOptions &Options) {
+  bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
+                              Options.UnsafeFPMath);
+  if (AllowFusionGlobally) {
+    auto FMLA = instCombineSVEVectorFMLA(IC, II);
+    if (FMLA)
+      return FMLA;
+  }
+  return instCombineSVEVectorBinOp(IC, II);
+}
+
 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                        IntrinsicInst &II) {
   auto *OpPredicate = II.getOperand(0);
@@ -900,7 +950,10 @@
   case Intrinsic::aarch64_sve_mul:
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
-  case Intrinsic::aarch64_sve_fadd:
+  case Intrinsic::aarch64_sve_fadd: {
+    const TargetOptions &Options = ST->getTargetLowering()->getTargetMachine().Options;
+    return instCombineSVEVectorFAdd(IC, II, Options);
+  }
   case Intrinsic::aarch64_sve_fsub:
     return instCombineSVEVectorBinOp(IC, II);
   case Intrinsic::aarch64_sve_tbl:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll
@@ -0,0 +1,109 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: ret <vscale x 8 x half> %6
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @combine_fmla_contract_flag_only(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla_contract_flag_only
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = call contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: ret <vscale x 8 x half> %6
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_no_fast_flag(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_no_fast_flag
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT: ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_no_fmul(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_no_fmul
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT: ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_pred(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_neq_pred
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+; CHECK-NEXT: %7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
+; CHECK-NEXT: %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: %9 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
+; CHECK-NEXT: ret <vscale x 8 x half> %9
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+  %7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
+  %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %9 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
+  ret <vscale x 8 x half> %9
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_two_fmul_uses(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_two_fmul_uses
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT: %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
+; CHECK-NEXT: ret <vscale x 8 x half> %8
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %8
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_and_flags_1(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_and_flags_1
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT: ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_and_flags_2(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_and_flags_2
+; CHECK-NEXT: %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT: %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT: %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT: ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+
+attributes #0 = { "target-features"="+sve" }
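
Note (not part of the patch, for context only): the intrinsic pair matched by instCombineSVEVectorFMLA is what clang typically emits for a predicated multiply-add written with the ACLE intrinsics. The sketch below is illustrative; the function name is made up, and it assumes clang is invoked with SVE enabled and a fast-math mode (e.g. -O2 -march=armv8-a+sve -ffast-math) so that the generated llvm.aarch64.sve.fmul/llvm.aarch64.sve.fadd calls carry the fast-math flags and the UnsafeFPMath/AllowFPOpFusion gate in instCombineSVEVectorFAdd is satisfied.

    // Illustrative sketch only, not part of this patch.
    // Assumes the ACLE intrinsics from <arm_sve.h>.
    #include <arm_sve.h>

    // Computes a + b * c element-wise under the predicate pg. With fast-math
    // enabled, clang should lower the two calls to llvm.aarch64.sve.fmul
    // followed by llvm.aarch64.sve.fadd on the same predicate, which this
    // patch then folds into a single llvm.aarch64.sve.fmla call.
    svfloat16_t madd_f16(svbool_t pg, svfloat16_t a, svfloat16_t b,
                         svfloat16_t c) {
      svfloat16_t prod = svmul_f16_x(pg, b, c);
      return svadd_f16_x(pg, a, prod);
    }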