Index: llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
===================================================================
--- llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -131,6 +131,7 @@
   // Math Library Optimizations
   Value *optimizeCos(CallInst *CI, IRBuilder<> &B);
   Value *optimizePow(CallInst *CI, IRBuilder<> &B);
+  Value *replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B);
   Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B);
   Value *optimizeLog(CallInst *CI, IRBuilder<> &B);
Index: llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1074,6 +1074,52 @@
   return InnerChain[Exp];
 }
 
+/// Use square root in place of pow(x, +/-0.5).
+Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
+  // TODO: There is some subset of 'fast' under which these transforms should
+  // be allowed.
+  if (!Pow->isFast())
+    return nullptr;
+
+  // TODO: This should use m_APFloat to allow vector splats.
+  ConstantFP *Op2C = dyn_cast<ConstantFP>(Pow->getArgOperand(1));
+  if (!Op2C)
+    return nullptr;
+  if (!Op2C->isExactlyValue(0.5) && !Op2C->isExactlyValue(-0.5))
+    return nullptr;
+
+  // Fast-math flags from the pow() are propagated to all replacement ops.
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  B.setFastMathFlags(Pow->getFastMathFlags());
+  Type *Ty = Pow->getType();
+  Value *Sqrt;
+  if (Pow->hasFnAttr(Attribute::ReadNone)) {
+    // We know that errno is never set, so replace with an intrinsic:
+    // pow(x, 0.5) --> llvm.sqrt(x)
+    // llvm.pow(x, 0.5) --> llvm.sqrt(x)
+    auto *F = Intrinsic::getDeclaration(Pow->getModule(), Intrinsic::sqrt, Ty);
+    Sqrt = B.CreateCall(F, Pow->getArgOperand(0));
+  } else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf,
+                             LibFunc_sqrtl)) {
+    // Errno could be set, so we must use a sqrt libcall.
+    // TODO: We also should check that the target can in fact lower the sqrt
+    // libcall. We currently have no way to ask this question, so we ask
+    // whether the target has a sqrt libcall, which is not exactly the same.
+    Sqrt = emitUnaryFloatFnCall(Pow->getArgOperand(0),
+                                TLI->getName(LibFunc_sqrt), B,
+                                Pow->getCalledFunction()->getAttributes());
+  } else {
+    // We can't replace with an intrinsic or a libcall.
+    return nullptr;
+  }
+
+  // If this is pow(x, -0.5), get the reciprocal.
+  if (Op2C->isExactlyValue(-0.5))
+    Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt);
+
+  return Sqrt;
+}
+
 Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
@@ -1131,42 +1177,13 @@
   if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0
     return ConstantFP::get(CI->getType(), 1.0);
 
-  if (Op2C->isExactlyValue(-0.5) &&
-      hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf,
-                      LibFunc_sqrtl)) {
-    // If -ffast-math:
-    // pow(x, -0.5) -> 1.0 / sqrt(x)
-    if (CI->isFast()) {
-      IRBuilder<>::FastMathFlagGuard Guard(B);
-      B.setFastMathFlags(CI->getFastMathFlags());
-
-      // TODO: If the pow call is an intrinsic, we should lower to the sqrt
-      // intrinsic, so we match errno semantics. We also should check that the
-      // target can in fact lower the sqrt intrinsic -- we currently have no way
-      // to ask this question other than asking whether the target has a sqrt
-      // libcall, which is a sufficient but not necessary condition.
-      Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B,
-                                         Callee->getAttributes());
-
-      return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Sqrt, "sqrtrecip");
-    }
-  }
+  if (Value *Sqrt = replacePowWithSqrt(CI, B))
+    return Sqrt;
 
+  // FIXME: Correct the transforms and pull this into replacePowWithSqrt().
   if (Op2C->isExactlyValue(0.5) &&
       hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf,
                       LibFunc_sqrtl)) {
-
-    // In -ffast-math, pow(x, 0.5) -> sqrt(x).
-    if (CI->isFast()) {
-      IRBuilder<>::FastMathFlagGuard Guard(B);
-      B.setFastMathFlags(CI->getFastMathFlags());
-
-      // TODO: As above, we should lower to the sqrt intrinsic if the pow is an
-      // intrinsic, to match errno semantics.
-      return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B,
-                                  Callee->getAttributes());
-    }
-
     // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))).
     // This is faster than calling pow, and still handles negative zero
     // and negative infinity correctly.
Index: llvm/trunk/test/Transforms/InstCombine/pow-sqrt.ll
===================================================================
--- llvm/trunk/test/Transforms/InstCombine/pow-sqrt.ll
+++ llvm/trunk/test/Transforms/InstCombine/pow-sqrt.ll
@@ -2,8 +2,8 @@
 
 define double @pow_intrinsic_half_fast(double %x) {
 ; CHECK-LABEL: @pow_intrinsic_half_fast(
-; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @sqrt(double %x) #1
-; CHECK-NEXT:    ret double [[SQRT]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double %x)
+; CHECK-NEXT:    ret double [[TMP1]]
 ;
   %pow = call fast double @llvm.pow.f64(double %x, double 5.000000e-01)
   ret double %pow
@@ -51,8 +51,8 @@
 define float @pow_libcall_neghalf_fast(float %x) {
 ; CHECK-LABEL: @pow_libcall_neghalf_fast(
 ; CHECK-NEXT:    [[SQRTF:%.*]] = call fast float @sqrtf(float %x)
-; CHECK-NEXT:    [[SQRTRECIP:%.*]] = fdiv fast float 1.000000e+00, [[SQRTF]]
-; CHECK-NEXT:    ret float [[SQRTRECIP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fdiv fast float 1.000000e+00, [[SQRTF]]
+; CHECK-NEXT:    ret float [[TMP1]]
 ;
   %pow = call fast float @powf(float %x, float -5.0e-01)
   ret float %pow