Index: llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -133,7 +133,7 @@
   Value *optimizeCAbs(CallInst *CI, IRBuilder<> &B);
   Value *optimizeCos(CallInst *CI, IRBuilder<> &B);
   Value *optimizePow(CallInst *CI, IRBuilder<> &B);
-  Value *replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B);
+  Value *replacePowWithRoot(CallInst *Pow, IRBuilder<> &B);
   Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B);
   Value *optimizeLog(CallInst *CI, IRBuilder<> &B);
Index: llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1119,53 +1119,75 @@
   return InnerChain[Exp];
 }
 
-/// Use square root in place of pow(x, +/-0.5).
-Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
-  Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
+/// Use sqrt() for pow(x, +/-0.5) and cbrt() for pow(x, +/-0.333...).
+Value *LibCallSimplifier::replacePowWithRoot(CallInst *Pow, IRBuilder<> &B) {
+  Value *Root, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
+  AttributeList Attrs = Pow->getCalledFunction()->getAttributes();
   Module *Mod = Pow->getModule();
   Type *Ty = Pow->getType();
 
   const APFloat *ExpoF;
-  if (!match(Expo, m_APFloat(ExpoF)) ||
-      (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5)))
-    return nullptr;
-
-  // If errno is never set, then use the intrinsic for sqrt().
-  if (Pow->hasFnAttr(Attribute::ReadNone)) {
-    Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(),
-                                                 Intrinsic::sqrt, Ty);
-    Sqrt = B.CreateCall(SqrtFn, Base, "sqrt");
-  }
-  // Otherwise, use the libcall for sqrt().
-  else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl))
-    // TODO: We also should check that the target can in fact lower the sqrt()
-    // libcall. We currently have no way to ask this question, so we ask if
-    // the target has a sqrt() libcall, which is not exactly the same.
-    Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B,
-                                Pow->getCalledFunction()->getAttributes());
-  else
+  if (!match(Expo, m_APFloat(ExpoF)))
     return nullptr;
 
-  // Handle signed zero base by expanding to fabs(sqrt(x)).
-  if (!Pow->hasNoSignedZeros()) {
-    Function *FAbsFn = Intrinsic::getDeclaration(Mod, Intrinsic::fabs, Ty);
-    Sqrt = B.CreateCall(FAbsFn, Sqrt, "abs");
-  }
+  // 1/3 rounded to float for float calls and to double otherwise.
+  const double OneThird = Ty->isFloatTy() ? (1.0f / 3.0f) : (1.0 / 3.0);
+  bool IsHalf = ExpoF->isExactlyValue(0.5) || ExpoF->isExactlyValue(-0.5);
+  bool IsThird =
+      ExpoF->isExactlyValue(OneThird) || ExpoF->isExactlyValue(-OneThird);
+  if (!IsHalf && !IsThird)
+    return nullptr;
+
+  // Expand pow(x, +/-0.5) to sqrt().
+  if (IsHalf) {
+    // If errno is never set, then use the intrinsic for sqrt().
+    if (Pow->hasFnAttr(Attribute::ReadNone)) {
+      Function *SqrtFn = Intrinsic::getDeclaration(Mod, Intrinsic::sqrt, Ty);
+      Root = B.CreateCall(SqrtFn, Base, "sqrt");
+    }
+    // Otherwise, use the libcall for sqrt().
+    else if (hasUnaryFloatFn(TLI, Ty,
+                             LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl))
+      // TODO: We also should check that the target can in fact lower the sqrt()
+      // libcall. We currently have no way to ask this question, so we ask if
+      // the target has a sqrt() libcall, which is not exactly the same.
+      Root = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, Attrs);
+    else
+      return nullptr;
 
