diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -402,6 +402,23 @@
 BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
 BUILTIN(__nvvm_f2h_rn, "Usf", "")
 
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
+
 // Bitcast
 
 BUILTIN(__nvvm_bitcast_f2i, "if", "")
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -754,4 +754,40 @@
   __nvvm_cp_async_wait_all();
 #endif
   // CHECK: ret void
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: nvvm_cvt_sm80
+__device__ void nvvm_cvt_sm80() {
+#if __CUDA_ARCH__ >= 800
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rn(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rn_relu(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rz(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rz_relu(1, 1);
+
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rn(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rn_relu(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rz(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rz_relu(1, 1);
+
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
+  __nvvm_f2bf16_rn(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
+  __nvvm_f2bf16_rn_relu(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
+  __nvvm_f2bf16_rz(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
+  __nvvm_f2bf16_rz_relu(1);
+
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
+  __nvvm_f2tf32_rna(1);
+#endif
+  // CHECK: ret void
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1185,6 +1185,36 @@
   def int_nvvm_f2h_rn : GCCBuiltin<"__nvvm_f2h_rn">,
       DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
 
+  def int_nvvm_ff2bf16x2_rn : GCCBuiltin<"__nvvm_ff2bf16x2_rn">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2bf16x2_rn_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2bf16x2_rz : GCCBuiltin<"__nvvm_ff2bf16x2_rz">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2bf16x2_rz_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rz_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+
+  def int_nvvm_ff2f16x2_rn : GCCBuiltin<"__nvvm_ff2f16x2_rn">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2f16x2_rn_relu : GCCBuiltin<"__nvvm_ff2f16x2_rn_relu">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2f16x2_rz : GCCBuiltin<"__nvvm_ff2f16x2_rz">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2f16x2_rz_relu : GCCBuiltin<"__nvvm_ff2f16x2_rz_relu">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+
+  def int_nvvm_f2bf16_rn : GCCBuiltin<"__nvvm_f2bf16_rn">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_f2bf16_rn_relu : GCCBuiltin<"__nvvm_f2bf16_rn_relu">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_f2bf16_rz : GCCBuiltin<"__nvvm_f2bf16_rz">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_f2bf16_rz_relu : GCCBuiltin<"__nvvm_f2bf16_rz_relu">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+
+  def int_nvvm_f2tf32_rna : GCCBuiltin<"__nvvm_f2tf32_rna">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem]>;
+
 //
 // Bitcast
 //
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -108,6 +108,10 @@
     // SAT flag
     if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
       O << ".sat";
+  } else if (strcmp(Modifier, "relu") == 0) {
+    // RELU flag
+    if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
+      O << ".relu";
   } else if (strcmp(Modifier, "base") == 0) {
     // Default operand
     switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
@@ -139,6 +143,9 @@
     case NVPTX::PTXCvtMode::RP:
       O << ".rp";
       break;
+    case NVPTX::PTXCvtMode::RNA:
+      O << ".rna";
+      break;
     }
   } else {
     llvm_unreachable("Invalid conversion modifier");
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -137,10 +137,12 @@
   RZ,
   RM,
   RP,
+  RNA,
 
   BASE_MASK = 0x0F,
   FTZ_FLAG = 0x10,
-  SAT_FLAG = 0x20
+  SAT_FLAG = 0x20,
+  RELU_FLAG = 0x40
 };
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -48,6 +48,7 @@
 def CvtRZ : PatLeaf<(i32 0x6)>;
 def CvtRM : PatLeaf<(i32 0x7)>;
 def CvtRP : PatLeaf<(i32 0x8)>;
+def CvtRNA : PatLeaf<(i32 0x9)>;
 
 def CvtNONE_FTZ : PatLeaf<(i32 0x10)>;
 def CvtRNI_FTZ : PatLeaf<(i32 0x11)>;
@@ -62,6 +63,10 @@
 def CvtSAT : PatLeaf<(i32 0x20)>;
 def CvtSAT_FTZ : PatLeaf<(i32 0x30)>;
 
+def CvtNONE_RELU : PatLeaf<(i32 0x40)>;
+def CvtRN_RELU : PatLeaf<(i32 0x45)>;
+def CvtRZ_RELU : PatLeaf<(i32 0x46)>;
+
 def CvtMode : Operand<i32> {
   let PrintMethod = "printCvtMode";
 }
@@ -526,6 +531,32 @@
                                     "cvt.s64.s16 \t$dst, $src;", []>;
   def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
                                     "cvt.s64.s32 \t$dst, $src;", []>;
+
+  // Conversion from a single f32 source. The asm string takes the rounding
+  // mode and the optional ".relu" suffix from the CvtMode operand.
+  multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
+    def _f32 :
+      NVPTXInst<(outs RC:$dst),
+                (ins Float32Regs:$src, CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:relu}.",
+                FromName, ".f32 \t$dst, $src;"), []>,
+                Requires<[hasPTX70, hasSM80]>;
+  }
+
+  defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>;
+
+  // Same, but with two f32 sources and a single destination register.
+  multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
+    def _f32 :
+      NVPTXInst<(outs RC:$dst),
+                (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:relu}.",
+                FromName, ".f32 \t$dst, $src1, $src2;"), []>,
+      Requires<[hasPTX70, hasSM80]>;
+  }
+
+  defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Float16x2Regs>;
+  defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
 }
 
 //-----------------------------------
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1046,6 +1046,41 @@
 def : Pat<(int_nvvm_ui2f_rp Int32Regs:$a),
           (CVT_f32_u32 Int32Regs:$a, CvtRP)>;
 
+def : Pat<(int_nvvm_ff2bf16x2_rn Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff2bf16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+def : Pat<(int_nvvm_ff2bf16x2_rz Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
+def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
+
+def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
+def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
+
+def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
+def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
+
+// The matching clang builtin is gated on AND(SM_80,PTX70); carry the same
+// predicates here as the CVT_*_SM80 instructions in NVPTXInstrInfo.td.
+def CVT_tf32_f32 :
+   NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
+                   "cvt.rna.tf32.f32 \t$dest, $a;",
+       [(set Int32Regs:$dest, (int_nvvm_f2tf32_rna Float32Regs:$a))]>,
+   Requires<[hasPTX70, hasSM80]>;
+
 def INT_NVVM_LOHI_I2D : F_MATH_2<"mov.b64 \t$dst, {{$src0, $src1}};",
   Float64Regs, Int32Regs, Int32Regs, int_nvvm_lohi_i2d>;
 
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -0,0 +1,136 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
+
+
+; CHECK-LABEL: cvt_rn_bf16x2_f32
+define i32 @cvt_rn_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.bf16x2.f32
+  %val = call i32 @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2);
+
+ret i32 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_bf16x2_f32
+define i32 @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.relu.bf16x2.f32
+%val = call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2);
+
+ret i32 %val
+}
+
+; CHECK-LABEL: cvt_rz_bf16x2_f32
+define i32 @cvt_rz_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.bf16x2.f32
+  %val = call i32 @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2);
+
+ret i32 %val
+}
+
+; CHECK-LABEL: cvt_rz_relu_bf16x2_f32
+define i32 @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.relu.bf16x2.f32
+%val = call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2);
+
+ret i32 %val
+}
+
+declare i32 @llvm.nvvm.ff2bf16x2.rn(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rn.relu(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rz(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rz.relu(float, float)
+
+; CHECK-LABEL: cvt_rn_f16x2_f32
+define <2 x half> @cvt_rn_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.f16x2.f32
+  %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_f16x2_f32
+define <2 x half> @cvt_rn_relu_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.relu.f16x2.f32
+%val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rz_f16x2_f32
+define <2 x half> @cvt_rz_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.f16x2.f32
+  %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rz_relu_f16x2_f32
+define <2 x half> @cvt_rz_relu_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.relu.f16x2.f32
+%val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+declare <2 x half> @llvm.nvvm.ff2f16x2.rn(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float)
+
+; CHECK-LABEL: cvt_rn_bf16_f32
+define i16 @cvt_rn_bf16_f32(float %f1) {
+
+; CHECK: cvt.rn.bf16.f32
+  %val = call i16 @llvm.nvvm.f2bf16.rn(float %f1);
+
+ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_bf16_f32
+define i16 @cvt_rn_relu_bf16_f32(float %f1) {
+
+; CHECK: cvt.rn.relu.bf16.f32
+%val = call i16 @llvm.nvvm.f2bf16.rn.relu(float %f1);
+
+ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rz_bf16_f32
+define i16 @cvt_rz_bf16_f32(float %f1) {
+
+; CHECK: cvt.rz.bf16.f32
+  %val = call i16 @llvm.nvvm.f2bf16.rz(float %f1);
+
+ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rz_relu_bf16_f32
+define i16 @cvt_rz_relu_bf16_f32(float %f1) {
+
+; CHECK: cvt.rz.relu.bf16.f32
+%val = call i16 @llvm.nvvm.f2bf16.rz.relu(float %f1);
+
+ret i16 %val
+}
+
+declare i16 @llvm.nvvm.f2bf16.rn(float)
+declare i16 @llvm.nvvm.f2bf16.rn.relu(float)
+declare i16 @llvm.nvvm.f2bf16.rz(float)
+declare i16 @llvm.nvvm.f2bf16.rz.relu(float)
+
+; CHECK-LABEL: cvt_rna_tf32_f32
+define i32 @cvt_rna_tf32_f32(float %f1) {
+
+; CHECK: cvt.rna.tf32.f32
+  %val = call i32 @llvm.nvvm.f2tf32.rna(float %f1);
+
+ret i32 %val
+}
+
+declare i32 @llvm.nvvm.f2tf32.rna(float)