Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -757,7 +757,29 @@
 
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
-defm FSQRT : F2<"sqrt.rn", fsqrt>;
+
+//
+// sqrt
+//
+
+// Emits "sqrt.<Options>f32" for an f32 fsqrt; the instruction is guarded by
+// the predicates listed in Preds.
+multiclass FSQRT_f32<string Options, list<Predicate> Preds> {
+  def : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a),
+                  !strconcat("sqrt.", Options, "f32 \t$dst, $a;"),
+                  [(set Float32Regs:$dst, (fsqrt Float32Regs:$a))]>,
+        Requires<Preds>;
+}
+defm FSQRT_f32_approx_ftz :
+  FSQRT_f32<"approx.ftz.", [doF32FTZ, do_SQRTF32_APPROX]>;
+defm FSQRT_f32_approx : FSQRT_f32<"approx.", [do_SQRTF32_APPROX]>;
+defm FSQRT_f32_ftz : FSQRT_f32<"rn.ftz.", [doF32FTZ]>;
+defm FSQRT_f32_noftz : FSQRT_f32<"rn.", []>;
+
+def FSQRT_f64 :
+  NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$a),
+            "sqrt.rn.f64 \t$dst, $a;",
+            [(set Float64Regs:$dst, (fsqrt Float64Regs:$a))]>;
 
 //
 // F64 division
@@ -908,17 +930,45 @@
       Requires<[reqPTX20]>;
 
 //
-// F32 rsqrt
+// F32 rsqrt. When do_SQRTF32_APPROX and do_DIVF32_APPROX are enabled, we can
+// transform 1.0f/sqrt(x) into rsqrt.approx.f32.
+//
+// We do this for both @llvm.sqrt.f32 and @llvm.nvvm.sqrt.f. Ideally we'd only
+// do it for the generic LLVM intrinsic, on the assumption that if you use the
+// nvvm-specific intrinsic, you really want that particular instruction. But
+// libdevice and the CUDA headers emit llvm.nvvm.sqrt.f, and we want this
+// transformation to apply there.
+//
+// TODO: Should we turn this on when only one of the *APPROX flags is enabled?
+// Our value is already approximate...
 //
 
-def RSQRTF32approx1r : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$b),
-                                 "rsqrt.approx.f32 \t$dst, $b;", []>;
+def RSQRTF32approx :
+  NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a),
+            "rsqrt.approx.f32 \t$dst, $a;", []>;
+def RSQRTF32approx_ftz :
+  NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a),
+            "rsqrt.approx.ftz.f32 \t$dst, $a;", []>;
 
-// Convert 1.0f/sqrt(x) to rsqrt.approx.f32. (There is an rsqrt.approx.f64, but
-// it's emulated in software.)
-def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$b)),
-         (RSQRTF32approx1r Float32Regs:$b)>,
-         Requires<[do_DIVF32_FULL, do_SQRTF32_APPROX, doNoF32FTZ]>;
+// 1.0f / @llvm.nvvm.sqrt.f(x) -> rsqrt.approx{.ftz}(x)
+def : Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+          (RSQRTF32approx_ftz Float32Regs:$a)>,
+          Requires<[do_SQRTF32_APPROX, do_DIVF32_APPROX, doF32FTZ]>;
+def : Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+          (RSQRTF32approx Float32Regs:$a)>,
+          Requires<[do_SQRTF32_APPROX, do_DIVF32_APPROX]>;
+
+// 1.0f / @llvm.sqrt.f32(x) -> rsqrt.approx{.ftz}(x)
+def : Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+          (RSQRTF32approx_ftz Float32Regs:$a)>,
+          Requires<[do_SQRTF32_APPROX, do_DIVF32_APPROX, doF32FTZ]>;
+def : Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+          (RSQRTF32approx Float32Regs:$a)>,
+          Requires<[do_SQRTF32_APPROX, do_DIVF32_APPROX]>;
+
+//
+// FMA
+//
 
 multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
    def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
