diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1075,6 +1075,28 @@
   return IC.replaceInstUsesWith(II, FMLA);
 }
 
+// Fold (ADD p c (MUL p a b)) -> (MAD p a b c)
+static std::optional<Instruction *> instCombineSVEVectorMAD(InstCombiner &IC,
+                                                            IntrinsicInst &II) {
+  Value *P = II.getOperand(0);
+  Value *Mul = II.getOperand(2);
+  Value *A, *B;
+  if (!match(Mul, m_Intrinsic<Intrinsic::aarch64_sve_mul>(
+                      m_Specific(P), m_Value(A), m_Value(B))))
+    return std::nullopt;
+
+  if (!Mul->hasOneUse())
+    return std::nullopt;
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  Value *C = II.getOperand(1);
+
+  CallInst *MAD = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_mad,
+                                          {II.getType()}, {P, A, B, C});
+  return IC.replaceInstUsesWith(II, MAD);
+}
+
 static bool isAllActivePredicate(Value *Pred) {
   // Look through convert.from.svbool(convert.to.svbool(...) chain.
   Value *UncastedPred;
@@ -1166,10 +1188,12 @@
   return IC.replaceInstUsesWith(II, BinOp);
 }
 
-static std::optional<Instruction *>
-instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
+static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
+                                                            IntrinsicInst &II) {
   if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
     return FMLA;
+  if (auto MAD = instCombineSVEVectorMAD(IC, II))
+    return MAD;
   return instCombineSVEVectorBinOp(IC, II);
 }
 
@@ -1470,7 +1494,8 @@
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
   case Intrinsic::aarch64_sve_fadd:
-    return instCombineSVEVectorFAdd(IC, II);
+  case Intrinsic::aarch64_sve_add:
+    return instCombineSVEVectorAdd(IC, II);
   case Intrinsic::aarch64_sve_fsub:
     return instCombineSVEVectorBinOp(IC, II);
   case Intrinsic::aarch64_sve_tbl:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
@@ -101,8 +101,103 @@
   ret <vscale x 8 x half> %3
 }
 
+define dso_local <vscale x 16 x i8> @combine_mad_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mad_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mad.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]], <vscale x 16 x i8> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP1]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %c, <vscale x 16 x i8> %1)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 8 x i16> @combine_mad_i16(<vscale x 16 x i1> %p, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mad_i16(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 8 x i16> @llvm.aarch64.sve.mad.nxv8i16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x i16> [[A:%.*]], <vscale x 8 x i16> [[B:%.*]], <vscale x 8 x i16> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x i16> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %1, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+  %3 = tail call <vscale x 8 x i16> @llvm.aarch64.sve.add.nxv8i16(<vscale x 8 x i1> %1, <vscale x 8 x i16> %c, <vscale x 8 x i16> %2)
+  ret <vscale x 8 x i16> %3
+}
+
+define dso_local <vscale x 4 x i32> @combine_mad_i32(<vscale x 16 x i1> %p, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mad_i32(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.mad.nxv4i32(<vscale x 4 x i1> [[TMP1]], <vscale x 4 x i32> [[A:%.*]], <vscale x 4 x i32> [[B:%.*]], <vscale x 4 x i32> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP2]]
+;
+  %1 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %p)
+  %2 = tail call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> %1, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b)
+  %3 = tail call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32(<vscale x 4 x i1> %1, <vscale x 4 x i32> %c, <vscale x 4 x i32> %2)
+  ret <vscale x 4 x i32> %3
+}
+
+define dso_local <vscale x 2 x i64> @combine_mad_i64(<vscale x 16 x i1> %p, <vscale x 2 x i64> %a, <vscale x 2 x i64> %b, <vscale x 2 x i64> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mad_i64(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.aarch64.sve.mad.nxv2i64(<vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> [[A:%.*]], <vscale x 2 x i64> [[B:%.*]], <vscale x 2 x i64> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[TMP2]]
+;
+  %1 = tail call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %p)
+  %2 = tail call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %1, <vscale x 2 x i64> %a, <vscale x 2 x i64> %b)
+  %3 = tail call <vscale x 2 x i64> @llvm.aarch64.sve.add.nxv2i64(<vscale x 2 x i1> %1, <vscale x 2 x i64> %c, <vscale x 2 x i64> %2)
+  ret <vscale x 2 x i64> %3
+}
+
+define dso_local <vscale x 16 x i8> @combine_mad_i8_a_dest(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i1> %p) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mad_i8_a_dest(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mad.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]], <vscale x 16 x i8> [[A]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP1]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %1)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 4 x i32> @neg_combine_mad_mul_first_operand(<vscale x 16 x i1> %p, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_mad_mul_first_operand(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> [[TMP1]], <vscale x 4 x i32> [[A:%.*]], <vscale x 4 x i32> [[B:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32(<vscale x 4 x i1> [[TMP1]], <vscale x 4 x i32> [[TMP2]], <vscale x 4 x i32> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP3]]
+;
+  %1 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %p)
+  %2 = tail call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> %1, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b)
+  %3 = tail call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32(<vscale x 4 x i1> %1, <vscale x 4 x i32> %2, <vscale x 4 x i32> %c)
+  ret <vscale x 4 x i32> %3
+}
+
+define dso_local <vscale x 8 x i16> @neg_combine_mad_two_mul_uses(<vscale x 16 x i1> %p, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_mad_two_mul_uses(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x i16> [[A:%.*]], <vscale x 8 x i16> [[B:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x i16> @llvm.aarch64.sve.add.nxv8i16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x i16> [[C:%.*]], <vscale x 8 x i16> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call <vscale x 8 x i16> @llvm.aarch64.sve.add.nxv8i16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x i16> [[TMP3]], <vscale x 8 x i16> [[TMP2]])
+; CHECK-NEXT:    ret <vscale x 8 x i16> [[TMP4]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %1, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+  %3 = tail call <vscale x 8 x i16> @llvm.aarch64.sve.add.nxv8i16(<vscale x 8 x i1> %1, <vscale x 8 x i16> %c, <vscale x 8 x i16> %2)
+  %4 = tail call <vscale x 8 x i16> @llvm.aarch64.sve.add.nxv8i16(<vscale x 8 x i1> %1, <vscale x 8 x i16> %3, <vscale x 8 x i16> %2)
+  ret <vscale x 8 x i16> %4
+}
+
 declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.add.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.add.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)
 declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
 
 attributes #0 = { "target-features"="+sve" }
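For context, one common way the mul+add pattern folded above can arise is from the ACLE SVE C intrinsics. The sketch below is illustrative only and not part of the patch: the function name madd_s32 is invented here, and it assumes clang lowers the merging _m forms to the predicated llvm.aarch64.sve.mul/add intrinsics (plus a convert.from.svbool of the predicate) checked in the tests above, which is the shape instCombineSVEVectorMAD rewrites into a single llvm.aarch64.sve.mad call.

    #include <arm_sve.h>

    // Computes c + a*b under a predicate. Assumed lowering: sve.mul followed
    // by sve.add with the same predicate, with the mul result used only once,
    // i.e. exactly the pattern the new combine turns into sve.mad.
    svint32_t madd_s32(svbool_t pg, svint32_t a, svint32_t b, svint32_t c) {
      svint32_t prod = svmul_s32_m(pg, a, b);
      return svadd_s32_m(pg, c, prod);
    }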