Index: lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1264,6 +1264,31 @@
   return nullptr;
 }
 
+/// Emit a call computing sqrt(V): the llvm.sqrt intrinsic when \p NoErrno is
+/// true, otherwise the sqrt() libcall if \p TLI reports one for V's type
+/// (\p Attrs is forwarded to emitUnaryFloatFnCall). Returns nullptr if no
+/// sqrt call can be emitted.
+static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno,
+                          Module *M, IRBuilder<> &B,
+                          const TargetLibraryInfo *TLI) {
+  // If errno is never set, then use the intrinsic for sqrt().
+  if (NoErrno) {
+    Function *SqrtFn = Intrinsic::getDeclaration(M, Intrinsic::sqrt,
+                                                 V->getType());
+    return B.CreateCall(SqrtFn, V, "sqrt");
+  }
+
+  // Otherwise, use the libcall for sqrt().
+  // TODO: We also should check that the target can in fact lower the sqrt()
+  // libcall. We currently have no way to ask this question, so we ask if
+  // the target has a sqrt() libcall, which is not exactly the same.
+  if (hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf,
+                      LibFunc_sqrtl))
+    return emitUnaryFloatFnCall(V, TLI->getName(LibFunc_sqrt), B, Attrs);
+
+  return nullptr;
+}
+
 /// Use square root in place of pow(x, +/-0.5).
 Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
   Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
@@ -1276,19 +1301,8 @@
       (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5)))
     return nullptr;
 
-  // If errno is never set, then use the intrinsic for sqrt().
-  if (Pow->doesNotAccessMemory()) {
-    Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(),
-                                                 Intrinsic::sqrt, Ty);
-    Sqrt = B.CreateCall(SqrtFn, Base, "sqrt");
-  }
-  // Otherwise, use the libcall for sqrt().
-  else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl))
-    // TODO: We also should check that the target can in fact lower the sqrt()
-    // libcall. We currently have no way to ask this question, so we ask if
-    // the target has a sqrt() libcall, which is not exactly the same.
-    Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, Attrs);
-  else
+  Sqrt = getSqrtCall(Base, Attrs, Pow->doesNotAccessMemory(), Mod, B, TLI);
+  if (!Sqrt)
     return nullptr;
 
   // Handle signed zero base by expanding to fabs(sqrt(x)).
@@ -1371,7 +1385,30 @@
     // We limit to a max of 7 multiplications, thus the maximum exponent is 32.
     APFloat LimF(ExpoF->getSemantics(), 33.0),
             ExpoA(abs(*ExpoF));
-    if (ExpoA.isInteger() && ExpoA.compare(LimF) == APFloat::cmpLessThan) {
+    if (ExpoA.compare(LimF) == APFloat::cmpLessThan) {
+      // This transformation applies to integer or integer+0.5 exponents only.
+      // For integer+0.5, we create a sqrt(Base) call.
+      Value *Sqrt = nullptr;
+      if (!ExpoA.isInteger()) {
+        APFloat Expo2 = ExpoA;
+        // To check if ExpoA is an integer + 0.5, we add it to itself. If there
+        // is no floating point exception and the result is an integer, then
+        // ExpoA == integer + 0.5
+        if (Expo2.add(ExpoA, APFloat::rmNearestTiesToEven) != APFloat::opOK)
+          return nullptr;
+
+        if (!Expo2.isInteger())
+          return nullptr;
+
+        Sqrt = getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(),
+                           Pow->doesNotAccessMemory(), Pow->getModule(), B,
+                           TLI);
+        // Without sqrt(Base) the code below would silently compute
+        // pow(Base, n) instead of pow(Base, n+0.5), so give up.
+        if (!Sqrt)
+          return nullptr;
+      }
+
       // We will memoize intermediate products of the Addition Chain.
       Value *InnerChain[33] = {nullptr};
       InnerChain[1] = Base;
@@ -1382,6 +1419,10 @@
       ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored);
 
       Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B);
+      // Expand pow(x, n+0.5) to pow(x, n) * sqrt(x).
+      if (Sqrt)
+        FMul = B.CreateFMul(FMul, Sqrt);
+
       // If the exponent is negative, then get the reciprocal.
       if (ExpoF->isNegative())
         FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal");
Index: test/Transforms/InstCombine/pow-4.ll
===================================================================
--- test/Transforms/InstCombine/pow-4.ll
+++ test/Transforms/InstCombine/pow-4.ll
@@ -5,6 +5,7 @@
 declare float @llvm.pow.f32(float, float)
 declare <2 x double> @llvm.pow.v2f64(<2 x double>, <2 x double>)
 declare <2 x float> @llvm.pow.v2f32(<2 x float>, <2 x float>)
+declare double @pow(double, double)
 
 ; pow(x, 3.0)
 define double @test_simplify_3(double %x) {
@@ -117,3 +118,65 @@
   ret double %1
 }
 
+; pow(x, 16.5)
+define double @test_simplify_16_5(double %x) {
+; CHECK-LABEL: @test_simplify_16_5(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT:    ret double [[TMP4]]
+;
+  %1 = call fast double @llvm.pow.f64(double %x, double 1.650000e+01)
+  ret double %1
+}
+
+; pow(x, -16.5)
+define double @test_simplify_neg_16_5(double %x) {
+; CHECK-LABEL: @test_simplify_neg_16_5(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT:    [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
+; CHECK-NEXT:    ret double [[RECIPROCAL]]
+;
+  %1 = call fast double @llvm.pow.f64(double %x, double -1.650000e+01)
+  ret double %1
+}
+
+
+; pow(x, 16.5) via the pow() libcall
+define double @test_simplify_16_5_libcall(double %x) {
+; CHECK-LABEL: @test_simplify_16_5_libcall(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT:    ret double [[TMP4]]
+;
+  %1 = call fast double @pow(double %x, double 1.650000e+01)
+  ret double %1
+}
+
+; pow(x, -16.5) via the pow() libcall
+define double @test_simplify_neg_16_5_libcall(double %x) {
+; CHECK-LABEL: @test_simplify_neg_16_5_libcall(
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]])
+; CHECK-NEXT:    [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
+; CHECK-NEXT:    [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
+; CHECK-NEXT:    ret double [[RECIPROCAL]]
+;
+  %1 = call fast double @pow(double %x, double -1.650000e+01)
+  ret double %1
+}