Skip to content

Commit c699a61

Browse files
committed Oct 16, 2014
fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)
If a square root call has an FP multiplication argument that can be reassociated, then we can hoist a repeated factor out of the square root call and into a fabs(). In the simplest case, this: `y = sqrt(x * x);` becomes this: `y = fabs(x);`. This patch relies on an earlier optimization in instcombine or reassociate to put the multiplication tree into a canonical form, so we don't have to search over every permutation of the multiplication tree. Because there are no IR-level FastMathFlags for intrinsics (PR21290), we have to use function-level attributes to do this optimization. This needs to be fixed both for the intrinsics and in the backend. Differential Revision: http://reviews.llvm.org/D5787 llvm-svn: 219944
1 parent d70f3c2 commit c699a61

File tree

3 files changed

+258
-1
lines changed

3 files changed

+258
-1
lines changed
 

Diff for: ‎llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class LibCallSimplifier {
9393
Value *optimizePow(CallInst *CI, IRBuilder<> &B);
9494
Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
9595
Value *optimizeFabs(CallInst *CI, IRBuilder<> &B);
96+
Value *optimizeSqrt(CallInst *CI, IRBuilder<> &B);
9697
Value *optimizeSinCosPi(CallInst *CI, IRBuilder<> &B);
9798

9899
// Integer Library Call Optimizations

Diff for: ‎llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

+87-1
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@
2727
#include "llvm/IR/Intrinsics.h"
2828
#include "llvm/IR/LLVMContext.h"
2929
#include "llvm/IR/Module.h"
30+
#include "llvm/IR/PatternMatch.h"
3031
#include "llvm/Support/Allocator.h"
3132
#include "llvm/Support/CommandLine.h"
3233
#include "llvm/Target/TargetLibraryInfo.h"
3334
#include "llvm/Transforms/Utils/BuildLibCalls.h"
3435

3536
using namespace llvm;
37+
using namespace PatternMatch;
3638

3739
static cl::opt<bool>
3840
ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
@@ -1254,6 +1256,85 @@ Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) {
12541256
return Ret;
12551257
}
12561258

