Index: llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
===================================================================
--- llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ llvm/trunk/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -93,6 +93,7 @@
   Value *optimizePow(CallInst *CI, IRBuilder<> &B);
   Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFabs(CallInst *CI, IRBuilder<> &B);
+  Value *optimizeSqrt(CallInst *CI, IRBuilder<> &B);
   Value *optimizeSinCosPi(CallInst *CI, IRBuilder<> &B);
 
   // Integer Library Call Optimizations
Index: llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ llvm/trunk/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -27,12 +27,14 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Target/TargetLibraryInfo.h"
 #include "llvm/Transforms/Utils/BuildLibCalls.h"
 
 using namespace llvm;
+using namespace PatternMatch;
 
 static cl::opt<bool>
     ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
@@ -1254,6 +1256,91 @@
   return Ret;
 }
 
+Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
+  Function *Callee = CI->getCalledFunction();
+
+  // When UnsafeFPShrink is enabled and the target has sqrtf, try to shrink
+  // a double-precision sqrt() call to sqrtf() first.
+  Value *Ret = nullptr;
+  if (UnsafeFPShrink && Callee->getName() == "sqrt" &&
+      TLI->has(LibFunc::sqrtf)) {
+    Ret = optimizeUnaryDoubleFP(CI, B, true);
+  }
+
+  // FIXME: For finer-grain optimization, we need intrinsics to have the same
+  // fast-math flag decorations that are applied to FP instructions. For now,
+  // we have to rely on the function-level unsafe-fp-math attribute to do this
+  // optimization because there's no other way to express that the sqrt can be
+  // reassociated. The attribute must be present *and* equal to "true": a
+  // function that does not carry it at all must not be transformed.
+  Function *F = CI->getParent()->getParent();
+  if (!F->hasFnAttribute("unsafe-fp-math") ||
+      F->getFnAttribute("unsafe-fp-math").getValueAsString() != "true")
+    return Ret;
+
+  // The operand must be a multiply that may itself be reassociated.
+  Value *Op = CI->getArgOperand(0);
+  Instruction *I = dyn_cast<Instruction>(Op);
+  if (!I || I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra())
+    return Ret;
+
+  // We're looking for a repeated factor in a multiplication tree,
+  // so we can do this fold: sqrt(x * x) -> fabs(x);
+  // or this fold: sqrt((x * x) * y) -> fabs(x) * sqrt(y).
+  Value *Op0 = I->getOperand(0);
+  Value *Op1 = I->getOperand(1);
+  Value *RepeatOp = nullptr;
+  Value *OtherOp = nullptr;
+  if (Op0 == Op1) {
+    // Simple match: the operands of the multiply are identical.
+    RepeatOp = Op0;
+  } else {
+    // Look for a more complicated pattern: one of the operands is itself
+    // a multiply, so search for a common factor in that multiply.
+    // Note: We don't bother looking any deeper than this first level or for
+    // variations of this pattern because instcombine's visitFMUL and/or the
+    // reassociation pass should give us this form.
+    Value *OtherMul0, *OtherMul1;
+    if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
+      // Pattern: sqrt((x * y) * z)
+      // The fold regroups this inner multiply too, so it must also permit
+      // unsafe algebra; require an Instruction so its flags can be queried.
+      Instruction *InnerMul = dyn_cast<Instruction>(Op0);
+      if (OtherMul0 == OtherMul1 && InnerMul &&
+          InnerMul->hasUnsafeAlgebra()) {
+        // Matched: sqrt((x * x) * z)
+        RepeatOp = OtherMul0;
+        OtherOp = Op1;
+      }
+    }
+  }
+  if (!RepeatOp)
+    return Ret;
+
+  // Fast math flags for any created instructions should match the sqrt
+  // and multiply.
+  // FIXME: We're not checking the sqrt because it doesn't have
+  // fast-math-flags (see earlier comment).
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  B.SetFastMathFlags(I->getFastMathFlags());
+
+  // If we found a repeated factor, hoist it out of the square root and
+  // replace it with the fabs of that factor.
+  Module *M = Callee->getParent();
+  Type *ArgType = Op->getType();
+  Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
+  Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
+  if (OtherOp) {
+    // If we found a non-repeated factor, we still need to get its square
+    // root. We then multiply that by the value that was simplified out
+    // of the square root calculation.
+    Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
+    Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
+    return B.CreateFMul(FabsCall, SqrtCall);
+  }
+  return FabsCall;
+}
+
 static bool isTrigLibCall(CallInst *CI);
 static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
                              bool UseFloat, Value *&Sin, Value *&Cos,
