Index: llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -133,6 +133,7 @@
   Value *optimizeCAbs(CallInst *CI, IRBuilder<> &B);
   Value *optimizeCos(CallInst *CI, IRBuilder<> &B);
   Value *optimizePow(CallInst *CI, IRBuilder<> &B);
+  Value *replacePowWithCbrt(CallInst *Pow, IRBuilder<> &B);
   Value *replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B);
   Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B);
Index: llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1119,6 +1119,60 @@
   return InnerChain[Exp];
 }
 
+/// Use cube root in place of pow(x, +/-0.333...).
+///
+/// This is only done under full 'fast' math because pow() and cbrt() differ
+/// on several inputs: for x < 0, pow() returns NaN but cbrt() returns a
+/// negative value (needs nnan); pow(-0.0, 1/3) is +0.0 but cbrt(-0.0) is
+/// -0.0 (needs nsz); pow(-inf, 1/3) is +inf but cbrt(-inf) is -inf (needs
+/// ninf); and the exponent constant is only the nearest representable value
+/// to 1/3, so the results may differ slightly (needs afn).
+Value *LibCallSimplifier::replacePowWithCbrt(CallInst *Pow, IRBuilder<> &B) {
+  if (!Pow->isFast())
+    return nullptr;
+
+  Type *Ty = Pow->getType();
+  Type::TypeID TyID = Ty->getTypeID();
+  // TODO: Other FP types (e.g. long double) would need 1/3 rounded to their
+  // own precision in the match below; they are not handled yet.
+  if (TyID != Type::FloatTyID && TyID != Type::DoubleTyID)
+    return nullptr;
+
+  ConstantFP *ExpoC = dyn_cast<ConstantFP>(Pow->getArgOperand(1));
+  if (!ExpoC)
+    return nullptr;
+
+  // The exponent must be exactly +/-1/3 rounded to the precision of the
+  // call's type; any other constant is left alone.
+  bool IsFloat = TyID == Type::FloatTyID;
+  const APFloat PlusThird =
+      IsFloat ? APFloat(1.0f / 3.0f) : APFloat(1.0 / 3.0);
+  const APFloat MinusThird =
+      IsFloat ? APFloat(-1.0f / 3.0f) : APFloat(-1.0 / 3.0);
+  bool IsPlusThird = ExpoC->isExactlyValue(PlusThird);
+  bool IsMinusThird = ExpoC->isExactlyValue(MinusThird);
+  if (!IsPlusThird && !IsMinusThird)
+    return nullptr;
+
+  if (!hasUnaryFloatFn(TLI, Ty, LibFunc_cbrt, LibFunc_cbrtf, LibFunc_cbrtl))
+    return nullptr;
+
+  // Fast-math flags from the pow() are propagated to all replacement ops.
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  B.setFastMathFlags(Pow->getFastMathFlags());
+  // The base name is suffixed by emitUnaryFloatFnCall for non-double operands
+  // (cbrtf for float).
+  Value *Cbrt = emitUnaryFloatFnCall(Pow->getArgOperand(0),
+                                     TLI->getName(LibFunc_cbrt), B,
+                                     Pow->getCalledFunction()->getAttributes());
+
+  // For pow(x, -0.333...), take the reciprocal (allowed by 'arcp').
+  if (IsMinusThird)
+    Cbrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Cbrt);
+
+  return Cbrt;
+}
+
 /// Use square root in place of pow(x, +/-0.5).
 Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
   // TODO: There is some subset of 'fast' under which these transforms should
@@ -1217,6 +1271,9 @@
   if (Value *Sqrt = replacePowWithSqrt(CI, B))
     return Sqrt;
 
+  if (Value *Cbrt = replacePowWithCbrt(CI, B))
+    return Cbrt;
+
   ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2);
   if (!Op2C)
     return Ret;
