diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -556,6 +556,10 @@ SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const; SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const; SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -546,13 +546,19 @@ // These map to conversion instructions for scalar FP types. for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT, - ISD::FROUND, ISD::FTRUNC}) { + ISD::FTRUNC}) { setOperationAction(Op, MVT::f16, Legal); setOperationAction(Op, MVT::f32, Legal); setOperationAction(Op, MVT::f64, Legal); setOperationAction(Op, MVT::v2f16, Expand); } + setOperationAction(ISD::FROUND, MVT::f16, Promote); + setOperationAction(ISD::FROUND, MVT::v2f16, Expand); + setOperationAction(ISD::FROUND, MVT::f32, Custom); + setOperationAction(ISD::FROUND, MVT::f64, Custom); + + // 'Expand' implements FCOPYSIGN without calling an external library. 
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
@@ -2068,6 +2074,104 @@
   }
 }
 
+/// Custom lowering entry point for ISD::FROUND: forwards to the f32 or f64
+/// expansion according to the result type of \p Op. Any other result type
+/// reaching this function hits llvm_unreachable.
+SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+
+  if (VT == MVT::f32)
+    return LowerFROUND32(Op, DAG);
+
+  if (VT == MVT::f64)
+    return LowerFROUND64(Op, DAG);
+
+  llvm_unreachable("unhandled type");
+}
+
+// This is the rounding method used in CUDA libdevice in C like code:
+// float roundf(float A)
+// {
+//   float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
+//   RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
+//   return abs(A) < 0.5 ? (float)(int)A : RoundedA;
+// }
+SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc SL(Op);
+  SDValue A = Op.getOperand(0);
+  EVT VT = Op.getValueType();
+
+  SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
+
+  // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
+  // +/-0.5 is built by OR-ing the sign bit of A into the bits of 0.5f.
+  const unsigned SignBitMask = 0x80000000u;
+  const unsigned PointFiveInBits = 0x3F000000u;
+  SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
+  SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
+                             DAG.getConstant(SignBitMask, SL, MVT::i32));
+  SDValue PointFiveRaw =
+      DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
+                  DAG.getConstant(PointFiveInBits, SL, MVT::i32));
+  SDValue PointFive = DAG.getNode(ISD::BITCAST, SL, VT, PointFiveRaw);
+  SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFive);
+  SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
+
+  // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
+  // getConstantFP takes the numeric value, not the IEEE bit pattern, so the
+  // threshold is written as the value 2^23.
+  EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue IsLarge =
+      DAG.getSetCC(SL, SetCCVT, AbsA,
+                   DAG.getConstantFP(8388608.0 /* 2^23 */, SL, VT),
+                   ISD::SETOGT);
+  RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
+
+  // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
+  SDValue IsSmall = DAG.getSetCC(SL, SetCCVT, AbsA,
+                                 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
+  SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
+  return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
+}
+
+// The implementation of round(double) is similar to that of round(float) in
+// that they both separate the value range into three regions and use a method
+// specific to the region to round the values. However, round(double) first
+// calculates the round of the absolute value and then adds the sign back while
+// round(float) directly rounds the value with sign.
+SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc SL(Op);
+  SDValue A = Op.getOperand(0);
+  EVT VT = Op.getValueType();
+
+  SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
+
+  // double RoundedA = (double) (int) (abs(A) + 0.5);
+  SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
+                                  DAG.getConstantFP(0.5, SL, VT));
+  SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
+
+  // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
+  EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue IsSmall = DAG.getSetCC(SL, SetCCVT, AbsA,
+                                 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
+  RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
+                         DAG.getConstantFP(0.0, SL, VT), RoundedA);
+
+  // Add sign to RoundedA
+  RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
+
+  // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
+  // As in the f32 path, pass the value 2^52 rather than its bit pattern.
+  SDValue IsLarge =
+      DAG.getSetCC(SL, SetCCVT, AbsA,
+                   DAG.getConstantFP(4503599627370496.0 /* 2^52 */, SL, VT),
+                   ISD::SETOGT);
+  return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
+}
+
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -2098,6 +2202,8 @@
     return LowerShiftRightParts(Op, DAG);
   case ISD::SELECT:
     return LowerSelect(Op, DAG);
