diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -684,6 +684,51 @@
   return None;
 }
 
+// Canonicalise a predicated SVE fmul/fadd intrinsic whose governing predicate
+// is ptrue(all) — possibly via convert.from.svbool — to an unpredicated IR
+// binary operator, so later passes can optimise it like ordinary IR.
+static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
+                                                         IntrinsicInst &II) {
+  // Returns true iff I is a call to the intrinsic In.
+  auto IsIntrinsic = [](auto *I, auto In) {
+    auto *IntrI = dyn_cast<IntrinsicInst>(I);
+    if (!IntrI || IntrI->getIntrinsicID() != In)
+      return false;
+    return true;
+  };
+
+  // Look through a convert.from.svbool wrapping the predicate operand.
+  auto *OperandValue = II.getOperand(0);
+  if (IsIntrinsic(OperandValue, Intrinsic::aarch64_sve_convert_from_svbool)) {
+    OperandValue = dyn_cast<IntrinsicInst>(OperandValue)->getOperand(0);
+  }
+
+  if (IsIntrinsic(OperandValue, Intrinsic::aarch64_sve_ptrue)) {
+    auto *PTrueIntrinsic = dyn_cast<IntrinsicInst>(OperandValue);
+    const uint64_t PTruePattern =
+        cast<ConstantInt>(PTrueIntrinsic->getOperand(0))->getZExtValue();
+    // Only an all-active predicate makes the operation equivalent to the
+    // unpredicated binary op for every lane.
+    if (PTruePattern != AArch64SVEPredPattern::all) {
+      return None;
+    }
+
+    Instruction::BinaryOps BinOp = Instruction::BinaryOpsEnd;
+    switch (II.getIntrinsicID()) {
+    case Intrinsic::aarch64_sve_fmul:
+      BinOp = Instruction::BinaryOps::FMul;
+      break;
+    case Intrinsic::aarch64_sve_fadd:
+      BinOp = Instruction::BinaryOps::FAdd;
+      break;
+    default:
+      return None;
+    }
+
+    IRBuilder<> Builder(II.getContext());
+    Builder.SetInsertPoint(&II);
+    return IC.replaceInstUsesWith(
+        II, Builder.CreateBinOp(BinOp, II.getOperand(1), II.getOperand(2)));
+  }
+
+  return None;
+}
+
 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                        IntrinsicInst &II) {
   auto *OpPredicate = II.getOperand(0);
@@ -736,7 +781,7 @@
     }
   }
 
-  return None;
+  return instCombineSVEVectorBinOp(IC, II);
 }
 
 static Optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
@@ -824,6 +869,8 @@
   case Intrinsic::aarch64_sve_mul:
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
+  case Intrinsic::aarch64_sve_fadd:
+    return instCombineSVEVectorBinOp(IC, II);
   case Intrinsic::aarch64_sve_tbl:
     return instCombineSVETBL(IC, II);
   case Intrinsic::aarch64_sve_uunpkhi: