Skip to content

Commit cc9dc59

Browse files
committed Sep 3, 2018
[SLC] Support expanding pow(x, n+0.5) to x * x * ... * sqrt(x)
Reviewers: evandro, efriedma, spatel Reviewed By: spatel Differential Revision: https://reviews.llvm.org/D51435 llvm-svn: 341330
1 parent f534485 commit cc9dc59

File tree

2 files changed

+158
-14
lines changed

2 files changed

+158
-14
lines changed
 

‎llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

+52-14
Original file line numberDiff line numberDiff line change
@@ -1286,6 +1286,27 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) {
12861286
return nullptr;
12871287
}
12881288

1289+
/// Build a call that computes the square root of \p V.
///
/// When \p NoErrno is true, the llvm.sqrt intrinsic for V's type is declared
/// in \p M and called through \p B; \p Attrs is not used on that path.
/// Otherwise, if hasUnaryFloatFn() reports a sqrt/sqrtf/sqrtl libcall for V's
/// type in \p TLI, that libcall is emitted with \p Attrs.
///
/// \returns the newly created call, or nullptr when neither form can be
/// emitted. Callers must check the result for null before using it.
static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno,
1290+
Module *M, IRBuilder<> &B,
1291+
const TargetLibraryInfo *TLI) {
1292+
// If errno is never set, then use the intrinsic for sqrt().
1293+
if (NoErrno) {
1294+
Function *SqrtFn =
1295+
Intrinsic::getDeclaration(M, Intrinsic::sqrt, V->getType());
1296+
return B.CreateCall(SqrtFn, V, "sqrt");
1297+
}
1298+
1299+
// Otherwise, use the libcall for sqrt().
1300+
if (hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf,
1301+
LibFunc_sqrtl))
1302+
// TODO: We also should check that the target can in fact lower the sqrt()
1303+
// libcall. We currently have no way to ask this question, so we ask if
1304+
// the target has a sqrt() libcall, which is not exactly the same.
1305+
return emitUnaryFloatFnCall(V, TLI->getName(LibFunc_sqrt), B, Attrs);
1306+
1307+
// Neither the intrinsic nor a libcall is usable: report failure to the caller.
return nullptr;
1308+
}
1309+
12891310
/// Use square root in place of pow(x, +/-0.5).
12901311
Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
12911312
Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
@@ -1298,19 +1319,8 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
12981319
(!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5)))
12991320
return nullptr;
13001321

1301-
// If errno is never set, then use the intrinsic for sqrt().
1302-
if (Pow->doesNotAccessMemory()) {
1303-
Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(),
1304-
Intrinsic::sqrt, Ty);
1305-
Sqrt = B.CreateCall(SqrtFn, Base, "sqrt");
1306-
}
1307-
// Otherwise, use the libcall for sqrt().
1308-
else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl))
1309-
// TODO: We also should check that the target can in fact lower the sqrt()
1310-
// libcall. We currently have no way to ask this question, so we ask if
1311-
// the target has a sqrt() libcall, which is not exactly the same.
1312-
Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, Attrs);
1313-
else
1322+
Sqrt = getSqrtCall(Base, Attrs, Pow->doesNotAccessMemory(), Mod, B, TLI);
1323+
if (!Sqrt)
13141324
return nullptr;
13151325

13161326
// Handle signed zero base by expanding to fabs(sqrt(x)).
@@ -1391,9 +1401,33 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
13911401
const APFloat *ExpoF;
13921402
if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) {
13931403
// We limit to a max of 7 multiplications, thus the maximum exponent is 32.
1404+
// If the exponent is an integer+0.5 we generate a call to sqrt and an
1405+
// additional fmul.
1406+
// TODO: This whole transformation should be backend specific (e.g. some
1407+
// backends might prefer libcalls or the limit for the exponent might
1408+
// be different) and it should also consider optimizing for size.
13941409
APFloat LimF(ExpoF->getSemantics(), 33.0),
13951410
ExpoA(abs(*ExpoF));
1396-
if (ExpoA.isInteger() && ExpoA.compare(LimF) == APFloat::cmpLessThan) {
1411+
if (ExpoA.compare(LimF) == APFloat::cmpLessThan) {
1412+
// This transformation applies to integer or integer+0.5 exponents only.
1413+
// For integer+0.5, we create a sqrt(Base) call.
1414+
Value *Sqrt = nullptr;
1415+
if (!ExpoA.isInteger()) {
1416+
APFloat Expo2 = ExpoA;
1417+
// To check if ExpoA is an integer + 0.5, we add it to itself. If there
1418+
// is no floating point exception and the result is an integer, then
1419+
// ExpoA == integer + 0.5
1420+
if (Expo2.add(ExpoA, APFloat::rmNearestTiesToEven) != APFloat::opOK)
1421+
return nullptr;
1422+
1423+
if (!Expo2.isInteger())
1424+
return nullptr;
1425+
1426+
Sqrt =
1427+
getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(),
1428+
Pow->doesNotAccessMemory(), Pow->getModule(), B, TLI);
1429+
}
1430+
13971431
// We will memoize intermediate products of the Addition Chain.
13981432
Value *InnerChain[33] = {nullptr};
13991433
InnerChain[1] = Base;
@@ -1404,6 +1438,10 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
14041438
ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored);
14051439
Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B);
14061440

1441+
// Expand pow(x, y+0.5) to pow(x, y) * sqrt(x).
1442+
if (Sqrt)
1443+
FMul = B.CreateFMul(FMul, Sqrt);
1444+
14071445
// If the exponent is negative, then get the reciprocal.
14081446
if (ExpoF->isNegative())
14091447
FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal");

‎llvm/test/Transforms/InstCombine/pow-4.ll

+106
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ declare double @llvm.pow.f64(double, double)
55
declare float @llvm.pow.f32(float, float)
66
declare <2 x double> @llvm.pow.v2f64(<2 x double>, <2 x double>)
77
declare <2 x float> @llvm.pow.v2f32(<2 x float>, <2 x float>)
8+
declare <4 x float> @llvm.pow.v4f32(<4 x float>, <4 x float>)
9+
declare double @pow(double, double)
810

