Index: llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -133,6 +133,7 @@
   Value *optimizeCAbs(CallInst *CI, IRBuilder<> &B);
   Value *optimizeCos(CallInst *CI, IRBuilder<> &B);
   Value *optimizePow(CallInst *CI, IRBuilder<> &B);
+  Value *replacePowWithExp(CallInst *Pow, IRBuilder<> &B);
   Value *replacePowWithRoot(CallInst *Pow, IRBuilder<> &B);
   Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B);
Index: llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/SimplifyLibCalls.h"
+#include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/Triple.h"
@@ -1104,6 +1105,96 @@
   return InnerChain[Exp];
 }
 
+/// Use exp2(n * x) for pow(2.0 ** n, x); exp10(x) for pow(10.0, x);
+/// exp{,2}(x * y) for pow(exp{,2}(x), y).
+/// Returns the replacement value, or nullptr when no rewrite applies.
+/// TODO: Handle exp10() when more targets have it available.
+Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) {
+  Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
+  AttributeList Attrs = Pow->getCalledFunction()->getAttributes();
+  Module *Mod = Pow->getModule();
+  Type *Ty = Pow->getType();
+  bool Ignored;
+
+  // Evaluate special cases related to a nested function as the base.
+
+  // pow(exp(x), y) -> exp(x * y)
+  // pow(exp2(x), y) -> exp2(x * y)
+  // We enable these only with fast-math. Besides rounding differences, the
+  // transformation changes overflow and underflow behavior quite dramatically.
+  // For example:
+  //   pow(exp(1000), 0.001) = pow(inf, 0.001) = inf
+  // Whereas:
+  //   exp(1000 * 0.001) = exp(1)
+  CallInst *BaseFn = dyn_cast<CallInst>(Base);
+  if (BaseFn && BaseFn->isFast() && Pow->isFast()) {
+    Function *CalledFn = BaseFn->getCalledFunction();
+    if (CalledFn) {
+      StringRef NameFn = CalledFn->getName();
+      LibFunc Fn;
+      if (TLI->getLibFunc(NameFn, Fn) && TLI->has(Fn)) {
+        Intrinsic::ID ID;
+        StringRef ExpName;
+
+        switch (Fn) {
+        default:
+          return nullptr;
+        case LibFunc_exp:
+        case LibFunc_expf:
+        case LibFunc_expl:
+          ID = Intrinsic::exp;
+          ExpName = "exp";
+          break;
+        case LibFunc_exp2:
+        case LibFunc_exp2f:
+        case LibFunc_exp2l:
+          ID = Intrinsic::exp2;
+          ExpName = "exp2";
+          break;
+        }
+
+        // Emit the product only once the base is known to be exp{,2}(), so
+        // that bailing out above never leaves a stray fmul in the IR.
+        Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul");
+        Value *ExpFn = Intrinsic::getDeclaration(Mod, ID, Ty);
+        return B.CreateCall(ExpFn, FMul, ExpName);
+      }
+    }
+  }
+
+  // Evaluate special cases related to a constant base.
+
+  const APFloat *BaseF;
+  if (!match(Pow->getArgOperand(0), m_APFloat(BaseF)))
+    return nullptr;
+
+  // pow(2.0 ** n, x) -> exp2(n * x)
+  // BaseR holds 1 / Base so that reciprocal powers of two yield a negative n.
+  APFloat BaseR = APFloat(1.0);
+  BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored);
+  BaseR = BaseR / *BaseF;
+  bool isInteger = BaseF->isInteger(),
+       isReciprocal = BaseR.isInteger();
+  const APFloat *NF = isReciprocal ? &BaseR : BaseF;
+  APSInt NI(64, false);
+  if ((isInteger || isReciprocal) &&
+      !NF->convertToInteger(NI, APFloat::rmTowardZero, &Ignored) &&
+      NI > 1 && NI.isPowerOf2()) {
+    double N = NI.logBase2() * (isReciprocal ? -1.0 : 1.0);
+    Value *ExpoN = B.CreateFMul(Expo, ConstantFP::get(Ty, N));
+    Value *Exp2Fn = Intrinsic::getDeclaration(Mod, Intrinsic::exp2, Ty);
+    return B.CreateCall(Exp2Fn, ExpoN, "exp2");
+  }
+
+  // pow(10.0, x) -> exp10(x)
+  // TODO: There is no exp10() intrinsic yet, but some day there shall be one.
+  if (BaseF->isExactlyValue(10.0) &&
+      hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l))
+    return emitUnaryFloatFnCall(Expo, TLI->getName(LibFunc_exp10), B, Attrs);
+
+  return nullptr;
+}
+
 /// Use sqrt() for pow(x, +/-0.5) and cbrt() for pow(x, +/-0.333...).
 Value *LibCallSimplifier::replacePowWithRoot(CallInst *Pow, IRBuilder<> &B) {
   Value *Root, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
@@ -1176,9 +1267,7 @@
 Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
   Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
   Function *Callee = Pow->getCalledFunction();
-  AttributeList Attrs = Callee->getAttributes();
   StringRef Name = Callee->getName();
-  Module *Module = Pow->getModule();
   Type *Ty = Pow->getType();
   Value *Shrunk = nullptr;
   bool Ignored;
@@ -1203,36 +1292,8 @@
   if (match(Base, m_FPOne()))
     return Base;
 
-  // pow(2.0, x) -> exp2(x)
-  if (match(Base, m_SpecificFP(2.0))) {
-    Value *Exp2 = Intrinsic::getDeclaration(Module, Intrinsic::exp2, Ty);
-    return B.CreateCall(Exp2, Expo, "exp2");
-  }
-
-  // pow(10.0, x) -> exp10(x)
-  if (ConstantFP *BaseC = dyn_cast<ConstantFP>(Base))
-    // There's no exp10 intrinsic yet, but, maybe, some day there shall be one.
-    if (BaseC->isExactlyValue(10.0) &&
-        hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l))
-      return emitUnaryFloatFnCall(Expo, TLI->getName(LibFunc_exp10), B, Attrs);
-
-  // pow(exp(x), y) -> exp(x * y)
-  // pow(exp2(x), y) -> exp2(x * y)
-  // We enable these only with fast-math. Besides rounding differences, the
-  // transformation changes overflow and underflow behavior quite dramatically.
-  // Example: x = 1000, y = 0.001.
-  // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x * y) = exp(1).
-  auto *BaseFn = dyn_cast<CallInst>(Base);
-  if (BaseFn && BaseFn->isFast() && Pow->isFast()) {
-    LibFunc LibFn;
-    Function *CalleeFn = BaseFn->getCalledFunction();
-    if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) &&
-        (LibFn == LibFunc_exp || LibFn == LibFunc_exp2) && TLI->has(LibFn)) {
-      Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul");
-      return emitUnaryFloatFnCall(FMul, CalleeFn->getName(), B,
-                                  CalleeFn->getAttributes());
-    }
-  }
+  if (Value *Exp = replacePowWithExp(Pow, B))
+    return Exp;
 
   // Evaluate special cases related to the exponent.
 