-  // Handle non finite base by expanding to
-  // (x == -infinity ? +infinity : sqrt(x)).
-  if (!Pow->hasNoInfs()) {
-    Value *PosInf = ConstantFP::getInfinity(Ty),
-          *NegInf = ConstantFP::getInfinity(Ty, true);
-    Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "iseq");
-    Sqrt = B.CreateSelect(FCmp, PosInf, Sqrt);
+    // Handle signed zero base by expanding to fabs(sqrt(x)).
+    if (!Pow->hasNoSignedZeros()) {
+      Function *FAbsFn = Intrinsic::getDeclaration(Mod, Intrinsic::fabs, Ty);
+      Root = B.CreateCall(FAbsFn, Root, "abs");
+    }
+
+    // Handle non finite base by expanding to
+    // (x == -infinity ? +infinity : sqrt(x)).
+    if (!Pow->hasNoInfs()) {
+      Value *PosInf = ConstantFP::getInfinity(Ty),
+            *NegInf = ConstantFP::getInfinity(Ty, true);
+      Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "iseq");
+      Root = B.CreateSelect(FCmp, PosInf, Root);
+    }
   }
+  // Expand pow(x, +/-0.333...) to cbrt(), but only for a regular base. The
+  // exponent is merely the closest FP value to 1/3, so cbrt() only
+  // approximates pow() here and approximate functions must be allowed.
+  else if (IsThird && Pow->hasApproxFunc() && Pow->hasNoInfs() &&
+           Pow->hasNoNaNs() && Pow->hasNoSignedZeros()) {
+    if (hasUnaryFloatFn(TLI, Ty, LibFunc_cbrt, LibFunc_cbrtf, LibFunc_cbrtl))
+      Root = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_cbrt), B, Attrs);
+    else
+      return nullptr;
+  }
+  else
+    return nullptr;
 
   // If the exponent is negative, then get the reciprocal.
   if (ExpoF->isNegative())
-    Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt, "reciprocal");
+    Root = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Root, "reciprocal");
 
-  return Sqrt;
+  return Root;
 }
 
 Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
@@ -1246,8 +1268,8 @@
   if (match(Expo, m_SpecificFP(2.0)))
     return B.CreateFMul(Base, Base, "square");
 
-  if (Value *Sqrt = replacePowWithSqrt(Pow, B))
-    return Sqrt;
+  if (Value *Root = replacePowWithRoot(Pow, B))
+    return Root;
 
   // pow(x, n) -> x * x * x * ...
   const APFloat *ExpoF;