911
; pow(x, 3.0)
1012
define double @test_simplify_3(double %x) {
@@ -117,3 +119,107 @@ define double @test_simplify_33(double %x) {
117119
ret double %1
118120
}
119121

122+
; pow(x, 16.5) with double
123+
; Expected expansion: x^16 by repeated squaring, then one fmul with sqrt(x).
; The pattern variable X is defined at its first use (the sqrt call) so the
; expectations do not depend on a match leaked from an earlier function.
define double @test_simplify_16_5(double %x) {
124+
; CHECK-LABEL: @test_simplify_16_5(
125+
; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
126+
; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
127+
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
128+
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
129+
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
130+
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
131+
; CHECK-NEXT: ret double [[TMP4]]
132+
;
133+
%1 = call fast double @llvm.pow.f64(double %x, double 1.650000e+01)
134+
ret double %1
135+
}
136+
137+
; pow(x, -16.5) with double
138+
; Expected expansion: x^16 by repeated squaring, one fmul with sqrt(x), then
; the reciprocal because the exponent is negative.
; The pattern variable X is defined at its first use (the sqrt call) so the
; expectations do not depend on a match leaked from an earlier function.
define double @test_simplify_neg_16_5(double %x) {
139+
; CHECK-LABEL: @test_simplify_neg_16_5(
140+
; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]])
141+
; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
142+
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
143+
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
144+
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
145+
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
146+
; CHECK-NEXT: [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
147+
; CHECK-NEXT: ret double [[RECIPROCAL]]
148+
;
149+
%1 = call fast double @llvm.pow.f64(double %x, double -1.650000e+01)
150+
ret double %1
151+
}
152+
153+
; pow(x, 16.5) with double, called through the @pow libcall instead of the
; llvm.pow intrinsic. The expected square root is a call to @sqrt, followed by
; x^16 built by repeated squaring and one fmul with that square root.
154+
define double @test_simplify_16_5_libcall(double %x) {
155+
; CHECK-LABEL: @test_simplify_16_5_libcall(
156+
; CHECK-NEXT: [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]])
157+
; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
158+
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
159+
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
160+
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
161+
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
162+
; CHECK-NEXT: ret double [[TMP4]]
163+
;
164+
%1 = call fast double @pow(double %x, double 1.650000e+01)
165+
ret double %1
166+
}
167+
168+
; pow(x, -16.5) with double, called through the @pow libcall instead of the
; llvm.pow intrinsic. The expected square root is a call to @sqrt; the product
; x^16 * sqrt(x) is then inverted because the exponent is negative.
169+
define double @test_simplify_neg_16_5_libcall(double %x) {
170+
; CHECK-LABEL: @test_simplify_neg_16_5_libcall(
171+
; CHECK-NEXT: [[SQRT:%.*]] = call fast double @sqrt(double [[X:%.*]])
172+
; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast double [[X]], [[X]]
173+
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[SQUARE]], [[SQUARE]]
174+
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
175+
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
176+
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[SQRT]]
177+
; CHECK-NEXT: [[RECIPROCAL:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
178+
; CHECK-NEXT: ret double [[RECIPROCAL]]
179+
;
180+
%1 = call fast double @pow(double %x, double -1.650000e+01)
181+
ret double %1
182+
}
183+
184+
; pow(x, -4.5) with float
; NOTE: despite the "neg_8_5" in the function name, the exponent passed below
; is -0.450000e+01 (-4.5), and the expected expansion is 1 / (x^4 * sqrt(x)).
185+
define float @test_simplify_neg_8_5(float %x) {
186+
; CHECK-LABEL: @test_simplify_neg_8_5(
187+
; CHECK-NEXT: [[SQRT:%.*]] = call fast float @llvm.sqrt.f32(float [[X:%.*]])
188+
; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast float [[X]], [[X]]
189+
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast float [[SQUARE]], [[SQUARE]]
190+
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast float [[TMP1]], [[SQRT]]
191+
; CHECK-NEXT: [[RECIPROCAL:%.*]] = fdiv fast float 1.000000e+00, [[TMP2]]
192+
; CHECK-NEXT: ret float [[RECIPROCAL]]
193+
;
194+
%1 = call fast float @llvm.pow.f32(float %x, float -0.450000e+01)
195+
ret float %1
196+
}
197+
198+
; pow(x, 7.5) with <2 x double>
; Expected expansion: x^2, x^4 = x^2 * x^2, x^5 = x^4 * x, x^7 = x^2 * x^5,
; then one fmul with the vector sqrt(x).
199+
define <2 x double> @test_simplify_7_5(<2 x double> %x) {
200+
; CHECK-LABEL: @test_simplify_7_5(
201+
; CHECK-NEXT: [[SQRT:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> [[X:%.*]])
202+
; CHECK-NEXT: [[SQUARE:%.*]] = fmul fast <2 x double> [[X]], [[X]]
203+
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast <2 x double> [[SQUARE]], [[SQUARE]]
204+
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x double> [[TMP1]], [[X]]
205+
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <2 x double> [[SQUARE]], [[TMP2]]
206+
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast <2 x double> [[TMP3]], [[SQRT]]
207+
; CHECK-NEXT: ret <2 x double> [[TMP4]]
208+
;
209+
%1 = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 7.500000e+00, double 7.500000e+00>)
210+
ret <2 x double> %1
211+
}
212+
213+
; pow(x, 3.5) with <4 x float>
; Expected expansion: x^2 = x * x, x^3 = x^2 * x, then one fmul with the
; vector sqrt(x).
214+
define <4 x float> @test_simplify_3_5(<4 x float> %x) {
215+
; CHECK-LABEL: @test_simplify_3_5(
216+
; CHECK-NEXT: [[SQRT:%.*]] = call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> [[X:%.*]])
217+
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast <4 x float> [[X]], [[X]]
218+
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <4 x float> [[TMP1]], [[X]]
219+
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <4 x float> [[TMP2]], [[SQRT]]
220+
; CHECK-NEXT: ret <4 x float> [[TMP3]]
221+
;
222+
%1 = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> <float 3.500000e+00, float 3.500000e+00, float 3.500000e+00, float 3.500000e+00>)
223+
ret <4 x float> %1
224+
}
225+

0 commit comments

Comments
 (0)
Please sign in to comment.