@@ -1919,6 +2006,8 @@
       return optimizeExp2(CI, Builder);
     case Intrinsic::fabs:
       return optimizeFabs(CI, Builder);
+    case Intrinsic::sqrt:
+      return optimizeSqrt(CI, Builder);
     default:
       return nullptr;
     }
@@ -1995,6 +2084,10 @@
     case LibFunc::fabs:
     case LibFunc::fabsl:
       return optimizeFabs(CI, Builder);
+    case LibFunc::sqrtf:
+    case LibFunc::sqrt:
+    case LibFunc::sqrtl:
+      return optimizeSqrt(CI, Builder);
     case LibFunc::ffs:
     case LibFunc::ffsl:
     case LibFunc::ffsll:
@@ -2055,7 +2148,6 @@
     case LibFunc::logb:
     case LibFunc::sin:
     case LibFunc::sinh:
-    case LibFunc::sqrt:
     case LibFunc::tan:
     case LibFunc::tanh:
       if (UnsafeFPShrink && hasFloatVersion(FuncName))
Index: llvm/trunk/test/Transforms/InstCombine/fast-math.ll
===================================================================
--- llvm/trunk/test/Transforms/InstCombine/fast-math.ll
+++ llvm/trunk/test/Transforms/InstCombine/fast-math.ll
@@ -530,3 +530,215 @@
 ; CHECK: fact_div6
 ; CHECK: %t3 = fsub fast float %t1, %t2
 }
