diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2258,14 +2258,25 @@
       return II;
     }
 
-    // Try to simplify the underlying FMul.
-    if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1),
-                                    II->getFastMathFlags(),
-                                    SQ.getWithInstruction(II))) {
-      auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
-      FAdd->copyFastMathFlags(II);
-      return FAdd;
-    }
+    // Try to simplify the underlying FMul to a value that does not require
+    // additional rounding. SimplifyFMulInst will either simplify to an existing
+    // value or constant fold the multiply. If simplified to an existing value,
+    // no additional rounding is required, but constant folding could introduce
+    // an extra rounding step that fma must not perform. Therefore, when both
+    // multiplicands are constants, only simplify if either is 1.0 or +/-0.0,
+    // whose products are exact. Fmuladd intrinsics do not make any guarantees
+    // about rounding, so arbitrary constant multiplies may be folded for them.
+    if (IID == Intrinsic::fmuladd || !isa<Constant>(Src0) ||
+        !isa<Constant>(Src1) || match(Src0, m_FPOne()) ||
+        match(Src1, m_FPOne()) || match(Src0, m_AnyZeroFP()) ||
+        match(Src1, m_AnyZeroFP())) {
+      if (Value *V = SimplifyFMulInst(Src0, Src1, II->getFastMathFlags(),
+                                      SQ.getWithInstruction(II))) {
+        auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
+        FAdd->copyFastMathFlags(II);
+        return FAdd;
+      }
+    }
     break;
   }
 
diff --git a/llvm/test/Transforms/InstCombine/fma.ll b/llvm/test/Transforms/InstCombine/fma.ll
--- a/llvm/test/Transforms/InstCombine/fma.ll
+++ b/llvm/test/Transforms/InstCombine/fma.ll
@@ -445,6 +445,70 @@
   ret <2 x double> %res
 }
 
+; We do not fold constant multiplies in FMAs, as they could require rounding, unless either constant is 0.0 or 1.0.
+define <2 x double> @fma_const_fmul(<2 x double> %b) { ; constant * constant (presumably neither 0.0 nor 1.0): CHECK expects the fma call to be kept
+; CHECK-LABEL: @fma_const_fmul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> [[B:%.*]])
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_zero(<2 x double> %b) { ; presumably a 0.0 multiplicand: CHECK expects the fma to fold to %b
+; CHECK-LABEL: @fma_const_fmul_zero(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret <2 x double> [[B:%.*]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_zero2(<2 x double> %b) { ; presumably 0.0 in the other multiplicand: CHECK expects the fma to fold to %b
+; CHECK-LABEL: @fma_const_fmul_zero2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret <2 x double> [[B:%.*]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_one(<2 x double> %b) { ; presumably a 1.0 multiplicand: CHECK expects an fadd with %b instead of the fma
+; CHECK-LABEL: @fma_const_fmul_one(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]],
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_one2(<2 x double> %b) { ; presumably 1.0 in the other multiplicand: CHECK expects an fadd with %b
+; CHECK-LABEL: @fma_const_fmul_one2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]],
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> , <2 x double> , <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fmuladd_const_fmul(<2 x double> %b) { ; fmuladd gives no rounding guarantee: CHECK expects the constant product folded into an fadd
+; CHECK-LABEL: @fmuladd_const_fmul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]],
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> , <2 x double> , <2 x double> %b)
+  ret <2 x double> %res
+}
 
 declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
 declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)