diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -544,6 +544,22 @@
       return replaceInstUsesWith(I, Sqrt);
     }
 
+    // The following transforms are done irrespective of the number of uses
+    // for the expression "1.0/sqrt(X)".
+    //  1) 1.0/sqrt(X) * X -> X/sqrt(X)
+    //  2) X * 1.0/sqrt(X) -> X/sqrt(X)
+    // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
+    // has the necessary (reassoc) fast-math-flags.
+    // Y is bound to the existing sqrt call and reused as the divisor.
+    if (I.hasNoSignedZeros() &&
+        match(Op0, m_FDiv(m_SpecificFP(1.0), m_Value(Y))) &&
+        match(Y, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && Op1 == X)
+      return BinaryOperator::CreateFDivFMF(X, Y, &I);
+    if (I.hasNoSignedZeros() &&
+        match(Op1, m_FDiv(m_SpecificFP(1.0), m_Value(Y))) &&
+        match(Y, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && Op0 == X)
+      return BinaryOperator::CreateFDivFMF(X, Y, &I);
+
     // Like the similar transform in instsimplify, this requires 'nsz' because
     // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
     if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 &&
diff --git a/llvm/test/Transforms/InstCombine/fmul-sqrt.ll b/llvm/test/Transforms/InstCombine/fmul-sqrt.ll
--- a/llvm/test/Transforms/InstCombine/fmul-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/fmul-sqrt.ll
@@ -103,7 +103,7 @@
 ; CHECK-LABEL: @rsqrt_x_reassociate_extra_use(
 ; CHECK-NEXT: [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]])
 ; CHECK-NEXT: [[RSQRT:%.*]] = fdiv double 1.000000e+00, [[SQRT]]
-; CHECK-NEXT: [[RES:%.*]] = fmul reassoc nsz double [[RSQRT]], [[X]]
+; CHECK-NEXT: [[RES:%.*]] = fdiv reassoc nsz double [[X]], [[SQRT]]
 ; CHECK-NEXT: store double [[RSQRT]], double* [[P:%.*]], align 8
 ; CHECK-NEXT: ret double [[RES]]
 ;
@@ -119,7 +119,7 @@
 ; CHECK-NEXT: [[ADD:%.*]] = fadd fast <2 x float> [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT: [[SQRT:%.*]] = call fast <2 x float> @llvm.sqrt.v2f32(<2 x float> [[ADD]])
 ; CHECK-NEXT: [[RSQRT:%.*]] = fdiv fast <2 x float> <float 1.000000e+00, float 1.000000e+00>, [[SQRT]]
-; CHECK-NEXT: [[RES:%.*]] = fmul fast <2 x float> [[ADD]], [[RSQRT]]
+; CHECK-NEXT: [[RES:%.*]] = fdiv fast <2 x float> [[ADD]], [[SQRT]]
 ; CHECK-NEXT: store <2 x float> [[RSQRT]], <2 x float>* [[P:%.*]], align 8
 ; CHECK-NEXT: ret <2 x float> [[RES]]
 ;