diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -584,6 +584,8 @@
   SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -623,9 +623,10 @@
     setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
   }
   for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
-    setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
-    setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
-    setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+    setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Custom), Custom);
+    setOperationAction(Op, MVT::f32, GetMinMaxAction(Custom));
+    setOperationAction(Op, MVT::f64, Custom);
+    setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Custom), Custom);
   }
 
   // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
@@ -2202,7 +2203,33 @@
   return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
 }
 
+// Lower FMINIMUM / FMAXIMUM for SM < 8.0. We use FMINNUM / FMAXNUM followed by
+// a NaN check to handle NaNs correctly.
+//
+// Technically, FMINNUM/FMAXNUM do not handle the -0.0 / +0.0 case correctly,
+// since they define them according to the IEEE 754-2008 semantics (it's
+// undefined which one is returned). However, the PTX min/max instructions to
+// which FMINNUM and FMAXNUM are lowered conform to the IEEE 754-2019
+// semantics (-0.0 < +0.0), thus the lowering ends up working out correctly.
+//
+// TODO: Replace FMINNUM/FMAXNUM with ops that conform to IEEE 754-2019 standard
+// once those are available in LLVM.
+SDValue NVPTXTargetLowering::LowerFMINIMUM_FMAXIMUM(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+  ISD::NodeType NT =
+      Op.getOpcode() == ISD::FMINIMUM ? ISD::FMINNUM : ISD::FMAXNUM;
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
+  SDLoc SL(Op);
+  // FMINNUM/FMAXNUM return the non-NaN operand when a single input is NaN.
+  SDValue NonPropagatingResult = DAG.getNode(NT, SL, VT, {LHS, RHS});
+  SDValue NaN =
+      DAG.getConstantFP(std::numeric_limits<double>::quiet_NaN(), SL, VT);
+  // SETUO holds when either operand is NaN; produce NaN in that case.
+  return DAG.getSelectCC(SL, LHS, RHS, NaN, NonPropagatingResult, ISD::SETUO);
+}
 
 
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
@@ -2236,6 +2263,9 @@
     return LowerSelect(Op, DAG);
   case ISD::FROUND:
     return LowerFROUND(Op, DAG);
+  case ISD::FMINIMUM:
+  case ISD::FMAXIMUM:
+    return LowerFMINIMUM_FMAXIMUM(Op, DAG);
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
diff --git a/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll b/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll
--- a/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll
+++ b/llvm/test/CodeGen/NVPTX/fminimum-fmaximum.ll
@@ -5,6 +5,10 @@
 
 ; ---- minimum ----
 
+declare half @llvm.minimum.f16(half %a, half %b)
+declare float @llvm.minimum.f32(float %a, float %b)
+declare double @llvm.minimum.f64(double %a, double %b)
+
 ; CHECK-LABEL: minimum_half
 define half @minimum_half(half %a) #0 {
   ; CHECK-NONAN: setp
@@ -15,6 +19,16 @@
   ret half %x
 }
 
+; CHECK-LABEL: minimum_intr_half
+define half @minimum_intr_half(half %a, half %b) #0 {
+  ; CHECK-NONAN-DAG: min.f32
+  ; CHECK-NONAN-DAG: setp.nan.f32
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NAN: min.NaN.f16
+  %x = call half @llvm.minimum.f16(half %a, half %b)
+  ret half %x
+}
+
 ; CHECK-LABEL: minimum_float
 define float @minimum_float(float %a) #0 {
   ; CHECK-NONAN: setp
@@ -25,6 +39,16 @@
   ret float %x
 }
 
+; CHECK-LABEL: minimum_intr_float
+define float @minimum_intr_float(float %a, float %b) #0 {
+  ; CHECK-NONAN-DAG: min.f32
+  ; CHECK-NONAN-DAG: setp.nan.f32
+  ; CHECK-NONAN-DAG: selp.f32
+  ; CHECK-NAN: min.NaN.f32
+  %x = call float @llvm.minimum.f32(float %a, float %b)
+  ret float %x
+}
+
 ; CHECK-LABEL: minimum_double
 define double @minimum_double(double %a) #0 {
   ; CHECK: setp
@@ -34,6 +58,15 @@
   ret double %x
 }
 
+; CHECK-LABEL: minimum_intr_double
+define double @minimum_intr_double(double %a, double %b) #0 {
+  ; CHECK-DAG: min.f64
+  ; CHECK-DAG: setp.nan.f64
+  ; CHECK-DAG: selp.f64
+  %x = call double @llvm.minimum.f64(double %a, double %b)
+  ret double %x
+}
+
 ; CHECK-LABEL: minimum_v2half
 define <2 x half> @minimum_v2half(<2 x half> %a) #0 {
   ; CHECK-NONAN-DAG: setp
@@ -48,6 +81,10 @@
 
 ; ---- maximum ----
 
+declare half @llvm.maximum.f16(half %a, half %b)
+declare float @llvm.maximum.f32(float %a, float %b)
+declare double @llvm.maximum.f64(double %a, double %b)
+
 ; CHECK-LABEL: maximum_half
 define half @maximum_half(half %a) #0 {
   ; CHECK-NONAN: setp
@@ -58,6 +95,16 @@
   ret half %x
 }
 
+; CHECK-LABEL: maximum_intr_half
+define half @maximum_intr_half(half %a, half %b) #0 {
+  ; CHECK-NONAN-DAG: max.f32
+  ; CHECK-NONAN-DAG: setp.nan.f32
+  ; CHECK-NONAN-DAG: selp.b16
+  ; CHECK-NAN: max.NaN.f16
+  %x = call half @llvm.maximum.f16(half %a, half %b)
+  ret half %x
+}
+
 ; CHECK-LABEL: maximum_float
 define float @maximum_float(float %a) #0 {
   ; CHECK-NONAN: setp
@@ -68,6 +115,16 @@
   ret float %x
 }
 
+; CHECK-LABEL: maximum_intr_float
+define float @maximum_intr_float(float %a, float %b) #0 {
+  ; CHECK-NONAN-DAG: max.f32
+  ; CHECK-NONAN-DAG: setp.nan.f32
+  ; CHECK-NONAN-DAG: selp.f32
+  ; CHECK-NAN: max.NaN.f32
+  %x = call float @llvm.maximum.f32(float %a, float %b)
+  ret float %x
+}
+
 ; CHECK-LABEL: maximum_double
 define double @maximum_double(double %a) #0 {
   ; CHECK: setp
@@ -77,6 +134,15 @@
   ret double %x
 }
 
+; CHECK-LABEL: maximum_intr_double
+define double @maximum_intr_double(double %a, double %b) #0 {
+  ; CHECK-DAG: max.f64
+  ; CHECK-DAG: setp.nan.f64
+  ; CHECK-DAG: selp.f64
+  %x = call double @llvm.maximum.f64(double %a, double %b)
+  ret double %x
+}
+
 ; CHECK-LABEL: maximum_v2half
 define <2 x half> @maximum_v2half(<2 x half> %a) #0 {
   ; CHECK-NONAN-DAG: setp