Index: llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp =================================================================== --- llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1286,6 +1286,27 @@ return nullptr; } +static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno, + Module *M, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + // If errno is never set, then use the intrinsic for sqrt(). + if (NoErrno) { + Function *SqrtFn = + Intrinsic::getDeclaration(M, Intrinsic::sqrt, V->getType()); + return B.CreateCall(SqrtFn, V, "sqrt"); + } + + // Otherwise, use the libcall for sqrt(). + if (hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) + // TODO: We also should check that the target can in fact lower the sqrt() + // libcall. We currently have no way to ask this question, so we ask if + // the target has a sqrt() libcall, which is not exactly the same. + return emitUnaryFloatFnCall(V, TLI->getName(LibFunc_sqrt), B, Attrs); + + return nullptr; +} + /// Use square root in place of pow(x, +/-0.5). Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); @@ -1298,19 +1319,8 @@ (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5))) return nullptr; - // If errno is never set, then use the intrinsic for sqrt(). - if (Pow->doesNotAccessMemory()) { - Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(), - Intrinsic::sqrt, Ty); - Sqrt = B.CreateCall(SqrtFn, Base, "sqrt"); - } - // Otherwise, use the libcall for sqrt(). - else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) - // TODO: We also should check that the target can in fact lower the sqrt() - // libcall. We currently have no way to ask this question, so we ask if - // the target has a sqrt() libcall, which is not exactly the same. 
- Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, Attrs); - else + Sqrt = getSqrtCall(Base, Attrs, Pow->doesNotAccessMemory(), Mod, B, TLI); + if (!Sqrt) return nullptr; // Handle signed zero base by expanding to fabs(sqrt(x)). @@ -1391,9 +1401,38 @@ const APFloat *ExpoF; if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) { // We limit to a max of 7 multiplications, thus the maximum exponent is 32. + // If the exponent is an integer+0.5 we generate a call to sqrt and an + // additional fmul. + // TODO: This whole transformation should be backend specific (e.g. some + // backends might prefer libcalls or the limit for the exponent might + // be different) and it should also consider optimizing for size. APFloat LimF(ExpoF->getSemantics(), 33.0), ExpoA(abs(*ExpoF)); - if (ExpoA.isInteger() && ExpoA.compare(LimF) == APFloat::cmpLessThan) { + if (ExpoA.compare(LimF) == APFloat::cmpLessThan) { + // This transformation applies to integer or integer+0.5 exponents only. + // For integer+0.5, we create a sqrt(Base) call. + Value *Sqrt = nullptr; + if (!ExpoA.isInteger()) { + APFloat Expo2 = ExpoA; + // To check if ExpoA is an integer + 0.5, we add it to itself. If there + // is no floating point exception and the result is an integer, then + // ExpoA == integer + 0.5 + if (Expo2.add(ExpoA, APFloat::rmNearestTiesToEven) != APFloat::opOK) + return nullptr; + + if (!Expo2.isInteger()) + return nullptr; + + Sqrt = + getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(), + Pow->doesNotAccessMemory(), Pow->getModule(), B, TLI); + // getSqrtCall() yields nullptr when errno may be set and the target has + // no sqrt() libcall; give up here instead of silently dropping the + // sqrt(Base) factor and emitting pow(x, y) for pow(x, y+0.5). + if (!Sqrt) + return nullptr; + } + // We will memoize intermediate products of the Addition Chain. Value *InnerChain[33] = {nullptr}; InnerChain[1] = Base; @@ -1404,6 +1443,10 @@ ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B); + // Expand pow(x, y+0.5) to pow(x, y) * sqrt(x). 
+ if (Sqrt) + FMul = B.CreateFMul(FMul, Sqrt); + // If the exponent is negative, then get the reciprocal. if (ExpoF->isNegative()) FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal"); Index: llvm/trunk/test/Transforms/InstCombine/pow-4.ll =================================================================== --- llvm/trunk/test/Transforms/InstCombine/pow-4.ll +++ llvm/trunk/test/Transforms/InstCombine/pow-4.ll @@ -5,6 +5,8 @@ declare float @llvm.pow.f32(float, float) declare <2 x double> @llvm.pow.v2f64(<2 x double>, <2 x double>) declare <2 x float> @llvm.pow.v2f32(<2 x float>, <2 x float>) +declare <4 x float> @llvm.pow.v4f32(<4 x float>, <4 x float>) +declare double @pow(double, double) ; pow(x, 3.0) define double @test_simplify_3(double %x) { @@ -117,3 +119,107 @@ ret double %1 } +; pow(x, 16.5) with double +define double @test_simplify_16_5(double %x) { +; CHECK-LABEL: @test_simplify_16_5( +; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]]) +; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]] +; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]] +; CHECK-NEXT: ret double [[TMP4]] +; + %1 = call fast double @llvm.pow.f64(double %x, double 1.650000e+01) + ret double %1 +} + +; pow(x, -16.5) with double +define double @test_simplify_neg_16_5(double %x) { +; CHECK-LABEL: @test_simplify_neg_16_5( +; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]]) +; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]] +; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]] +; 
CHECK-NEXT: [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]] +; CHECK-NEXT: ret double [[RECIPROCAL]] +; + %1 = call fast double @llvm.pow.f64(double %x, double -1.650000e+01) + ret double %1 +} + +; pow(x, 16.5) with double, called through the @pow libcall +define double @test_simplify_16_5_libcall(double %x) { +; CHECK-LABEL: @test_simplify_16_5_libcall( +; CHECK-NEXT: [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]]) +; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]] +; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]] +; CHECK-NEXT: ret double [[TMP4]] +; + %1 = call fast double @pow(double %x, double 1.650000e+01) + ret double %1 +} + +; pow(x, -16.5) with double, called through the @pow libcall +define double @test_simplify_neg_16_5_libcall(double %x) { +; CHECK-LABEL: @test_simplify_neg_16_5_libcall( +; CHECK-NEXT: [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]]) +; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]] +; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]] +; CHECK-NEXT: [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]] +; CHECK-NEXT: ret double [[RECIPROCAL]] +; + %1 = call fast double @pow(double %x, double -1.650000e+01) + ret double %1 +} + +; pow(x, -4.5) with float (NOTE: the test name says 8_5, but the exponent passed is -4.5) +define float @test_simplify_neg_8_5(float %x) { +; CHECK-LABEL: @test_simplify_neg_8_5( +; CHECK-NEXT: [[SQRT:%.*]] = call fast float @llvm.sqrt.f32(float [[X:%.*]]) +; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast float [[X]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul fast float [[SQUARE]], [[SQUARE]] +; CHECK-NEXT: [[TMP2:%.*]] = fmul fast float [[TMP1]], [[SQRT]] +; 
CHECK-NEXT: [[RECIPROCAL:%.*]] = fdiv fast float 1.000000e+00, [[TMP2]] +; CHECK-NEXT: ret float [[RECIPROCAL]] +; + %1 = call fast float @llvm.pow.f32(float %x, float -0.450000e+01) + ret float %1 +} + +; pow(x, 7.5) with <2 x double> +define <2 x double> @test_simplify_7_5(<2 x double> %x) { +; CHECK-LABEL: @test_simplify_7_5( +; CHECK-NEXT: [[SQRT:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> [[X:%.*]]) +; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast <2 x double> [[X]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul fast <2 x double> [[SQUARE]], [[SQUARE]] +; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x double> [[TMP1]], [[X]] +; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <2 x double> [[SQUARE]], [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = fmul fast <2 x double> [[TMP3]], [[SQRT]] +; CHECK-NEXT: ret <2 x double> [[TMP4]] +; + %1 = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 7.500000e+00, double 7.500000e+00>) + ret <2 x double> %1 +} + +; pow(x, 3.5) with <4 x float> +define <4 x float> @test_simplify_3_5(<4 x float> %x) { +; CHECK-LABEL: @test_simplify_3_5( +; CHECK-NEXT: [[SQRT:%.*]] = call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> [[X:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = fmul fast <4 x float> [[X]], [[X]] +; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <4 x float> [[TMP1]], [[X]] +; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <4 x float> [[TMP2]], [[SQRT]] +; CHECK-NEXT: ret <4 x float> [[TMP3]] +; + %1 = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> <float 3.500000e+00, float 3.500000e+00, float 3.500000e+00, float 3.500000e+00>) + ret <4 x float> %1 +} +