Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -279,6 +279,28 @@
   setTargetDAGCombine(ISD::SHL);
   setTargetDAGCombine(ISD::SELECT);
 
+  // Library functions. These default to Expand, but we have instructions
+  // for them.
+  setOperationAction(ISD::FCEIL, MVT::f32, Legal);
+  setOperationAction(ISD::FCEIL, MVT::f64, Legal);
+  setOperationAction(ISD::FFLOOR, MVT::f32, Legal);
+  setOperationAction(ISD::FFLOOR, MVT::f64, Legal);
+  setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
+  setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal);
+  setOperationAction(ISD::FRINT, MVT::f32, Legal);
+  setOperationAction(ISD::FRINT, MVT::f64, Legal);
+  setOperationAction(ISD::FROUND, MVT::f32, Legal);
+  setOperationAction(ISD::FROUND, MVT::f64, Legal);
+  setOperationAction(ISD::FTRUNC, MVT::f32, Legal);
+  setOperationAction(ISD::FTRUNC, MVT::f64, Legal);
+  setOperationAction(ISD::FMINNUM, MVT::f32, Legal);
+  setOperationAction(ISD::FMINNUM, MVT::f64, Legal);
+  setOperationAction(ISD::FMAXNUM, MVT::f32, Legal);
+  setOperationAction(ISD::FMAXNUM, MVT::f64, Legal);
+
+  // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
+  // No FPOW or FREM in PTX.
+
   // Now deduce the information based on the above mentioned
   // actions
   computeRegisterProperties(STI.getRegisterInfo());
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -207,15 +207,63 @@
 }
 
 // Template for instructions which take three fp64 or fp32 args. The
-// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
+// instructions are named "<OpcStr>.f<Width>" (e.g. "min.f64").
 //
 // Also defines ftz (flush subnormal inputs and results to sign-preserving
 // zero) variants for fp32 functions.
+//
+// This multiclass should be used for nodes that cannot be folded into FMAs.
+// For nodes that can be folded into FMAs (i.e. adds and muls), use
+// F3_fma_component.
 multiclass F3<string OpcStr, SDNode OpNode> {
   def f64rr :
     NVPTXInst<(outs Float64Regs:$dst),
               (ins Float64Regs:$a, Float64Regs:$b),
               !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
+              [(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>;
+  def f64ri :
+    NVPTXInst<(outs Float64Regs:$dst),
+              (ins Float64Regs:$a, f64imm:$b),
+              !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
+              [(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>;
+  def f32rr_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
+              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
+              Requires<[doF32FTZ]>;
+  def f32ri_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
+              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
+              Requires<[doF32FTZ]>;
+  def f32rr :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
+              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>;
+  def f32ri :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
+              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
+}
+
+// Template for instructions which take three fp64 or fp32 args. The
+// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
+//
+// Also defines ftz (flush subnormal inputs and results to sign-preserving
+// zero) variants for fp32 functions.
+//
+// This multiclass should be used for nodes that can be folded to make fma ops.
+// In this case, we use the ".rn" variant when FMA is disabled, as this behaves
+// just like the non ".rn" op, but prevents ptxas from creating FMAs.
+multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
+  def f64rr :
+    NVPTXInst<(outs Float64Regs:$dst),
+              (ins Float64Regs:$a, Float64Regs:$b),
+              !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>,
               Requires<[allowFMA]>;
   def f64ri :
@@ -248,41 +296,39 @@
               !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
               Requires<[allowFMA]>;
-}
 
-// Same as F3, but defines ".rn" variants (round to nearest even).
-multiclass F3_rn<string OpcStr, SDNode OpNode> {
-  def f64rr :
+  // These have strange names so we don't perturb existing mir tests.
+  def _rnf64rr :
     NVPTXInst<(outs Float64Regs:$dst),
               (ins Float64Regs:$a, Float64Regs:$b),
               !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>,
               Requires<[noFMA]>;
-  def f64ri :
+  def _rnf64ri :
     NVPTXInst<(outs Float64Regs:$dst),
               (ins Float64Regs:$a, f64imm:$b),
               !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>,
               Requires<[noFMA]>;
-  def f32rr_ftz :
+  def _rnf32rr_ftz :
     NVPTXInst<(outs Float32Regs:$dst),
              (ins Float32Regs:$a, Float32Regs:$b),
              !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
              Requires<[noFMA, doF32FTZ]>;
-  def f32ri_ftz :
+  def _rnf32ri_ftz :
     NVPTXInst<(outs Float32Regs:$dst),
              (ins Float32Regs:$a, f32imm:$b),
              !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
              Requires<[noFMA, doF32FTZ]>;
-  def f32rr :
+  def _rnf32rr :
    NVPTXInst<(outs Float32Regs:$dst),
              (ins Float32Regs:$a, Float32Regs:$b),
              !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
              Requires<[noFMA]>;
-  def f32ri :
+  def _rnf32ri :
    NVPTXInst<(outs Float32Regs:$dst),
              (ins Float32Regs:$a, f32imm:$b),
              !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
@@ -713,13 +759,12 @@
   N->getValueAPF().convertToDouble() == 1.0;
 }]>;
 
-defm FADD : F3<"add", fadd>;
-defm FSUB : F3<"sub", fsub>;
-defm FMUL : F3<"mul", fmul>;
+defm FADD : F3_fma_component<"add", fadd>;
+defm FSUB : F3_fma_component<"sub", fsub>;
+defm FMUL : F3_fma_component<"mul", fmul>;
 
-defm FADD_rn : F3_rn<"add", fadd>;
-defm FSUB_rn : F3_rn<"sub", fsub>;
-defm FMUL_rn : F3_rn<"mul", fmul>;
+defm FMIN : F3<"min", fminnum>;
+defm FMAX : F3<"max", fmaxnum>;
 
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
@@ -2628,6 +2673,55 @@
 def retflag : SDNode<"NVPTXISD::RET_FLAG", SDTNone,
                      [SDNPHasChain, SDNPOptInGlue]>;
 
+// fceil, ffloor, fround, ftrunc.
+
+def : Pat<(fceil Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRPI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(fceil Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRPI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(fceil Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRPI)>;
+
+def : Pat<(ffloor Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRMI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(ffloor Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRMI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(ffloor Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRMI)>;
+
+def : Pat<(fround Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(f32 (fround Float32Regs:$a)),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(f64 (fround Float64Regs:$a)),
+          (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
+
+def : Pat<(ftrunc Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(ftrunc Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRZI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(ftrunc Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRZI)>;
+
+// nearbyint and rint are implemented as rounding to nearest even. This isn't
+// strictly correct, because it causes us to ignore the rounding mode. But it
+// matches what CUDA's "libm" does.
+
+def : Pat<(fnearbyint Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(fnearbyint Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(fnearbyint Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
+
+def : Pat<(frint Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(frint Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(frint Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
+
+
 //-----------------------------------
 // Control-flow
 //-----------------------------------
Index: llvm/test/CodeGen/NVPTX/bug22322.ll
===================================================================
--- llvm/test/CodeGen/NVPTX/bug22322.ll
+++ llvm/test/CodeGen/NVPTX/bug22322.ll
@@ -22,7 +22,7 @@
   %8 = icmp eq i32 %7, 0
   %9 = select i1 %8, float 0.000000e+00, float -1.000000e+00
   store float %9, float* %ret_vec.sroa.8.i, align 4
-; CHECK: setp.lt.f32 %p{{[0-9]+}}, %f{{[0-9]+}}, 0f00000000
+; CHECK: max.f32 %f{{[0-9]+}}, %f{{[0-9]+}}, 0f00000000
   %10 = fcmp olt float %9, 0.000000e+00
   %ret_vec.sroa.8.i.val = load float, float* %ret_vec.sroa.8.i, align 4
   %11 = select i1 %10, float 0.000000e+00, float %ret_vec.sroa.8.i.val
Index: llvm/test/CodeGen/NVPTX/math-intrins.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -0,0 +1,233 @@
+; RUN: llc < %s | FileCheck %s
+target triple = "nvptx64-nvidia-cuda"
+
+; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
+
+declare float @llvm.ceil.f32(float) #0
+declare double @llvm.ceil.f64(double) #0
+declare float @llvm.floor.f32(float) #0
+declare double @llvm.floor.f64(double) #0
+declare float @llvm.round.f32(float) #0
+declare double @llvm.round.f64(double) #0
+declare float @llvm.nearbyint.f32(float) #0
+declare double @llvm.nearbyint.f64(double) #0
+declare float @llvm.rint.f32(float) #0
+declare double @llvm.rint.f64(double) #0
+declare float @llvm.trunc.f32(float) #0
+declare double @llvm.trunc.f64(double) #0
+declare float @llvm.fabs.f32(float) #0
+declare double @llvm.fabs.f64(double) #0
+declare float @llvm.minnum.f32(float, float) #0
+declare double @llvm.minnum.f64(double, double) #0
+declare float @llvm.maxnum.f32(float, float) #0
+declare double @llvm.maxnum.f64(double, double) #0
+
+; ---- ceil ----
+
+; CHECK-LABEL: ceil_float
+define float @ceil_float(float %a) {
+  ; CHECK: cvt.rpi.f32.f32
+  %b = call float @llvm.ceil.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: ceil_float_ftz
+define float @ceil_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rpi.ftz.f32.f32
+  %b = call float @llvm.ceil.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: ceil_double
+define double @ceil_double(double %a) {
+  ; CHECK: cvt.rpi.f64.f64
+  %b = call double @llvm.ceil.f64(double %a)
+  ret double %b
+}
+
+; ---- floor ----
+
+; CHECK-LABEL: floor_float
+define float @floor_float(float %a) {
+  ; CHECK: cvt.rmi.f32.f32
+  %b = call float @llvm.floor.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: floor_float_ftz
+define float @floor_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rmi.ftz.f32.f32
+  %b = call float @llvm.floor.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: floor_double
+define double @floor_double(double %a) {
+  ; CHECK: cvt.rmi.f64.f64
+  %b = call double @llvm.floor.f64(double %a)
+  ret double %b
+}
+
+; ---- round ----
+
+; CHECK-LABEL: round_float
+define float @round_float(float %a) {
+  ; CHECK: cvt.rni.f32.f32
+  %b = call float @llvm.round.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: round_float_ftz
+define float @round_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rni.ftz.f32.f32
+  %b = call float @llvm.round.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: round_double
+define double @round_double(double %a) {
+  ; CHECK: cvt.rni.f64.f64
+  %b = call double @llvm.round.f64(double %a)
+  ret double %b
+}
+
+; ---- nearbyint ----
+
+; CHECK-LABEL: nearbyint_float
+define float @nearbyint_float(float %a) {
+  ; CHECK: cvt.rni.f32.f32
+  %b = call float @llvm.nearbyint.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: nearbyint_float_ftz
+define float @nearbyint_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rni.ftz.f32.f32
+  %b = call float @llvm.nearbyint.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: nearbyint_double
+define double @nearbyint_double(double %a) {
+  ; CHECK: cvt.rni.f64.f64
+  %b = call double @llvm.nearbyint.f64(double %a)
+  ret double %b
+}
+
+; ---- rint ----
+
+; CHECK-LABEL: rint_float
+define float @rint_float(float %a) {
+  ; CHECK: cvt.rni.f32.f32
+  %b = call float @llvm.rint.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: rint_float_ftz
+define float @rint_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rni.ftz.f32.f32
+  %b = call float @llvm.rint.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: rint_double
+define double @rint_double(double %a) {
+  ; CHECK: cvt.rni.f64.f64
+  %b = call double @llvm.rint.f64(double %a)
+  ret double %b
+}
+
+; ---- trunc ----
+
+; CHECK-LABEL: trunc_float
+define float @trunc_float(float %a) {
+  ; CHECK: cvt.rzi.f32.f32
+  %b = call float @llvm.trunc.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: trunc_float_ftz
+define float @trunc_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rzi.ftz.f32.f32
+  %b = call float @llvm.trunc.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: trunc_double
+define double @trunc_double(double %a) {
+  ; CHECK: cvt.rzi.f64.f64
+  %b = call double @llvm.trunc.f64(double %a)
+  ret double %b
+}
+
+; ---- abs ----
+
+; CHECK-LABEL: abs_float
+define float @abs_float(float %a) {
+  ; CHECK: abs.f32
+  %b = call float @llvm.fabs.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: abs_float_ftz
+define float @abs_float_ftz(float %a) #1 {
+  ; CHECK: abs.ftz.f32
+  %b = call float @llvm.fabs.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: abs_double
+define double @abs_double(double %a) {
+  ; CHECK: abs.f64
+  %b = call double @llvm.fabs.f64(double %a)
+  ret double %b
+}
+
+; ---- min ----
+
+; CHECK-LABEL: min_float
+define float @min_float(float %a, float %b) {
+  ; CHECK: min.f32
+  %x = call float @llvm.minnum.f32(float %a, float %b)
+  ret float %x
+}
+
+; CHECK-LABEL: min_float_ftz
+define float @min_float_ftz(float %a, float %b) #1 {
+  ; CHECK: min.ftz.f32
+  %x = call float @llvm.minnum.f32(float %a, float %b)
+  ret float %x
+}
+
+; CHECK-LABEL: min_double
+define double @min_double(double %a, double %b) {
+  ; CHECK: min.f64
+  %x = call double @llvm.minnum.f64(double %a, double %b)
+  ret double %x
+}
+
+; ---- max ----
+
+; CHECK-LABEL: max_float
+define float @max_float(float %a, float %b) {
+  ; CHECK: max.f32
+  %x = call float @llvm.maxnum.f32(float %a, float %b)
+  ret float %x
+}
+
+; CHECK-LABEL: max_float_ftz
+define float @max_float_ftz(float %a, float %b) #1 {
+  ; CHECK: max.ftz.f32
+  %x = call float @llvm.maxnum.f32(float %a, float %b)
+  ret float %x
+}
+
+; CHECK-LABEL: max_double
+define double @max_double(double %a, double %b) {
+  ; CHECK: max.f64
+  %x = call double @llvm.maxnum.f64(double %a, double %b)
+  ret double %x
+}
+
+attributes #0 = { nounwind readnone }
+attributes #1 = { "nvptx-f32ftz" = "true" }
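
The new test only exercises the register-register forms of FMIN/FMAX. Since the F3 multiclass also adds register-immediate (f32ri/f64ri) variants, and the updated bug22322.ll CHECK line shows that a maxnum against a floating-point constant selects "max.f32 ..., 0f00000000", a follow-up case along the following lines could be appended to math-intrins.ll to cover that path. This is only a suggested sketch, not part of the patch; the function name max_imm is hypothetical and the exact PTX register numbering is illustrative.

; CHECK-LABEL: max_imm
define float @max_imm(float %a) {
  ; The immediate operand should be folded into the instruction via the
  ; f32ri pattern rather than first being moved into a register.
  ; CHECK: max.f32 %f{{[0-9]+}}, %f{{[0-9]+}}, 0f00000000
  %x = call float @llvm.maxnum.f32(float %a, float 0.0)
  ret float %x
}

An analogous min_imm case (and f64 versions) would cover the remaining immediate patterns added by the F3 multiclass.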