+
+; =========================================================================
+;
+;   Test-cases for square root
+;
+; =========================================================================
+
+; A squared factor fed into a square root intrinsic should be hoisted out
+; as a fabs() value.
+; We have to rely on a function-level attribute to enable this optimization
+; because intrinsics don't currently have access to IR-level fast-math
+; flags. If that changes, we can relax the requirement on all of these
+; tests to just specify 'fast' on the sqrt.
+
+attributes #0 = { "unsafe-fp-math" = "true" }
+
+declare double @llvm.sqrt.f64(double)
+
+define double @sqrt_intrinsic_arg_squared(double %x) #0 {
+  %mul = fmul fast double %x, %x
+  %sqrt = call double @llvm.sqrt.f64(double %mul)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_arg_squared(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: ret double %fabs
+}
+
+; Check all 6 combinations of a 3-way multiplication tree where
+; one factor is repeated.
+
+define double @sqrt_intrinsic_three_args1(double %x, double %y) #0 {
+  %mul = fmul fast double %y, %x
+  %mul2 = fmul fast double %mul, %x
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_three_args1(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args2(double %x, double %y) #0 {
+  %mul = fmul fast double %x, %y
+  %mul2 = fmul fast double %mul, %x
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_three_args2(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args3(double %x, double %y) #0 {
+  %mul = fmul fast double %x, %x
+  %mul2 = fmul fast double %mul, %y
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_three_args3(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args4(double %x, double %y) #0 {
+  %mul = fmul fast double %y, %x
+  %mul2 = fmul fast double %x, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_three_args4(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args5(double %x, double %y) #0 {
+  %mul = fmul fast double %x, %y
+  %mul2 = fmul fast double %x, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_three_args5(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_three_args6(double %x, double %y) #0 {
+  %mul = fmul fast double %x, %x
+  %mul2 = fmul fast double %y, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_three_args6(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
+; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+define double @sqrt_intrinsic_arg_4th(double %x) #0 {
+  %mul = fmul fast double %x, %x
+  %mul2 = fmul fast double %mul, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_arg_4th(
+; CHECK-NEXT: %mul = fmul fast double %x, %x
+; CHECK-NEXT: ret double %mul
+}
+
+define double @sqrt_intrinsic_arg_5th(double %x) #0 {
+  %mul = fmul fast double %x, %x
+  %mul2 = fmul fast double %mul, %x
+  %mul3 = fmul fast double %mul2, %mul
+  %sqrt = call double @llvm.sqrt.f64(double %mul3)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_arg_5th(
+; CHECK-NEXT: %mul = fmul fast double %x, %x
+; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %x)
+; CHECK-NEXT: %1 = fmul fast double %mul, %sqrt1
+; CHECK-NEXT: ret double %1
+}
+
+; Check that square root calls have the same behavior.
+
+declare float @sqrtf(float)
+declare double @sqrt(double)
+declare fp128 @sqrtl(fp128)
+
+define float @sqrt_call_squared_f32(float %x) #0 {
+  %mul = fmul fast float %x, %x
+  %sqrt = call float @sqrtf(float %mul)
+  ret float %sqrt
+
+; CHECK-LABEL: sqrt_call_squared_f32(
+; CHECK-NEXT: %fabs = call float @llvm.fabs.f32(float %x)
+; CHECK-NEXT: ret float %fabs
+}
+
+define double @sqrt_call_squared_f64(double %x) #0 {
+  %mul = fmul fast double %x, %x
+  %sqrt = call double @sqrt(double %mul)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_call_squared_f64(
+; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
+; CHECK-NEXT: ret double %fabs
+}
+
+define fp128 @sqrt_call_squared_f128(fp128 %x) #0 {
+  %mul = fmul fast fp128 %x, %x
+  %sqrt = call fp128 @sqrtl(fp128 %mul)
+  ret fp128 %sqrt
+
+; CHECK-LABEL: sqrt_call_squared_f128(
+; CHECK-NEXT: %fabs = call fp128 @llvm.fabs.f128(fp128 %x)
+; CHECK-NEXT: ret fp128 %fabs
+}
+
+; The fold must not fire unless the function explicitly enables
+; unsafe-fp-math, even if the multiply itself is 'fast'.
+
+attributes #1 = { "unsafe-fp-math" = "false" }
+
+define double @sqrt_intrinsic_no_fn_attr(double %x) {
+  %mul = fmul fast double %x, %x
+  %sqrt = call double @llvm.sqrt.f64(double %mul)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_no_fn_attr(
+; CHECK-NOT: fabs
+; CHECK: %sqrt = call double @llvm.sqrt.f64(double %mul)
+; CHECK-NEXT: ret double %sqrt
+}
+
+define double @sqrt_intrinsic_fn_attr_false(double %x) #1 {
+  %mul = fmul fast double %x, %x
+  %sqrt = call double @llvm.sqrt.f64(double %mul)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_fn_attr_false(
+; CHECK-NOT: fabs
+; CHECK: %sqrt = call double @llvm.sqrt.f64(double %mul)
+; CHECK-NEXT: ret double %sqrt
+}
+
+; The inner multiply of a (x * x) * y tree is regrouped too, so it must
+; also be 'fast' for the fold to fire.
+
+define double @sqrt_intrinsic_inner_mul_not_fast(double %x, double %y) #0 {
+  %mul = fmul double %x, %x
+  %mul2 = fmul fast double %mul, %y
+  %sqrt = call double @llvm.sqrt.f64(double %mul2)
+  ret double %sqrt
+
+; CHECK-LABEL: sqrt_intrinsic_inner_mul_not_fast(
+; CHECK-NOT: fabs
+; CHECK: %sqrt = call double @llvm.sqrt.f64(double %mul2)
+; CHECK-NEXT: ret double %sqrt
+}
+