+  case ISD::FROUND:
+    return LowerFROUND(Op, DAG);
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3002,15 +3002,6 @@
 def : Pat<(ffloor Float64Regs:$a),
           (CVT_f64_f64 Float64Regs:$a, CvtRMI)>;
 
-def : Pat<(f16 (fround Float16Regs:$a)),
-          (CVT_f16_f16 Float16Regs:$a, CvtRNI)>;
-def : Pat<(fround Float32Regs:$a),
-          (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
-def : Pat<(f32 (fround Float32Regs:$a)),
-          (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
-def : Pat<(f64 (fround Float64Regs:$a)),
-          (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
-
 def : Pat<(ftrunc Float16Regs:$a),
           (CVT_f16_f16 Float16Regs:$a, CvtRZI)>;
 def : Pat<(ftrunc Float32Regs:$a),
diff --git a/llvm/test/CodeGen/NVPTX/f16-instructions.ll b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
--- a/llvm/test/CodeGen/NVPTX/f16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
@@ -1108,7 +1108,9 @@
 
 ; CHECK-LABEL: test_round(
 ; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_round_param_0];
-; CHECK: cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]];
+; llvm.round can't be replaced with cvt.rni. We will have cuda runnable tests
+; to check for codegen correctness.
+; CHECK-NOT: cvt.rni.f16.f16
-; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK: st.param.b16 [func_retval0+0], {{.*}};
 ; CHECK: ret;
 define half @test_round(half %a) #0 {
diff --git a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
--- a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
@@ -1380,8 +1380,9 @@
 ; CHECK-LABEL: test_round(
 ; CHECK: ld.param.b32 [[A:%hh[0-9]+]], [test_round_param_0];
 ; CHECK-DAG: mov.b32 {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]];
-; CHECK-DAG: cvt.rni.f16.f16 [[R1:%h[0-9]+]], [[A1]];
-; CHECK-DAG: cvt.rni.f16.f16 [[R0:%h[0-9]+]], [[A0]];
-; CHECK: mov.b32 [[R:%hh[0-9]+]], {[[R0]], [[R1]]}
+; llvm.round can't be replaced with cvt.rni. We will have cuda runnable tests
+; to check for codegen correctness.
+; CHECK-NOT: cvt.rni.f16.f16
+; CHECK: mov.b32 [[R:%hh[0-9]+]], {{.*}}
 ; CHECK: st.param.b32 [func_retval0+0], [[R]];
 ; CHECK: ret;
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins.ll b/llvm/test/CodeGen/NVPTX/math-intrins.ll
--- a/llvm/test/CodeGen/NVPTX/math-intrins.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -74,21 +74,21 @@
 
 ; CHECK-LABEL: round_float
 define float @round_float(float %a) {
-  ; CHECK: cvt.rni.f32.f32
+  ; CHECK: cvt.rzi.f32.f32
   %b = call float @llvm.round.f32(float %a)
   ret float %b
 }
 
 ; CHECK-LABEL: round_float_ftz
 define float @round_float_ftz(float %a) #1 {
-  ; CHECK: cvt.rni.ftz.f32.f32
+  ; CHECK: cvt.rzi.ftz.f32.f32
   %b = call float @llvm.round.f32(float %a)
   ret float %b
 }
 
 ; CHECK-LABEL: round_double
 define double @round_double(double %a) {
-  ; CHECK: cvt.rni.f64.f64
+  ; CHECK: cvt.rzi.f64.f64
   %b = call double @llvm.round.f64(double %a)
   ret double %b
 }