diff --git a/clang/lib/Basic/Targets/NVPTX.h b/clang/lib/Basic/Targets/NVPTX.h --- a/clang/lib/Basic/Targets/NVPTX.h +++ b/clang/lib/Basic/Targets/NVPTX.h @@ -176,6 +176,8 @@ } bool hasBitIntType() const override { return true; } + bool hasBFloat16Type() const override { return true; } + const char *getBFloat16Mangling() const override { return "u6__bf16"; } }; } // namespace targets } // namespace clang diff --git a/clang/lib/Basic/Targets/NVPTX.cpp b/clang/lib/Basic/Targets/NVPTX.cpp --- a/clang/lib/Basic/Targets/NVPTX.cpp +++ b/clang/lib/Basic/Targets/NVPTX.cpp @@ -52,6 +52,9 @@ VLASupported = false; AddrSpaceMap = &NVPTXAddrSpaceMap; UseAddrSpaceMapMangling = true; + // __bf16 is always available as a load/store only type. + BFloat16Width = BFloat16Align = 16; + BFloat16Format = &llvm::APFloat::BFloat(); // Define available target features // These must be defined in sorted order! diff --git a/clang/test/CodeGenCUDA/bf16.cu b/clang/test/CodeGenCUDA/bf16.cu new file mode 100644 --- /dev/null +++ b/clang/test/CodeGenCUDA/bf16.cu @@ -0,0 +1,46 @@ +// REQUIRES: nvptx-registered-target +// REQUIRES: x86-registered-target + +// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "nvptx64-nvidia-cuda" \ +// RUN: -fcuda-is-device "-aux-target-cpu" "x86-64" -S -o - %s | FileCheck %s + +#include "Inputs/cuda.h" + +// CHECK-LABEL: .visible .func _Z8test_argPu6__bf16u6__bf16( +// CHECK: .param .b64 _Z8test_argPu6__bf16u6__bf16_param_0, +// CHECK: .param .b16 _Z8test_argPu6__bf16u6__bf16_param_1 +// +__device__ void test_arg(__bf16 *out, __bf16 in) { +// CHECK: ld.param.b16 %{{h.*}}, [_Z8test_argPu6__bf16u6__bf16_param_1]; + __bf16 bf16 = in; + *out = bf16; +// CHECK: st.b16 +// CHECK: ret; +} + + +// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retu6__bf16( +// CHECK: .param .b16 _Z8test_retu6__bf16_param_0 +__device__ __bf16 test_ret( __bf16 in) { +// CHECK: ld.param.b16 %h{{.*}}, [_Z8test_retu6__bf16_param_0]; + return in; +// CHECK: 
st.param.b16 [func_retval0+0], %h +// CHECK: ret; +} + +// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callu6__bf16( +// CHECK: .param .b16 _Z9test_callu6__bf16_param_0 +__device__ __bf16 test_call( __bf16 in) { +// CHECK: ld.param.b16 %h{{.*}}, [_Z9test_callu6__bf16_param_0]; +// CHECK: st.param.b16 [param0+0], %h{{[0-9]+}}; +// CHECK: .param .b32 retval0; +// CHECK: call.uni (retval0), +// CHECK-NEXT: _Z8test_retu6__bf16, +// CHECK-NEXT: ( +// CHECK-NEXT: param0 +// CHECK-NEXT: ); +// CHECK: ld.param.b16 %h{{.*}}, [retval0+0]; + return test_ret(in); +// CHECK: st.param.b16 [func_retval0+0], %h +// CHECK: ret; +} diff --git a/clang/test/SemaCUDA/bf16.cu b/clang/test/SemaCUDA/bf16.cu new file mode 100644 --- /dev/null +++ b/clang/test/SemaCUDA/bf16.cu @@ -0,0 +1,33 @@ +// REQUIRES: nvptx-registered-target +// REQUIRES: x86-registered-target + +// RUN: %clang_cc1 "-triple" "x86_64-unknown-linux-gnu" "-aux-triple" "nvptx64-nvidia-cuda" \ +// RUN: "-target-cpu" "x86-64" -fsyntax-only -verify=scalar %s +// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "nvptx64-nvidia-cuda" \ +// RUN: -fcuda-is-device "-aux-target-cpu" "x86-64" -fsyntax-only -verify=scalar %s + +#include "Inputs/cuda.h" + +__device__ void test(bool b, __bf16 *out, __bf16 in) { + __bf16 bf16 = in; // No error on using the type itself.
+ + bf16 + bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}} + bf16 - bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}} + bf16 * bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}} + bf16 / bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}} + + __fp16 fp16; + + bf16 + fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}} + fp16 + bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}} + bf16 - fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}} + fp16 - bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}} + bf16 * fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}} + fp16 * bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}} + bf16 / fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}} + fp16 / bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}} + bf16 = fp16; // scalar-error {{assigning to '__bf16' from incompatible type '__fp16'}} + fp16 = bf16; // scalar-error {{assigning to '__fp16' from incompatible type '__bf16'}} + bf16 + (b ? 
fp16 : bf16); // scalar-error {{incompatible operand types ('__fp16' and '__bf16')}} + *out = bf16; +} diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -1831,6 +1831,7 @@ break; case Type::HalfTyID: + case Type::BFloatTyID: case Type::FloatTyID: case Type::DoubleTyID: AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt()); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -823,8 +823,10 @@ case MVT::i64: return Opcode_i64; case MVT::f16: + case MVT::bf16: return Opcode_f16; case MVT::v2f16: + case MVT::v2bf16: return Opcode_f16x2; case MVT::f32: return Opcode_f32; @@ -834,6 +836,20 @@ return None; } } + +static unsigned getLdStRegType(EVT VT) { + if (!VT.isFloatingPoint()) + return NVPTX::PTXLdStInstCode::Unsigned; + switch (VT.getSimpleVT().SimpleTy) { + case MVT::f16: + case MVT::bf16: + case MVT::v2f16: + case MVT::v2bf16: + return NVPTX::PTXLdStInstCode::Untyped; // Untyped .b16/.b32 storage. + default: + return NVPTX::PTXLdStInstCode::Float; + } +} bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { SDLoc dl(N); @@ -891,19 +907,16 @@ // Vector Setting unsigned vecType = NVPTX::PTXLdStInstCode::Scalar; if (SimpleVT.isVector()) { - assert(LoadedVT == MVT::v2f16 && "Unexpected vector type"); - // v2f16 is loaded using ld.b32 + assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) && + "Unexpected vector type"); + // v2f16/v2bf16 is loaded using ld.b32 fromTypeWidth = 32; } if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD)) fromType = NVPTX::PTXLdStInstCode::Signed; - else if (ScalarVT.isFloatingPoint()) - // f16 uses .b16 as its storage type. - fromType = ScalarVT.SimpleTy == MVT::f16 ? 
NVPTX::PTXLdStInstCode::Untyped - : NVPTX::PTXLdStInstCode::Float; else - fromType = NVPTX::PTXLdStInstCode::Unsigned; + fromType = getLdStRegType(ScalarVT); // Create the machine instruction DAG SDValue Chain = N->getOperand(0); @@ -1033,11 +1046,8 @@ N->getOperand(N->getNumOperands() - 1))->getZExtValue(); if (ExtensionType == ISD::SEXTLOAD) FromType = NVPTX::PTXLdStInstCode::Signed; - else if (ScalarVT.isFloatingPoint()) - FromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped - : NVPTX::PTXLdStInstCode::Float; else - FromType = NVPTX::PTXLdStInstCode::Unsigned; + FromType = getLdStRegType(ScalarVT); unsigned VecType; @@ -1057,7 +1067,7 @@ // v8f16 is a special case. PTX doesn't have ld.v8.f16 // instruction. Instead, we split the vector into v2f16 chunks and // load them with ld.v4.b32. - if (EltVT == MVT::v2f16) { + if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) { assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode."); EltVT = MVT::i32; FromType = NVPTX::PTXLdStInstCode::Untyped; @@ -1745,18 +1755,13 @@ MVT ScalarVT = SimpleVT.getScalarType(); unsigned toTypeWidth = ScalarVT.getSizeInBits(); if (SimpleVT.isVector()) { - assert(StoreVT == MVT::v2f16 && "Unexpected vector type"); + assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) && + "Unexpected vector type"); // v2f16 is stored using st.b32 toTypeWidth = 32; } - unsigned int toType; - if (ScalarVT.isFloatingPoint()) - // f16 uses .b16 as its storage type. - toType = ScalarVT.SimpleTy == MVT::f16 ? 
NVPTX::PTXLdStInstCode::Untyped - : NVPTX::PTXLdStInstCode::Float; - else - toType = NVPTX::PTXLdStInstCode::Unsigned; + unsigned int toType = getLdStRegType(ScalarVT); // Create the machine instruction DAG SDValue Chain = ST->getChain(); @@ -1896,12 +1901,7 @@ assert(StoreVT.isSimple() && "Store value is not simple"); MVT ScalarVT = StoreVT.getSimpleVT().getScalarType(); unsigned ToTypeWidth = ScalarVT.getSizeInBits(); - unsigned ToType; - if (ScalarVT.isFloatingPoint()) - ToType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped - : NVPTX::PTXLdStInstCode::Float; - else - ToType = NVPTX::PTXLdStInstCode::Unsigned; + unsigned ToType = getLdStRegType(ScalarVT); SmallVector StOps; SDValue N2; @@ -1929,7 +1929,7 @@ // v8f16 is a special case. PTX doesn't have st.v8.f16 // instruction. Instead, we split the vector into v2f16 chunks and // store them with st.v4.b32. - if (EltVT == MVT::v2f16) { + if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) { assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode."); EltVT = MVT::i32; ToType = NVPTX::PTXLdStInstCode::Untyped; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -133,6 +133,9 @@ case MVT::v2f16: case MVT::v4f16: case MVT::v8f16: // <4 x f16x2> + case MVT::v2bf16: + case MVT::v4bf16: + case MVT::v8bf16: // <4 x bf16x2> case MVT::v2f32: case MVT::v4f32: case MVT::v2f64: @@ -190,8 +193,8 @@ // Vectors with an even number of f16 elements will be passed to // us as an array of v2f16 elements. We must match this so we // stay in sync with Ins/Outs. - if (EltVT == MVT::f16 && NumElts % 2 == 0) { - EltVT = MVT::v2f16; + if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) { + EltVT = EltVT == MVT::f16 ? 
MVT::v2f16 : MVT::v2bf16; NumElts /= 2; } for (unsigned j = 0; j != NumElts; ++j) { @@ -400,6 +403,8 @@ addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass); addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass); addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass); + addRegisterClass(MVT::bf16, &NVPTX::Float16RegsRegClass); + addRegisterClass(MVT::v2bf16, &NVPTX::Float16x2RegsRegClass); // Conversion to/from FP16/FP16x2 is always legal. setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal); @@ -495,6 +500,7 @@ setOperationAction(ISD::ConstantFP, MVT::f64, Legal); setOperationAction(ISD::ConstantFP, MVT::f32, Legal); setOperationAction(ISD::ConstantFP, MVT::f16, Legal); + setOperationAction(ISD::ConstantFP, MVT::bf16, Legal); // TRAP can be lowered to PTX trap setOperationAction(ISD::TRAP, MVT::Other, Legal); @@ -2334,14 +2340,17 @@ case MVT::v2i32: case MVT::v2i64: case MVT::v2f16: + case MVT::v2bf16: case MVT::v2f32: case MVT::v2f64: case MVT::v4i8: case MVT::v4i16: case MVT::v4i32: case MVT::v4f16: + case MVT::v4bf16: case MVT::v4f32: case MVT::v8f16: // <4 x f16x2> + case MVT::v8bf16: // <4 x bf16x2> // This is a "native" vector type break; } @@ -2386,7 +2395,8 @@ // v8f16 is a special case. PTX doesn't have st.v8.f16 // instruction. Instead, we split the vector into v2f16 chunks and // store them with st.v4.b32. - assert(EltVT == MVT::f16 && "Wrong type for the vector."); + assert((EltVT == MVT::f16 || EltVT == MVT::bf16) && + "Wrong type for the vector."); Opcode = NVPTXISD::StoreV4; StoreF16x2 = true; break; @@ -4987,11 +4997,12 @@ // v8f16 is a special case. PTX doesn't have ld.v8.f16 // instruction. Instead, we split the vector into v2f16 chunks and // load them with ld.v4.b32. 
- assert(EltVT == MVT::f16 && "Unsupported v8 vector type."); + assert((EltVT == MVT::f16 || EltVT == MVT::bf16) && + "Unsupported v8 vector type."); LoadF16x2 = true; Opcode = NVPTXISD::LoadV4; - EVT ListVTs[] = {MVT::v2f16, MVT::v2f16, MVT::v2f16, MVT::v2f16, - MVT::Other}; + EVT VVT = (EltVT == MVT::f16) ? MVT::v2f16 : MVT::v2bf16; + EVT ListVTs[] = {VVT, VVT, VVT, VVT, MVT::Other}; LdResVTs = DAG.getVTList(ListVTs); break; } diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -172,6 +172,30 @@ def useShortPtr : Predicate<"useShortPointers()">; def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; +// Helper class to aid conversion between ValueType and a matching RegisterClass. + +class ValueToRegClass<ValueType T> { + string name = !cast<string>(T); + NVPTXRegClass ret = !cond( + !eq(name, "i1"): Int1Regs, + !eq(name, "i16"): Int16Regs, + !eq(name, "i32"): Int32Regs, + !eq(name, "i64"): Int64Regs, + !eq(name, "f16"): Float16Regs, + !eq(name, "v2f16"): Float16x2Regs, + !eq(name, "bf16"): Float16Regs, + !eq(name, "v2bf16"): Float16x2Regs, + !eq(name, "f32"): Float32Regs, + !eq(name, "f64"): Float64Regs, + !eq(name, "ai32"): Int32ArgRegs, + !eq(name, "ai64"): Int64ArgRegs, + !eq(name, "af32"): Float32ArgRegs, + !eq(name, "af64"): Float64ArgRegs, + ); +} + + + //===----------------------------------------------------------------------===// // Some Common Instruction Class Templates //===----------------------------------------------------------------------===// @@ -277,26 +301,26 @@ NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$a, Float16Regs:$b), !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"), - [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + [(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>, Requires<[useFP16Math, doF32FTZ]>; def f16rr : NVPTXInst<(outs Float16Regs:$dst), (ins 
Float16Regs:$a, Float16Regs:$b), !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"), - [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + [(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>, Requires<[useFP16Math]>; def f16x2rr_ftz : NVPTXInst<(outs Float16x2Regs:$dst), (ins Float16x2Regs:$a, Float16x2Regs:$b), !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"), - [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>, + [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>, Requires<[useFP16Math, doF32FTZ]>; def f16x2rr : NVPTXInst<(outs Float16x2Regs:$dst), (ins Float16x2Regs:$a, Float16x2Regs:$b), !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"), - [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>, + [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>, Requires<[useFP16Math]>; } @@ -351,26 +375,26 @@ NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$a, Float16Regs:$b), !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"), - [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + [(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>, Requires<[useFP16Math, allowFMA, doF32FTZ]>; def f16rr : NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$a, Float16Regs:$b), !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"), - [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + [(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>, Requires<[useFP16Math, allowFMA]>; def f16x2rr_ftz : NVPTXInst<(outs Float16x2Regs:$dst), (ins Float16x2Regs:$a, Float16x2Regs:$b), !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"), - [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>, + [(set (v2f16 Float16x2Regs:$dst), (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>, Requires<[useFP16Math, allowFMA, doF32FTZ]>; def f16x2rr : NVPTXInst<(outs 
Float16x2Regs:$dst), (ins Float16x2Regs:$a, Float16x2Regs:$b), !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"), - [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>, + [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>, Requires<[useFP16Math, allowFMA]>; // These have strange names so we don't perturb existing mir tests. @@ -414,25 +438,25 @@ NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$a, Float16Regs:$b), !strconcat(OpcStr, ".rn.ftz.f16 \t$dst, $a, $b;"), - [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + [(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>, Requires<[useFP16Math, noFMA, doF32FTZ]>; def _rnf16rr : NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$a, Float16Regs:$b), !strconcat(OpcStr, ".rn.f16 \t$dst, $a, $b;"), - [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + [(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>, Requires<[useFP16Math, noFMA]>; def _rnf16x2rr_ftz : NVPTXInst<(outs Float16x2Regs:$dst), (ins Float16x2Regs:$a, Float16x2Regs:$b), !strconcat(OpcStr, ".rn.ftz.f16x2 \t$dst, $a, $b;"), - [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>, + [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>, Requires<[useFP16Math, noFMA, doF32FTZ]>; def _rnf16x2rr : NVPTXInst<(outs Float16x2Regs:$dst), (ins Float16x2Regs:$a, Float16x2Regs:$b), !strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"), - [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>, + [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>, Requires<[useFP16Math, noFMA]>; } @@ -924,15 +948,15 @@ // // F16 NEG // -class FNEG_F16_F16X2 : +class FNEG_F16_F16X2 : NVPTXInst<(outs RC:$dst), (ins RC:$src), !strconcat(OpcStr, " \t$dst, $src;"), - [(set RC:$dst, (fneg RC:$src))]>, + [(set RC:$dst, (fneg (T RC:$src)))]>, 
Requires<[useFP16Math, hasPTX60, hasSM53, Pred]>; -def FNEG16_ftz : FNEG_F16_F16X2<"neg.ftz.f16", Float16Regs, doF32FTZ>; -def FNEG16 : FNEG_F16_F16X2<"neg.f16", Float16Regs, True>; -def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", Float16x2Regs, doF32FTZ>; -def FNEG16x2 : FNEG_F16_F16X2<"neg.f16x2", Float16x2Regs, True>; +def FNEG16_ftz : FNEG_F16_F16X2<"neg.ftz.f16", f16, Float16Regs, doF32FTZ>; +def FNEG16 : FNEG_F16_F16X2<"neg.f16", f16, Float16Regs, True>; +def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>; +def FNEG16x2 : FNEG_F16_F16X2<"neg.f16x2", v2f16, Float16x2Regs, True>; // // F64 division @@ -1105,17 +1129,17 @@ Requires<[Pred]>; } -multiclass FMA_F16 { +multiclass FMA_F16 { def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>, + [(set RC:$dst, (fma (T RC:$a), (T RC:$b), (T RC:$c)))]>, Requires<[useFP16Math, Pred]>; } -defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", Float16Regs, doF32FTZ>; -defm FMA16 : FMA_F16<"fma.rn.f16", Float16Regs, True>; -defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", Float16x2Regs, doF32FTZ>; -defm FMA16x2 : FMA_F16<"fma.rn.f16x2", Float16x2Regs, True>; +defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Float16Regs, doF32FTZ>; +defm FMA16 : FMA_F16<"fma.rn.f16", f16, Float16Regs, True>; +defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>; +defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Float16x2Regs, True>; defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>; defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>; defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>; @@ -1569,52 +1593,57 @@ !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>; } - multiclass SELP_PATTERN { + multiclass SELP_PATTERN { def rr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, Int1Regs:$p), !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), - [(set RC:$dst, (select 
Int1Regs:$p, RC:$a, RC:$b))]>; + [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T RC:$b)))]>; def ri : NVPTXInst<(outs RC:$dst), (ins RC:$a, ImmCls:$b, Int1Regs:$p), !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), - [(set RC:$dst, (select Int1Regs:$p, RC:$a, ImmNode:$b))]>; + [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T ImmNode:$b)))]>; def ir : NVPTXInst<(outs RC:$dst), (ins ImmCls:$a, RC:$b, Int1Regs:$p), !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), - [(set RC:$dst, (select Int1Regs:$p, ImmNode:$a, RC:$b))]>; + [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, (T RC:$b)))]>; def ii : NVPTXInst<(outs RC:$dst), (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p), !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), - [(set RC:$dst, (select Int1Regs:$p, ImmNode:$a, ImmNode:$b))]>; + [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, ImmNode:$b))]>; } } // Don't pattern match on selp.{s,u}{16,32,64} -- selp.b{16,32,64} is just as // good. -defm SELP_b16 : SELP_PATTERN<"b16", Int16Regs, i16imm, imm>; +defm SELP_b16 : SELP_PATTERN<"b16", i16, Int16Regs, i16imm, imm>; defm SELP_s16 : SELP<"s16", Int16Regs, i16imm>; defm SELP_u16 : SELP<"u16", Int16Regs, i16imm>; -defm SELP_b32 : SELP_PATTERN<"b32", Int32Regs, i32imm, imm>; +defm SELP_b32 : SELP_PATTERN<"b32", i32, Int32Regs, i32imm, imm>; defm SELP_s32 : SELP<"s32", Int32Regs, i32imm>; defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>; -defm SELP_b64 : SELP_PATTERN<"b64", Int64Regs, i64imm, imm>; +defm SELP_b64 : SELP_PATTERN<"b64", i64, Int64Regs, i64imm, imm>; defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>; defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>; -defm SELP_f16 : SELP_PATTERN<"b16", Float16Regs, f16imm, fpimm>; -defm SELP_f32 : SELP_PATTERN<"f32", Float32Regs, f32imm, fpimm>; -defm SELP_f64 : SELP_PATTERN<"f64", Float64Regs, f64imm, fpimm>; +defm SELP_f16 : SELP_PATTERN<"b16", f16, Float16Regs, f16imm, fpimm>; + +defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>; +defm 
SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>; + +// This does not work as tablegen fails to infer the type of 'imm'. +//def v2f16imm : Operand; +//defm SELP_f16x2 : SELP_PATTERN<"b32", v2f16, Float16x2Regs, v2f16imm, imm>; def SELP_f16x2rr : NVPTXInst<(outs Float16x2Regs:$dst), (ins Float16x2Regs:$a, Float16x2Regs:$b, Int1Regs:$p), "selp.b32 \t$dst, $a, $b, $p;", [(set Float16x2Regs:$dst, - (select Int1Regs:$p, Float16x2Regs:$a, Float16x2Regs:$b))]>; + (select Int1Regs:$p, (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>; //----------------------------------- // Data Movement (Load / Store, Move) @@ -1847,22 +1876,22 @@ multiclass FSET_FORMAT { // f16 -> pred - def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)), + def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))), (SETP_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>, Requires<[useFP16Math,doF32FTZ]>; - def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)), + def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))), (SETP_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>, Requires<[useFP16Math]>; - def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)), + def : Pat<(i1 (OpNode (f16 Float16Regs:$a), fpimm:$b)), (SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>, Requires<[useFP16Math,doF32FTZ]>; - def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)), + def : Pat<(i1 (OpNode (f16 Float16Regs:$a), fpimm:$b)), (SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>, Requires<[useFP16Math]>; - def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)), + def : Pat<(i1 (OpNode fpimm:$a, (f16 Float16Regs:$b))), (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>, Requires<[useFP16Math,doF32FTZ]>; - def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)), + def : Pat<(i1 (OpNode fpimm:$a, (f16 Float16Regs:$b))), (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>, Requires<[useFP16Math]>; @@ -1892,22 +1921,22 @@ (SETP_f64ir fpimm:$a, Float64Regs:$b, 
Mode)>; // f16 -> i32 - def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)), + def : Pat<(i32 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))), (SET_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>, Requires<[useFP16Math, doF32FTZ]>; - def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)), + def : Pat<(i32 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))), (SET_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>, Requires<[useFP16Math]>; - def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)), + def : Pat<(i32 (OpNode (f16 Float16Regs:$a), fpimm:$b)), (SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>, Requires<[useFP16Math, doF32FTZ]>; - def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)), + def : Pat<(i32 (OpNode (f16 Float16Regs:$a), fpimm:$b)), (SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>, Requires<[useFP16Math]>; - def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)), + def : Pat<(i32 (OpNode fpimm:$a, (f16 Float16Regs:$b))), (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>, Requires<[useFP16Math, doF32FTZ]>; - def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)), + def : Pat<(i32 (OpNode fpimm:$a, (f16 Float16Regs:$b))), (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>, Requires<[useFP16Math]>; @@ -2329,10 +2358,10 @@ ".reg .b$size param$a;", [(DeclareScalarParam (i32 imm:$a), (i32 imm:$size), (i32 1))]>; -class MoveParamInst : +class MoveParamInst : NVPTXInst<(outs regclass:$dst), (ins regclass:$src), !strconcat("mov", asmstr, " \t$dst, $src;"), - [(set regclass:$dst, (MoveParam regclass:$src))]>; + [(set (T regclass:$dst), (MoveParam (T regclass:$src)))]>; class MoveParamSymbolInst : @@ -2340,8 +2369,8 @@ !strconcat("mov", asmstr, " \t$dst, $src;"), [(set regclass:$dst, (MoveParam texternalsym:$src))]>; -def MoveParamI64 : MoveParamInst; -def MoveParamI32 : MoveParamInst; +def MoveParamI64 : MoveParamInst; +def MoveParamI32 : MoveParamInst; def MoveParamSymbolI64 : MoveParamSymbolInst; def MoveParamSymbolI32 
: MoveParamSymbolInst; @@ -2350,9 +2379,9 @@ NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src), "cvt.u16.u32 \t$dst, $src;", [(set Int16Regs:$dst, (MoveParam Int16Regs:$src))]>; -def MoveParamF64 : MoveParamInst; -def MoveParamF32 : MoveParamInst; -def MoveParamF16 : MoveParamInst; +def MoveParamF64 : MoveParamInst; +def MoveParamF32 : MoveParamInst; +def MoveParamF16 : MoveParamInst; class PseudoUseParamInst : NVPTXInst<(outs), (ins regclass:$src), @@ -2365,20 +2394,22 @@ def PseudoUseParamF64 : PseudoUseParamInst; def PseudoUseParamF32 : PseudoUseParamInst; -class ProxyRegInst : +class ProxyRegInst : NVPTXInst<(outs regclass:$dst), (ins regclass:$src), !strconcat("mov.", SzStr, " \t$dst, $src;"), - [(set regclass:$dst, (ProxyReg regclass:$src))]>; + [(set (T regclass:$dst), (ProxyReg (T regclass:$src)))]>; let isCodeGenOnly=1, isPseudo=1 in { - def ProxyRegI1 : ProxyRegInst<"pred", Int1Regs>; - def ProxyRegI16 : ProxyRegInst<"b16", Int16Regs>; - def ProxyRegI32 : ProxyRegInst<"b32", Int32Regs>; - def ProxyRegI64 : ProxyRegInst<"b64", Int64Regs>; - def ProxyRegF16 : ProxyRegInst<"b16", Float16Regs>; - def ProxyRegF32 : ProxyRegInst<"f32", Float32Regs>; - def ProxyRegF64 : ProxyRegInst<"f64", Float64Regs>; - def ProxyRegF16x2 : ProxyRegInst<"b32", Float16x2Regs>; + def ProxyRegI1 : ProxyRegInst<"pred", i1, Int1Regs>; + def ProxyRegI16 : ProxyRegInst<"b16", i16, Int16Regs>; + def ProxyRegI32 : ProxyRegInst<"b32", i32, Int32Regs>; + def ProxyRegI64 : ProxyRegInst<"b64", i64, Int64Regs>; + def ProxyRegF16 : ProxyRegInst<"b16", f16, Float16Regs>; + def ProxyRegBF16 : ProxyRegInst<"b16", bf16, Float16Regs>; + def ProxyRegF32 : ProxyRegInst<"f32", f32, Float32Regs>; + def ProxyRegF64 : ProxyRegInst<"f64", f64, Float64Regs>; + def ProxyRegF16x2 : ProxyRegInst<"b32", v2f16, Float16x2Regs>; + def ProxyRegBF16x2 : ProxyRegInst<"b32", v2bf16, Float16x2Regs>; } // @@ -2669,22 +2700,29 @@ //---- Conversion ---- -class F_BITCONVERT : +class F_BITCONVERT.ret, + NVPTXRegClass 
regclassOut = ValueToRegClass.ret> : NVPTXInst<(outs regclassOut:$d), (ins regclassIn:$a), !strconcat("mov.b", SzStr, " \t$d, $a;"), - [(set regclassOut:$d, (bitconvert regclassIn:$a))]>; + [(set (TOut regclassOut:$d), (bitconvert (TIn regclassIn:$a)))]>; -def BITCONVERT_16_I2F : F_BITCONVERT<"16", Int16Regs, Float16Regs>; -def BITCONVERT_16_F2I : F_BITCONVERT<"16", Float16Regs, Int16Regs>; -def BITCONVERT_32_I2F : F_BITCONVERT<"32", Int32Regs, Float32Regs>; -def BITCONVERT_32_F2I : F_BITCONVERT<"32", Float32Regs, Int32Regs>; -def BITCONVERT_64_I2F : F_BITCONVERT<"64", Int64Regs, Float64Regs>; -def BITCONVERT_64_F2I : F_BITCONVERT<"64", Float64Regs, Int64Regs>; -def BITCONVERT_32_I2F16x2 : F_BITCONVERT<"32", Int32Regs, Float16x2Regs>; -def BITCONVERT_32_F16x22I : F_BITCONVERT<"32", Float16x2Regs, Int32Regs>; -def BITCONVERT_32_F2F16x2 : F_BITCONVERT<"32", Float32Regs, Float16x2Regs>; -def BITCONVERT_32_F16x22F : F_BITCONVERT<"32", Float16x2Regs, Float32Regs>; +def BITCONVERT_16_I2F : F_BITCONVERT<"16", i16, f16>; +def BITCONVERT_16_F2I : F_BITCONVERT<"16", f16, i16>; +def BITCONVERT_16_I2BF : F_BITCONVERT<"16", i16, bf16>; +def BITCONVERT_16_BF2I : F_BITCONVERT<"16", bf16, i16>; +def BITCONVERT_32_I2F : F_BITCONVERT<"32", i32, f32>; +def BITCONVERT_32_F2I : F_BITCONVERT<"32", f32, i32>; +def BITCONVERT_64_I2F : F_BITCONVERT<"64", i64, f64>; +def BITCONVERT_64_F2I : F_BITCONVERT<"64", f64, i64>; +def BITCONVERT_32_I2F16x2 : F_BITCONVERT<"32", i32, v2f16>; +def BITCONVERT_32_F16x22I : F_BITCONVERT<"32", v2f16, i32>; +def BITCONVERT_32_F2F16x2 : F_BITCONVERT<"32", f32, v2f16>; +def BITCONVERT_32_F16x22F : F_BITCONVERT<"32", v2f16, f32>; +def BITCONVERT_32_I2BF16x2 : F_BITCONVERT<"32", i32, v2bf16>; +def BITCONVERT_32_BF16x22I : F_BITCONVERT<"32", v2bf16, i32>; +def BITCONVERT_32_F2BF16x2 : F_BITCONVERT<"32", f32, v2bf16>; +def BITCONVERT_32_BF16x22F : F_BITCONVERT<"32", v2bf16, f32>; // NOTE: pred->fp are currently sub-optimal due to an issue in TableGen where // we 
cannot specify floating-point literals in isel patterns. Therefore, we @@ -2752,23 +2790,23 @@ // f16 -> sint -def : Pat<(i1 (fp_to_sint Float16Regs:$a)), +def : Pat<(i1 (fp_to_sint (f16 Float16Regs:$a))), (SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>; -def : Pat<(i16 (fp_to_sint Float16Regs:$a)), - (CVT_s16_f16 Float16Regs:$a, CvtRZI)>; -def : Pat<(i32 (fp_to_sint Float16Regs:$a)), - (CVT_s32_f16 Float16Regs:$a, CvtRZI)>; -def : Pat<(i64 (fp_to_sint Float16Regs:$a)), +def : Pat<(i16 (fp_to_sint (f16 Float16Regs:$a))), + (CVT_s16_f16 (f16 Float16Regs:$a), CvtRZI)>; +def : Pat<(i32 (fp_to_sint (f16 Float16Regs:$a))), + (CVT_s32_f16 (f16 Float16Regs:$a), CvtRZI)>; +def : Pat<(i64 (fp_to_sint (f16 Float16Regs:$a))), (CVT_s64_f16 Float16Regs:$a, CvtRZI)>; // f16 -> uint -def : Pat<(i1 (fp_to_uint Float16Regs:$a)), +def : Pat<(i1 (fp_to_uint (f16 Float16Regs:$a))), (SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>; -def : Pat<(i16 (fp_to_uint Float16Regs:$a)), +def : Pat<(i16 (fp_to_uint (f16 Float16Regs:$a))), (CVT_u16_f16 Float16Regs:$a, CvtRZI)>; -def : Pat<(i32 (fp_to_uint Float16Regs:$a)), +def : Pat<(i32 (fp_to_uint (f16 Float16Regs:$a))), (CVT_u32_f16 Float16Regs:$a, CvtRZI)>; -def : Pat<(i64 (fp_to_uint Float16Regs:$a)), +def : Pat<(i64 (fp_to_uint (f16 Float16Regs:$a))), (CVT_u64_f16 Float16Regs:$a, CvtRZI)>; // f32 -> sint @@ -2915,7 +2953,7 @@ def : Pat<(select Int32Regs:$pred, Int64Regs:$a, Int64Regs:$b), (SELP_b64rr Int64Regs:$a, Int64Regs:$b, (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; -def : Pat<(select Int32Regs:$pred, Float16Regs:$a, Float16Regs:$b), +def : Pat<(select Int32Regs:$pred, (f16 Float16Regs:$a), (f16 Float16Regs:$b)), (SELP_f16rr Float16Regs:$a, Float16Regs:$b, (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b), @@ -2980,7 +3018,7 @@ def BuildF16x2 : NVPTXInst<(outs Float16x2Regs:$dst), (ins Float16Regs:$a, Float16Regs:$b), "mov.b32 
\t$dst, {{$a, $b}};", - [(set Float16x2Regs:$dst, + [(set (v2f16 Float16x2Regs:$dst), (build_vector (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>; // Directly initializing underlying the b32 register is one less SASS @@ -3079,13 +3117,13 @@ (CVT_f32_f64 Float64Regs:$a, CvtRN)>; // fpextend f16 -> f32 -def : Pat<(f32 (fpextend Float16Regs:$a)), +def : Pat<(f32 (fpextend (f16 Float16Regs:$a))), (CVT_f32_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; -def : Pat<(f32 (fpextend Float16Regs:$a)), +def : Pat<(f32 (fpextend (f16 Float16Regs:$a))), (CVT_f32_f16 Float16Regs:$a, CvtNONE)>; // fpextend f16 -> f64 -def : Pat<(f64 (fpextend Float16Regs:$a)), +def : Pat<(f64 (fpextend (f16 Float16Regs:$a))), (CVT_f64_f16 Float16Regs:$a, CvtNONE)>; // fpextend f32 -> f64 @@ -3100,7 +3138,7 @@ // fceil, ffloor, froundeven, ftrunc. multiclass CVT_ROUND { - def : Pat<(OpNode Float16Regs:$a), + def : Pat<(OpNode (f16 Float16Regs:$a)), (CVT_f16_f16 Float16Regs:$a, Mode)>; def : Pat<(OpNode Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, ModeFTZ)>, Requires<[doF32FTZ]>; diff --git a/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp b/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp --- a/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp @@ -75,6 +75,8 @@ case NVPTX::ProxyRegI64: case NVPTX::ProxyRegF16: case NVPTX::ProxyRegF16x2: + case NVPTX::ProxyRegBF16: + case NVPTX::ProxyRegBF16x2: case NVPTX::ProxyRegF32: case NVPTX::ProxyRegF64: replaceMachineInstructionUsage(MF, MI); diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td --- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td @@ -60,8 +60,8 @@ def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>; def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4), VRFrame32, VRFrameLocal32)>; def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), 
VRFrame64, VRFrameLocal64)>; -def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>; -def Float16x2Regs : NVPTXRegClass<[v2f16], 32, (add (sequence "HH%u", 0, 4))>; +def Float16Regs : NVPTXRegClass<[f16,bf16], 16, (add (sequence "H%u", 0, 4))>; +def Float16x2Regs : NVPTXRegClass<[v2f16,v2bf16], 32, (add (sequence "HH%u", 0, 4))>; def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>; def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>; def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>; diff --git a/llvm/test/CodeGen/NVPTX/bf16.ll b/llvm/test/CodeGen/NVPTX/bf16.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/bf16.ll @@ -0,0 +1,35 @@ +; RUN: llc < %s -march=nvptx | FileCheck %s +; RUN: %if ptxas %{ llc < %s -march=nvptx | %ptxas-verify %} + +; CHECK: .b8 bfloat_array[8] = {1, 2, 3, 4, 5, 6, 7, 8}; +@"bfloat_array" = addrspace(1) constant [4 x bfloat] + [bfloat 0xR0201, bfloat 0xR0403, bfloat 0xR0605, bfloat 0xR0807] + +define void @test_load_store(bfloat addrspace(1)* %in, bfloat addrspace(1)* %out) { +; CHECK-LABEL: @test_load_store +; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}] +; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]] + %val = load bfloat, bfloat addrspace(1)* %in + store bfloat %val, bfloat addrspace(1) * %out + ret void +} + +define void @test_bitcast_from_bfloat(bfloat addrspace(1)* %in, i16 addrspace(1)* %out) { +; CHECK-LABEL: @test_bitcast_from_bfloat +; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}] +; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]] + %val = load bfloat, bfloat addrspace(1) * %in + %val_int = bitcast bfloat %val to i16 + store i16 %val_int, i16 addrspace(1)* %out + ret void +} + +define void @test_bitcast_to_bfloat(bfloat addrspace(1)* %out, i16 addrspace(1)* %in) { +; CHECK-LABEL: @test_bitcast_to_bfloat +; CHECK: ld.global.u16 [[TMP:%rs[0-9]+]], [{{%r[0-9]+}}] +; CHECK: st.global.u16 [{{%r[0-9]+}}], [[TMP]] + %val = 
load i16, i16 addrspace(1)* %in + %val_fp = bitcast i16 %val to bfloat + store bfloat %val_fp, bfloat addrspace(1)* %out + ret void +}