Add basic bfloat16 (bf16) lowering support, with an X86 codegen test.

- compiler-rt: add the __truncsfbf2 builtin (float -> bfloat truncation).
- SelectionDAG: add the BF16_TO_FP / FP_TO_BF16 nodes. BF16_TO_FP is expanded
  inline to zext + shl by 16 (plus an fp_extend when the result type is wider
  than f32); FP_TO_BF16 is lowered to the __truncsfbf2 libcall.
- TargetLoweringBase: promote bf16 to f32 when bf16 is not a legal type.
- X86: mark the bf16 extloads, truncstores and conversion nodes as Expand.

diff --git a/compiler-rt/lib/builtins/CMakeLists.txt b/compiler-rt/lib/builtins/CMakeLists.txt
--- a/compiler-rt/lib/builtins/CMakeLists.txt
+++ b/compiler-rt/lib/builtins/CMakeLists.txt
@@ -167,6 +167,7 @@
   trampoline_setup.c
   truncdfhf2.c
   truncdfsf2.c
+  truncsfbf2.c
   truncsfhf2.c
   ucmpdi2.c
   ucmpti2.c
diff --git a/compiler-rt/lib/builtins/fp_trunc.h b/compiler-rt/lib/builtins/fp_trunc.h
--- a/compiler-rt/lib/builtins/fp_trunc.h
+++ b/compiler-rt/lib/builtins/fp_trunc.h
@@ -59,6 +59,12 @@
 #define DST_REP_C UINT16_C
 static const int dstSigBits = 10;
 
+#elif defined DST_BFLOAT
+typedef uint16_t dst_t;
+typedef uint16_t dst_rep_t;
+#define DST_REP_C UINT16_C
+static const int dstSigBits = 7;
+
 #else
 #error Destination should be single precision or double precision!
 #endif // end destination precision
diff --git a/compiler-rt/lib/builtins/truncsfbf2.c b/compiler-rt/lib/builtins/truncsfbf2.c
new file mode 100644
--- /dev/null
+++ b/compiler-rt/lib/builtins/truncsfbf2.c
@@ -0,0 +1,13 @@
+//===-- lib/truncsfbf2.c - single -> bfloat conversion ------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#define SRC_SINGLE
+#define DST_BFLOAT
+#include "fp_trunc_impl.inc"
+
+COMPILER_RT_ABI dst_t __truncsfbf2(float a) { return __truncXfYf2__(a); }
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -898,6 +898,13 @@
   STRICT_FP16_TO_FP,
   STRICT_FP_TO_FP16,
 
+  /// BF16_TO_FP, FP_TO_BF16 - These operators are used to perform promotions
+  /// and truncation for bfloat16. These nodes form a semi-softened interface
+  /// for dealing with bf16 (as an i16), which is often a storage-only type but
+  /// has native conversions.
+  BF16_TO_FP,
+  FP_TO_BF16,
+
   /// Perform various unary floating-point operations inspired by libm. For
-  /// FPOWI, the result is undefined if if the integer operand doesn't fit into
+  /// FPOWI, the result is undefined if the integer operand doesn't fit into
   /// sizeof(int).
diff --git a/llvm/include/llvm/IR/RuntimeLibcalls.def b/llvm/include/llvm/IR/RuntimeLibcalls.def
--- a/llvm/include/llvm/IR/RuntimeLibcalls.def
+++ b/llvm/include/llvm/IR/RuntimeLibcalls.def
@@ -310,6 +310,7 @@
 HANDLE_LIBCALL(FPROUND_F80_F16, "__truncxfhf2")
 HANDLE_LIBCALL(FPROUND_F128_F16, "__trunctfhf2")
 HANDLE_LIBCALL(FPROUND_PPCF128_F16, "__trunctfhf2")
+HANDLE_LIBCALL(FPROUND_F32_BF16, "__truncsfbf2")
 HANDLE_LIBCALL(FPROUND_F64_F32, "__truncdfsf2")
 HANDLE_LIBCALL(FPROUND_F80_F32, "__truncxfsf2")
 HANDLE_LIBCALL(FPROUND_F128_F32, "__trunctfsf2")
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -998,6 +998,7 @@
     Action = TLI.getOperationAction(Node->getOpcode(), MVT::Other);
     break;
   case ISD::FP_TO_FP16:
