NVPTX: use native min/max instructions for f16 and v2f16 where available.

* NVPTXISelLowering.cpp: mark ISD::FMINNUM/FMAXNUM as Legal for MVT::f16 and
  MVT::v2f16 when STI.allowFP16Math() holds, the SM version is >= 80 and the
  PTX version is >= 70.
* NVPTXInstrInfo.td: add f16 and f16x2 register-register patterns, each with
  an .ftz variant that additionally requires doF32FTZ.
* math-intrins.ll: add sm_80 RUN lines and minnum/maxnum tests for half and
  <2 x half>.

NOTE(review): the first hunk also deletes the unconditional f16 Promote
actions for FMINNUM/FMAXNUM/FMINIMUM/FMAXIMUM and adds no fallback branch.
The CHECK-NOF16 expectations imply FMINNUM/FMAXNUM are still promoted by code
outside this diff; confirm that, and confirm that dropping the
FMINIMUM/FMAXIMUM entries is intended.
NOTE(review): indentation of context lines below was reconstructed; verify
against the target files (or apply with whitespace tolerance).

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -560,10 +560,14 @@
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
   }
-  setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
-  setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
-  setOperationAction(ISD::FMINIMUM, MVT::f16, Promote);
-  setOperationAction(ISD::FMAXIMUM, MVT::f16, Promote);
+  // min/max on f16 and v2f16 need FP16 math, SM >= 80 and PTX version >= 70.
+  if (STI.allowFP16Math() && STI.getSmVersion() >= 80 &&
+      STI.getPTXVersion() >= 70) {
+    setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
+    setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
+    setOperationAction(ISD::FMINNUM, MVT::v2f16, Legal);
+    setOperationAction(ISD::FMAXNUM, MVT::v2f16, Legal);
+  }
 
   // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
   // No FPOW or FREM in PTX.
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -249,6 +249,32 @@
                (ins Float32Regs:$a, f32imm:$b),
                !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
                [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
+
+   def f16rr_ftz :
+     NVPTXInst<(outs Float16Regs:$dst),
+               (ins Float16Regs:$a, Float16Regs:$b),
+               !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
+               [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
+               Requires<[useFP16Math, doF32FTZ]>;
+   def f16rr :
+     NVPTXInst<(outs Float16Regs:$dst),
+               (ins Float16Regs:$a, Float16Regs:$b),
+               !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
+               [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
+               Requires<[useFP16Math]>;
+
+   def f16x2rr_ftz :
+     NVPTXInst<(outs Float16x2Regs:$dst),
+               (ins Float16x2Regs:$a, Float16x2Regs:$b),
+               !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
+               [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
+               Requires<[useFP16Math, doF32FTZ]>;
+   def f16x2rr :
+     NVPTXInst<(outs Float16x2Regs:$dst),
+               (ins Float16x2Regs:$a, Float16x2Regs:$b),
+               !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
+               [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
+               Requires<[useFP16Math]>;
 }
 
 // Template for instructions which take three FP args. The
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins.ll b/llvm/test/CodeGen/NVPTX/math-intrins.ll
--- a/llvm/test/CodeGen/NVPTX/math-intrins.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -1,4 +1,6 @@
-; RUN: llc < %s | FileCheck %s
+; RUN: llc < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
+; RUN: llc < %s -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK-F16
+; RUN: llc < %s -mcpu=sm_80 --nvptx-no-f16-math | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
 target triple = "nvptx64-nvidia-cuda"
 
 ; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
@@ -17,10 +19,14 @@
 declare double @llvm.trunc.f64(double) #0
 declare float @llvm.fabs.f32(float) #0
 declare double @llvm.fabs.f64(double) #0
+declare half @llvm.minnum.f16(half, half) #0
 declare float @llvm.minnum.f32(float, float) #0
 declare double @llvm.minnum.f64(double, double) #0
+declare <2 x half> @llvm.minnum.v2f16(<2 x half>, <2 x half>) #0
+declare half @llvm.maxnum.f16(half, half) #0
 declare float @llvm.maxnum.f32(float, float) #0
 declare double @llvm.maxnum.f64(double, double) #0
+declare <2 x half> @llvm.maxnum.v2f16(<2 x half>, <2 x half>) #0
 declare float @llvm.fma.f32(float, float, float) #0
 declare double @llvm.fma.f64(double, double, double) #0
 
@@ -193,6 +199,14 @@
 
 ; ---- min ----
 
+; CHECK-LABEL: min_half
+define half @min_half(half %a, half %b) {
+  ; CHECK-NOF16: min.f32
+  ; CHECK-F16: min.f16
+  %x = call half @llvm.minnum.f16(half %a, half %b)
+  ret half %x
+}
+
 ; CHECK-LABEL: min_float
 define float @min_float(float %a, float %b) {
   ; CHECK: min.f32
@@ -228,8 +242,25 @@
   ret double %x
 }
 
+; CHECK-LABEL: min_v2half
+define <2 x half> @min_v2half(<2 x half> %a, <2 x half> %b) {
+  ; CHECK-NOF16: min.f32
+  ; CHECK-NOF16: min.f32
+  ; CHECK-F16: min.f16x2
+  %x = call <2 x half> @llvm.minnum.v2f16(<2 x half> %a, <2 x half> %b)
+  ret <2 x half> %x
+}
+
 ; ---- max ----
 
+; CHECK-LABEL: max_half
+define half @max_half(half %a, half %b) {
+  ; CHECK-NOF16: max.f32
+  ; CHECK-F16: max.f16
+  %x = call half @llvm.maxnum.f16(half %a, half %b)
+  ret half %x
+}
+
 ; CHECK-LABEL: max_imm1
 define float @max_imm1(float %a) {
   ; CHECK: max.f32
@@ -265,6 +296,15 @@
   ret double %x
 }
 
+; CHECK-LABEL: max_v2half
+define <2 x half> @max_v2half(<2 x half> %a, <2 x half> %b) {
+  ; CHECK-NOF16: max.f32
+  ; CHECK-NOF16: max.f32
+  ; CHECK-F16: max.f16x2
+  %x = call <2 x half> @llvm.maxnum.v2f16(<2 x half> %a, <2 x half> %b)
+  ret <2 x half> %x
+}
+
 ; ---- fma ----
 
 ; CHECK-LABEL: @fma_float