diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -247,6 +247,9 @@
   void operator|=(const FastMathFlags &OtherFlags) {
     Flags |= OtherFlags.Flags;
   }
+  bool operator!=(const FastMathFlags &OtherFlags) const {
+    return Flags != OtherFlags.Flags;
+  }
 };
 
 /// Utility class for floating point operations which can have
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -695,6 +695,36 @@
   return None;
 }
 
+static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  // fold (fadd p a (fmul p b c)) -> (fma p a b c)
+  Value *P = II.getOperand(0);
+  Value *A = II.getOperand(1);
+  auto FMul = II.getOperand(2);
+  Value *B, *C;
+  if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
+                       m_Specific(P), m_Value(B), m_Value(C))))
+    return None;
+
+  if (!FMul->hasOneUse())
+    return None;
+
+  llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
+  // Stop the combine when the flags on the inputs differ in case dropping
+  // flags would lead to us missing out on more beneficial optimizations.
+  if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
+    return None;
+  if (!FAddFlags.allowContract())
+    return None;
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
+                                      {II.getType()}, {P, A, B, C}, &II);
+  FMLA->setFastMathFlags(FAddFlags);
+  return IC.replaceInstUsesWith(II, FMLA);
+}
+
 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
   switch (Intrinsic) {
   case Intrinsic::aarch64_sve_fmul:
@@ -724,6 +754,13 @@
   return IC.replaceInstUsesWith(II, BinOp);
 }
 
+static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
+    return FMLA;
+  return instCombineSVEVectorBinOp(IC, II);
+}
+
 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                        IntrinsicInst &II) {
   auto *OpPredicate = II.getOperand(0);
@@ -969,6 +1006,7 @@
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
   case Intrinsic::aarch64_sve_fadd:
+    return instCombineSVEVectorFAdd(IC, II);
   case Intrinsic::aarch64_sve_fsub:
     return instCombineSVEVectorBinOp(IC, II);
   case Intrinsic::aarch64_sve_tbl:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
@@ -0,0 +1,108 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_mul_first_operand(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_mul_first_operand(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP2]], <vscale x 8 x half> [[A:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %a)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_contract_flag_only(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_contract_flag_only(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_no_flags(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_no_flags(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_pred(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_neq_pred(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP3]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP4]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP5]]
+;
+; ret %9
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+  %3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
+  %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %5 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %3, <vscale x 8 x half> %a, <vscale x 8 x half> %4)
+  ret <vscale x 8 x half> %5
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_two_fmul_uses(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_two_fmul_uses(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP3]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP4]]
+;
+; ret %8
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %3, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %4
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_flags(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_neq_flags(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
+;
+; ret %7
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+attributes #0 = { "target-features"="+sve" }
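
For reference, IR like the combine_fmla test above is what the ACLE SVE intrinsics typically produce when built with fast-math. The C sketch below is illustrative only and not part of the patch; it assumes that the _x forms of svmul/svadd lower to the predicated llvm.aarch64.sve.fmul/fadd intrinsics with fast-math flags attached (as in the tests), so that instCombineSVEVectorFAdd can fold the pair into a single llvm.aarch64.sve.fmla.

// Illustrative sketch (not part of the patch). Assumed build line:
//   clang -O2 -march=armv8-a+sve -ffast-math -S -emit-llvm mla.c
#include <arm_sve.h>

svfloat16_t fused_mla(svbool_t pg, svfloat16_t a, svfloat16_t b, svfloat16_t c) {
  // b * c under predicate pg, expected to become llvm.aarch64.sve.fmul.nxv8f16
  svfloat16_t prod = svmul_f16_x(pg, b, c);
  // a + (b * c); with matching fast-math flags on both calls, the new
  // combine should rewrite the fmul/fadd pair as one fmla intrinsic.
  return svadd_f16_x(pg, a, prod);
}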