diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1619,9 +1619,17 @@
   case Intrinsic::aarch64_sve_fadd:
   case Intrinsic::aarch64_sve_add:
     return instCombineSVEVectorAdd(IC, II);
+  case Intrinsic::aarch64_sve_fadd_u:
+    return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
+                                             Intrinsic::aarch64_sve_fmla_u>(
+        IC, II, true);
   case Intrinsic::aarch64_sve_fsub:
   case Intrinsic::aarch64_sve_sub:
     return instCombineSVEVectorSub(IC, II);
+  case Intrinsic::aarch64_sve_fsub_u:
+    return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
+                                             Intrinsic::aarch64_sve_fmls_u>(
+        IC, II, true);
   case Intrinsic::aarch64_sve_tbl:
     return instCombineSVETBL(IC, II);
   case Intrinsic::aarch64_sve_uunpkhi:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll
@@ -15,6 +15,18 @@
   ret <vscale x 8 x half> %3
 }
 
+define <vscale x 8 x half> @combine_fmla_u(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) #0 {
+; CHECK-LABEL: @combine_fmla_u(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.u.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.u.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
 define <vscale x 16 x i8> @combine_mla_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) #0 {
 ; CHECK-LABEL: @combine_mla_i8(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mla.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[C:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
@@ -59,6 +71,18 @@
   ret <vscale x 8 x half> %3
 }
 
+define <vscale x 8 x half> @combine_fmls_u(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) #0 {
+; CHECK-LABEL: @combine_fmls_u(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmls.u.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fsub.u.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
 define <vscale x 16 x i8> @combine_mls_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) #0 {
 ; CHECK-LABEL: @combine_mls_i8(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mls.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[C:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
@@ -173,6 +197,9 @@
 declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fsub.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.u.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fsub.u.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
 declare <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
 declare <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
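For context, the two new switch cases extend the existing predicated fmul+fadd/fsub fusion to the unpredicated (_u) intrinsic forms: an fmul_u whose result feeds an fadd_u or fsub_u under the same predicate is rewritten into a single fmla_u or fmls_u call, as the new tests check. Below is a minimal stand-alone sketch of the input pattern; it is illustrative only and not part of the patch, and the function and value names (@fmla_u_sketch, %pg, %mul, %res) are made up.

; Input shape matched by the new fadd_u case: fmul_u feeding fadd_u under one
; predicate. InstCombine is expected to rewrite the pair into a single
; @llvm.aarch64.sve.fmla.u.nxv8f16 call, as combine_fmla_u above checks.
define <vscale x 8 x half> @fmla_u_sketch(<vscale x 8 x i1> %pg, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) #0 {
  %mul = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1> %pg, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
  %res = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.u.nxv8f16(<vscale x 8 x i1> %pg, <vscale x 8 x half> %c, <vscale x 8 x half> %mul)
  ret <vscale x 8 x half> %res
}

declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.u.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)

attributes #0 = { "target-features"="+sve" }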