diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -553,20 +553,31 @@
   // These map to corresponding instructions for f32/f64. f16 must be
   // promoted to f32. v2f16 is expanded to f16, which is then promoted
   // to f32.
-  for (const auto &Op : {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS,
-                         ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM}) {
+  for (const auto &Op :
+       {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FABS}) {
     setOperationAction(Op, MVT::f16, Promote);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
   }
-  // max.f16 is supported on sm_80+.
-  if (STI.allowFP16Math() && STI.getSmVersion() >= 80 &&
-      STI.getPTXVersion() >= 70) {
-    setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
-    setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
-    setOperationAction(ISD::FMINNUM, MVT::v2f16, Legal);
-    setOperationAction(ISD::FMAXNUM, MVT::v2f16, Legal);
+  // max.f16, max.f16x2 and max.NaN need sm_80+ and PTX ISA 7.0+.
+  auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
+    bool IsAtLeastSm80 = STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
+    return IsAtLeastSm80 ? Legal : NotSm80Action;
+  };
+  for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
+    setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
+    setOperationAction(Op, MVT::f32, Legal);
+    setOperationAction(Op, MVT::f64, Legal);
+    setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+  }
+  for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
+    setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
+    setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
+    // PTX has no min.NaN.f64 / max.NaN.f64 (only min.f64 / max.f64), so
+    // f64 FMINIMUM/FMAXIMUM must be expanded on every target.
+    setOperationAction(Op, MVT::f64, Expand);
+    setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
   }
 
   // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -868,6 +868,10 @@
 
 defm FMIN : F3<"min", fminnum>;
 defm FMAX : F3<"max", fmaxnum>;
+// PTX has no min.NaN.f64/max.NaN.f64; any f64 variants F3 defines here are
+// dead because f64 FMINIMUM/FMAXIMUM are always Expand in NVPTXISelLowering.
+defm FMINNAN : F3<"min.NaN", fminimum>;
+defm FMAXNAN : F3<"max.NaN", fmaximum>;
 
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
diff --git a/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll b/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll
@@ -0,0 +1,92 @@
+; RUN: llc < %s -march=nvptx | FileCheck %s --check-prefixes=CHECK,CHECK-NONAN
+; RUN: llc < %s -march=nvptx -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK-NAN
+
+; ---- minimum ----
+
+; CHECK-LABEL: minimum_half
+define half @minimum_half(half %a) #0 {
+  ; CHECK-NONAN: setp
+  ; CHECK-NONAN: selp.b16
+  ; CHECK-NAN: min.NaN.f16
+  %p = fcmp ult half %a, 0.0
+  %x = select i1 %p, half %a, half 0.0
+  ret half %x
+}
+
+; CHECK-LABEL: minimum_float
+define float @minimum_float(float %a) #0 {
+  ; CHECK-NONAN: setp
+  ; CHECK-NONAN: selp.f32
+  ; CHECK-NAN: min.NaN.f32
+  %p = fcmp ult float %a, 0.0
+  %x = select i1 %p, float %a, float 0.0
+  ret float %x
+}
+
+; CHECK-LABEL: minimum_double
+define double @minimum_double(double %a) #0 {
+  ; PTX has no min.NaN.f64, so f64 minimum is expanded on every target.
+  ; CHECK: setp
+  ; CHECK: selp.f64
+  %p = fcmp ult double %a, 0.0
+  %x = select i1 %p, double %a, double 0.0
+  ret double %x
+}
+
+; CHECK-LABEL: minimum_v2half
+define <2 x half> @minimum_v2half(<2 x half> %a) #0 {
+  ; CHECK-NONAN-DAG: setp
+  ; CHECK-NONAN-DAG: setp
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NAN: min.NaN.f16x2
+  %p = fcmp ult <2 x half> %a, zeroinitializer
+  %x = select <2 x i1> %p, <2 x half> %a, <2 x half> zeroinitializer
+  ret <2 x half> %x
+}
+
+; ---- maximum ----
+
+; CHECK-LABEL: maximum_half
+define half @maximum_half(half %a) #0 {
+  ; CHECK-NONAN: setp
+  ; CHECK-NONAN: selp.b16
+  ; CHECK-NAN: max.NaN.f16
+  %p = fcmp ugt half %a, 0.0
+  %x = select i1 %p, half %a, half 0.0
+  ret half %x
+}
+
+; CHECK-LABEL: maximum_float
+define float @maximum_float(float %a) #0 {
+  ; CHECK-NONAN: setp
+  ; CHECK-NONAN: selp.f32
+  ; CHECK-NAN: max.NaN.f32
+  %p = fcmp ugt float %a, 0.0
+  %x = select i1 %p, float %a, float 0.0
+  ret float %x
+}
+
+; CHECK-LABEL: maximum_double
+define double @maximum_double(double %a) #0 {
+  ; PTX has no max.NaN.f64, so f64 maximum is expanded on every target.
+  ; CHECK: setp
+  ; CHECK: selp.f64
+  %p = fcmp ugt double %a, 0.0
+  %x = select i1 %p, double %a, double 0.0
+  ret double %x
+}
+
+; CHECK-LABEL: maximum_v2half
+define <2 x half> @maximum_v2half(<2 x half> %a) #0 {
+  ; CHECK-NONAN-DAG: setp
+  ; CHECK-NONAN-DAG: setp
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NAN: max.NaN.f16x2
+  %p = fcmp ugt <2 x half> %a, zeroinitializer
+  %x = select <2 x i1> %p, <2 x half> %a, <2 x half> zeroinitializer
+  ret <2 x half> %x
+}
+
+attributes #0 = { "no-signed-zeros-fp-math"="true" }