Turn the "is FSQRT cheap?" decision into a per-operand target hook.

* TargetLowering: the FsqrtIsCheap flag and setFsqrtIsCheap() are removed.
  isFsqrtCheap(SDValue X, SelectionDAG &DAG) becomes virtual and returns
  false by default, so SQRT(X) may be rewritten as X*RSQRT(X).
* DAGCombiner::visitFSQRT still bails out unless UnsafeFPMath is set.
  It then passes the sqrt operand to the target hook before calling
  buildSqrtEstimate().
* AMDGPU overrides the hook to return true, replacing the old
  setFsqrtIsCheap(true) call in its constructor.
* X86 adds the subtarget features fast-scalar-fsqrt and fast-vector-fsqrt
  (HasFastScalarFSQRT / HasFastVectorFSQRT, both initialised to false).
  FeatureFastScalarFSQRT is appended to the feature list that ends in
  FeatureLAHFSAHF, just before SandyBridgeProc.
  X86TargetLowering::isFsqrtCheap returns false if an X86ISD::FRSQRT node
  for the same operand already exists. Otherwise it returns the vector or
  scalar feature bit, depending on the operand type.
* test/CodeGen/X86/sqrt-fastmath-tune.ll checks whether rsqrt or sqrt is
  selected for each -mcpu and -mattr combination.

NOTE(review): this copy of the patch looks damaged. Line breaks are
collapsed, and some <...> spans appear to be stripped (see
"class SandyBridgeProc : ProcModel;" in the X86.td hunk).
The skylake RUN line expects VECTOR-ACC, but no visible hunk adds
FeatureFastVectorFSQRT to any processor, so that part was presumably lost.
Regenerate the diff from the source tree before applying. TODO: confirm.