Index: llvm/test/Transforms/InstCombine/pow-cbrt.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/InstCombine/pow-cbrt.ll
@@ -0,0 +1,129 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+define double @pow_intrinsic_third_fast(double %x) {
+; CHECK-LABEL: @pow_intrinsic_third_fast(
+; CHECK-NEXT:    [[POW:%.*]] = call fast double @cbrt(double %x)
+; CHECK-NEXT:    ret double [[POW]]
+;
+  %pow = call fast double @llvm.pow.f64(double %x, double 0x3fd5555555555555)
+  ret double %pow
+}
+
+define float @powf_intrinsic_third_fast(float %x) {
+; CHECK-LABEL: @powf_intrinsic_third_fast(
+; CHECK-NEXT:    [[POW:%.*]] = call fast float @cbrtf(float %x)
+; CHECK-NEXT:    ret float [[POW]]
+;
+  %pow = call fast float @llvm.pow.f32(float %x, float 0x3fd5555560000000)
+  ret float %pow
+}
+
+define double @pow_intrinsic_third_approx(double %x) {
+; CHECK-LABEL: @pow_intrinsic_third_approx(
+; CHECK-NEXT:    [[POW:%.*]] = call afn double @llvm.pow.f64(double %x, double 0x3FD5555555555555)
+; CHECK-NEXT:    ret double [[POW]]
+;
+  %pow = call afn double @llvm.pow.f64(double %x, double 0x3fd5555555555555)
+  ret double %pow
+}
+
+define float @powf_intrinsic_third_approx(float %x) {
+; CHECK-LABEL: @powf_intrinsic_third_approx(
+; CHECK-NEXT:    [[POW:%.*]] = call afn float @llvm.pow.f32(float %x, float 0x3FD5555560000000)
+; CHECK-NEXT:    ret float [[POW]]
+;
+  %pow = call afn float @llvm.pow.f32(float %x, float 0x3fd5555560000000)
+  ret float %pow
+}
+
+define double @pow_libcall_third_fast(double %x) {
+; CHECK-LABEL: @pow_libcall_third_fast(
+; CHECK-NEXT:    [[POW:%.*]] = call fast double @cbrt(double %x)
+; CHECK-NEXT:    ret double [[POW]]
+;
+  %pow = call fast double @pow(double %x, double 0x3fd5555555555555)
+  ret double %pow
+}
+
+define float @powf_libcall_third_fast(float %x) {
+; CHECK-LABEL: @powf_libcall_third_fast(
+; CHECK-NEXT:    [[POW:%.*]] = call fast float @cbrtf(float %x)
+; CHECK-NEXT:    ret float [[POW]]
+;
+  %pow = call fast float @powf(float %x, float 0x3fd5555560000000)
+  ret float %pow
+}
+
+define double @pow_intrinsic_negthird_fast(double %x) {
+; CHECK-LABEL: @pow_intrinsic_negthird_fast(
+; CHECK-NEXT:    [[CBR:%.*]] = call fast double @cbrt(double %x)
+; CHECK-NEXT:    [[POW:%.*]] = fdiv fast double 1.000000e+00, [[CBR]]
+; CHECK-NEXT:    ret double [[POW]]
+;
+  %pow = call fast double @llvm.pow.f64(double %x, double 0xbfd5555555555555)
+  ret double %pow
+}
+
+define float @powf_intrinsic_negthird_fast(float %x) {
+; CHECK-LABEL: @powf_intrinsic_negthird_fast(
+; CHECK-NEXT:    [[CBR:%.*]] = call fast float @cbrtf(float %x)
+; CHECK-NEXT:    [[POW:%.*]] = fdiv fast float 1.000000e+00, [[CBR]]
+; CHECK-NEXT:    ret float [[POW]]
+;
+  %pow = call fast float @llvm.pow.f32(float %x, float 0xbfd5555560000000)
+  ret float %pow
+}
+
+define double @pow_intrinsic_negthird_approx(double %x) {
+; CHECK-LABEL: @pow_intrinsic_negthird_approx(
+; CHECK-NEXT:    [[POW:%.*]] = call afn double @llvm.pow.f64(double %x, double 0xBFD5555555555555)
+; CHECK-NEXT:    ret double [[POW]]
+;
+  %pow = call afn double @llvm.pow.f64(double %x, double 0xbfd5555555555555)
+  ret double %pow
+}
+
+define float @powf_intrinsic_negthird_approx(float %x) {
+; CHECK-LABEL: @powf_intrinsic_negthird_approx(
+; CHECK-NEXT:    [[POW:%.*]] = call afn float @llvm.pow.f32(float %x, float 0xBFD5555560000000)
+; CHECK-NEXT:    ret float [[POW]]
+;
+  %pow = call afn float @llvm.pow.f32(float %x, float 0xbfd5555560000000)
+  ret float %pow
+}
+
+define double @pow_libcall_negthird_fast(double %x) {
+; CHECK-LABEL: @pow_libcall_negthird_fast(
+; CHECK-NEXT:    [[CBR:%.*]] = call fast double @cbrt(double %x)
+; CHECK-NEXT:    [[POW:%.*]] = fdiv fast double 1.000000e+00, [[CBR]]
+; CHECK-NEXT:    ret double [[POW]]
+;
+  %pow = call fast double @pow(double %x, double 0xbfd5555555555555)
+  ret double %pow
+}
+
+define float @powf_libcall_negthird_fast(float %x) {
+; CHECK-LABEL: @powf_libcall_negthird_fast(
+; CHECK-NEXT:    [[CBR:%.*]] = call fast float @cbrtf(float %x)
+; CHECK-NEXT:    [[POW:%.*]] = fdiv fast float 1.000000e+00, [[CBR]]
+; CHECK-NEXT:    ret float [[POW]]
+;
+  %pow = call fast float @powf(float %x, float 0xbfd5555560000000)
+  ret float %pow
+}
+
+define double @pow_intrinsic_not_quite_third_fast(double %x) {
+; CHECK-LABEL: @pow_intrinsic_not_quite_third_fast(
+; CHECK-NEXT:    [[POW:%.*]] = call fast double @llvm.pow.f64(double %x, double 0x3FD5555555555556)
+; CHECK-NEXT:    ret double [[POW]]
+;
+  %pow = call fast double @llvm.pow.f64(double %x, double 0x3fd5555555555556)
+  ret double %pow
+}
+
+declare double @llvm.pow.f64(double, double) #0
+declare float @llvm.pow.f32(float, float) #0
+declare double @pow(double, double)
+declare float @powf(float, float)
+
+attributes #0 = { nounwind readnone speculatable }