diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -684,6 +684,45 @@
   return None;
 }

+static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
+                                                         IntrinsicInst &II) {
+  auto IsIntrinsic = [](auto *I, auto In) {
+    auto *IntrI = dyn_cast<IntrinsicInst>(I);
+    if (!IntrI || IntrI->getIntrinsicID() != In)
+      return false;
+    return true;
+  };
+
+  auto *OpPredicate = II.getOperand(0);
+  if (IsIntrinsic(OpPredicate, Intrinsic::aarch64_sve_ptrue)) {
+    auto *PTrueIntrinsic = dyn_cast<IntrinsicInst>(OpPredicate);
+    const uint64_t PTruePattern =
+        cast<ConstantInt>(PTrueIntrinsic->getOperand(0))->getZExtValue();
+    if (PTruePattern != AArch64SVEPredPattern::all) {
+      return None;
+    }
+
+    Instruction::BinaryOps BinOp = Instruction::BinaryOpsEnd;
+    switch (II.getIntrinsicID()) {
+    case Intrinsic::aarch64_sve_fmul:
+      BinOp = Instruction::BinaryOps::FMul;
+      break;
+    case Intrinsic::aarch64_sve_fadd:
+      BinOp = Instruction::BinaryOps::FAdd;
+      break;
+    default:
+      return None;
+    }
+
+    IRBuilder<> Builder(II.getContext());
+    Builder.SetInsertPoint(&II);
+    return IC.replaceInstUsesWith(
+        II, Builder.CreateBinOp(BinOp, II.getOperand(1), II.getOperand(2)));
+  }
+
+  return None;
+}
+
 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                        IntrinsicInst &II) {
   auto *OpPredicate = II.getOperand(0);
@@ -736,7 +775,7 @@
     }
   }

-  return None;
+  return instCombineSVEVectorBinOp(IC, II);
 }

 static Optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
@@ -824,6 +863,8 @@
   case Intrinsic::aarch64_sve_mul:
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
+  case Intrinsic::aarch64_sve_fadd:
+    return instCombineSVEVectorBinOp(IC, II);
   case Intrinsic::aarch64_sve_tbl:
     return instCombineSVETBL(IC, II);
   case Intrinsic::aarch64_sve_uunpkhi:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fma-binops.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fma-binops.ll
new file
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fma-binops.ll
@@ -0,0 +1,51 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; SVE intrinsics fmul and fadd should be replaced with regular fmul and fadd.
+define <vscale x 8 x half> @replace_fmul_intrinsic_half(<vscale x 8 x half> %a, <vscale x 8 x half> %b) #0 {
+; CHECK-LABEL: @replace_fmul_intrinsic_half
+; CHECK-NEXT: %1 = fmul <vscale x 8 x half> %a, %b
+; CHECK-NEXT: %2 = fadd <vscale x 8 x half> %1, %a
+; CHECK-NEXT: ret <vscale x 8 x half> %2
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define <vscale x 4 x float> @replace_fmul_intrinsic_float(<vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; CHECK-LABEL: @replace_fmul_intrinsic_float
+; CHECK-NEXT: %1 = fmul <vscale x 4 x float> %a, %b
+; CHECK-NEXT: %2 = fadd <vscale x 4 x float> %1, %a
+; CHECK-NEXT: ret <vscale x 4 x float> %2
+  %1 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %2 = tail call fast <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1> %1, <vscale x 4 x float> %a, <vscale x 4 x float> %b)
+  %3 = tail call fast <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32(<vscale x 4 x i1> %1, <vscale x 4 x float> %a, <vscale x 4 x float> %2)
+  ret <vscale x 4 x float> %3
+}
+
+define <vscale x 2 x double> @replace_fmul_intrinsic_double(<vscale x 2 x double> %a, <vscale x 2 x double> %b) #0 {
+; CHECK-LABEL: @replace_fmul_intrinsic_double
+; CHECK-NEXT: %1 = fmul <vscale x 2 x double> %a, %b
+; CHECK-NEXT: %2 = fadd <vscale x 2 x double> %1, %a
+; CHECK-NEXT: ret <vscale x 2 x double> %2
+  %1 = tail call <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32 31)
+  %2 = tail call fast <vscale x 2 x double> @llvm.aarch64.sve.fmul.nxv2f64(<vscale x 2 x i1> %1, <vscale x 2 x double> %a, <vscale x 2 x double> %b)
+  %3 = tail call fast <vscale x 2 x double> @llvm.aarch64.sve.fadd.nxv2f64(<vscale x 2 x i1> %1, <vscale x 2 x double> %a, <vscale x 2 x double> %2)
+  ret <vscale x 2 x double> %3
+}
+
+attributes #0 = { "target-features"="+sve" }
+
+declare <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)
+declare <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)
+
+declare <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32)
+declare <vscale x 2 x double> @llvm.aarch64.sve.fmul.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>, <vscale x 2 x double>)
+declare <vscale x 2 x double> @llvm.aarch64.sve.fadd.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>, <vscale x 2 x double>)
\ No newline at end of file