Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -515,6 +515,10 @@
   bool usePrecSqrtF32() const;
   bool useF32FTZ(const MachineFunction &MF) const;
 
+  SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+                          int &ExtraSteps, bool &UseOneConst,
+                          bool Reciprocal) const override;
+
   bool allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const;
   bool allowUnsafeFPMath(MachineFunction &MF) const;
 
Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1042,6 +1042,57 @@
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
 
+// Returns an approximation of sqrt(Operand) or 1/sqrt(Operand) as an NVVM
+// approx intrinsic node, or an empty SDValue() when no estimate is produced
+// (estimates not enabled, or no approx intrinsic fits the operand type).
+//
+// Enabled and ExtraSteps are compared against ReciprocalEstimate constants.
+// ExtraSteps is in/out: if it arrives Unspecified it is set to 0 on success,
+// so the approx result is used as-is. UseOneConst is left untouched.
+SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
+                                             int Enabled, int &ExtraSteps,
+                                             bool &UseOneConst,
+                                             bool Reciprocal) const {
+  // Estimate only when explicitly enabled, or when the caller expressed no
+  // preference and precise f32 sqrt was not requested.
+  if (!(Enabled == ReciprocalEstimate::Enabled ||
+        (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
+    return SDValue();
+
+  // Resolve the step count up front; it decides which seed we must produce.
+  const int Steps =
+      ExtraSteps == ReciprocalEstimate::Unspecified ? 0 : ExtraSteps;
+
+  // NOTE(review): the target-independent refinement that runs for Steps > 0
+  // is understood to start from a reciprocal-sqrt seed even when the caller
+  // ultimately wants sqrt (confirm against DAGCombiner). Returning sqrt.approx
+  // there would refine the wrong quantity, so hand back rsqrt in that case.
+  const bool WantRsqrt = Reciprocal || Steps > 0;
+
+  const EVT VT = Operand.getValueType();
+  const bool Ftz = useF32FTZ(DAG.getMachineFunction());
+  unsigned IID;
+  if (VT == MVT::f32) {
+    if (WantRsqrt)
+      IID = Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
+                : Intrinsic::nvvm_rsqrt_approx_f;
+    else
+      IID = Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
+                : Intrinsic::nvvm_sqrt_approx_f;
+  } else if (VT == MVT::f64 && WantRsqrt) {
+    IID = Intrinsic::nvvm_rsqrt_approx_d;
+  } else {
+    // TODO: We should probably lower approx sqrt(f64) as 1/rsqrt.
+    return SDValue();
+  }
+
+  ExtraSteps = Steps;
+
+  SDLoc DL(Operand);
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                     DAG.getConstant(IID, DL, MVT::i32), Operand);
+}
+
 SDValue
 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
   SDLoc dl(Op);
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -964,18 +964,9 @@
                Requires<[reqPTX20]>;
 
 //
-// F32 rsqrt
+// FMA
 //
 
-def RSQRTF32approx1r : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$b),
-                                 "rsqrt.approx.f32 \t$dst, $b;", []>;
-
-// Convert 1.0f/sqrt(x) to rsqrt.approx.f32. (There is an rsqrt.approx.f64, but
-// it's emulated in software.)
-def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$b)),
-         (RSQRTF32approx1r Float32Regs:$b)>,
-         Requires<[do_DIVF32_FULL, do_SQRTF32_APPROX, doNoF32FTZ]>;
-
 multiclass FMA {
    def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
                        !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
Index: llvm/test/CodeGen/NVPTX/fast-math.ll
===================================================================
--- llvm/test/CodeGen/NVPTX/fast-math.ll
+++ llvm/test/CodeGen/NVPTX/fast-math.ll
@@ -1,25 +1,90 @@
 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
 
-declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)
 
-; CHECK-LABEL: sqrt_div
+; CHECK-LABEL: sqrt_div(
 ; CHECK: sqrt.rn.f32
 ; CHECK: div.rn.f32
 define float @sqrt_div(float %a, float %b) {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
   %t2 = fdiv float %t1, %b
   ret float %t2
 }
 
-; CHECK-LABEL: sqrt_div_fast
+; CHECK-LABEL: sqrt_div_fast(
 ; CHECK: sqrt.approx.f32
 ; CHECK: div.approx.f32
 define float @sqrt_div_fast(float %a, float %b) #0 {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
   %t2 = fdiv float %t1, %b
   ret float %t2
 }
 
+; CHECK-LABEL: sqrt_div_ftz(
+; CHECK: sqrt.rn.ftz.f32
+; CHECK: div.rn.ftz.f32
+define float @sqrt_div_ftz(float %a, float %b) #1 {
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
+  %t2 = fdiv float %t1, %b
+  ret float %t2
+}
+
+; CHECK-LABEL: sqrt_div_fast_ftz(
+; CHECK: sqrt.approx.ftz.f32
+; CHECK: div.approx.ftz.f32
+define float @sqrt_div_fast_ftz(float %a, float %b) #0 #1 {
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
+  %t2 = fdiv float %t1, %b
+  ret float %t2
+}
+
+; There are no fast-math or ftz versions of sqrt and div for f64; we just emit
+; the vanilla instructions.
+;
+; CHECK-LABEL: sqrt_div_fast_ftz_f64(
+; CHECK: sqrt.rn.f64
+; CHECK: div.rn.f64
+define double @sqrt_div_fast_ftz_f64(double %a, double %b) #0 #1 {
+  %t1 = tail call double @llvm.sqrt.f64(double %a)
+  %t2 = fdiv double %t1, %b
+  ret double %t2
+}
+
+; CHECK-LABEL: rsqrt(
+; CHECK-NOT: rsqrt.approx
+; CHECK: sqrt.rn.f32
+; CHECK-NOT: rsqrt.approx
+define float @rsqrt(float %a) {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
+; CHECK-LABEL: rsqrt_fast(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast(float %a) #0 {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
+; CHECK-LABEL: rsqrt_fast_ftz(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.ftz.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast_ftz(float %a) #0 #1 {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
 ; CHECK-LABEL: fadd
 ; CHECK: add.rn.f32
 define float @fadd(float %a, float %b) {
Index: llvm/test/CodeGen/NVPTX/rsqrt.ll
===================================================================
--- llvm/test/CodeGen/NVPTX/rsqrt.ll
+++ llvm/test/CodeGen/NVPTX/rsqrt.ll
@@ -1,13 +1,34 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=1 -nvptx-prec-sqrtf32=0 | FileCheck %s
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=0 -nvptx-prec-sqrtf32=0 \
+; RUN:   | FileCheck %s
 
 target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
 
-declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)
 
-define float @foo(float %a) {
+; The colon after CHECK-LABEL is required; without it FileCheck treats the line
+; as an ordinary comment and the checks are not scoped to their function.
+; CHECK-LABEL: test_rsqrt32(
+define float @test_rsqrt32(float %a) #0 {
 ; CHECK: rsqrt.approx.f32
-  %val = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %val = tail call float @llvm.sqrt.f32(float %a)
   %ret = fdiv float 1.0, %val
   ret float %ret
 }
-
+
+; CHECK-LABEL: test_rsqrt_ftz32(
+define float @test_rsqrt_ftz32(float %a) #0 #1 {
+; CHECK: rsqrt.approx.ftz.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt64(
+define double @test_rsqrt64(double %a) #0 {
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+attributes #0 = { "unsafe-fp-math" = "true" }
+attributes #1 = { "nvptx-f32ftz" = "true" }