diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1045,34 +1045,51 @@
   return std::nullopt;
 }
 
+template <Intrinsic::ID MulOpc, Intrinsic::ID FuseOpc>
 static std::optional<Instruction *>
-instCombineSVEVectorFMLA(InstCombiner &IC, IntrinsicInst &II) {
-  // fold (fadd p a (fmul p b c)) -> (fma p a b c)
+instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
+                                  bool MergeIntoAddendOp) {
   Value *P = II.getOperand(0);
-  Value *A = II.getOperand(1);
-  auto FMul = II.getOperand(2);
-  Value *B, *C;
-  if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
-                       m_Specific(P), m_Value(B), m_Value(C))))
-    return std::nullopt;
+  Value *MulOp0, *MulOp1, *AddendOp, *Mul;
+  if (MergeIntoAddendOp) {
+    AddendOp = II.getOperand(1);
+    Mul = II.getOperand(2);
+  } else {
+    AddendOp = II.getOperand(2);
+    Mul = II.getOperand(1);
+  }
 
-  if (!FMul->hasOneUse())
+  if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
+                                      m_Value(MulOp1))))
     return std::nullopt;
 
-  llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
-  // Stop the combine when the flags on the inputs differ in case dropping flags
-  // would lead to us missing out on more beneficial optimizations.
-  if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
-    return std::nullopt;
-  if (!FAddFlags.allowContract())
+  if (!Mul->hasOneUse())
     return std::nullopt;
 
+  Instruction *FMFSource = nullptr;
+  if (II.getType()->isFPOrFPVectorTy()) {
+    llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
+    // Stop the combine when the flags on the inputs differ in case dropping
+    // flags would lead to us missing out on more beneficial optimizations.
+    if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
+      return std::nullopt;
+    if (!FAddFlags.allowContract())
+      return std::nullopt;
+    FMFSource = &II;
+  }
+
   IRBuilder<> Builder(II.getContext());
   Builder.SetInsertPoint(&II);
-  auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
-                                      {II.getType()}, {P, A, B, C}, &II);
-  FMLA->setFastMathFlags(FAddFlags);
-  return IC.replaceInstUsesWith(II, FMLA);
+
+  CallInst *Res;
+  if (MergeIntoAddendOp)
+    Res = Builder.CreateIntrinsic(FuseOpc, {II.getType()},
+                                  {P, AddendOp, MulOp0, MulOp1}, FMFSource);
+  else
+    Res = Builder.CreateIntrinsic(FuseOpc, {II.getType()},
+                                  {P, MulOp0, MulOp1, AddendOp}, FMFSource);
+
+  return IC.replaceInstUsesWith(II, Res);
 }
 
 static bool isAllActivePredicate(Value *Pred) {
@@ -1166,10 +1183,45 @@
   return IC.replaceInstUsesWith(II, BinOp);
 }
 
-static std::optional<Instruction *>
-instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
-  if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
+static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
+                                                            IntrinsicInst &II) {
+  if (auto FMLA =
+          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
+                                            Intrinsic::aarch64_sve_fmla>(IC, II,
+                                                                         true))
     return FMLA;
+  if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
+                                                   Intrinsic::aarch64_sve_mla>(
+          IC, II, true))
+    return MLA;
+  if (auto FMAD =
+          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
+                                            Intrinsic::aarch64_sve_fmad>(IC, II,
+                                                                         false))
+    return FMAD;
+  if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
+                                                   Intrinsic::aarch64_sve_mad>(
+          IC, II, false))
+    return MAD;
+  return instCombineSVEVectorBinOp(IC, II);
+}
+
+static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
+                                                            IntrinsicInst &II) {
+  if (auto FMLS =
+          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
+                                            Intrinsic::aarch64_sve_fmls>(IC, II,
+                                                                         true))
+    return FMLS;
+  if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
+                                                   Intrinsic::aarch64_sve_mls>(
+          IC, II, true))
+    return MLS;
+  if (auto FMSB =
+          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
+                                            Intrinsic::aarch64_sve_fnmsb>(
+              IC, II, false))
+    return FMSB;
   return instCombineSVEVectorBinOp(IC, II);
 }
 
@@ -1470,9 +1522,11 @@
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
   case Intrinsic::aarch64_sve_fadd:
-    return instCombineSVEVectorFAdd(IC, II);
+  case Intrinsic::aarch64_sve_add:
+    return instCombineSVEVectorAdd(IC, II);
   case Intrinsic::aarch64_sve_fsub:
-    return instCombineSVEVectorBinOp(IC, II);
+  case Intrinsic::aarch64_sve_sub:
+    return instCombineSVEVectorSub(IC, II);
   case Intrinsic::aarch64_sve_tbl:
     return instCombineSVETBL(IC, II);
   case Intrinsic::aarch64_sve_uunpkhi:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll
rename from llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
rename to llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll
@@ -6,50 +6,115 @@
 define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
 ; CHECK-LABEL: @combine_fmla(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
 ;
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
   ret <vscale x 8 x half> %3
 }
 
-define dso_local <vscale x 8 x half> @neg_combine_fmla_mul_first_operand(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
-; CHECK-LABEL: @neg_combine_fmla_mul_first_operand(
+define dso_local <vscale x 16 x i8> @combine_mla_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mla_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mla.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[C:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP1]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %c, <vscale x 16 x i8> %1)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 8 x half> @combine_fmad(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmad(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
-; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP2]], <vscale x 8 x half> [[A:%.*]])
-; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmad.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
 ;
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %a)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %c)
   ret <vscale x 8 x half> %3
 }
 
-define dso_local <vscale x 8 x half> @neg_combine_fmla_contract_flag_only(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
-; CHECK-LABEL: @neg_combine_fmla_contract_flag_only(
+define dso_local <vscale x 16 x i8> @combine_mad_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mad_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mad.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]], <vscale x 16 x i8> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP1]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %1, <vscale x 16 x i8> %c)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 8 x half> @combine_fmls(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmls(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = call contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmls.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
 ;
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fsub.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 16 x i8> @combine_mls_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mls_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mls.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[C:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP1]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %c, <vscale x 16 x i8> %1)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 8 x half> @combine_fnmsb(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fnmsb(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fnmsb.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fsub.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %c)
+  ret <vscale x 8 x half> %3
+}
+
+; No integer variant of fnmsb exists; Do not combine
+define dso_local <vscale x 16 x i8> @neg_combine_nmsb_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_nmsb_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1> [[P]], <vscale x 16 x i8> [[TMP1]], <vscale x 16 x i8> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP2]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %1, <vscale x 16 x i8> %c)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 8 x half> @combine_fmla_contract_flag_only(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla_contract_flag_only(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
   ret <vscale x 8 x half> %3
 }
 
 define dso_local <vscale x 8 x half> @neg_combine_fmla_no_flags(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
 ; CHECK-LABEL: @neg_combine_fmla_no_flags(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
-; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[TMP2]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
 ;
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %2 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
   ret <vscale x 8 x half> %3
 }
 
@@ -58,31 +123,31 @@
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
 ; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
 ; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP2]])
-; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
-; CHECK-NEXT:    [[TMP5:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP3]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP4]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP3]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[TMP4]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP5]]
 ;
 ; ret %9
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
   %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
   %3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
-  %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %5 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %3, <vscale x 8 x half> %a, <vscale x 8 x half> %4)
+  %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %5 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %3, <vscale x 8 x half> %c, <vscale x 8 x half> %4)
   ret <vscale x 8 x half> %5
 }
 
 define dso_local <vscale x 8 x half> @neg_combine_fmla_two_fmul_uses(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
 ; CHECK-LABEL: @neg_combine_fmla_two_fmul_uses(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
-; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[TMP2]])
 ; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP3]], <vscale x 8 x half> [[TMP2]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP4]]
 ;
 ; ret %8
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
   %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %3, <vscale x 8 x half> %2)
   ret <vscale x 8 x half> %4
 }
@@ -101,8 +166,15 @@
   ret <vscale x 8 x half> %3
 }
 
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
 declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
-declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fsub.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+
attributes #0 = { "target-features"="+sve" }
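
For a quick manual check of the new integer fold outside the checked-in test, a minimal standalone module such as the sketch below can be fed to opt -S -passes=instcombine. This reproducer is illustrative only and not part of the patch: the function name, RUN line, and target triple are hypothetical, while the intrinsic signatures mirror the declarations in the test above. The single-use predicated mul feeding an add that shares the same predicate is expected to collapse into one @llvm.aarch64.sve.mla.nxv16i8 call.

; Hypothetical reproducer (not part of the patch).
; RUN: opt -S -passes=instcombine < %s
target triple = "aarch64-unknown-linux-gnu"

define <vscale x 16 x i8> @mla_repro(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) #0 {
  ; %mul has a single use and shares predicate %p with the add, so the
  ; combine introduced by this patch should rewrite the pair into sve.mla.
  %mul = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
  %add = tail call <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %c, <vscale x 16 x i8> %mul)
  ret <vscale x 16 x i8> %add
}

declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
declare <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)

attributes #0 = { "target-features"="+sve" }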