diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -247,6 +247,9 @@
   void operator|=(const FastMathFlags &OtherFlags) {
     Flags |= OtherFlags.Flags;
   }
+  bool operator!=(const FastMathFlags &OtherFlags) const {
+    return Flags != OtherFlags.Flags;
+  }
 };

 /// Utility class for floating point operations which can have
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -695,6 +695,37 @@
   return None;
 }

+static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  // fold (fadd p a (fmul p b c)) -> (fma p a b c)
+  Value *p = II.getOperand(0);
+  Value *a = II.getOperand(1);
+  auto FMul = II.getOperand(2);
+  Value *b, *c;
+  if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
+                       m_Deferred(p), m_Value(b), m_Value(c))))
+    return None;
+
+  if (!FMul->hasOneUse())
+    return None;
+
+  llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
+  llvm::FastMathFlags FMulFlags = cast<IntrinsicInst>(FMul)->getFastMathFlags();
+  // Don't combine when the FMul and FAdd flags differ, to prevent the loss of
+  // any additional important flags.
+  if (FAddFlags != FMulFlags)
+    return None;
+  if (!FAddFlags.allowContract() || !FMulFlags.allowContract())
+    return None;
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
+                                      {II.getType()}, {p, a, b, c}, &II);
+  FMLA->setFastMathFlags(FAddFlags);
+  return IC.replaceInstUsesWith(II, FMLA);
+}
+
 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
   switch (Intrinsic) {
   case Intrinsic::aarch64_sve_fmul:
@@ -724,6 +755,14 @@
   return IC.replaceInstUsesWith(II, BinOp);
 }

+static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  auto FMLA = instCombineSVEVectorFMLA(IC, II);
+  if (FMLA)
+    return FMLA;
+  return instCombineSVEVectorBinOp(IC, II);
+}
+
 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                        IntrinsicInst &II) {
   auto *OpPredicate = II.getOperand(0);
@@ -901,6 +940,7 @@
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
   case Intrinsic::aarch64_sve_fadd:
+    return instCombineSVEVectorFAdd(IC, II);
   case Intrinsic::aarch64_sve_fsub:
     return instCombineSVEVectorBinOp(IC, II);
   case Intrinsic::aarch64_sve_tbl:
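
For context, a minimal LLVM IR sketch of the fold this patch enables (the function name and the nxv4f32 element type are illustrative, not taken from the patch's tests; the predicated SVE intrinsic signatures are the standard overloaded ones). With identical fast-math flags, including contract, on both calls and a single use of the fmul, InstCombine is expected to rewrite the fadd/fmul pair into one fmla:

define <vscale x 4 x float> @fadd_of_fmul(<vscale x 4 x i1> %pg, <vscale x 4 x float> %a,
                                          <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
  ; %mul has a single use and carries the same flags as the fadd, so the combine fires.
  %mul = call contract <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1> %pg, <vscale x 4 x float> %b, <vscale x 4 x float> %c)
  %res = call contract <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32(<vscale x 4 x i1> %pg, <vscale x 4 x float> %a, <vscale x 4 x float> %mul)
  ; Expected result after instcombine:
  ;   %res = call contract <vscale x 4 x float> @llvm.aarch64.sve.fmla.nxv4f32(<vscale x 4 x i1> %pg, <vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c)
  ret <vscale x 4 x float> %res
}

declare <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)
declare <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)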