Index: llvm/test/CodeGen/NVPTX/fast-math.ll
===================================================================
--- llvm/test/CodeGen/NVPTX/fast-math.ll
+++ llvm/test/CodeGen/NVPTX/fast-math.ll
@@ -1,23 +1,97 @@
 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
 
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)
 declare float @llvm.nvvm.sqrt.f(float)
 
-; CHECK-LABEL: sqrt_div
+; CHECK-LABEL: sqrt_div(
 ; CHECK: sqrt.rn.f32
 ; CHECK: div.rn.f32
+; CHECK: sqrt.rn.f32
 define float @sqrt_div(float %a, float %b) {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
   %t2 = fdiv float %t1, %b
-  ret float %t2
+  %t3 = tail call float @llvm.nvvm.sqrt.f(float %t2)
+  ret float %t3
 }
 
-; CHECK-LABEL: sqrt_div_fast
+; CHECK-LABEL: sqrt_div_fast(
 ; CHECK: sqrt.approx.f32
 ; CHECK: div.approx.f32
+; CHECK: sqrt.approx.f32
 define float @sqrt_div_fast(float %a, float %b) #0 {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
   %t2 = fdiv float %t1, %b
-  ret float %t2
+  %t3 = tail call float @llvm.nvvm.sqrt.f(float %t2)
+  ret float %t3
+}
+
+; CHECK-LABEL: sqrt_div_ftz(
+; CHECK: sqrt.rn.ftz.f32
+; CHECK: div.rn.ftz.f32
+; CHECK: sqrt.rn.ftz.f32
+define float @sqrt_div_ftz(float %a, float %b) #1 {
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
+  %t2 = fdiv float %t1, %b
+  %t3 = tail call float @llvm.nvvm.sqrt.f(float %t2)
+  ret float %t3
+}
+
+; CHECK-LABEL: sqrt_div_fast_ftz(
+; CHECK: sqrt.approx.ftz.f32
+; CHECK: div.approx.ftz.f32
+; CHECK: sqrt.approx.ftz.f32
+define float @sqrt_div_fast_ftz(float %a, float %b) #0 #1 {
+  %t1 = tail call float @llvm.sqrt.f32(float %a)
+  %t2 = fdiv float %t1, %b
+  %t3 = tail call float @llvm.nvvm.sqrt.f(float %t2)
+  ret float %t3
+}
+
+; There are no fast-math or ftz versions of f64 sqrt and div; we just emit the
+; vanilla instructions. There's also no @llvm.nvvm.sqrt.d intrinsic.
+;
+; CHECK-LABEL: sqrt_div_fast_ftz_f64(
+; CHECK: sqrt.rn.f64
+; CHECK: div.rn.f64
+define double @sqrt_div_fast_ftz_f64(double %a, double %b) #0 #1 {
+  %t1 = tail call double @llvm.sqrt.f64(double %a)
+  %t2 = fdiv double %t1, %b
+  ret double %t2
+}
+
+; CHECK-LABEL: rsqrt(
+; CHECK-NOT: rsqrt.approx
+; CHECK: sqrt.rn.f32
+; CHECK-NOT: rsqrt.approx
+define float @rsqrt(float %a) {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
+; CHECK-LABEL: rsqrt_fast(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast(float %a) #0 {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
+}
+
+; CHECK-LABEL: rsqrt_fast_ftz(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.ftz.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast_ftz(float %a) #0 #1 {
+  %b = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %b
+  ret float %ret
 }
 
 ; CHECK-LABEL: fadd
Index: llvm/test/CodeGen/NVPTX/rsqrt.ll
===================================================================
--- llvm/test/CodeGen/NVPTX/rsqrt.ll
+++ llvm/test/CodeGen/NVPTX/rsqrt.ll
@@ -1,4 +1,5 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=1 -nvptx-prec-sqrtf32=0 | FileCheck %s
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=0 -nvptx-prec-sqrtf32=0 \
+; RUN: | FileCheck %s
 
 target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
 
@@ -10,4 +11,3 @@
   %ret = fdiv float 1.0, %val
   ret float %ret
 }
-