Index: llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -133,6 +133,7 @@
   Value *optimizeCAbs(CallInst *CI, IRBuilder<> &B);
   Value *optimizeCos(CallInst *CI, IRBuilder<> &B);
   Value *optimizePow(CallInst *CI, IRBuilder<> &B);
+  Value *replacePowWithCbrt(CallInst *Pow, IRBuilder<> &B);
   Value *replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B);
   Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B);
Index: llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1119,6 +1119,53 @@
   return InnerChain[Exp];
 }
 
+/// Use cube root in place of pow(x, +/-0.333...).
+///
+/// \param Pow The pow() call to simplify.
+/// \param B Builder used to emit the replacement instructions.
+/// \returns the replacement value, or nullptr if \p Pow is left unchanged.
+Value *LibCallSimplifier::replacePowWithCbrt(CallInst *Pow, IRBuilder<> &B) {
+  // pow(x, 1/3) and cbrt(x) disagree for negative finite x (NaN vs. the real
+  // root), for -Inf (+Inf vs. -Inf) and for -0.0 (+0.0 vs. -0.0), and the
+  // constant matched below is only the nearest double to 1/3, so the results
+  // agree only approximately. Require every flag that permits this.
+  if (!Pow->hasNoNaNs() || !Pow->hasNoInfs() || !Pow->hasNoSignedZeros() ||
+      !Pow->hasApproxFunc())
+    return nullptr;
+
+  // TODO: Handle SP FP as well, though the argument seems to never match below.
+  // Only double is accepted, so the cbrt() call below uses the double name.
+  Type *Ty = Pow->getType();
+  if (!Ty->isDoubleTy())
+    return nullptr;
+
+  auto *Arg2C = dyn_cast<ConstantFP>(Pow->getArgOperand(1));
+  if (!Arg2C)
+    return nullptr;
+
+  // The exponent must be exactly the double nearest to +/-1/3.
+  const bool IsPlusThird = Arg2C->isExactlyValue(APFloat(1.0 / 3.0));
+  const bool IsMinusThird = Arg2C->isExactlyValue(APFloat(-1.0 / 3.0));
+  if (!IsPlusThird && !IsMinusThird)
+    return nullptr;
+
+  if (!hasUnaryFloatFn(TLI, Ty, LibFunc_cbrt, LibFunc_cbrtf, LibFunc_cbrtl))
+    return nullptr;
+
+  // Fast-math flags from the pow() are propagated to all replacement ops.
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  B.setFastMathFlags(Pow->getFastMathFlags());
+  Value *Cbrt = emitUnaryFloatFnCall(Pow->getArgOperand(0),
+                                     TLI->getName(LibFunc_cbrt), B,
+                                     Pow->getCalledFunction()->getAttributes());
+
+  // If this is pow(x, -0.333...), get the reciprocal.
+  if (IsMinusThird)
+    Cbrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Cbrt);
+
+  return Cbrt;
+}
+
 /// Use square root in place of pow(x, +/-0.5).
 Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
   // TODO: There is some subset of 'fast' under which these transforms should
@@ -1217,6 +1264,9 @@
   if (Value *Sqrt = replacePowWithSqrt(CI, B))
     return Sqrt;
 
+  if (Value *Cbrt = replacePowWithCbrt(CI, B))
+    return Cbrt;
+
   ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2);
   if (!Op2C)
     return Ret;