diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -553,20 +553,31 @@
   // These map to corresponding instructions for f32/f64. f16 must be
   // promoted to f32. v2f16 is expanded to f16, which is then promoted
   // to f32.
-  for (const auto &Op : {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS,
-                         ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM}) {
+  for (const auto &Op :
+       {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FABS}) {
     setOperationAction(Op, MVT::f16, Promote);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
   }
-  // max.f16 is supported on sm_80+.
-  if (STI.allowFP16Math() && STI.getSmVersion() >= 80 &&
-      STI.getPTXVersion() >= 70) {
-    setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
-    setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
-    setOperationAction(ISD::FMINNUM, MVT::v2f16, Legal);
-    setOperationAction(ISD::FMAXNUM, MVT::v2f16, Legal);
+  // max.f16, max.f16x2 and max.NaN are supported on sm_80+.
+  // GetMinMaxAction yields Legal on sm_80+ with PTX 7.0+, else the fallback.
+  auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
+    bool IsAtLeastSm80 = STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
+    return IsAtLeastSm80 ? Legal : NotSm80Action;
+  };
+  for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
+    setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
+    setOperationAction(Op, MVT::f32, Legal);
+    setOperationAction(Op, MVT::f64, Legal);
+    setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+  }
+  for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
+    setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
+    setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
+    // PTX has no min.NaN.f64/max.NaN.f64, so f64 is always expanded.
+    setOperationAction(Op, MVT::f64, Expand);
+    setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
   }
 
   // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -868,6 +868,10 @@
 
 defm FMIN : F3<"min", fminnum>;
 defm FMAX : F3<"max", fmaxnum>;
+// Note: min.NaN.f64 and max.NaN.f64 do not exist in PTX. fminimum/fmaximum
+// are never marked Legal for f64, so any f64 variant here goes unused.
+defm FMINNAN : F3<"min.NaN", fminimum>;
+defm FMAXNAN : F3<"max.NaN", fmaximum>;
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
 
diff --git a/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll b/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll
@@ -0,0 +1,92 @@
+; RUN: llc < %s -march=nvptx | FileCheck %s --check-prefixes=CHECK,CHECK-NONAN
+; RUN: llc < %s -march=nvptx -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK-NAN
+
+; ---- minimum ----
+
+; CHECK-LABEL: minimum_half
+define half @minimum_half(half %a) #0 {
+  ; CHECK-NONAN: setp
+  ; CHECK-NONAN: selp.b16
+  ; CHECK-NAN: min.NaN.f16
+  %p = fcmp ult half %a, 0.0
+  %x = select i1 %p, half %a, half 0.0
+  ret half %x
+}
+
+; CHECK-LABEL: minimum_float
+define float @minimum_float(float %a) #0 {
+  ; CHECK-NONAN: setp
+  ; CHECK-NONAN: selp.f32
+  ; CHECK-NAN: min.NaN.f32
+  %p = fcmp ult float %a, 0.0
+  %x = select i1 %p, float %a, float 0.0
+  ret float %x
+}
+
+; CHECK-LABEL: minimum_double
+define double @minimum_double(double %a) #0 {
+  ; PTX has no min.NaN.f64, so f64 is expanded on every target.
+  ; CHECK: setp
+  ; CHECK: selp.f64
+  %p = fcmp ult double %a, 0.0
+  %x = select i1 %p, double %a, double 0.0
+  ret double %x
+}
+
+; CHECK-LABEL: minimum_v2half
+define <2 x half> @minimum_v2half(<2 x half> %a, <2 x half> %b) #0 {
+  ; CHECK-NONAN-DAG: setp
+  ; CHECK-NONAN-DAG: setp
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NAN: min.NaN.f16x2
+  %p = fcmp ult <2 x half> %a, zeroinitializer
+  %x = select <2 x i1> %p, <2 x half> %a, <2 x half> zeroinitializer
+  ret <2 x half> %x
+}
+
+; ---- maximum ----
+
+; CHECK-LABEL: maximum_half
+define half @maximum_half(half %a) #0 {
+  ; CHECK-NONAN: setp
+  ; CHECK-NONAN: selp.b16
+  ; CHECK-NAN: max.NaN.f16
+  %p = fcmp ugt half %a, 0.0
+  %x = select i1 %p, half %a, half 0.0
+  ret half %x
+}
+
+; CHECK-LABEL: maximum_float
+define float @maximum_float(float %a) #0 {
+  ; CHECK-NONAN: setp
+  ; CHECK-NONAN: selp.f32
+  ; CHECK-NAN: max.NaN.f32
+  %p = fcmp ugt float %a, 0.0
+  %x = select i1 %p, float %a, float 0.0
+  ret float %x
+}
+
+; CHECK-LABEL: maximum_double
+define double @maximum_double(double %a) #0 {
+  ; PTX has no max.NaN.f64, so f64 is expanded on every target.
+  ; CHECK: setp
+  ; CHECK: selp.f64
+  %p = fcmp ugt double %a, 0.0
+  %x = select i1 %p, double %a, double 0.0
+  ret double %x
+}
+
+; CHECK-LABEL: maximum_v2half
+define <2 x half> @maximum_v2half(<2 x half> %a, <2 x half> %b) #0 {
+  ; CHECK-NONAN-DAG: setp
+  ; CHECK-NONAN-DAG: setp
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NAN: max.NaN.f16x2
+  %p = fcmp ugt <2 x half> %a, zeroinitializer
+  %x = select <2 x i1> %p, <2 x half> %a, <2 x half> zeroinitializer
+  ret <2 x half> %x
+}
+
+attributes #0 = { "no-signed-zeros-fp-math"="true" }
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins.ll b/llvm/test/CodeGen/NVPTX/math-intrins.ll
--- a/llvm/test/CodeGen/NVPTX/math-intrins.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -1,6 +1,4 @@
-; RUN: llc < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
-; RUN: llc < %s -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK-F16
-; RUN: llc < %s -mcpu=sm_80 --nvptx-no-f16-math | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
+; RUN: llc < %s | FileCheck %s
 target triple = "nvptx64-nvidia-cuda"
 
 ; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
