Index: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
===================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -526,6 +526,10 @@
   // to sign-preserving zero.
   bool useF32FTZ(const MachineFunction &MF) const;
 
+  SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+                          int &ExtraSteps, bool &UseOneConst,
+                          bool Reciprocal) const override;
+
   bool allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const;
   bool allowUnsafeFPMath(MachineFunction &MF) const;
 
Index: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1043,6 +1043,59 @@
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
 
+// Builds the initial sqrt (or reciprocal sqrt) estimate for Operand out of the
+// NVVM approx intrinsics.  Returns an empty SDValue when estimates are not
+// enabled or when Operand is neither f32 nor f64.  ExtraSteps is set to 0 if
+// the caller left it unspecified; UseOneConst is not touched.
+SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
+                                             int Enabled, int &ExtraSteps,
+                                             bool &UseOneConst,
+                                             bool Reciprocal) const {
+  if (!(Enabled == ReciprocalEstimate::Enabled ||
+        (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
+    return SDValue();
+
+  if (ExtraSteps == ReciprocalEstimate::Unspecified)
+    ExtraSteps = 0;
+
+  SDLoc DL(Operand);
+  EVT VT = Operand.getValueType();
+
+  // Only f32 and f64 are handled below.  Decline everything else up front so
+  // the sqrt path cannot emit the f64 rsqrt intrinsic for some other type.
+  if (VT != MVT::f32 && VT != MVT::f64)
+    return SDValue();
+
+  bool Ftz = useF32FTZ(DAG.getMachineFunction());
+
+  auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                       DAG.getConstant(IID, DL, MVT::i32), Operand);
+  };
+
+  // The sqrt and rsqrt refinement processes assume we always start out with an
+  // approximation of the rsqrt.  Therefore, if we're going to do any refinement
+  // (i.e. ExtraSteps > 0), we must return an rsqrt.  But if we're *not* doing
+  // any refinement, we must return a regular sqrt.
+  if (Reciprocal || ExtraSteps > 0) {
+    if (VT == MVT::f32)
+      return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
+                                   : Intrinsic::nvvm_rsqrt_approx_f);
+    // The f64 path has a single intrinsic; Ftz is not consulted for it.
+    return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
+  }
+
+  if (VT == MVT::f32)
+    return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
+                                 : Intrinsic::nvvm_sqrt_approx_f);
+
+  // There's no sqrt.approx.f64 instruction, so we emit x * rsqrt(x).
+  // NOTE(review): for x == 0 this looks like 0 * rsqrt(0), which would be NaN
+  // rather than 0 if rsqrt(0) is infinite -- confirm and special-case if so.
+  return DAG.getNode(ISD::FMUL, DL, VT, Operand,
+                     MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
+}
+
 SDValue
 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
   SDLoc dl(Op);
Index: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -966,18 +966,9 @@
                Requires<[reqPTX20]>;
 
 //
-// F32 rsqrt
+// FMA
 //
 
-def RSQRTF32approx1r : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$b),
-                                 "rsqrt.approx.f32 \t$dst, $b;", []>;
-
-// Convert 1.0f/sqrt(x) to rsqrt.approx.f32.  (There is an rsqrt.approx.f64, but
-// it's emulated in software.)
-def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$b)),
-         (RSQRTF32approx1r Float32Regs:$b)>,
-         Requires<[do_DIVF32_FULL, do_SQRTF32_APPROX, doNoF32FTZ]>;
-
 multiclass FMA {
    def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
                        !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
Index: llvm/trunk/test/CodeGen/NVPTX/fast-math.ll
===================================================================
--- llvm/trunk/test/CodeGen/NVPTX/fast-math.ll
+++ llvm/trunk/test/CodeGen/NVPTX/fast-math.ll
@@ -1,25 +1,91 @@
 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
 
-declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)
 