Index: llvm/test/Transforms/InstCombine/pow-1.ll
===================================================================
--- llvm/test/Transforms/InstCombine/pow-1.ll
+++ llvm/test/Transforms/InstCombine/pow-1.ll
@@ -48,12 +48,13 @@
 ; CHECK-NEXT: ret <2 x double> <double 1.000000e+00, double 1.000000e+00>
 }
 
-; Check pow(2.0, x) -> exp2(x).
+; Check pow(2.0 ** n, x) -> exp2(n * x).
 
 define float @test_simplify3(float %x) {
 ; CHECK-LABEL: @test_simplify3(
-  %retval = call float @powf(float 2.0, float %x)
-; CHECK-NEXT: [[EXP2F:%[a-z0-9]+]] = call float @llvm.exp2.f32(float %x)
+  %retval = call float @powf(float 0.25, float %x)
+; CHECK-NEXT: [[TMP1:%.*]] = fmul float %x, -2.000000e+00
+; CHECK-NEXT: [[EXP2F:%[a-z0-9]+]] = call float @llvm.exp2.f32(float [[TMP1]])
   ret float %retval
 ; CHECK-NEXT: ret float [[EXP2F]]
 }
@@ -68,16 +69,18 @@
 
 define double @test_simplify4(double %x) {
 ; CHECK-LABEL: @test_simplify4(
-  %retval = call double @pow(double 2.0, double %x)
-; CHECK-NEXT: [[EXP2:%[a-z0-9]+]] = call double @llvm.exp2.f64(double %x)
+  %retval = call double @pow(double 4.0, double %x)
+; CHECK-NEXT: [[TMP1:%.*]] = fmul double %x, 2.000000e+00
+; CHECK-NEXT: [[EXP2:%[a-z0-9]+]] = call double @llvm.exp2.f64(double [[TMP1]])
   ret double %retval
 ; CHECK-NEXT: ret double [[EXP2]]
 }
 
 define <2 x double> @test_simplify4v(<2 x double> %x) {
 ; CHECK-LABEL: @test_simplify4v(
-  %retval = call <2 x double> @llvm.pow.v2f64(<2 x double> <double 2.0, double 2.0>, <2 x double> %x)
-; CHECK-NEXT: [[EXP2:%[a-z0-9]+]] = call <2 x double> @llvm.exp2.v2f64(<2 x double> %x)
+  %retval = call <2 x double> @llvm.pow.v2f64(<2 x double> <double 0.5, double 0.5>, <2 x double> %x)
+; CHECK-NEXT: [[TMP1:%.*]] = fsub <2 x double> <double -0.000000e+00, double -0.000000e+00>, %x
+; CHECK-NEXT: [[EXP2:%[a-z0-9]+]] = call <2 x double> @llvm.exp2.v2f64(<2 x double> [[TMP1]])
   ret <2 x double> %retval
 ; CHECK-NEXT: ret <2 x double> [[EXP2]]
 }
Index: llvm/test/Transforms/InstCombine/pow-exp.ll
===================================================================
--- llvm/test/Transforms/InstCombine/pow-exp.ll
+++ llvm/test/Transforms/InstCombine/pow-exp.ll
@@ -1,15 +1,15 @@
 ; RUN: opt < %s -instcombine -S | FileCheck %s
 
-define double @pow_exp(double %x, double %y) {
-  %call = call fast double @exp(double %x) nounwind readnone
-  %pow = call fast double @llvm.pow.f64(double %call, double %y)
-  ret double %pow
+define float @pow_exp(float %x, float %y) {
+  %call = call fast float @expf(float %x) nounwind readnone
+  %pow = call fast float @llvm.pow.f32(float %call, float %y)
+  ret float %pow
 }
 
-; CHECK-LABEL: define double @pow_exp(
-; CHECK-NEXT: %mul = fmul fast double %x, %y
-; CHECK-NEXT: %exp = call fast double @exp(double %mul)
-; CHECK-NEXT: ret double %exp
+; CHECK-LABEL: define float @pow_exp(
+; CHECK-NEXT: %mul = fmul fast float %x, %y
+; CHECK-NEXT: %exp = call fast float @llvm.exp.f32(float %mul)
+; CHECK-NEXT: ret float %exp
 
 define double @pow_exp2(double %x, double %y) {
   %call = call fast double @exp2(double %x) nounwind readnone
@@ -19,9 +19,21 @@
 
 ; CHECK-LABEL: define double @pow_exp2(
 ; CHECK-NEXT: %mul = fmul fast double %x, %y
-; CHECK-NEXT: %exp2 = call fast double @exp2(double %mul)
+; CHECK-NEXT: %exp2 = call fast double @llvm.exp2.f64(double %mul)
 ; CHECK-NEXT: ret double %exp2
 
+; TODO: exp10() is not widely enabled on many targets.
+define float @pow_exp10(float %x, float %y) {
+  %call = call fast float @exp10f(float %x) nounwind readnone
+  %pow = call fast float @llvm.pow.f32(float %call, float %y)
+  ret float %pow
+}
+
+; CHECK-LABEL: define float @pow_exp10(
+; CHECK-NEXT: %call = call fast float @exp10f(float %x)
+; CHECK-NEXT: %pow = call fast float @llvm.pow.f32(float %call, float %y)
+; CHECK-NEXT: ret float %pow
+
 define double @pow_exp_not_fast(double %x, double %y) {
   %call = call double @exp(double %x)
   %pow = call fast double @llvm.pow.f64(double %call, double %y)
@@ -43,7 +55,9 @@
 ; CHECK-NEXT: %call1 = call fast double %fptr()
 ; CHECK-NEXT: %pow = call fast double @llvm.pow.f64(double %call1, double %p1)
 
+declare float @expf(float)
 declare double @exp(double)
 declare double @exp2(double)
+declare float @exp10f(float)
+declare float @llvm.pow.f32(float, float)
 declare double @llvm.pow.f64(double, double)
-
Index: llvm/test/Transforms/InstCombine/pow-sqrt.ll
===================================================================
--- llvm/test/Transforms/InstCombine/pow-sqrt.ll
+++ llvm/test/Transforms/InstCombine/pow-sqrt.ll
@@ -1,3 +1,4 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -instcombine -S | FileCheck %s
 
 define float @powf_intrinsic_half_fast(float %x) {
@@ -33,9 +34,9 @@
 
 define <2 x double> @pow_intrinsic_neghalf_fast(<2 x double> %x) {
 ; CHECK-LABEL: @pow_intrinsic_neghalf_fast(
-; CHECK-NEXT: [[TMP1:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> %x)
-; CHECK-NEXT: [[TMP2:%.*]] = fdiv fast <2 x double> <double 1.000000e+00, double 1.000000e+00>, [[TMP1]]
-; CHECK-NEXT: ret <2 x double> [[TMP2]]
+; CHECK-NEXT: [[SQRT:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> [[X:%.*]])
+; CHECK-NEXT: [[RECP:%.*]] = fdiv fast <2 x double> <double 1.000000e+00, double 1.000000e+00>, [[SQRT]]
+; CHECK-NEXT: ret <2 x double> [[RECP]]
 ;
   %pow = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double -5.0e-01, double -5.0e-01>)
   ret <2 x double> %pow
@@ -84,5 +85,3 @@
 declare float @powf(float, float)
 
 attributes #0 = { nounwind readnone speculatable }
-attributes #1 = { nounwind readnone }
-