1259+
/// Simplify a call to sqrt/sqrtf/sqrtl or to the llvm.sqrt intrinsic.
///
/// Two independent simplifications are attempted:
///  1. When UnsafeFPShrink is set, the callee is named "sqrt", and sqrtf is
///     available, try to shrink the double call to float with
///     optimizeUnaryDoubleFP.
///  2. When the enclosing function carries "unsafe-fp-math"="true" and the
///     argument is an fmul that permits unsafe algebra and has a repeated
///     factor, hoist that factor out of the root:
///       sqrt(x * x)       -> fabs(x)
///       sqrt((x * x) * y) -> fabs(x) * sqrt(y)
///
/// \returns the replacement value for \p CI. If fold 2 applies, its result is
/// returned. Otherwise the result of step 1 is returned, which may be nullptr
/// when nothing could be simplified.
Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
  Function *Callee = CI->getCalledFunction();

  Value *Ret = nullptr;
  if (UnsafeFPShrink && Callee->getName() == "sqrt" &&
      TLI->has(LibFunc::sqrtf)) {
    Ret = optimizeUnaryDoubleFP(CI, B, true);
  }

  // FIXME: For finer-grain optimization, we need intrinsics to have the same
  // fast-math flag decorations that are applied to FP instructions. For now,
  // we have to rely on the function-level unsafe-fp-math attribute to do this
  // optimization because there's no other way to express that the sqrt can be
  // reassociated.
  //
  // The fold is licensed only when the function explicitly opts in with
  // "unsafe-fp-math"="true". A missing attribute means default (strict) FP
  // semantics, so it must block the transform exactly like "false" does.
  Function *F = CI->getParent()->getParent();
  if (!F->hasFnAttribute("unsafe-fp-math") ||
      F->getFnAttribute("unsafe-fp-math").getValueAsString() != "true")
    return Ret;

  Value *Op = CI->getArgOperand(0);
  Instruction *I = dyn_cast<Instruction>(Op);
  if (!I || I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra())
    return Ret;

  // We're looking for a repeated factor in a multiplication tree,
  // so we can do this fold: sqrt(x * x) -> fabs(x);
  // or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y).
  Value *Op0 = I->getOperand(0);
  Value *Op1 = I->getOperand(1);
  Value *RepeatOp = nullptr;
  Value *OtherOp = nullptr;
  if (Op0 == Op1) {
    // Simple match: the operands of the multiply are identical.
    RepeatOp = Op0;
  } else {
    // Look for a more complicated pattern: one of the operands is itself
    // a multiply, so search for a common factor in that multiply.
    // Note: We don't bother looking any deeper than this first level or for
    // variations of this pattern because instcombine's visitFMUL and/or the
    // reassociation pass should give us this form.
    Value *OtherMul0, *OtherMul1;
    if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
      // Pattern: sqrt((x * y) * z)
      // The inner multiply is folded away as well, so it must also permit
      // unsafe algebra. A constant-expression fmul has no fast-math flags
      // and therefore does not qualify.
      Instruction *InnerMul = dyn_cast<Instruction>(Op0);
      if (OtherMul0 == OtherMul1 && InnerMul && InnerMul->hasUnsafeAlgebra()) {
        // Matched: sqrt((x * x) * z)
        RepeatOp = OtherMul0;
        OtherOp = Op1;
      }
    }
  }
  if (!RepeatOp)
    return Ret;

  // Fast math flags for any created instructions should match the sqrt
  // and multiply.
  // FIXME: We're not checking the sqrt because it doesn't have
  // fast-math-flags (see earlier comment).
  IRBuilder<>::FastMathFlagGuard Guard(B);
  B.SetFastMathFlags(I->getFastMathFlags());

  // If we found a repeated factor, hoist it out of the square root and
  // replace it with the fabs of that factor.
  Module *M = Callee->getParent();
  Type *ArgType = Op->getType();
  Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
  Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
  if (OtherOp) {
    // If we found a non-repeated factor, we still need to get its square
    // root. We then multiply that by the value that was simplified out
    // of the square root calculation.
    Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
    Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
    return B.CreateFMul(FabsCall, SqrtCall);
  }
  return FabsCall;
}
1337+
12571338
static bool isTrigLibCall(CallInst *CI);
12581339
static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
12591340
bool UseFloat, Value *&Sin, Value *&Cos,
@@ -1919,6 +2000,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
19192000
return optimizeExp2(CI, Builder);
19202001
case Intrinsic::fabs:
19212002
return optimizeFabs(CI, Builder);
2003+
case Intrinsic::sqrt:
2004+
return optimizeSqrt(CI, Builder);
19222005
default:
19232006
return nullptr;
19242007
}
@@ -1995,6 +2078,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
19952078
case LibFunc::fabs:
19962079
case LibFunc::fabsl:
19972080
return optimizeFabs(CI, Builder);
2081+
case LibFunc::sqrtf:
2082+
case LibFunc::sqrt:
2083+
case LibFunc::sqrtl:
2084+
return optimizeSqrt(CI, Builder);
19982085
case LibFunc::ffs:
19992086
case LibFunc::ffsl:
20002087
case LibFunc::ffsll:
@@ -2055,7 +2142,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
20552142
case LibFunc::logb:
20562143
case LibFunc::sin:
20572144
case LibFunc::sinh:
2058-
case LibFunc::sqrt:
20592145
case LibFunc::tan:
20602146
case LibFunc::tanh:
20612147
if (UnsafeFPShrink && hasFloatVersion(FuncName))

Diff for: ‎llvm/test/Transforms/InstCombine/fast-math.ll