-; CHECK-LABEL: sqrt_div
+; CHECK-LABEL: sqrt_div(
 ; CHECK: sqrt.rn.f32
 ; CHECK: div.rn.f32
 define float @sqrt_div(float %a, float %b) {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
   %t2 = fdiv float %t1, %b
   ret float %t2
 }
 
-; CHECK-LABEL: sqrt_div_fast
+; CHECK-LABEL: sqrt_div_fast(
 ; CHECK: sqrt.approx.f32
 ; CHECK: div.approx.f32
 define float @sqrt_div_fast(float %a, float %b) #0 {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
   %t2 = fdiv float %t1, %b
   ret float %t2
 }
 
+; CHECK-LABEL: sqrt_div_ftz(
+; CHECK: sqrt.rn.ftz.f32
+; CHECK: div.rn.ftz.f32
+define float @sqrt_div_ftz(float %a, float %b) #1 {
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
+  %t2 = fdiv float %t1, %b
+  ret float %t2
+}
+
+; CHECK-LABEL: sqrt_div_fast_ftz(
+; CHECK: sqrt.approx.ftz.f32
+; CHECK: div.approx.ftz.f32
+define float @sqrt_div_fast_ftz(float %a, float %b) #0 #1 {
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
+  %t2 = fdiv float %t1, %b
+  ret float %t2
+}
+
+; There are no approx or ftz versions of sqrt and div for f64.  Under fast-math
+; we emit x * rsqrt.approx.f64(x) for sqrt(x), and a plain div.rn.f64.
+
+; CHECK-LABEL: sqrt_div_fast_ftz_f64(
+; CHECK: rsqrt.approx.f64
+; CHECK: mul.f64
+; CHECK: div.rn.f64
+define double @sqrt_div_fast_ftz_f64(double %a, double %b) #0 #1 {
+  %t1 = tail call double @llvm.sqrt.f64(double %a)
+  %t2 = fdiv double %t1, %b
+  ret double %t2
+}
+
+; CHECK-LABEL: rsqrt(
+; CHECK-NOT: rsqrt.approx
+; CHECK: sqrt.rn.f32
+; CHECK-NOT: rsqrt.approx
+define float @rsqrt(float %a) {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
+; CHECK-LABEL: rsqrt_fast(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast(float %a) #0 {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
+; CHECK-LABEL: rsqrt_fast_ftz(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.ftz.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast_ftz(float %a) #0 #1 {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
 ; CHECK-LABEL: fadd
 ; CHECK: add.rn.f32
 define float @fadd(float %a, float %b) {
Index: llvm/trunk/test/CodeGen/NVPTX/rsqrt.ll
===================================================================
--- llvm/trunk/test/CodeGen/NVPTX/rsqrt.ll
+++ llvm/trunk/test/CodeGen/NVPTX/rsqrt.ll
@@ -1,13 +0,0 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=1 -nvptx-prec-sqrtf32=0 | FileCheck %s
-
-target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
-
-declare float @llvm.nvvm.sqrt.f(float)
-
-define float @foo(float %a) {
-; CHECK: rsqrt.approx.f32
-  %val = tail call float @llvm.nvvm.sqrt.f(float %a)
-  %ret = fdiv float 1.0, %val
-  ret float %ret
-}
-
Index: llvm/trunk/test/CodeGen/NVPTX/sqrt-approx.ll
===================================================================
--- llvm/trunk/test/CodeGen/NVPTX/sqrt-approx.ll
+++ 
llvm/trunk/test/CodeGen/NVPTX/sqrt-approx.ll
@@ -0,0 +1,151 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=0 -nvptx-prec-sqrtf32=0 \
+; RUN:   | FileCheck %s
+
+; Checks that llvm.sqrt and 1.0 / llvm.sqrt are lowered to the PTX approx
+; sqrt/rsqrt instructions when run with the flags above and "unsafe-fp-math".
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
+
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)
+
+; -- reciprocal sqrt --
+
+; CHECK-LABEL: test_rsqrt32
+define float @test_rsqrt32(float %a) #0 {
+; CHECK: rsqrt.approx.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt_ftz
+define float @test_rsqrt_ftz(float %a) #0 #1 {
+; CHECK: rsqrt.approx.ftz.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt64
+define double @test_rsqrt64(double %a) #0 {
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+; CHECK-LABEL: test_rsqrt64_ftz
+define double @test_rsqrt64_ftz(double %a) #0 #1 {
+; There's no rsqrt.approx.ftz.f64 instruction; we just use the non-ftz version.
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+; -- sqrt --
+
+; CHECK-LABEL: test_sqrt32
+define float @test_sqrt32(float %a) #0 {
+; CHECK: sqrt.approx.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a)
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt_ftz
+define float @test_sqrt_ftz(float %a) #0 #1 {
+; CHECK: sqrt.approx.ftz.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a)
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt64
+define double @test_sqrt64(double %a) #0 {
+; There's no sqrt.approx.f64 instruction; we emit x * rsqrt.approx.f64(x).
+; CHECK: rsqrt.approx.f64
+; CHECK: mul.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+; CHECK-LABEL: test_sqrt64_ftz
+define double @test_sqrt64_ftz(double %a) #0 #1 {
+; There's no sqrt.approx.ftz.f64 instruction; we just use the non-ftz version.
+; CHECK: rsqrt.approx.f64
+; CHECK: mul.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+; -- refined sqrt and rsqrt --
+;
+; The sqrt and rsqrt refinement algorithms both emit an rsqrt.approx, followed
+; by some math.
+
+; CHECK-LABEL: test_rsqrt32_refined
+define float @test_rsqrt32_refined(float %a) #0 #2 {
+; CHECK: rsqrt.approx.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt32_refined
+define float @test_sqrt32_refined(float %a) #0 #2 {
+; CHECK: rsqrt.approx.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a)
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt64_refined
+define double @test_rsqrt64_refined(double %a) #0 #2 {
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+; CHECK-LABEL: test_sqrt64_refined
+define double @test_sqrt64_refined(double %a) #0 #2 {
+; CHECK: rsqrt.approx.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+; -- refined sqrt and rsqrt with ftz enabled --
+
+; CHECK-LABEL: test_rsqrt32_refined_ftz
+define float @test_rsqrt32_refined_ftz(float %a) #0 #1 #2 {
+; CHECK: rsqrt.approx.ftz.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt32_refined_ftz
+define float @test_sqrt32_refined_ftz(float %a) #0 #1 #2 {
+; CHECK: rsqrt.approx.ftz.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a)
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt64_refined_ftz
+define double @test_rsqrt64_refined_ftz(double %a) #0 #1 #2 {
+; There's no rsqrt.approx.ftz.f64, so we just use the non-ftz version.
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+; CHECK-LABEL: test_sqrt64_refined_ftz
+define double @test_sqrt64_refined_ftz(double %a) #0 #1 #2 {
+; CHECK: rsqrt.approx.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+attributes #0 = { "unsafe-fp-math" = "true" }
+attributes #1 = { "nvptx-f32ftz" = "true" }
+attributes #2 = { "reciprocal-estimates" = "rsqrtf:1,rsqrtd:1,sqrtf:1,sqrtd:1" }