Index: llvm/include/llvm/Analysis/InstructionSimplify.h =================================================================== --- llvm/include/llvm/Analysis/InstructionSimplify.h +++ llvm/include/llvm/Analysis/InstructionSimplify.h @@ -142,6 +142,12 @@ Value *SimplifyFMulInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); +/// Given operands for the multiplication of an FMA, fold the result or return +/// null. In contrast to SimplifyFMulInst, this function won't return values +/// requiring rounding. +Value *SimplifyFMAFMul(Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q); + /// Given operands for a Mul, fold the result or return null. Value *SimplifyMulInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); Index: llvm/lib/Analysis/InstructionSimplify.cpp =================================================================== --- llvm/lib/Analysis/InstructionSimplify.cpp +++ llvm/lib/Analysis/InstructionSimplify.cpp @@ -4533,23 +4533,27 @@ return nullptr; } -/// Given the operands for an FMul, see if we can fold the result -static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) - return C; - - if (Constant *C = simplifyFPBinop(Op0, Op1)) - return C; - +/// Given operands for the multiplication of an FMA, fold the result or return +/// null. In contrast to SimplifyFMulInst, this function won't return values +/// requiring rounding. 
+static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { // fmul X, 1.0 ==> X if (match(Op1, m_FPOne())) return Op0; + // fmul 1.0, X ==> X + if (match(Op0, m_FPOne())) + return Op1; + // fmul nnan nsz X, 0 ==> 0 if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZeroFP())) return ConstantFP::getNullValue(Op0->getType()); + // fmul nnan nsz 0, X ==> 0 + if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP())) + return ConstantFP::getNullValue(Op1->getType()); + // sqrt(X) * sqrt(X) --> X, if we can: // 1. Remove the intermediate rounding (reassociate). // 2. Ignore non-zero negative numbers because sqrt would produce NAN. @@ -4562,6 +4566,19 @@ return nullptr; } +/// Given the operands for an FMul, see if we can fold the result +static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) + return C; + + if (Constant *C = simplifyFPBinop(Op0, Op1)) + return C; + + // Now apply simplifications that do not require rounding. 
+ return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse); +} + Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q) { return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit); @@ -4578,6 +4595,11 @@ return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit); } +Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit); +} + static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned) { if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) Index: llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -2234,6 +2234,15 @@ return replaceInstUsesWith(*II, Add); } + // Try to simplify the underlying FMul. + if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1), + II->getFastMathFlags(), + SQ.getWithInstruction(II))) { + auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2)); + FAdd->copyFastMathFlags(II); + return FAdd; + } + LLVM_FALLTHROUGH; } case Intrinsic::fma: { @@ -2258,9 +2267,12 @@ return II; } - // fma x, 1, z -> fadd x, z - if (match(Src1, m_FPOne())) { - auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2)); + // Try to simplify the underlying FMul. We can only apply simplifications + // that do not require rounding. 
+ if (Value *V = SimplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1), + II->getFastMathFlags(), + SQ.getWithInstruction(II))) { + auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2)); FAdd->copyFastMathFlags(II); return FAdd; } Index: llvm/test/Transforms/InstCombine/fma.ll =================================================================== --- llvm/test/Transforms/InstCombine/fma.ll +++ llvm/test/Transforms/InstCombine/fma.ll @@ -372,8 +372,7 @@ define <2 x double> @fmuladd_a_0_b(<2 x double> %a, <2 x double> %b) { ; CHECK-LABEL: @fmuladd_a_0_b( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]]) -; CHECK-NEXT: ret <2 x double> [[RES]] +; CHECK-NEXT: ret <2 x double> [[B:%.*]] ; entry: %res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> zeroinitializer, <2 x double> %b) @@ -383,8 +382,7 @@ define <2 x double> @fmuladd_0_a_b(<2 x double> %a, <2 x double> %b) { ; CHECK-LABEL: @fmuladd_0_a_b( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]]) -; CHECK-NEXT: ret <2 x double> [[RES]] +; CHECK-NEXT: ret <2 x double> [[B:%.*]] ; entry: %res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> zeroinitializer, <2 x double> %a, <2 x double> %b) @@ -407,8 +405,7 @@ define <2 x double> @fma_a_0_b(<2 x double> %a, <2 x double> %b) { ; CHECK-LABEL: @fma_a_0_b( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]]) -; CHECK-NEXT: ret <2 x double> [[RES]] +; CHECK-NEXT: ret <2 x double> [[B:%.*]] ; entry: %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> %a, <2 x double> zeroinitializer, <2 x double> %b) @@ -418,8 +415,7 @@ define <2 x double> 
@fma_0_a_b(<2 x double> %a, <2 x double> %b) { ; CHECK-LABEL: @fma_0_a_b( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]]) -; CHECK-NEXT: ret <2 x double> [[RES]] +; CHECK-NEXT: ret <2 x double> [[B:%.*]] ; entry: %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> zeroinitializer, <2 x double> %a, <2 x double> %b) @@ -440,8 +436,7 @@ define <2 x double> @fma_sqrt(<2 x double> %a, <2 x double> %b) { ; CHECK-LABEL: @fma_sqrt( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[SQRT:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> [[A:%.*]]) -; CHECK-NEXT: [[RES:%.*]] = call fast <2 x double> @llvm.fma.v2f64(<2 x double> [[SQRT]], <2 x double> [[SQRT]], <2 x double> [[B:%.*]]) +; CHECK-NEXT: [[RES:%.*]] = fadd fast <2 x double> [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: ret <2 x double> [[RES]] ; entry: @@ -450,6 +445,70 @@ ret <2 x double> %res } +; We do not fold constant multiplies in FMAs, as they could require rounding, unless either constant is 0.0 or 1.0. 
+define <2 x double> @fma_const_fmul(<2 x double> %b) { +; CHECK-LABEL: @fma_const_fmul( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> [[B:%.*]]) +; CHECK-NEXT: ret <2 x double> [[RES]] +; +entry: + %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b) + ret <2 x double> %res +} + +define <2 x double> @fma_const_fmul_zero(<2 x double> %b) { +; CHECK-LABEL: @fma_const_fmul_zero( +; CHECK-NEXT: entry: +; CHECK-NEXT: ret <2 x double> [[B:%.*]] +; +entry: + %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b) + ret <2 x double> %res +} + +define <2 x double> @fma_const_fmul_zero2(<2 x double> %b) { +; CHECK-LABEL: @fma_const_fmul_zero2( +; CHECK-NEXT: entry: +; CHECK-NEXT: ret <2 x double> [[B:%.*]] +; +entry: + %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b) + ret <2 x double> %res +} + +define <2 x double> @fma_const_fmul_one(<2 x double> %b) { +; CHECK-LABEL: @fma_const_fmul_one( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]], +; CHECK-NEXT: ret <2 x double> [[RES]] +; +entry: + %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b) + ret <2 x double> %res +} + +define <2 x double> @fma_const_fmul_one2(<2 x double> %b) { +; CHECK-LABEL: @fma_const_fmul_one2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]], +; CHECK-NEXT: ret <2 x double> [[RES]] +; +entry: + %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b) + ret <2 x double> %res +} + +define <2 x double> @fmuladd_const_fmul(<2 x double> %b) { +; CHECK-LABEL: @fmuladd_const_fmul( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]], +; CHECK-NEXT: ret <2 x double> [[RES]] +; +entry: + %res = 
call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> , <2 x double> , <2 x double> %b) + ret <2 x double> %res +} declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>) declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)