+  case ISD::FP_TO_BF16:
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
   case ISD::EXTRACT_VECTOR_ELT:
@@ -2907,6 +2908,21 @@
                                  Node->getValueType(0), dl)))
       Results.push_back(Tmp1);
     break;
+  case ISD::BF16_TO_FP: {
+    // Always expand bf16 to f32 casts, they lower to zext + shift.
+    SDValue Op = DAG.getNode(ISD::BITCAST, dl, MVT::i16, Node->getOperand(0));
+    Op = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Op);
+    Op = DAG.getNode(
+        ISD::SHL, dl, MVT::i32, Op,
+        DAG.getConstant(16, dl,
+                        TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+    Op = DAG.getNode(ISD::BITCAST, dl, MVT::f32, Op);
+    // Add fp_extend in case the output is bigger than f32.
+    if (Node->getValueType(0) != MVT::f32)
+      Op = DAG.getNode(ISD::FP_EXTEND, dl, Node->getValueType(0), Op);
+    Results.push_back(Op);
+    break;
+  }
   case ISD::SIGN_EXTEND_INREG: {
     EVT ExtraVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
     EVT VT = Node->getValueType(0);
@@ -4219,6 +4235,13 @@
     Results.push_back(ExpandLibCall(LC, Node, false));
     break;
   }
+  case ISD::FP_TO_BF16: {
+    RTLIB::Libcall LC =
+        RTLIB::getFPROUND(Node->getOperand(0).getValueType(), MVT::bf16);
+    assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to expand fp_to_bf16");
+    Results.push_back(ExpandLibCall(LC, Node, false));
+    break;
+  }
   case ISD::STRICT_SINT_TO_FP:
   case ISD::STRICT_UINT_TO_FP:
   case ISD::SINT_TO_FP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -834,6 +834,7 @@
   case ISD::BR_CC:       Res = SoftenFloatOp_BR_CC(N); break;
   case ISD::STRICT_FP_TO_FP16:
   case ISD::FP_TO_FP16:  // Same as FP_ROUND for softening purposes
+  case ISD::FP_TO_BF16:
   case ISD::STRICT_FP_ROUND:
   case ISD::FP_ROUND:    Res = SoftenFloatOp_FP_ROUND(N); break;
   case ISD::STRICT_FP_TO_SINT:
@@ -885,16 +886,19 @@
   // returns an i16 so doesn't meet the constraints necessary for FP_ROUND.
   assert(N->getOpcode() == ISD::FP_ROUND || N->getOpcode() == ISD::FP_TO_FP16 ||
          N->getOpcode() == ISD::STRICT_FP_TO_FP16 ||
+         N->getOpcode() == ISD::FP_TO_BF16 ||
          N->getOpcode() == ISD::STRICT_FP_ROUND);
 
   bool IsStrict = N->isStrictFPOpcode();
   SDValue Op = N->getOperand(IsStrict ? 1 : 0);
   EVT SVT = Op.getValueType();
   EVT RVT = N->getValueType(0);
-  EVT FloatRVT = (N->getOpcode() == ISD::FP_TO_FP16 ||
-                  N->getOpcode() == ISD::STRICT_FP_TO_FP16)
-                     ? MVT::f16
-                     : RVT;
+  EVT FloatRVT = RVT;
+  if (N->getOpcode() == ISD::FP_TO_FP16 ||
+      N->getOpcode() == ISD::STRICT_FP_TO_FP16)
+    FloatRVT = MVT::f16;
+  else if (N->getOpcode() == ISD::FP_TO_BF16)
+    FloatRVT = MVT::bf16;
 
   RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, FloatRVT);
   assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_ROUND libcall");
@@ -2068,9 +2072,13 @@
 
 static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
   if (OpVT == MVT::f16) {
-      return ISD::FP16_TO_FP;
+    return ISD::FP16_TO_FP;
   } else if (RetVT == MVT::f16) {
-      return ISD::FP_TO_FP16;
+    return ISD::FP_TO_FP16;
+  } else if (OpVT == MVT::bf16) {
+    return ISD::BF16_TO_FP;
+  } else if (RetVT == MVT::bf16) {
+    return ISD::FP_TO_BF16;
   }
 
   report_fatal_error("Attempt at an invalid promotion-related conversion");
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -365,6 +365,8 @@
   case ISD::STRICT_FP16_TO_FP:          return "strict_fp16_to_fp";
   case ISD::FP_TO_FP16:                 return "fp_to_fp16";
   case ISD::STRICT_FP_TO_FP16:          return "strict_fp_to_fp16";