Index: llvm/test/Transforms/InstCombine/pow-cbrt.ll
===================================================================
--- llvm/test/Transforms/InstCombine/pow-cbrt.ll
+++ llvm/test/Transforms/InstCombine/pow-cbrt.ll
@@ -3,8 +3,8 @@
 
 define double @pow_intrinsic_third_fast(double %x) {
 ; CHECK-LABEL: @pow_intrinsic_third_fast(
-; CHECK-NEXT:    [[POW:%.*]] = call fast double @llvm.pow.f64(double [[X:%.*]], double 0x3FD5555555555555)
-; CHECK-NEXT:    ret double [[POW]]
+; CHECK-NEXT:    [[CBRT:%.*]] = call fast double @cbrt(double [[X:%.*]]) #1
+; CHECK-NEXT:    ret double [[CBRT]]
 ;
   %pow = call fast double @llvm.pow.f64(double %x, double 0x3fd5555555555555)
   ret double %pow
@@ -12,8 +12,8 @@
 
 define float @powf_intrinsic_third_fast(float %x) {
 ; CHECK-LABEL: @powf_intrinsic_third_fast(
-; CHECK-NEXT:    [[POW:%.*]] = call fast float @llvm.pow.f32(float [[X:%.*]], float 0x3FD5555560000000)
-; CHECK-NEXT:    ret float [[POW]]
+; CHECK-NEXT:    [[CBRTF:%.*]] = call fast float @cbrtf(float [[X:%.*]]) #1
+; CHECK-NEXT:    ret float [[CBRTF]]
 ;
   %pow = call fast float @llvm.pow.f32(float %x, float 0x3fd5555560000000)
   ret float %pow
@@ -37,10 +37,19 @@
   ret float %pow
 }
 
+define double @pow_intrinsic_third_nnan_ninf_nsz(double %x) {
+; CHECK-LABEL: @pow_intrinsic_third_nnan_ninf_nsz(
+; CHECK-NEXT:    [[POW:%.*]] = call nnan ninf nsz double @llvm.pow.f64(double [[X:%.*]], double 0x3FD5555555555555)
+; CHECK-NEXT:    ret double [[POW]]
+;
+  %pow = call nnan ninf nsz double @llvm.pow.f64(double %x, double 0x3fd5555555555555)
+  ret double %pow
+}
+
 define double @pow_libcall_third_fast(double %x) {
 ; CHECK-LABEL: @pow_libcall_third_fast(
-; CHECK-NEXT:    [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double 0x3FD5555555555555)
-; CHECK-NEXT:    ret double [[POW]]
+; CHECK-NEXT:    [[CBRT:%.*]] = call fast double @cbrt(double [[X:%.*]])
+; CHECK-NEXT:    ret double [[CBRT]]
 ;
   %pow = call fast double @pow(double %x, double 0x3fd5555555555555)
   ret double %pow
@@ -48,8 +57,8 @@
 
 define float @powf_libcall_third_fast(float %x) {
 ; CHECK-LABEL: @powf_libcall_third_fast(
-; CHECK-NEXT:    [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float 0x3FD5555560000000)
-; CHECK-NEXT:    ret float [[POW]]
+; CHECK-NEXT:    [[CBRTF:%.*]] = call fast float @cbrtf(float [[X:%.*]])
+; CHECK-NEXT:    ret float [[CBRTF]]
 ;
   %pow = call fast float @powf(float %x, float 0x3fd5555560000000)
   ret float %pow
@@ -57,8 +66,9 @@
 
 define double @pow_intrinsic_negthird_fast(double %x) {
 ; CHECK-LABEL: @pow_intrinsic_negthird_fast(
-; CHECK-NEXT:    [[POW:%.*]] = call fast double @llvm.pow.f64(double [[X:%.*]], double 0xBFD5555555555555)
-; CHECK-NEXT:    ret double [[POW]]
+; CHECK-NEXT:    [[CBRT:%.*]] = call fast double @cbrt(double [[X:%.*]]) #1
+; CHECK-NEXT:    [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[CBRT]]
+; CHECK-NEXT:    ret double [[RECIPROCAL]]
 ;
   %pow = call fast double @llvm.pow.f64(double %x, double 0xbfd5555555555555)
   ret double %pow
@@ -66,8 +76,9 @@
 
 define float @powf_intrinsic_negthird_fast(float %x) {
 ; CHECK-LABEL: @powf_intrinsic_negthird_fast(
-; CHECK-NEXT:    [[POW:%.*]] = call fast float @llvm.pow.f32(float [[X:%.*]], float 0xBFD5555560000000)
-; CHECK-NEXT:    ret float [[POW]]
+; CHECK-NEXT:    [[CBRTF:%.*]] = call fast float @cbrtf(float [[X:%.*]]) #1
+; CHECK-NEXT:    [[RECIPROCAL:%.*]] = fdiv fast float 1.000000e+00, [[CBRTF]]
+; CHECK-NEXT:    ret float [[RECIPROCAL]]
 ;
   %pow = call fast float @llvm.pow.f32(float %x, float 0xbfd5555560000000)
   ret float %pow
@@ -93,8 +104,9 @@
 
 define double @pow_libcall_negthird_fast(double %x) {
 ; CHECK-LABEL: @pow_libcall_negthird_fast(
-; CHECK-NEXT:    [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double 0xBFD5555555555555)
-; CHECK-NEXT:    ret double [[POW]]
+; CHECK-NEXT:    [[CBRT:%.*]] = call fast double @cbrt(double [[X:%.*]])
+; CHECK-NEXT:    [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[CBRT]]
+; CHECK-NEXT:    ret double [[RECIPROCAL]]
 ;
   %pow = call fast double @pow(double %x, double 0xbfd5555555555555)
   ret double %pow
@@ -102,8 +114,9 @@
 
 define float @powf_libcall_negthird_fast(float %x) {
 ; CHECK-LABEL: @powf_libcall_negthird_fast(
-; CHECK-NEXT:    [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float 0xBFD5555560000000)
-; CHECK-NEXT:    ret float [[POW]]
+; CHECK-NEXT:    [[CBRTF:%.*]] = call fast float @cbrtf(float [[X:%.*]])
+; CHECK-NEXT:    [[RECIPROCAL:%.*]] = fdiv fast float 1.000000e+00, [[CBRTF]]
+; CHECK-NEXT:    ret float [[RECIPROCAL]]
 ;
   %pow = call fast float @powf(float %x, float 0xbfd5555560000000)
   ret float %pow