Index: include/llvm/Target/TargetLowering.h =================================================================== --- include/llvm/Target/TargetLowering.h +++ include/llvm/Target/TargetLowering.h @@ -238,9 +238,10 @@ return false; } - /// Return true if sqrt(x) is as cheap or cheaper than 1 / rsqrt(x) - bool isFsqrtCheap() const { - return FsqrtIsCheap; + /// Return true if SQRT(X) shouldn't be replaced with X*RSQRT(X). + virtual bool isFsqrtCheap(SDValue X, SelectionDAG &DAG) const { + // Default behavior is to replace SQRT(X) with X*RSQRT(X). + return false; } /// Returns true if target has indicated at least one type should be bypassed. @@ -1356,10 +1357,6 @@ /// control. void setJumpIsExpensive(bool isExpensive = true); - /// Tells the code generator that fsqrt is cheap, and should not be replaced - /// with an alternative sequence of instructions. - void setFsqrtIsCheap(bool isCheap = true) { FsqrtIsCheap = isCheap; } - /// Tells the code generator that this target supports floating point /// exceptions and cares about preserving floating point exception behavior. void setHasFloatingPointExceptions(bool FPExceptions = true) { @@ -1880,9 +1877,6 @@ /// combined with "shift" to BitExtract instructions. bool HasExtractBitsInsn; - // Don't expand fsqrt with an approximation based on the inverse sqrt. - bool FsqrtIsCheap; - /// Tells the code generator to bypass slow divide or remainder /// instructions.
For example, BypassSlowDivWidths[32,8] tells the code /// generator to bypass 32-bit integer div/rem with an 8-bit unsigned integer Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -8873,14 +8873,18 @@ } SDValue DAGCombiner::visitFSQRT(SDNode *N) { - if (!DAG.getTarget().Options.UnsafeFPMath || TLI.isFsqrtCheap()) + if (!DAG.getTarget().Options.UnsafeFPMath) + return SDValue(); + + SDValue N0 = N->getOperand(0); + if (TLI.isFsqrtCheap(N0, DAG)) return SDValue(); // TODO: FSQRT nodes should have flags that propagate to the created nodes. // For now, create a Flags object for use with all unsafe math transforms. SDNodeFlags Flags; Flags.setUnsafeAlgebra(true); - return buildSqrtEstimate(N->getOperand(0), &Flags); + return buildSqrtEstimate(N0, &Flags); } /// copysign(x, fp_extend(y)) -> copysign(x, y) Index: lib/CodeGen/TargetLoweringBase.cpp =================================================================== --- lib/CodeGen/TargetLoweringBase.cpp +++ lib/CodeGen/TargetLoweringBase.cpp @@ -807,7 +807,6 @@ SelectIsExpensive = false; HasMultipleConditionRegisters = false; HasExtractBitsInsn = false; - FsqrtIsCheap = false; JumpIsExpensive = JumpIsExpensiveOverride; PredictableSelectIsExpensive = false; MaskAndBranchFoldingIsLegal = false; Index: lib/Target/AMDGPU/AMDGPUISelLowering.h =================================================================== --- lib/Target/AMDGPU/AMDGPUISelLowering.h +++ lib/Target/AMDGPU/AMDGPUISelLowering.h @@ -175,6 +175,9 @@ const char* getTargetNodeName(unsigned Opcode) const override; + bool isFsqrtCheap(SDValue Operand, SelectionDAG &DAG) const override { + return true; + } SDValue getRsqrtEstimate(SDValue Operand, DAGCombinerInfo &DCI, unsigned &RefinementSteps, Index: lib/Target/AMDGPU/AMDGPUISelLowering.cpp ===================================================================
--- lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -409,8 +409,6 @@ setSelectIsExpensive(false); PredictableSelectIsExpensive = false; - setFsqrtIsCheap(true); - // We want to find all load dependencies for long chains of stores to enable // merging into very wide vectors. The problem is with vectors with > 4 // elements. MergeConsecutiveStores will attempt to merge these because x8/x16 Index: lib/Target/X86/X86.td =================================================================== --- lib/Target/X86/X86.td +++ lib/Target/X86/X86.td @@ -247,6 +247,12 @@ def FeatureFastPartialYMMWrite : SubtargetFeature<"fast-partial-ymm-write", "HasFastPartialYMMWrite", "true", "Partial writes to YMM registers are fast">; +def FeatureFastScalarFSQRT + : SubtargetFeature<"fast-scalar-fsqrt", "HasFastScalarFSQRT", + "true", "Scalar SQRT is fast (disable Newton-Raphson)">; +def FeatureFastVectorFSQRT + : SubtargetFeature<"fast-vector-fsqrt", "HasFastVectorFSQRT", + "true", "Vector SQRT is fast (disable Newton-Raphson)">; //===----------------------------------------------------------------------===// // X86 processors supported. @@ -440,7 +446,8 @@ FeaturePCLMUL, FeatureXSAVE, FeatureXSAVEOPT, - FeatureLAHFSAHF + FeatureLAHFSAHF, + FeatureFastScalarFSQRT ]>; class SandyBridgeProc : ProcModel; // FIXME: define SKL model Index: lib/Target/X86/X86ISelLowering.h =================================================================== --- lib/Target/X86/X86ISelLowering.h +++ lib/Target/X86/X86ISelLowering.h @@ -1202,6 +1202,9 @@ /// Convert a comparison if required by the subtarget. SDValue ConvertCmpIfNecessary(SDValue Cmp, SelectionDAG &DAG) const; + /// Check if replacement of SQRT with RSQRT should be disabled. + bool isFsqrtCheap(SDValue Operand, SelectionDAG &DAG) const override; + /// Use rsqrt* to speed up sqrt calculations.
SDValue getRsqrtEstimate(SDValue Operand, DAGCombinerInfo &DCI, unsigned &RefinementSteps, Index: lib/Target/X86/X86ISelLowering.cpp =================================================================== --- lib/Target/X86/X86ISelLowering.cpp +++ lib/Target/X86/X86ISelLowering.cpp @@ -14691,6 +14691,19 @@ return DAG.getNode(X86ISD::SAHF, dl, MVT::i32, TruncSrl); } +/// Check if replacement of SQRT with RSQRT should be disabled. +bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + + // We never want to use both SQRT and RSQRT instructions for the same input. + if (DAG.getNodeIfExists(X86ISD::FRSQRT, DAG.getVTList(VT), Op)) + return false; + + if (VT.isVector()) + return Subtarget.hasFastVectorFSQRT(); + return Subtarget.hasFastScalarFSQRT(); +} + /// The minimum architected relative accuracy is 2^-12. We need one /// Newton-Raphson step to have a good float result (24 bits of precision). SDValue X86TargetLowering::getRsqrtEstimate(SDValue Op, Index: lib/Target/X86/X86Subtarget.h =================================================================== --- lib/Target/X86/X86Subtarget.h +++ lib/Target/X86/X86Subtarget.h @@ -195,6 +195,14 @@ /// of a YMM register without clearing the upper part. bool HasFastPartialYMMWrite; + /// True if hardware SQRTSS instruction is at least as fast (latency) as + /// RSQRTSS followed by a Newton-Raphson iteration. + bool HasFastScalarFSQRT; + + /// True if hardware SQRTPS/VSQRTPS instructions are at least as fast + /// (throughput) as RSQRTPS/VRSQRTPS followed by a Newton-Raphson iteration. + bool HasFastVectorFSQRT; + /// True if 8-bit divisions are significantly faster than /// 32-bit divisions and should be used when possible.
bool HasSlowDivide32; @@ -429,6 +437,8 @@ bool hasCmpxchg16b() const { return HasCmpxchg16b; } bool useLeaForSP() const { return UseLeaForSP; } bool hasFastPartialYMMWrite() const { return HasFastPartialYMMWrite; } + bool hasFastScalarFSQRT() const { return HasFastScalarFSQRT; } + bool hasFastVectorFSQRT() const { return HasFastVectorFSQRT; } bool hasSlowDivide32() const { return HasSlowDivide32; } bool hasSlowDivide64() const { return HasSlowDivide64; } bool padShortFunctions() const { return PadShortFunctions; } Index: lib/Target/X86/X86Subtarget.cpp =================================================================== --- lib/Target/X86/X86Subtarget.cpp +++ lib/Target/X86/X86Subtarget.cpp @@ -322,6 +322,8 @@ HasCmpxchg16b = false; UseLeaForSP = false; HasFastPartialYMMWrite = false; + HasFastScalarFSQRT = false; + HasFastVectorFSQRT = false; HasSlowDivide32 = false; HasSlowDivide64 = false; PadShortFunctions = false; Index: test/CodeGen/X86/sqrt-fastmath-tune.ll =================================================================== --- /dev/null +++ test/CodeGen/X86/sqrt-fastmath-tune.ll @@ -0,0 +1,57 @@ +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -O2 -mcpu=nehalem | FileCheck %s --check-prefix=SCALAR-EST --check-prefix=VECTOR-EST +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -O2 -mcpu=sandybridge | FileCheck %s --check-prefix=SCALAR-ACC --check-prefix=VECTOR-EST +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -O2 -mcpu=broadwell | FileCheck %s --check-prefix=SCALAR-ACC --check-prefix=VECTOR-EST +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -O2 -mcpu=skylake | FileCheck %s --check-prefix=SCALAR-ACC --check-prefix=VECTOR-ACC + +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -O2 -mattr=+fast-scalar-fsqrt,-fast-vector-fsqrt | FileCheck %s --check-prefix=SCALAR-ACC --check-prefix=VECTOR-EST +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -O2 -mattr=-fast-scalar-fsqrt,+fast-vector-fsqrt | FileCheck %s --check-prefix=SCALAR-EST --check-prefix=VECTOR-ACC +
+declare float @llvm.sqrt.f32(float) #0 +declare <4 x float> @llvm.sqrt.v4f32(<4 x float>) #0 +declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0 + +define float @foo_x1(float %f) #0 { +; SCALAR-EST-LABEL: foo_x1: +; SCALAR-EST: # BB#0: +; SCALAR-EST-NEXT: rsqrtss %xmm0 +; SCALAR-EST: retq +; +; SCALAR-ACC-LABEL: foo_x1: +; SCALAR-ACC: # BB#0: +; SCALAR-ACC-NEXT: {{^ *v?sqrtss %xmm0}} +; SCALAR-ACC-NEXT: retq + %call = tail call float @llvm.sqrt.f32(float %f) #1 + ret float %call +} + +define <4 x float> @foo_x4(<4 x float> %f) #0 { +; VECTOR-EST-LABEL: foo_x4: +; VECTOR-EST: # BB#0: +; VECTOR-EST-NEXT: rsqrtps %xmm0 +; VECTOR-EST: retq +; +; VECTOR-ACC-LABEL: foo_x4: +; VECTOR-ACC: # BB#0: +; VECTOR-ACC-NEXT: {{^ *v?sqrtps %xmm0}} +; VECTOR-ACC-NEXT: retq + %call = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %f) #1 + ret <4 x float> %call +} + +define <8 x float> @foo_x8(<8 x float> %f) #0 { +; VECTOR-EST-LABEL: foo_x8: +; VECTOR-EST: # BB#0: +; VECTOR-EST-NEXT: rsqrtps +; VECTOR-EST: retq +; +; VECTOR-ACC-LABEL: foo_x8: +; VECTOR-ACC: # BB#0: +; VECTOR-ACC-NEXT: {{^ *v?sqrtps %[xy]mm0}} +; VECTOR-ACC-NOT: rsqrt +; VECTOR-ACC: retq + %call = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> %f) #1 + ret <8 x float> %call +} + +attributes #0 = { "unsafe-fp-math"="true" } +attributes #1 = { nounwind readnone }