@@ -19,14 +17,10 @@
 declare double @llvm.trunc.f64(double) #0
 declare float @llvm.fabs.f32(float) #0
 declare double @llvm.fabs.f64(double) #0
-declare half @llvm.minnum.f16(half, half) #0
 declare float @llvm.minnum.f32(float, float) #0
 declare double @llvm.minnum.f64(double, double) #0
-declare <2 x half> @llvm.minnum.v2f16(<2 x half>, <2 x half>) #0
-declare half @llvm.maxnum.f16(half, half) #0
 declare float @llvm.maxnum.f32(float, float) #0
 declare double @llvm.maxnum.f64(double, double) #0
-declare <2 x half> @llvm.maxnum.v2f16(<2 x half>, <2 x half>) #0
 declare float @llvm.fma.f32(float, float, float) #0
 declare double @llvm.fma.f64(double, double, double) #0
 
@@ -199,14 +193,6 @@
 
 ; ---- min ----
 
-; CHECK-LABEL: min_half
-define half @min_half(half %a, half %b) {
-  ; CHECK-NOF16: min.f32
-  ; CHECK-F16: min.f16
-  %x = call half @llvm.minnum.f16(half %a, half %b)
-  ret half %x
-}
-
 ; CHECK-LABEL: min_float
 define float @min_float(float %a, float %b) {
   ; CHECK: min.f32
@@ -242,25 +228,8 @@
   ret double %x
 }
 
-; CHECK-LABEL: min_v2half
-define <2 x half> @min_v2half(<2 x half> %a, <2 x half> %b) {
-  ; CHECK-NOF16: min.f32
-  ; CHECK-NOF16: min.f32
-  ; CHECK-F16: min.f16x2
-  %x = call <2 x half> @llvm.minnum.v2f16(<2 x half> %a, <2 x half> %b)
-  ret <2 x half> %x
-}
-
 ; ---- max ----
 
-; CHECK-LABEL: max_half
-define half @max_half(half %a, half %b) {
-  ; CHECK-NOF16: max.f32
-  ; CHECK-F16: max.f16
-  %x = call half @llvm.maxnum.f16(half %a, half %b)
-  ret half %x
-}
-
 ; CHECK-LABEL: max_imm1
 define float @max_imm1(float %a) {
   ; CHECK: max.f32
@@ -296,15 +265,6 @@
   ret double %x
 }
 
-; CHECK-LABEL: max_v2half
-define <2 x half> @max_v2half(<2 x half> %a, <2 x half> %b) {
-  ; CHECK-NOF16: max.f32
-  ; CHECK-NOF16: max.f32
-  ; CHECK-F16: max.f16x2
-  %x = call <2 x half> @llvm.maxnum.v2f16(<2 x half> %a, <2 x half> %b)
-  ret <2 x half> %x
-}
-
 ; ---- fma ----
 
 ; CHECK-LABEL: @fma_float