+170
Original file line numberDiff line numberDiff line change
@@ -530,3 +530,173 @@ define float @fact_div6(float %x) {
530530
; CHECK: fact_div6
531531
; CHECK: %t3 = fsub fast float %t1, %t2
532532
}
533+
534+
; =========================================================================
535+
;
536+
; Test-cases for square root
537+
;
538+
; =========================================================================
539+
540+
; A squared factor fed into a square root intrinsic should be hoisted out
541+
; as a fabs() value.
542+
; We have to rely on a function-level attribute to enable this optimization
543+
; because intrinsics don't currently have access to IR-level fast-math
544+
; flags. If that changes, we can relax the requirement on all of these
545+
; tests to just specify 'fast' on the sqrt.
546+
547+
attributes #0 = { "unsafe-fp-math" = "true" }
548+
549+
declare double @llvm.sqrt.f64(double)
550+
551+
; Simplest fold: sqrt(x * x) on a fast fmul is expected to become fabs(x).
define double @sqrt_intrinsic_arg_squared(double %x) #0 {
552+
%mul = fmul fast double %x, %x
553+
%sqrt = call double @llvm.sqrt.f64(double %mul)
554+
ret double %sqrt
555+
556+
; CHECK-LABEL: sqrt_intrinsic_arg_squared(
557+
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
558+
; CHECK-NEXT: ret double %fabs
559+
}
560+
561+
; Check all 6 combinations of a 3-way multiplication tree where
562+
; one factor is repeated.
563+
564+
; Input tree: sqrt((y * x) * x). The repeated x is not in the (x * x) * z
; shape the fold matches directly. The expected output below is still
; fabs(x) * sqrt(y), presumably after instcombine canonicalizes the tree.
define double @sqrt_intrinsic_three_args1(double %x, double %y) #0 {
565+
%mul = fmul fast double %y, %x
566+
%mul2 = fmul fast double %mul, %x
567+
%sqrt = call double @llvm.sqrt.f64(double %mul2)
568+
ret double %sqrt
569+
570+
; CHECK-LABEL: sqrt_intrinsic_three_args1(
571+
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
572+
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
573+
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
574+
; CHECK-NEXT: ret double %1
575+
}
576+
577+
; Input tree: sqrt((x * y) * x). Expected result: fabs(x) * sqrt(y).
define double @sqrt_intrinsic_three_args2(double %x, double %y) #0 {
578+
%mul = fmul fast double %x, %y
579+
%mul2 = fmul fast double %mul, %x
580+
%sqrt = call double @llvm.sqrt.f64(double %mul2)
581+
ret double %sqrt
582+
583+
; CHECK-LABEL: sqrt_intrinsic_three_args2(
584+
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
585+
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
586+
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
587+
; CHECK-NEXT: ret double %1
588+
}
589+
590+
; Input tree: sqrt((x * x) * y), which is directly the matched shape.
; Expected result: fabs(x) * sqrt(y).
define double @sqrt_intrinsic_three_args3(double %x, double %y) #0 {
591+
%mul = fmul fast double %x, %x
592+
%mul2 = fmul fast double %mul, %y
593+
%sqrt = call double @llvm.sqrt.f64(double %mul2)
594+
ret double %sqrt
595+
596+
; CHECK-LABEL: sqrt_intrinsic_three_args3(
597+
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
598+
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
599+
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
600+
; CHECK-NEXT: ret double %1
601+
}
602+
603+
; Input tree: sqrt(x * (y * x)), with the inner multiply as the right operand.
; Expected result: fabs(x) * sqrt(y).
define double @sqrt_intrinsic_three_args4(double %x, double %y) #0 {
604+
%mul = fmul fast double %y, %x
605+
%mul2 = fmul fast double %x, %mul
606+
%sqrt = call double @llvm.sqrt.f64(double %mul2)
607+
ret double %sqrt
608+
609+
; CHECK-LABEL: sqrt_intrinsic_three_args4(
610+
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
611+
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
612+
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
613+
; CHECK-NEXT: ret double %1
614+
}
615+
616+
; Input tree: sqrt(x * (x * y)), with the inner multiply as the right operand.
; Expected result: fabs(x) * sqrt(y).
define double @sqrt_intrinsic_three_args5(double %x, double %y) #0 {
617+
%mul = fmul fast double %x, %y
618+
%mul2 = fmul fast double %x, %mul
619+
%sqrt = call double @llvm.sqrt.f64(double %mul2)
620+
ret double %sqrt
621+
622+
; CHECK-LABEL: sqrt_intrinsic_three_args5(
623+
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
624+
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
625+
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
626+
; CHECK-NEXT: ret double %1
627+
}
628+
629+
; Input tree: sqrt(y * (x * x)), with the squared factor as the right operand.
; Expected result: fabs(x) * sqrt(y).
define double @sqrt_intrinsic_three_args6(double %x, double %y) #0 {
630+
%mul = fmul fast double %x, %x
631+
%mul2 = fmul fast double %y, %mul
632+
%sqrt = call double @llvm.sqrt.f64(double %mul2)
633+
ret double %sqrt
634+
635+
; CHECK-LABEL: sqrt_intrinsic_three_args6(
636+
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
637+
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
638+
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
639+
; CHECK-NEXT: ret double %1
640+
}
641+
642+
; sqrt((x * x) * (x * x)): both operands are the same value, so the root of
; x^4 is expected to reduce to plain x * x. No fabs remains in the expected
; output, presumably because fabs of a square is folded away afterwards.
define double @sqrt_intrinsic_arg_4th(double %x) #0 {
643+
%mul = fmul fast double %x, %x
644+
%mul2 = fmul fast double %mul, %mul
645+
%sqrt = call double @llvm.sqrt.f64(double %mul2)
646+
ret double %sqrt
647+
648+
; CHECK-LABEL: sqrt_intrinsic_arg_4th(
649+
; CHECK-NEXT: %mul = fmul fast double %x, %x
650+
; CHECK-NEXT: ret double %mul
651+
}
652+
653+
; sqrt(x^5), built as ((x * x) * x) * (x * x). The expected result is
; (x * x) * sqrt(x), i.e. the even power is hoisted out of the root.
define double @sqrt_intrinsic_arg_5th(double %x) #0 {
654+
%mul = fmul fast double %x, %x
655+
%mul2 = fmul fast double %mul, %x
656+
%mul3 = fmul fast double %mul2, %mul
657+
%sqrt = call double @llvm.sqrt.f64(double %mul3)
658+
ret double %sqrt
659+
660+
; CHECK-LABEL: sqrt_intrinsic_arg_5th(
661+
; CHECK-NEXT: %mul = fmul fast double %x, %x
662+
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %x)
663+
; CHECK-NEXT: %1 = fmul fast double %mul, %sqrt1
664+
; CHECK-NEXT: ret double %1
665+
}
666+
667+
; Check that square root calls have the same behavior.
668+
669+
declare float @sqrtf(float)
670+
declare double @sqrt(double)
671+
declare fp128 @sqrtl(fp128)
672+
673+
; Same fold through the sqrtf libcall (float): sqrt(x * x) -> llvm.fabs.f32(x).
define float @sqrt_call_squared_f32(float %x) #0 {
674+
%mul = fmul fast float %x, %x
675+
%sqrt = call float @sqrtf(float %mul)
676+
ret float %sqrt
677+
678+
; CHECK-LABEL: sqrt_call_squared_f32(
679+
; CHECK-NEXT: %fabs = call float @llvm.fabs.f32(float %x)
680+
; CHECK-NEXT: ret float %fabs
681+
}
682+
683+
; Same fold through the sqrt libcall (double): sqrt(x * x) -> llvm.fabs.f64(x).
define double @sqrt_call_squared_f64(double %x) #0 {
684+
%mul = fmul fast double %x, %x
685+
%sqrt = call double @sqrt(double %mul)
686+
ret double %sqrt
687+
688+
; CHECK-LABEL: sqrt_call_squared_f64(
689+
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
690+
; CHECK-NEXT: ret double %fabs
691+
}
692+
693+
; Same fold through the sqrtl libcall, declared here on fp128:
; sqrt(x * x) -> llvm.fabs.f128(x).
define fp128 @sqrt_call_squared_f128(fp128 %x) #0 {
694+
%mul = fmul fast fp128 %x, %x
695+
%sqrt = call fp128 @sqrtl(fp128 %mul)
696+
ret fp128 %sqrt
697+
698+
; CHECK-LABEL: sqrt_call_squared_f128(
699+
; CHECK-NEXT: %fabs = call fp128 @llvm.fabs.f128(fp128 %x)
700+
; CHECK-NEXT: ret fp128 %fabs
701+
}
702+

0 commit comments

Comments
 (0)
Please sign in to comment.