+  case ISD::BF16_TO_FP:                 return "bf16_to_fp";
+  case ISD::FP_TO_BF16:                 return "fp_to_bf16";
   case ISD::LROUND:                     return "lround";
   case ISD::STRICT_LROUND:              return "strict_lround";
   case ISD::LLROUND:                    return "llround";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -274,6 +274,9 @@
       return FPROUND_F128_F16;
     if (OpVT == MVT::ppcf128)
       return FPROUND_PPCF128_F16;
+  } else if (RetVT == MVT::bf16) {
+    if (OpVT == MVT::f32)
+      return FPROUND_F32_BF16;
   } else if (RetVT == MVT::f32) {
     if (OpVT == MVT::f64)
       return FPROUND_F64_F32;
@@ -1373,6 +1376,16 @@
     }
   }
 
+  // Decide how to handle bf16. If the target does not have native bf16 support,
+  // promote it to f32, because there are no bf16 library calls (except for
+  // converting from f32 to bf16).
+  if (!isTypeLegal(MVT::bf16)) {
+    NumRegistersForVT[MVT::bf16] = NumRegistersForVT[MVT::f32];
+    RegisterTypeForVT[MVT::bf16] = RegisterTypeForVT[MVT::f32];
+    TransformToType[MVT::bf16] = MVT::f32;
+    ValueTypeActions.setTypeAction(MVT::bf16, TypePromoteFloat);
+  }
+
   // Loop over all of the vector value types to see which need transformations.
   for (unsigned i = MVT::FIRST_VECTOR_VALUETYPE;
        i <= (unsigned)MVT::LAST_VECTOR_VALUETYPE; ++i) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -412,14 +412,15 @@
     setOperationAction(Op, MVT::f128, Expand);
   }
 
-  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
-  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
-  setLoadExtAction(ISD::EXTLOAD, MVT::f80, MVT::f16, Expand);
-  setLoadExtAction(ISD::EXTLOAD, MVT::f128, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f32, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f64, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f80, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f128, MVT::f16, Expand);
+  for (MVT VT : {MVT::f32, MVT::f64, MVT::f80, MVT::f128}) {
+    setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
+    setLoadExtAction(ISD::EXTLOAD, VT, MVT::bf16, Expand);
+    setTruncStoreAction(VT, MVT::f16, Expand);
+    setTruncStoreAction(VT, MVT::bf16, Expand);
+
+    setOperationAction(ISD::BF16_TO_FP, VT, Expand);
+    setOperationAction(ISD::FP_TO_BF16, VT, Expand);
+  }
 
   setOperationAction(ISD::PARITY, MVT::i8, Custom);
   setOperationAction(ISD::PARITY, MVT::i16, Custom);
@@ -873,7 +874,8 @@
 
       // EXTLOAD for MVT::f16 vectors is not legal because f16 vectors are
-      // split/scalarized right now.
-      if (VT.getVectorElementType() == MVT::f16)
+      // split/scalarized right now. bf16 vectors get the same treatment.
+      if (VT.getVectorElementType() == MVT::f16 ||
+          VT.getVectorElementType() == MVT::bf16)
         setLoadExtAction(ISD::EXTLOAD, InnerVT, VT, Expand);
     }
   }
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -0,0 +1,28 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s
+
+define void @add(ptr %pa, ptr %pb, ptr %pc) {
+; CHECK-LABEL: add:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset %rbx, -16
+; CHECK-NEXT:    movq %rdx, %rbx
+; CHECK-NEXT:    movzwl (%rdi), %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    movd %eax, %xmm1
+; CHECK-NEXT:    movzwl (%rsi), %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    movd %eax, %xmm0
+; CHECK-NEXT:    addss %xmm1, %xmm0
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movw %ax, (%rbx)
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    .cfi_def_cfa_offset 8
+; CHECK-NEXT:    retq
+  %a = load bfloat, ptr %pa
+  %b = load bfloat, ptr %pb
+  %add = fadd bfloat %a, %b
+  store bfloat %add, ptr %pc
+  ret void
+}