Index: lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -728,6 +728,28 @@
     }
   }
 
+  // sqrt(a) * sqrt(b) -> sqrt(a * b)
+  // Require at least one single-use sqrt so the fold never increases the
+  // instruction count when both operands have other users.
+  if (AllowReassociate) {
+    auto *I0 = dyn_cast<IntrinsicInst>(Op0);
+    auto *I1 = dyn_cast<IntrinsicInst>(Op1);
+    if (I0 && I1 && (I0->hasOneUse() || I1->hasOneUse()) &&
+        I0->getIntrinsicID() == Intrinsic::sqrt &&
+        I1->getIntrinsicID() == Intrinsic::sqrt) {
+      Value *Opnd0 = I0->getArgOperand(0);
+      Value *Opnd1 = I1->getArgOperand(0);
+      // Scope the fast-math flags: restore the Builder's previous FMF state
+      // on exit so later Builder-created instructions are not affected.
+      BuilderTy::FastMathFlagGuard Guard(Builder);
+      Builder.setFastMathFlags(I.getFastMathFlags());
+      Value *FMulVal = Builder.CreateFMul(Opnd0, Opnd1);
+      Function *Sqrt = Intrinsic::getDeclaration(
+          I.getModule(), Intrinsic::sqrt, I.getType());
+      Value *SqrtCall = Builder.CreateCall(Sqrt, FMulVal);
+      return replaceInstUsesWith(I, SqrtCall);
+    }
+  }
+
   // Handle symmetric situation in a 2-iteration loop
   Value *Opnd0 = Op0;
   Value *Opnd1 = Op1;
Index: test/Transforms/InstCombine/fmul.ll
===================================================================
--- test/Transforms/InstCombine/fmul.ll
+++ test/Transforms/InstCombine/fmul.ll
@@ -181,3 +181,29 @@
   %mul = fmul float %x.fabs, %y.fabs
   ret float %mul
 }
+
+; CHECK-LABEL: @sqrt_a_sqrt_b(
+; CHECK: fmul fast double %a, %b
+; CHECK: call fast double @llvm.sqrt.f64(double %1)
+define double @sqrt_a_sqrt_b(double %a, double %b) {
+  %1 = call fast double @llvm.sqrt.f64(double %a)
+  %2 = call fast double @llvm.sqrt.f64(double %b)
+  %mul = fmul fast double %1, %2
+  ret double %mul
+}
+
+; CHECK-LABEL: @sqrt_a_sqrt_b_sqrt_c_sqrt_d(
+; CHECK: fmul fast double %a, %b
+; CHECK: fmul fast double %1, %c
+; CHECK: fmul fast double %2, %d
+; CHECK: call fast double @llvm.sqrt.f64(double %3)
+define double @sqrt_a_sqrt_b_sqrt_c_sqrt_d(double %a, double %b, double %c, double %d) {
+  %1 = call fast double @llvm.sqrt.f64(double %a)
+  %2 = call fast double @llvm.sqrt.f64(double %b)
+  %mul = fmul fast double %1, %2
+  %3 = call fast double @llvm.sqrt.f64(double %c)
+  %mul1 = fmul fast double %mul, %3
+  %4 = call fast double @llvm.sqrt.f64(double %d)
+  %mul2 = fmul fast double %mul1, %4
+  ret double %mul2
+}