Index: llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -133,6 +133,7 @@
   Value *optimizeCAbs(CallInst *CI, IRBuilder<> &B);
   Value *optimizeCos(CallInst *CI, IRBuilder<> &B);
   Value *optimizePow(CallInst *CI, IRBuilder<> &B);
+  Value *replacePowWithExp(CallInst *Pow, IRBuilder<> &B);
   Value *replacePowWithRoot(CallInst *Pow, IRBuilder<> &B);
   Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B);
Index: llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/SimplifyLibCalls.h"
+#include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/Triple.h"
@@ -1119,6 +1120,106 @@
   return InnerChain[Exp];
 }
 
+/// Use exp2(n * x) for pow(2.0 ** n, x); exp10(x) for pow(10.0, x);
+/// exp{,2,10}(x * y) for pow(exp{,2,10}(x), y).
+///
+/// Returns the replacement value, or nullptr when no rewrite applies. No
+/// instruction is emitted unless a replacement is returned.
+Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) {
+  Value *BaseV = Pow->getArgOperand(0), *ExpoV = Pow->getArgOperand(1);
+  AttributeList Attrs = Pow->getCalledFunction()->getAttributes();
+  Module *Mod = Pow->getModule();
+  Type *Ty = Pow->getType();
+
+  // Build everything below with the fast-math flags of the pow() call.
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  B.setFastMathFlags(Pow->getFastMathFlags());
+
+  // Evaluate special cases related to a nested function as the base.
+
+  // pow(exp(x), y) -> exp(x * y)
+  // pow(exp2(x), y) -> exp2(x * y)
+  // pow(exp10(x), y) -> exp10(x * y)
+  // We enable these only with fast-math. Besides rounding differences, the
+  // transformation changes overflow and underflow behavior quite dramatically.
+  // For example:
+  //   pow(exp(1000), 0.001) = pow(inf, 0.001) = inf
+  // Whereas:
+  //   exp(1000 * 0.001) = exp(1)
+  CallInst *BaseFn = dyn_cast<CallInst>(BaseV);
+  if (BaseFn && BaseFn->isFast() && Pow->isFast()) {
+    Function *CalledFn = BaseFn->getCalledFunction();
+    LibFunc Fn;
+
+    if (CalledFn && TLI->getLibFunc(CalledFn->getName(), Fn) && TLI->has(Fn)) {
+      // Classify the callee before emitting anything, so that bailing out on
+      // an unrelated library function leaves no dead fmul behind.
+      Intrinsic::ID ExpID = Intrinsic::not_intrinsic;
+      StringRef ExpName;
+
+      switch (Fn) {
+      default:
+        return nullptr;
+      case LibFunc_exp:
+      case LibFunc_expf:
+      case LibFunc_expl:
+        ExpID = Intrinsic::exp;
+        ExpName = "exp";
+        break;
+      case LibFunc_exp2:
+      case LibFunc_exp2f:
+      case LibFunc_exp2l:
+        ExpID = Intrinsic::exp2;
+        ExpName = "exp2";
+        break;
+      case LibFunc_exp10:
+      case LibFunc_exp10f:
+      case LibFunc_exp10l:
+        // There's no exp10() intrinsic yet, so ExpID stays not_intrinsic.
+        break;
+      }
+
+      Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), ExpoV, "mul");
+      if (ExpID != Intrinsic::not_intrinsic) {
+        Value *ExpFn = Intrinsic::getDeclaration(Mod, ExpID, Ty);
+        return B.CreateCall(ExpFn, FMul, ExpName);
+      }
+
+      // Pass the unsuffixed name, exactly as the constant-base case below
+      // does, so that a float or long double callee such as exp10f() is not
+      // handed to the helper with its type suffix already attached.
+      return emitUnaryFloatFnCall(FMul, TLI->getName(LibFunc_exp10), B, Attrs);
+    }
+  }
+
+  // Evaluate special cases related to a constant base.
+
+  const APFloat *BaseF;
+  if (!match(BaseV, m_APFloat(BaseF)))
+    return nullptr;
+
+  // pow(2.0 ** n, x) -> exp2(n * x)
+  // The base must convert exactly to an integer that is a power of 2 above 1.
+  APSInt BaseI(64, false);
+  bool Ignored;
+  if (BaseF->isInteger() &&
+      !BaseF->convertToInteger(BaseI, APFloat::rmTowardZero, &Ignored) &&
+      BaseI > 1 && BaseI.isPowerOf2()) {
+    unsigned BaseN = BaseI.logBase2();
+    Value *ExpoN = B.CreateFMul(ExpoV, ConstantFP::get(Ty, BaseN), "mul");
+    Value *Exp2Fn = Intrinsic::getDeclaration(Mod, Intrinsic::exp2, Ty);
+    return B.CreateCall(Exp2Fn, ExpoN, "exp2");
+  }
+
+  // pow(10.0, x) -> exp10(x)
+  // TODO: There is no exp10() intrinsic yet, but some day there shall be one.
+  if (BaseF->isExactlyValue(10.0) &&
+      hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l))
+    return emitUnaryFloatFnCall(ExpoV, TLI->getName(LibFunc_exp10), B, Attrs);
+
+  return nullptr;
+}
+
 /// Use sqrt() for pow(x, +/-0.5) and cbrt() for pow(x, +/-0.333...).
 Value *LibCallSimplifier::replacePowWithRoot(CallInst *Pow, IRBuilder<> &B) {
   Module *Mod = Pow->getModule();
@@ -1195,7 +1296,6 @@
 
 Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
-  AttributeList Attrs = Callee->getAttributes();
   Value *Base = CI->getArgOperand(0), *Expo = CI->getArgOperand(1);
   Type *Ty = CI->getType();
   Value *Shrunk = nullptr;
@@ -1219,36 +1319,8 @@
   if (match(Base, m_FPOne()))
     return Base;
 
-  // pow(2.0, x) -> exp2(x)
-  if (match(Base, m_SpecificFP(2.0))) {
-    Value *Exp2 = Intrinsic::getDeclaration(CI->getModule(), Intrinsic::exp2,
-                                            Ty);
-    return B.CreateCall(Exp2, Expo, "exp2");
-  }
-
-  // pow(10.0, x) -> exp10(x)
-  // TODO: There is no exp10() intrinsic yet, but some day there shall be one.
-  ConstantFP *Op1C = dyn_cast<ConstantFP>(Base);
-  if (Op1C && Op1C->isExactlyValue(10.0) &&
-      hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l))
-    return emitUnaryFloatFnCall(Expo, TLI->getName(LibFunc_exp10), B, Attrs);
-
-  // pow(exp(x), y) -> exp(x * y)
-  // pow(exp2(x), y) -> exp2(x * y)
-  // We enable these only with fast-math. Besides rounding differences, the
-  // transformation changes overflow and underflow behavior quite dramatically.
-  // Example: x = 1000, y = 0.001.
-  // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1).
-  auto *BaseFn = dyn_cast<CallInst>(Base);
-  if (BaseFn && BaseFn->isFast() && CI->isFast()) {
-    LibFunc LibFn;
-    Function *CalleeFn = BaseFn->getCalledFunction();
-    if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) &&
-        TLI->has(LibFn) && (LibFn == LibFunc_exp || LibFn == LibFunc_exp2)) {
-      Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul");
-      return emitUnaryFloatFnCall(FMul, CalleeFn->getName(), B, Attrs);
-    }
-  }
+  if (Value *Exp = replacePowWithExp(CI, B))
+    return Exp;
 
   // Evaluate special cases related to the exponent.
 
