Index: llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h =================================================================== --- llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h +++ llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h @@ -133,7 +133,7 @@ Value *optimizeCAbs(CallInst *CI, IRBuilder<> &B); Value *optimizeCos(CallInst *CI, IRBuilder<> &B); Value *optimizePow(CallInst *CI, IRBuilder<> &B); - Value *replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B); + Value *replacePowWithRoot(CallInst *Pow, IRBuilder<> &B); Value *optimizeExp2(CallInst *CI, IRBuilder<> &B); Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B); Value *optimizeLog(CallInst *CI, IRBuilder<> &B); Index: llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp =================================================================== --- llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1119,80 +1119,114 @@ return InnerChain[Exp]; } -/// Use square root in place of pow(x, +/-0.5). -Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { - // TODO: There is some subset of 'fast' under which these transforms should - // be allowed. - if (!Pow->isFast()) +/// Use sqrt() for pow(x, +/-0.5) and cbrt() for pow(x, +/-0.333...). +Value *LibCallSimplifier::replacePowWithRoot(CallInst *Pow, IRBuilder<> &B) { + const APFloat *Exp; + if (!match(Pow->getArgOperand(1), m_APFloat(Exp))) return nullptr; - const APFloat *Arg1C; - if (!match(Pow->getArgOperand(1), m_APFloat(Arg1C))) - return nullptr; - if (!Arg1C->isExactlyValue(0.5) && !Arg1C->isExactlyValue(-0.5)) + Type *Ty = Pow->getType(); + const double OneHalf = 0.5, + OneThird = (Ty->getTypeID() == Type::FloatTyID) + ? 
(1.0f / 3.0f) : (1.0 / 3.0); + bool isHalf (Exp->isExactlyValue(OneHalf) || Exp->isExactlyValue(-OneHalf)), + isThird (Exp->isExactlyValue(OneThird) || Exp->isExactlyValue(-OneThird)); + if (!isHalf && !isThird) return nullptr; // Fast-math flags from the pow() are propagated to all replacement ops. IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(Pow->getFastMathFlags()); - Type *Ty = Pow->getType(); - Value *Sqrt; - if (Pow->hasFnAttr(Attribute::ReadNone)) { - // We know that errno is never set, so replace with an intrinsic: - // pow(x, 0.5) --> llvm.sqrt(x) - // llvm.pow(x, 0.5) --> llvm.sqrt(x) - auto *F = Intrinsic::getDeclaration(Pow->getModule(), Intrinsic::sqrt, Ty); - Sqrt = B.CreateCall(F, Pow->getArgOperand(0)); - } else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, - LibFunc_sqrtl)) { - // Errno could be set, so we must use a sqrt libcall. - // TODO: We also should check that the target can in fact lower the sqrt - // libcall. We currently have no way to ask this question, so we ask - // whether the target has a sqrt libcall which is not exactly the same. - Sqrt = emitUnaryFloatFnCall(Pow->getArgOperand(0), - TLI->getName(LibFunc_sqrt), B, - Pow->getCalledFunction()->getAttributes()); - } else { - // We can't replace with an intrinsic or a libcall. - return nullptr; + + Value *Root, *Base = Pow->getArgOperand(0); + AttributeList Attributes = Pow->getCalledFunction()->getAttributes(); + + if (isHalf) { + // Expand pow(x, +/-0.5) to sqrt(). + if (Pow->hasFnAttr(Attribute::ReadNone)) { + // We know that errno is never set, so replace with an intrinsic. + Function *SqrtF = Intrinsic::getDeclaration(Pow->getModule(), + Intrinsic::sqrt, Ty); + Root = B.CreateCall(SqrtF, Base); + } else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) + // TODO: We also should check that the target can in fact lower the sqrt() + // libcall. 
We currently have no way to ask this question, so we ask if + // the target has a sqrt() libcall, which is not exactly the same. + Root = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), + B, Attributes); + else + return nullptr; + + // Handle signed zero base by expanding to fabs(sqrt(x)). + if (!Pow->hasNoSignedZeros()) { + Function *FAbsF = Intrinsic::getDeclaration(Pow->getModule(), + Intrinsic::fabs, Ty); + Root = B.CreateCall(FAbsF, Root); + } + + // Handle non finite base by expanding to + // (x == -infinity ? +infinity : sqrt(x)). + if (!Pow->hasNoInfs()) { + Value *PosInf = ConstantFP::getInfinity(Ty), + *NegInf = ConstantFP::getInfinity(Ty, true); + Value *FCmp = B.CreateFCmpOEQ(Base, NegInf); + Root = B.CreateSelect(FCmp, PosInf, Root); + } } + else if (isThird && + Pow->hasNoInfs() && Pow->hasNoNaNs() && Pow->hasNoSignedZeros()) { + // Expand pow(x, +/-0.333...) to cbrt(), but only for regular base. + if (hasUnaryFloatFn(TLI, Ty, LibFunc_cbrt, LibFunc_cbrtf, LibFunc_cbrtl)) + Root = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_cbrt), + B, Attributes); + else + return nullptr; + } + else + return nullptr; - // If this is pow(x, -0.5), get the reciprocal. - if (Arg1C->isExactlyValue(-0.5)) - Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt); + // If the exponent is negative, then get the reciprocal. + if (Exp->isNegative()) + Root = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Root); - return Sqrt; + return Root; } Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); Value *Ret = nullptr; + Type *Ty = CI->getType(); + Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); + + // Bail out if simplifying libcalls to pow() is disabled. 
+ if (!hasUnaryFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl)) + return nullptr; + StringRef Name = Callee->getName(); - if (UnsafeFPShrink && Name == "pow" && hasFloatVersion(Name)) + if (UnsafeFPShrink && + Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name)) Ret = optimizeUnaryDoubleFP(CI, B, true); - Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); - // pow(1.0, x) -> 1.0 if (match(Op1, m_SpecificFP(1.0))) return Op1; + // pow(2.0, x) -> llvm.exp2(x) if (match(Op1, m_SpecificFP(2.0))) { Value *Exp2 = Intrinsic::getDeclaration(CI->getModule(), Intrinsic::exp2, - CI->getType()); + Ty); return B.CreateCall(Exp2, Op2, "exp2"); } - // There's no llvm.exp10 intrinsic yet, but, maybe, some day there will - // be one. - if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { - // pow(10.0, x) -> exp10(x) - if (Op1C->isExactlyValue(10.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc_exp10, LibFunc_exp10f, - LibFunc_exp10l)) - return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc_exp10), B, - Callee->getAttributes()); - } + // pow(10.0, x) -> exp10(x) + // TODO: There is no exp10() intrinsic yet, but some day there shall be one. + ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1); + if (Op1C && Op1C->isExactlyValue(10.0) && + hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) + // There's no llvm.exp10 intrinsic yet. + return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc_exp10), B, + Callee->getAttributes()); // pow(exp(x), y) -> exp(x * y) // pow(exp2(x), y) -> exp2(x * y) @@ -1214,44 +1248,20 @@ } } - if (Value *Sqrt = replacePowWithSqrt(CI, B)) - return Sqrt; + if (Value *Root = replacePowWithRoot(CI, B)) + return Root; ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2); if (!Op2C) return Ret; if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 - return ConstantFP::get(CI->getType(), 1.0); - - // FIXME: Correct the transforms and pull this into replacePowWithSqrt().
- if (Op2C->isExactlyValue(0.5) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf, - LibFunc_sqrtl)) { - // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). - // This is faster than calling pow, and still handles negative zero - // and negative infinity correctly. - // TODO: In finite-only mode, this could be just fabs(sqrt(x)). - Value *Inf = ConstantFP::getInfinity(CI->getType()); - Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); - - // TODO: As above, we should lower to the sqrt intrinsic if the pow is an - // intrinsic, to match errno semantics. - Value *Sqrt = emitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes()); - - Module *M = Callee->getParent(); - Function *FabsF = Intrinsic::getDeclaration(M, Intrinsic::fabs, - CI->getType()); - Value *FAbs = B.CreateCall(FabsF, Sqrt); - - Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); - Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); - return Sel; - } + return ConstantFP::get(Ty, 1.0); // Propagate fast-math-flags from the call to any created instructions. IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); + // pow(x, 1.0) --> x if (Op2C->isExactlyValue(1.0)) return Op1; @@ -1260,7 +1270,7 @@ return B.CreateFMul(Op1, Op1, "pow2"); // pow(x, -1.0) --> 1.0 / x if (Op2C->isExactlyValue(-1.0)) - return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip"); + return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Op1, "powrecip"); // In -ffast-math, generate repeated fmul instead of generating pow(x, n). if (CI->isFast()) { @@ -1284,7 +1294,7 @@ Value *FMul = getPow(InnerChain, V.convertToDouble(), B); // For negative exponents simply compute the reciprocal. 
if (Op2C->isNegative()) - FMul = B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), FMul); + FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul); return FMul; } Index: llvm/test/Transforms/InstCombine/pow-1.ll =================================================================== --- llvm/test/Transforms/InstCombine/pow-1.ll +++ llvm/test/Transforms/InstCombine/pow-1.ll @@ -177,7 +177,7 @@ define double @test_simplify17(double %x) { ; CHECK-LABEL: @test_simplify17( %retval = call double @llvm.pow.f64(double %x, double 0.5) -; CHECK-NEXT: [[SQRT:%[a-z0-9]+]] = call double @sqrt(double %x) +; CHECK-NEXT: [[SQRT:%[a-z0-9]+]] = call double @llvm.sqrt.f64(double %x) ; CHECK-NEXT: [[FABS:%[a-z0-9]+]] = call double @llvm.fabs.f64(double [[SQRT]]) ; CHECK-NEXT: [[FCMP:%[a-z0-9]+]] = fcmp oeq double %x, 0xFFF0000000000000 ; CHECK-NEXT: [[SELECT:%[a-z0-9]+]] = select i1 [[FCMP]], double 0x7FF0000000000000, double [[FABS]] Index: llvm/test/Transforms/InstCombine/pow-cbrt.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/InstCombine/pow-cbrt.ll @@ -0,0 +1,120 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +define double @pow_intrinsic_third_fast(double %x) { +; CHECK-LABEL: @pow_intrinsic_third_fast( +; CHECK-NEXT: [[CBRT:%.*]] = call fast double @cbrt(double %x) #1 +; CHECK-NEXT: ret double [[CBRT]] +; + %pow = call fast double @llvm.pow.f64(double %x, double 0x3fd5555555555555) + ret double %pow +} + +define float @powf_intrinsic_third_fast(float %x) { +; CHECK-LABEL: @powf_intrinsic_third_fast( +; CHECK-NEXT: [[CBRTF:%.*]] = call fast float @cbrtf(float %x) #1 +; CHECK-NEXT: ret float [[CBRTF]] +; + %pow = call fast float @llvm.pow.f32(float %x, float 0x3fd5555560000000) + ret float %pow +} + +define double @pow_intrinsic_third_approx(double %x) { +; CHECK-LABEL: @pow_intrinsic_third_approx( +; CHECK-NEXT: [[POW:%.*]] = call afn double @llvm.pow.f64(double %x, double 0x3FD5555555555555) +; CHECK-NEXT: ret double 
[[POW]] +; + %pow = call afn double @llvm.pow.f64(double %x, double 0x3fd5555555555555) + ret double %pow +} + +define float @powf_intrinsic_third_approx(float %x) { +; CHECK-LABEL: @powf_intrinsic_third_approx( +; CHECK-NEXT: [[POW:%.*]] = call afn float @llvm.pow.f32(float %x, float 0x3FD5555560000000) +; CHECK-NEXT: ret float [[POW]] +; + %pow = call afn float @llvm.pow.f32(float %x, float 0x3fd5555560000000) + ret float %pow +} + +define double @pow_libcall_third_fast(double %x) { +; CHECK-LABEL: @pow_libcall_third_fast( +; CHECK-NEXT: [[CBRT:%.*]] = call fast double @cbrt(double %x) +; CHECK-NEXT: ret double [[CBRT]] +; + %pow = call fast double @pow(double %x, double 0x3fd5555555555555) + ret double %pow +} + +define float @powf_libcall_third_fast(float %x) { +; CHECK-LABEL: @powf_libcall_third_fast( +; CHECK-NEXT: [[CBRTF:%.*]] = call fast float @cbrtf(float %x) +; CHECK-NEXT: ret float [[CBRTF]] +; + %pow = call fast float @powf(float %x, float 0x3fd5555560000000) + ret float %pow +} + +define double @pow_intrinsic_negthird_fast(double %x) { +; CHECK-LABEL: @pow_intrinsic_negthird_fast( +; CHECK-NEXT: [[CBRT:%.*]] = call fast double @cbrt(double %x) #1 +; CHECK-NEXT: [[TMP1:%.*]] = fdiv fast double 1.000000e+00, [[CBRT]] +; CHECK-NEXT: ret double [[TMP1]] +; + %pow = call fast double @llvm.pow.f64(double %x, double 0xbfd5555555555555) + ret double %pow +} + +define float @powf_intrinsic_negthird_fast(float %x) { +; CHECK-LABEL: @powf_intrinsic_negthird_fast( +; CHECK-NEXT: [[CBRTF:%.*]] = call fast float @cbrtf(float %x) #1 +; CHECK-NEXT: [[TMP1:%.*]] = fdiv fast float 1.000000e+00, [[CBRTF]] +; CHECK-NEXT: ret float [[TMP1]] +; + %pow = call fast float @llvm.pow.f32(float %x, float 0xbfd5555560000000) + ret float %pow +} + +define double @pow_intrinsic_negthird_approx(double %x) { +; CHECK-LABEL: @pow_intrinsic_negthird_approx( +; CHECK-NEXT: [[POW:%.*]] = call afn double @llvm.pow.f64(double %x, double 0xBFD5555555555555) +; CHECK-NEXT: ret double [[POW]] 
+; + %pow = call afn double @llvm.pow.f64(double %x, double 0xbfd5555555555555) + ret double %pow +} + +define float @powf_intrinsic_negthird_approx(float %x) { +; CHECK-LABEL: @powf_intrinsic_negthird_approx( +; CHECK-NEXT: [[POW:%.*]] = call afn float @llvm.pow.f32(float %x, float 0xBFD5555560000000) +; CHECK-NEXT: ret float [[POW]] +; + %pow = call afn float @llvm.pow.f32(float %x, float 0xbfd5555560000000) + ret float %pow +} + +define double @pow_libcall_negthird_fast(double %x) { +; CHECK-LABEL: @pow_libcall_negthird_fast( +; CHECK-NEXT: [[CBRT:%.*]] = call fast double @cbrt(double %x) +; CHECK-NEXT: [[TMP1:%.*]] = fdiv fast double 1.000000e+00, [[CBRT]] +; CHECK-NEXT: ret double [[TMP1]] +; + %pow = call fast double @pow(double %x, double 0xbfd5555555555555) + ret double %pow +} + +define float @powf_libcall_negthird_fast(float %x) { +; CHECK-LABEL: @powf_libcall_negthird_fast( +; CHECK-NEXT: [[CBRTF:%.*]] = call fast float @cbrtf(float %x) +; CHECK-NEXT: [[TMP1:%.*]] = fdiv fast float 1.000000e+00, [[CBRTF]] +; CHECK-NEXT: ret float [[TMP1]] +; + %pow = call fast float @powf(float %x, float 0xbfd5555560000000) + ret float %pow +} + +declare double @llvm.pow.f64(double, double) #0 +declare float @llvm.pow.f32(float, float) #0 +declare double @pow(double, double) +declare float @powf(float, float) + +attributes #0 = { nounwind readnone speculatable } Index: llvm/test/Transforms/InstCombine/pow-sqrt.ll =================================================================== --- llvm/test/Transforms/InstCombine/pow-sqrt.ll +++ llvm/test/Transforms/InstCombine/pow-sqrt.ll @@ -9,24 +9,25 @@ ret double %pow } -define <2 x double> @pow_intrinsic_half_approx(<2 x double> %x) { -; CHECK-LABEL: @pow_intrinsic_half_approx( -; CHECK-NEXT: [[POW:%.*]] = call afn <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 5.000000e-01, double 5.000000e-01>) -; CHECK-NEXT: ret <2 x double> [[POW]] +define <2 x double> @pow_intrinsic_half(<2 x double> %x) { +; CHECK-LABEL: @pow_intrinsic_half( +; CHECK-NEXT:
[[TMP1:%.*]] = call <2 x double> @llvm.sqrt.v2f64(<2 x double> %x) +; CHECK-NEXT: [[TMP2:%.*]] = call <2 x double> @llvm.fabs.v2f64(<2 x double> [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = fcmp oeq <2 x double> %x, <double 0xFFF0000000000000, double 0xFFF0000000000000> +; CHECK-NEXT: [[TMP4:%.*]] = select <2 x i1> [[TMP3]], <2 x double> <double 0x7FF0000000000000, double 0x7FF0000000000000>, <2 x double> [[TMP2]] +; CHECK-NEXT: ret <2 x double> [[TMP4]] ; - %pow = call afn <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 5.0e-01, double 5.0e-01>) + %pow = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 5.0e-01, double 5.0e-01>) ret <2 x double> %pow } -define double @pow_libcall_half_approx(double %x) { -; CHECK-LABEL: @pow_libcall_half_approx( -; CHECK-NEXT: [[SQRT:%.*]] = call double @sqrt(double %x) -; CHECK-NEXT: [[TMP1:%.*]] = call double @llvm.fabs.f64(double [[SQRT]]) -; CHECK-NEXT: [[TMP2:%.*]] = fcmp oeq double %x, 0xFFF0000000000000 -; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], double 0x7FF0000000000000, double [[TMP1]] -; CHECK-NEXT: ret double [[TMP3]] +define double @pow_libcall_half_ninf(double %x) { +; CHECK-LABEL: @pow_libcall_half_ninf( +; CHECK-NEXT: [[SQRT:%.*]] = call ninf double @sqrt(double %x) +; CHECK-NEXT: [[TMP1:%.*]] = call ninf double @llvm.fabs.f64(double [[SQRT]]) +; CHECK-NEXT: ret double [[TMP1]] ; - %pow = call afn double @pow(double %x, double 5.0e-01) + %pow = call ninf double @pow(double %x, double 5.0e-01) ret double %pow } @@ -40,12 +41,16 @@ ret <2 x double> %pow } -define double @pow_intrinsic_neghalf_approx(double %x) { -; CHECK-LABEL: @pow_intrinsic_neghalf_approx( -; CHECK-NEXT: [[POW:%.*]] = call afn double @llvm.pow.f64(double %x, double -5.000000e-01) -; CHECK-NEXT: ret double [[POW]] +define double @pow_intrinsic_neghalf(double %x) { +; CHECK-LABEL: @pow_intrinsic_neghalf( +; CHECK-NEXT: [[TMP1:%.*]] = call double @llvm.sqrt.f64(double %x) +; CHECK-NEXT: [[TMP2:%.*]] = call double @llvm.fabs.f64(double [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = fcmp oeq double %x, 0xFFF0000000000000 +; CHECK-NEXT: [[DOTOP:%.*]] = fdiv double 1.000000e+00, [[TMP2]] +;
CHECK-NEXT: [[TMP4:%.*]] = select i1 [[TMP3]], double 0.000000e+00, double [[DOTOP]] +; CHECK-NEXT: ret double [[TMP4]] ; - %pow = call afn double @llvm.pow.f64(double %x, double -5.0e-01) + %pow = call double @llvm.pow.f64(double %x, double -5.0e-01) ret double %pow } @@ -65,5 +70,3 @@ declare float @powf(float, float) attributes #0 = { nounwind readnone speculatable } -attributes #1 = { nounwind readnone } -