Index: llvm/test/Transforms/InstCombine/pow-1.ll
===================================================================
--- llvm/test/Transforms/InstCombine/pow-1.ll
+++ llvm/test/Transforms/InstCombine/pow-1.ll
@@ -32,7 +32,7 @@
 ; CHECK-NEXT: ret double 1.000000e+00
 }
 
-; Check pow(2.0, x) -> exp2(x).
+; Check pow(2.0 ** n, x) -> exp2(n * x).
 
 define float @test_simplify3(float %x) {
 ; CHECK-LABEL: @test_simplify3(
@@ -44,8 +44,9 @@
 
 define double @test_simplify4(double %x) {
 ; CHECK-LABEL: @test_simplify4(
-  %retval = call double @pow(double 2.0, double %x)
-; CHECK-NEXT: [[EXP2:%[a-z0-9]+]] = call double @llvm.exp2.f64(double %x)
+  %retval = call double @pow(double 1024.0, double %x)
+; CHECK-NEXT: [[TMP1:%.*]] = fmul double %x, 1.000000e+01
+; CHECK-NEXT: [[EXP2:%[a-z0-9]+]] = call double @llvm.exp2.f64(double [[TMP1]])
   ret double %retval
 ; CHECK-NEXT: ret double [[EXP2]]
 }
Index: llvm/test/Transforms/InstCombine/pow-exp.ll
===================================================================
--- llvm/test/Transforms/InstCombine/pow-exp.ll
+++ llvm/test/Transforms/InstCombine/pow-exp.ll
@@ -1,15 +1,15 @@
 ; RUN: opt < %s -instcombine -S | FileCheck %s
 
-define double @pow_exp(double %x, double %y) {
-  %call = call fast double @exp(double %x) nounwind readnone
-  %pow = call fast double @llvm.pow.f64(double %call, double %y)
-  ret double %pow
+define float @pow_exp(float %x, float %y) {
+  %call = call fast float @expf(float %x) nounwind readnone
+  %pow = call fast float @llvm.pow.f32(float %call, float %y)
+  ret float %pow
 }
 
-; CHECK-LABEL: define double @pow_exp(
-; CHECK-NEXT: %mul = fmul fast double %x, %y
-; CHECK-NEXT: %exp = call fast double @exp(double %mul)
-; CHECK-NEXT: ret double %exp
+; CHECK-LABEL: define float @pow_exp
+; CHECK-NEXT: %mul = fmul fast float %x, %y
+; CHECK-NEXT: %exp = call fast float @llvm.exp.f32(float %mul)
+; CHECK-NEXT: ret float %exp
 
 define double @pow_exp2(double %x, double %y) {
   %call = call fast double @exp2(double %x) nounwind readnone
@@ -17,9 +17,9 @@
   ret double %pow
 }
 
-; CHECK-LABEL: define double @pow_exp2(
+; CHECK-LABEL: define double @pow_exp2
 ; CHECK-NEXT: %mul = fmul fast double %x, %y
-; CHECK-NEXT: %exp2 = call fast double @exp2(double %mul)
+; CHECK-NEXT: %exp2 = call fast double @llvm.exp2.f64(double %mul)
 ; CHECK-NEXT: ret double %exp2
 
 define double @pow_exp_not_fast(double %x, double %y) {
@@ -43,7 +43,8 @@
 ; CHECK-NEXT: %call1 = call fast double %fptr()
 ; CHECK-NEXT: %pow = call fast double @llvm.pow.f64(double %call1, double %p1)
 
+declare float @expf(float)
 declare double @exp(double)
 declare double @exp2(double)
+declare float @llvm.pow.f32(float, float)
 declare double @llvm.pow.f64(double, double)
-