diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst --- a/clang/docs/LanguageExtensions.rst +++ b/clang/docs/LanguageExtensions.rst @@ -495,8 +495,8 @@ Half-Precision Floating Point ============================= -Clang supports two half-precision (16-bit) floating point types: ``__fp16`` and -``_Float16``. These types are supported in all language modes. +Clang supports three half-precision (16-bit) floating point types: ``__fp16``, +``_Float16`` and ``__bf16``. These types are supported in all language modes. ``__fp16`` is supported on every target, as it is purely a storage format; see below. ``_Float16`` is currently only supported on the following targets, with further @@ -508,6 +508,11 @@ ``_Float16`` will be supported on more targets as they define ABIs for it. +``__bf16`` is purely a storage format; it is currently only supported on the following targets: + +* 32-bit ARM +* 64-bit ARM (AArch64) + ``__fp16`` is a storage and interchange format only. This means that values of ``__fp16`` are immediately promoted to (at least) ``float`` when used in arithmetic operations, so that e.g. the result of adding two ``__fp16`` values has type ``float``. 
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h --- a/clang/include/clang-c/Index.h +++ b/clang/include/clang-c/Index.h @@ -3249,6 +3249,7 @@ CXType_UShortAccum = 36, CXType_UAccum = 37, CXType_ULongAccum = 38, + CXType_Bfloat16 = 39, CXType_FirstBuiltin = CXType_Void, CXType_LastBuiltin = CXType_ULongAccum, diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -956,6 +956,7 @@ CanQualType SatUnsignedShortFractTy, SatUnsignedFractTy, SatUnsignedLongFractTy; CanQualType HalfTy; // [OpenCL 6.1.1.1], ARM NEON + CanQualType Bfloat16Ty; // ARM NEON CanQualType Float16Ty; // C11 extension ISO/IEC TS 18661-3 CanQualType FloatComplexTy, DoubleComplexTy, LongDoubleComplexTy; CanQualType Float128ComplexTy; diff --git a/clang/include/clang/AST/BuiltinTypes.def b/clang/include/clang/AST/BuiltinTypes.def --- a/clang/include/clang/AST/BuiltinTypes.def +++ b/clang/include/clang/AST/BuiltinTypes.def @@ -212,6 +212,9 @@ // '_Float16' FLOATING_TYPE(Float16, HalfTy) +// '__bf16' +FLOATING_TYPE(Bfloat16, Bfloat16Ty) + // '__float128' FLOATING_TYPE(Float128, Float128Ty) diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -1979,6 +1979,7 @@ bool isFloatingType() const; // C99 6.2.5p11 (real floating + complex) bool isHalfType() const; // OpenCL 6.1.1.1, NEON (IEEE 754-2008 half) bool isFloat16Type() const; // C11 extension ISO/IEC TS 18661 + bool isBFloat16Type() const; // ARM BFloat bool isFloat128Type() const; bool isRealType() const; // C99 6.2.5p17 (real floating + integer) bool isArithmeticType() const; // C99 6.2.5p18 (integer + floating) @@ -6720,6 +6721,10 @@ return isSpecificBuiltinType(BuiltinType::Float16); } +inline bool Type::isBFloat16Type() const { + return isSpecificBuiltinType(BuiltinType::Bfloat16); +} + inline bool 
Type::isFloat128Type() const { return isSpecificBuiltinType(BuiltinType::Float128); } diff --git a/clang/include/clang/Basic/Specifiers.h b/clang/include/clang/Basic/Specifiers.h --- a/clang/include/clang/Basic/Specifiers.h +++ b/clang/include/clang/Basic/Specifiers.h @@ -71,6 +71,7 @@ TST_Float16, // C11 extension ISO/IEC TS 18661-3 TST_Accum, // ISO/IEC JTC1 SC22 WG14 N1169 Extension TST_Fract, + TST_Bfloat16, TST_float, TST_double, TST_float128, diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h --- a/clang/include/clang/Basic/TargetBuiltins.h +++ b/clang/include/clang/Basic/TargetBuiltins.h @@ -141,7 +141,8 @@ Poly128, Float16, Float32, - Float64 + Float64, + Bfloat16 }; NeonTypeFlags(unsigned F) : Flags(F) {} diff --git a/clang/include/clang/Basic/TargetInfo.h b/clang/include/clang/Basic/TargetInfo.h --- a/clang/include/clang/Basic/TargetInfo.h +++ b/clang/include/clang/Basic/TargetInfo.h @@ -59,6 +59,7 @@ unsigned char BoolWidth, BoolAlign; unsigned char IntWidth, IntAlign; unsigned char HalfWidth, HalfAlign; + unsigned char Bfloat16Width, Bfloat16Align; unsigned char FloatWidth, FloatAlign; unsigned char DoubleWidth, DoubleAlign; unsigned char LongDoubleWidth, LongDoubleAlign, Float128Align; @@ -100,8 +101,8 @@ unsigned short MaxVectorAlign; unsigned short MaxTLSAlign; - const llvm::fltSemantics *HalfFormat, *FloatFormat, *DoubleFormat, - *LongDoubleFormat, *Float128Format; + const llvm::fltSemantics *HalfFormat, *Bfloat16Format, *FloatFormat, + *DoubleFormat, *LongDoubleFormat, *Float128Format; ///===---- Target Data Type Query Methods -------------------------------===// enum IntType { @@ -188,6 +189,7 @@ // LLVM IR type. bool HasFloat128; bool HasFloat16; + bool HasBfloat16; unsigned char MaxAtomicPromoteWidth, MaxAtomicInlineWidth; unsigned short SimdDefaultAlign; @@ -556,6 +558,9 @@ /// Determine whether the _Float16 type is supported on this target. 
virtual bool hasFloat16Type() const { return HasFloat16; } + /// Determine whether the __bf16 type is supported on this target. + virtual bool hasBfloat16Type() const { return HasBfloat16; } + /// Return the alignment that is suitable for storing any /// object with a fundamental alignment requirement. unsigned getSuitableAlign() const { return SuitableAlign; } @@ -604,6 +609,11 @@ unsigned getFloatAlign() const { return FloatAlign; } const llvm::fltSemantics &getFloatFormat() const { return *FloatFormat; } + /// getBfloat16Width/Align/Format - Return the size/align/format of '__bf16'. + unsigned getBfloat16Width() const { return Bfloat16Width; } + unsigned getBfloat16Align() const { return Bfloat16Align; } + const llvm::fltSemantics &getBfloat16Format() const { return *Bfloat16Format; } + /// getDoubleWidth/Align/Format - Return the size/align/format of 'double'. unsigned getDoubleWidth() const { return DoubleWidth; } unsigned getDoubleAlign() const { return DoubleAlign; } @@ -631,6 +641,11 @@ /// Return the mangled code of __float128. virtual const char *getFloat128Mangling() const { return "g"; } + /// Return the mangled code of __bf16. + virtual const char *getBfloat16Mangling() const { + llvm_unreachable("bfloat not implemented on this target"); + } + /// Return the value for the C99 FLT_EVAL_METHOD macro. virtual unsigned getFloatEvalMethod() const { return 0; } diff --git a/clang/include/clang/Basic/TokenKinds.def b/clang/include/clang/Basic/TokenKinds.def --- a/clang/include/clang/Basic/TokenKinds.def +++ b/clang/include/clang/Basic/TokenKinds.def @@ -588,6 +588,7 @@ // ARM NEON extensions. ALIAS("__fp16", half , KEYALL) +KEYWORD(__bf16 , KEYALL) // OpenCL Extension. 
KEYWORD(half , HALFSUPPORT) diff --git a/clang/include/clang/Sema/DeclSpec.h b/clang/include/clang/Sema/DeclSpec.h --- a/clang/include/clang/Sema/DeclSpec.h +++ b/clang/include/clang/Sema/DeclSpec.h @@ -279,6 +279,7 @@ static const TST TST_int = clang::TST_int; static const TST TST_int128 = clang::TST_int128; static const TST TST_half = clang::TST_half; + static const TST TST_Bfloat16 = clang::TST_Bfloat16; static const TST TST_float = clang::TST_float; static const TST TST_double = clang::TST_double; static const TST TST_float16 = clang::TST_Float16; diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h --- a/clang/include/clang/Serialization/ASTBitCodes.h +++ b/clang/include/clang/Serialization/ASTBitCodes.h @@ -1022,6 +1022,9 @@ /// The placeholder type for OpenMP iterator expression. PREDEF_TYPE_OMP_ITERATOR = 71, + /// \brief The '__bf16' type + PREDEF_TYPE_BFLOAT16_ID = 72, + /// OpenCL image types with auto numeration #define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix) \ PREDEF_TYPE_##Id##_ID, diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -100,7 +100,7 @@ using namespace clang; enum FloatingRank { - Float16Rank, HalfRank, FloatRank, DoubleRank, LongDoubleRank, Float128Rank + Bfloat16Rank, Float16Rank, HalfRank, FloatRank, DoubleRank, LongDoubleRank, Float128Rank }; /// \returns location that is relevant when searching for Doc comments related @@ -1449,6 +1449,8 @@ // half type (OpenCL 6.1.1.1) / ARM NEON __fp16 InitBuiltinType(HalfTy, BuiltinType::Half); + InitBuiltinType(Bfloat16Ty, BuiltinType::Bfloat16); + // Builtin type used to help define __builtin_va_list. 
VaListTagDecl = nullptr; } @@ -1646,6 +1648,8 @@ switch (T->castAs()->getKind()) { default: llvm_unreachable("Not a floating point type!"); + case BuiltinType::Bfloat16: + return Target->getBfloat16Format(); case BuiltinType::Float16: case BuiltinType::Half: return Target->getHalfFormat(); @@ -2029,6 +2033,10 @@ Width = Target->getLongFractWidth(); Align = Target->getLongFractAlign(); break; + case BuiltinType::Bfloat16: + Width = Target->getBfloat16Width(); + Align = Target->getBfloat16Align(); + break; case BuiltinType::Float16: case BuiltinType::Half: if (Target->hasFloat16Type() || !getLangOpts().OpenMP || @@ -5839,6 +5847,7 @@ case BuiltinType::Double: return DoubleRank; case BuiltinType::LongDouble: return LongDoubleRank; case BuiltinType::Float128: return Float128Rank; + case BuiltinType::Bfloat16: return Bfloat16Rank; } } @@ -5851,6 +5860,7 @@ FloatingRank EltRank = getFloatingRank(Size); if (Domain->isComplexType()) { switch (EltRank) { + case Bfloat16Rank: case Float16Rank: case HalfRank: llvm_unreachable("Complex half is not supported"); case FloatRank: return FloatComplexTy; @@ -5863,6 +5873,7 @@ assert(Domain->isRealFloatingType() && "Unknown domain!"); switch (EltRank) { case Float16Rank: return HalfTy; + case Bfloat16Rank: return Bfloat16Ty; case HalfRank: return HalfTy; case FloatRank: return FloatTy; case DoubleRank: return DoubleTy; @@ -6835,6 +6846,7 @@ case BuiltinType::LongDouble: return 'D'; case BuiltinType::NullPtr: return '*'; // like char* + case BuiltinType::Bfloat16: case BuiltinType::Float16: case BuiltinType::Float128: case BuiltinType::Half: @@ -9696,6 +9708,11 @@ // Read the base type. 
switch (*Str++) { default: llvm_unreachable("Unknown builtin type letter!"); + case 'y': + assert(HowLong == 0 && !Signed && !Unsigned && + "Bad modifiers used with 'y'!"); + Type = Context.Bfloat16Ty; + break; case 'v': assert(HowLong == 0 && !Signed && !Unsigned && "Bad modifiers used with 'v'!"); diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -2746,6 +2746,11 @@ Out << TI->getFloat128Mangling(); break; } + case BuiltinType::Bfloat16: { + const TargetInfo *TI = &getASTContext().getTargetInfo(); + Out << TI->getBfloat16Mangling(); + break; + } case BuiltinType::NullPtr: Out << "Dn"; break; @@ -3162,6 +3167,7 @@ case BuiltinType::Double: EltName = "float64_t"; break; case BuiltinType::Float: EltName = "float32_t"; break; case BuiltinType::Half: EltName = "float16_t";break; + case BuiltinType::Bfloat16: EltName = "bfloat16x1_t";break; default: llvm_unreachable("unexpected Neon vector element type"); } @@ -3213,6 +3219,8 @@ return "Float32"; case BuiltinType::Double: return "Float64"; + case BuiltinType::Bfloat16: + return "Bfloat16"; default: llvm_unreachable("Unexpected vector element base type"); } diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -2139,6 +2139,7 @@ case BuiltinType::SatUShortFract: case BuiltinType::SatUFract: case BuiltinType::SatULongFract: + case BuiltinType::Bfloat16: case BuiltinType::Float128: { DiagnosticsEngine &Diags = Context.getDiags(); unsigned DiagID = Diags.getCustomDiagID( diff --git a/clang/lib/AST/NSAPI.cpp b/clang/lib/AST/NSAPI.cpp --- a/clang/lib/AST/NSAPI.cpp +++ b/clang/lib/AST/NSAPI.cpp @@ -485,6 +485,7 @@ case BuiltinType::OMPArraySection: case BuiltinType::OMPArrayShaping: case BuiltinType::OMPIterator: + case BuiltinType::Bfloat16: break; } diff --git a/clang/lib/AST/PrintfFormatString.cpp 
b/clang/lib/AST/PrintfFormatString.cpp --- a/clang/lib/AST/PrintfFormatString.cpp +++ b/clang/lib/AST/PrintfFormatString.cpp @@ -752,6 +752,7 @@ case BuiltinType::UInt128: case BuiltinType::Int128: case BuiltinType::Half: + case BuiltinType::Bfloat16: case BuiltinType::Float16: case BuiltinType::Float128: case BuiltinType::ShortAccum: diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -2021,7 +2021,8 @@ bool Type::isArithmeticType() const { if (const auto *BT = dyn_cast(CanonicalType)) return BT->getKind() >= BuiltinType::Bool && - BT->getKind() <= BuiltinType::Float128; + BT->getKind() <= BuiltinType::Float128 && + BT->getKind() != BuiltinType::Bfloat16; if (const auto *ET = dyn_cast(CanonicalType)) // GCC allows forward declaration of enum types (forbid by C99 6.7.2.3p2). // If a body isn't seen by the time we get here, return false. @@ -2803,6 +2804,8 @@ return "unsigned __int128"; case Half: return Policy.Half ? "half" : "__fp16"; + case Bfloat16: + return "__bf16"; case Float: return "float"; case Double: diff --git a/clang/lib/AST/TypeLoc.cpp b/clang/lib/AST/TypeLoc.cpp --- a/clang/lib/AST/TypeLoc.cpp +++ b/clang/lib/AST/TypeLoc.cpp @@ -375,6 +375,7 @@ case BuiltinType::SatUShortFract: case BuiltinType::SatUFract: case BuiltinType::SatULongFract: + case BuiltinType::Bfloat16: llvm_unreachable("Builtin type needs extra local data!"); // Fall through, if the impossible happens. 
diff --git a/clang/lib/Basic/TargetInfo.cpp b/clang/lib/Basic/TargetInfo.cpp --- a/clang/lib/Basic/TargetInfo.cpp +++ b/clang/lib/Basic/TargetInfo.cpp @@ -36,6 +36,7 @@ HasLegalHalfType = false; HasFloat128 = false; HasFloat16 = false; + HasBfloat16 = false; PointerWidth = PointerAlign = 32; BoolWidth = BoolAlign = 8; IntWidth = IntAlign = 32; diff --git a/clang/lib/Basic/Targets/AArch64.h b/clang/lib/Basic/Targets/AArch64.h --- a/clang/lib/Basic/Targets/AArch64.h +++ b/clang/lib/Basic/Targets/AArch64.h @@ -118,6 +118,7 @@ int getEHDataRegisterNumber(unsigned RegNo) const override; + const char *getBfloat16Mangling() const override { return "u6__bf16"; }; bool hasInt128Type() const override; }; diff --git a/clang/lib/Basic/Targets/AArch64.cpp b/clang/lib/Basic/Targets/AArch64.cpp --- a/clang/lib/Basic/Targets/AArch64.cpp +++ b/clang/lib/Basic/Targets/AArch64.cpp @@ -70,6 +70,9 @@ LongDoubleWidth = LongDoubleAlign = SuitableAlign = 128; LongDoubleFormat = &llvm::APFloat::IEEEquad(); + Bfloat16Width = Bfloat16Align = 16; + Bfloat16Format = &llvm::APFloat::Bfloat(); + // Make __builtin_ms_va_list available. 
HasBuiltinMSVaList = true; @@ -356,6 +359,7 @@ HasFP16FML = false; HasMTE = false; HasTME = false; + HasBfloat16 = false; ArchKind = llvm::AArch64::ArchKind::ARMV8A; for (const auto &Feature : Features) { @@ -391,6 +395,8 @@ HasMTE = true; if (Feature == "+tme") HasTME = true; + if (Feature == "+bf16") + HasBfloat16 = true; } setDataLayout(); diff --git a/clang/lib/Basic/Targets/ARM.h b/clang/lib/Basic/Targets/ARM.h --- a/clang/lib/Basic/Targets/ARM.h +++ b/clang/lib/Basic/Targets/ARM.h @@ -181,6 +181,8 @@ int getEHDataRegisterNumber(unsigned RegNo) const override; bool hasSjLjLowering() const override; + + const char *getBfloat16Mangling() const override { return "u6__bf16"; }; }; class LLVM_LIBRARY_VISIBILITY ARMleTargetInfo : public ARMTargetInfo { diff --git a/clang/lib/Basic/Targets/ARM.cpp b/clang/lib/Basic/Targets/ARM.cpp --- a/clang/lib/Basic/Targets/ARM.cpp +++ b/clang/lib/Basic/Targets/ARM.cpp @@ -25,6 +25,9 @@ IsAAPCS = true; DoubleAlign = LongLongAlign = LongDoubleAlign = SuitableAlign = 64; + Bfloat16Width = Bfloat16Align = 16; + Bfloat16Format = &llvm::APFloat::Bfloat(); + const llvm::Triple &T = getTriple(); bool IsNetBSD = T.isOSNetBSD(); @@ -74,6 +77,8 @@ DoubleAlign = LongLongAlign = LongDoubleAlign = SuitableAlign = 64; else DoubleAlign = LongLongAlign = LongDoubleAlign = SuitableAlign = 32; + Bfloat16Width = Bfloat16Align = 16; + Bfloat16Format = &llvm::APFloat::Bfloat(); WCharType = SignedInt; @@ -427,6 +432,7 @@ DotProd = 0; HasFloat16 = true; ARMCDECoprocMask = 0; + HasBfloat16 = false; // This does not diagnose illegal cases like having both // "+vfpv2" and "+vfpv3" or having "+neon" and "-fp64". 
@@ -495,6 +501,8 @@ Feature <= "+cdecp7") { unsigned Coproc = Feature.back() - '0'; ARMCDECoprocMask |= (1U << Coproc); + } else if (Feature == "+bf16") { + HasBfloat16 = true; } } diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -4476,6 +4476,11 @@ case NeonTypeFlags::Int16: case NeonTypeFlags::Poly16: return llvm::VectorType::get(CGF->Int16Ty, V1Ty ? 1 : (4 << IsQuad)); + case NeonTypeFlags::Bfloat16: + if (HasLegalHalfType) + return llvm::VectorType::get(CGF->BfloatTy, V1Ty ? 1 : (4 << IsQuad)); + else + return llvm::VectorType::get(CGF->Int16Ty, V1Ty ? 1 : (4 << IsQuad)); case NeonTypeFlags::Float16: if (HasLegalHalfType) return llvm::VectorType::get(CGF->HalfTy, V1Ty ? 1 : (4 << IsQuad)); diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp --- a/clang/lib/CodeGen/CGDebugInfo.cpp +++ b/clang/lib/CodeGen/CGDebugInfo.cpp @@ -761,6 +761,7 @@ case BuiltinType::Float: case BuiltinType::LongDouble: case BuiltinType::Float16: + case BuiltinType::Bfloat16: case BuiltinType::Float128: case BuiltinType::Double: // FIXME: For targets where long double and __float128 have the same size, diff --git a/clang/lib/CodeGen/CodeGenTypeCache.h b/clang/lib/CodeGen/CodeGenTypeCache.h --- a/clang/lib/CodeGen/CodeGenTypeCache.h +++ b/clang/lib/CodeGen/CodeGenTypeCache.h @@ -35,8 +35,8 @@ /// i8, i16, i32, and i64 llvm::IntegerType *Int8Ty, *Int16Ty, *Int32Ty, *Int64Ty; - /// float, double - llvm::Type *HalfTy, *FloatTy, *DoubleTy; + /// half, bfloat, float, double + llvm::Type *HalfTy, *BfloatTy, *FloatTy, *DoubleTy; /// int llvm::IntegerType *IntTy; diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -295,6 +295,12 @@ else return llvm::Type::getInt16Ty(VMContext); } + if (&format == &llvm::APFloat::Bfloat()) { + if (UseNativeHalf) 
+ return llvm::Type::getBfloatTy(VMContext); + else + return llvm::Type::getInt16Ty(VMContext); + } if (&format == &llvm::APFloat::IEEEsingle()) return llvm::Type::getFloatTy(VMContext); if (&format == &llvm::APFloat::IEEEdouble()) @@ -486,6 +492,7 @@ /* UseNativeHalf = */ true); break; + case BuiltinType::Bfloat16: case BuiltinType::Half: // Half FP can either be storage-only (lowered to i16) or native. ResultType = getTypeForFormat( diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -3027,6 +3027,7 @@ case BuiltinType::SatUShortFract: case BuiltinType::SatUFract: case BuiltinType::SatULongFract: + case BuiltinType::Bfloat16: return false; case BuiltinType::Dependent: diff --git a/clang/lib/CodeGen/TargetInfo.cpp b/clang/lib/CodeGen/TargetInfo.cpp --- a/clang/lib/CodeGen/TargetInfo.cpp +++ b/clang/lib/CodeGen/TargetInfo.cpp @@ -5956,10 +5956,10 @@ if (isIllegalVectorType(Ty)) return coerceIllegalVector(Ty); - // _Float16 and __fp16 get passed as if it were an int or float, but with - // the top 16 bits unspecified. This is not done for OpenCL as it handles the - // half type natively, and does not need to interwork with AAPCS code. - if ((Ty->isFloat16Type() || Ty->isHalfType()) && + // _Float16, __fp16 and __bf16 get passed as if it were an int or float, but + // with the top 16 bits unspecified. This is not done for OpenCL as it handles + // the half type natively, and does not need to interwork with AAPCS code. + if ((Ty->isFloat16Type() || Ty->isHalfType() || Ty->isBFloat16Type()) && !getContext().getLangOpts().NativeHalfArgsAndReturns) { llvm::Type *ResType = IsAAPCS_VFP ? 
llvm::Type::getFloatTy(getVMContext()) : @@ -6159,6 +6159,7 @@ // FP16 vectors should be converted to integer vectors if (!getTarget().hasLegalHalfType() && (VT->getElementType()->isFloat16Type() || + VT->getElementType()->isBFloat16Type() || VT->getElementType()->isHalfType())) return coerceIllegalVector(RetTy); } @@ -6166,7 +6167,7 @@ // _Float16 and __fp16 get returned as if it were an int or float, but with // the top 16 bits unspecified. This is not done for OpenCL as it handles the // half type natively, and does not need to interwork with AAPCS code. - if ((RetTy->isFloat16Type() || RetTy->isHalfType()) && + if ((RetTy->isFloat16Type() || RetTy->isBFloat16Type() || RetTy->isHalfType()) && !getContext().getLangOpts().NativeHalfArgsAndReturns) { llvm::Type *ResType = IsAAPCS_VFP ? llvm::Type::getFloatTy(getVMContext()) : @@ -6256,11 +6257,13 @@ /// isIllegalVector - check whether Ty is an illegal vector type. bool ARMABIInfo::isIllegalVectorType(QualType Ty) const { if (const VectorType *VT = Ty->getAs ()) { - // On targets that don't support FP16, FP16 is expanded into float, and we - // don't want the ABI to depend on whether or not FP16 is supported in - // hardware. Thus return false to coerce FP16 vectors into integer vectors. + // On targets that don't support half, fp16 or bfloat, they are expanded + // into float, and we don't want the ABI to depend on whether or not they + // are supported in hardware. Thus return true to coerce vectors of these + // types into integer vectors. 
if (!getTarget().hasLegalHalfType() && (VT->getElementType()->isFloat16Type() || + VT->getElementType()->isBFloat16Type() || VT->getElementType()->isHalfType())) return true; if (isAndroid()) { @@ -6313,6 +6316,7 @@ } else { if (const VectorType *VT = Ty->getAs()) return (VT->getElementType()->isFloat16Type() || + VT->getElementType()->isBFloat16Type() || VT->getElementType()->isHalfType()); return false; } diff --git a/clang/lib/Format/FormatToken.cpp b/clang/lib/Format/FormatToken.cpp --- a/clang/lib/Format/FormatToken.cpp +++ b/clang/lib/Format/FormatToken.cpp @@ -50,6 +50,7 @@ case tok::kw_half: case tok::kw_float: case tok::kw_double: + case tok::kw___bf16: case tok::kw__Float16: case tok::kw___float128: case tok::kw_wchar_t: diff --git a/clang/lib/Index/USRGeneration.cpp b/clang/lib/Index/USRGeneration.cpp --- a/clang/lib/Index/USRGeneration.cpp +++ b/clang/lib/Index/USRGeneration.cpp @@ -753,6 +753,7 @@ case BuiltinType::SatUShortFract: case BuiltinType::SatUFract: case BuiltinType::SatULongFract: + case BuiltinType::Bfloat16: IgnoreResults = true; return; case BuiltinType::ObjCId: diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp --- a/clang/lib/Parse/ParseDecl.cpp +++ b/clang/lib/Parse/ParseDecl.cpp @@ -3815,6 +3815,10 @@ isInvalid = DS.SetTypeSpecType(DeclSpec::TST_half, Loc, PrevSpec, DiagID, Policy); break; + case tok::kw___bf16: + isInvalid = DS.SetTypeSpecType(DeclSpec::TST_Bfloat16, Loc, PrevSpec, + DiagID, Policy); + break; case tok::kw_float: isInvalid = DS.SetTypeSpecType(DeclSpec::TST_float, Loc, PrevSpec, DiagID, Policy); @@ -4890,6 +4894,7 @@ case tok::kw_char16_t: case tok::kw_char32_t: case tok::kw_int: + case tok::kw___bf16: case tok::kw_half: case tok::kw_float: case tok::kw_double: @@ -4970,6 +4975,7 @@ case tok::kw_char32_t: case tok::kw_int: case tok::kw_half: + case tok::kw___bf16: case tok::kw_float: case tok::kw_double: case tok::kw__Accum: @@ -5136,6 +5142,7 @@ case tok::kw_int: case tok::kw_half: + case 
tok::kw___bf16: case tok::kw_float: case tok::kw_double: case tok::kw__Accum: diff --git a/clang/lib/Parse/ParseExpr.cpp b/clang/lib/Parse/ParseExpr.cpp --- a/clang/lib/Parse/ParseExpr.cpp +++ b/clang/lib/Parse/ParseExpr.cpp @@ -1497,6 +1497,7 @@ case tok::kw_half: case tok::kw_float: case tok::kw_double: + case tok::kw___bf16: case tok::kw__Float16: case tok::kw___float128: case tok::kw_void: diff --git a/clang/lib/Parse/ParseExprCXX.cpp b/clang/lib/Parse/ParseExprCXX.cpp --- a/clang/lib/Parse/ParseExprCXX.cpp +++ b/clang/lib/Parse/ParseExprCXX.cpp @@ -2184,6 +2184,9 @@ case tok::kw___int128: DS.SetTypeSpecType(DeclSpec::TST_int128, Loc, PrevSpec, DiagID, Policy); break; + case tok::kw___bf16: + DS.SetTypeSpecType(DeclSpec::TST_Bfloat16, Loc, PrevSpec, DiagID, Policy); + break; case tok::kw_half: DS.SetTypeSpecType(DeclSpec::TST_half, Loc, PrevSpec, DiagID, Policy); break; diff --git a/clang/lib/Parse/ParseTentative.cpp b/clang/lib/Parse/ParseTentative.cpp --- a/clang/lib/Parse/ParseTentative.cpp +++ b/clang/lib/Parse/ParseTentative.cpp @@ -1135,6 +1135,7 @@ case tok::kw_char: case tok::kw_const: case tok::kw_double: + case tok::kw___bf16: case tok::kw__Float16: case tok::kw___float128: case tok::kw_enum: @@ -1724,6 +1725,7 @@ case tok::kw_half: case tok::kw_float: case tok::kw_double: + case tok::kw___bf16: case tok::kw__Float16: case tok::kw___float128: case tok::kw_void: @@ -1818,6 +1820,7 @@ case tok::kw_half: case tok::kw_float: case tok::kw_double: + case tok::kw___bf16: case tok::kw__Float16: case tok::kw___float128: case tok::kw_void: diff --git a/clang/lib/Sema/DeclSpec.cpp b/clang/lib/Sema/DeclSpec.cpp --- a/clang/lib/Sema/DeclSpec.cpp +++ b/clang/lib/Sema/DeclSpec.cpp @@ -367,6 +367,7 @@ case TST_unspecified: case TST_void: case TST_wchar: + case TST_Bfloat16: #define GENERIC_IMAGE_TYPE(ImgType, Id) case TST_##ImgType##_t: #include "clang/Basic/OpenCLImageTypes.def" return false; @@ -564,6 +565,7 @@ case DeclSpec::TST_underlyingType: return 
"__underlying_type"; case DeclSpec::TST_unknown_anytype: return "__unknown_anytype"; case DeclSpec::TST_atomic: return "_Atomic"; + case DeclSpec::TST_Bfloat16: return "__bf16"; #define GENERIC_IMAGE_TYPE(ImgType, Id) \ case DeclSpec::TST_##ImgType##_t: \ return #ImgType "_t"; diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -1955,6 +1955,9 @@ case NeonTypeFlags::Float64: assert(!shift && "cannot shift float types!"); return (1 << IsQuad) - 1; + case NeonTypeFlags::Bfloat16: + assert(!shift && "cannot shift float types!"); + return (4 << IsQuad) - 1; } llvm_unreachable("Invalid NeonTypeFlag!"); } @@ -1994,6 +1997,8 @@ return Context.FloatTy; case NeonTypeFlags::Float64: return Context.DoubleTy; + case NeonTypeFlags::Bfloat16: + return Context.Bfloat16Ty; } llvm_unreachable("Invalid NeonTypeFlag!"); } diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -138,6 +138,7 @@ case tok::kw_half: case tok::kw_float: case tok::kw_double: + case tok::kw___bf16: case tok::kw__Float16: case tok::kw___float128: case tok::kw_wchar_t: diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp --- a/clang/lib/Sema/SemaOverload.cpp +++ b/clang/lib/Sema/SemaOverload.cpp @@ -1870,6 +1870,8 @@ // FIXME: disable conversions between long double and __float128 if // their representation is different until there is back end support // We of course allow this conversion if long double is really double. 
+ if (FromType == S.Context.Bfloat16Ty || ToType == S.Context.Bfloat16Ty) + return false; if (&S.Context.getFloatTypeSemantics(FromType) != &S.Context.getFloatTypeSemantics(ToType)) { bool Float128AndLongDouble = ((FromType == S.Context.Float128Ty && diff --git a/clang/lib/Sema/SemaTemplateVariadic.cpp b/clang/lib/Sema/SemaTemplateVariadic.cpp --- a/clang/lib/Sema/SemaTemplateVariadic.cpp +++ b/clang/lib/Sema/SemaTemplateVariadic.cpp @@ -880,6 +880,7 @@ case TST_auto: case TST_auto_type: case TST_decltype_auto: + case TST_Bfloat16: #define GENERIC_IMAGE_TYPE(ImgType, Id) case TST_##ImgType##_t: #include "clang/Basic/OpenCLImageTypes.def" case TST_unknown_anytype: diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -1508,6 +1508,12 @@ Result = Context.Float16Ty; break; case DeclSpec::TST_half: Result = Context.HalfTy; break; + case DeclSpec::TST_Bfloat16: + if (!S.Context.getTargetInfo().hasBfloat16Type()) + S.Diag(DS.getTypeSpecTypeLoc(), diag::err_type_unsupported) + << "__bf16"; + Result = Context.Bfloat16Ty; + break; case DeclSpec::TST_float: Result = Context.FloatTy; break; case DeclSpec::TST_double: if (DS.getTypeSpecWidth() == DeclSpec::TSW_long) @@ -7512,7 +7518,8 @@ BTy->getKind() == BuiltinType::LongLong || BTy->getKind() == BuiltinType::ULongLong || BTy->getKind() == BuiltinType::Float || - BTy->getKind() == BuiltinType::Half; + BTy->getKind() == BuiltinType::Half || + BTy->getKind() == BuiltinType::Bfloat16; } /// HandleNeonVectorTypeAttr - The "neon_vector_type" and diff --git a/clang/lib/Serialization/ASTCommon.cpp b/clang/lib/Serialization/ASTCommon.cpp --- a/clang/lib/Serialization/ASTCommon.cpp +++ b/clang/lib/Serialization/ASTCommon.cpp @@ -249,6 +249,9 @@ case BuiltinType::OMPIterator: ID = PREDEF_TYPE_OMP_ITERATOR; break; + case BuiltinType::Bfloat16: + ID = PREDEF_TYPE_BFLOAT16_ID; + break; } return TypeIdx(ID); diff --git 
a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -6793,6 +6793,9 @@ case PREDEF_TYPE_INT128_ID: T = Context.Int128Ty; break; + case PREDEF_TYPE_BFLOAT16_ID: + T = Context.Bfloat16Ty; + break; case PREDEF_TYPE_HALF_ID: T = Context.HalfTy; break; diff --git a/clang/test/CodeGen/arm-mangle-16bit-float.cpp b/clang/test/CodeGen/arm-mangle-16bit-float.cpp new file mode 100644 --- /dev/null +++ b/clang/test/CodeGen/arm-mangle-16bit-float.cpp @@ -0,0 +1,10 @@ +// RUN: %clang_cc1 -triple aarch64-arm-none-eabi -fallow-half-arguments-and-returns -target-feature +bf16 -target-feature +fullfp16 -emit-llvm -o - %s | FileCheck %s --check-prefix=CHECK64 +// RUN: %clang_cc1 -triple arm-arm-none-eabi -fallow-half-arguments-and-returns -target-feature +bf16 -target-feature +fullfp16 -emit-llvm -o - %s | FileCheck %s --check-prefix=CHECK32 + +// CHECK64: define {{.*}}void @_Z3foou6__bf16(bfloat %b) +// CHECK32: define {{.*}}void @_Z3foou6__bf16(i32 %b.coerce) +void foo(__bf16 b) {} + +// CHECK64: define {{.*}}void @_Z3barDh(half %b) +// CHECK32: define {{.*}}void @_Z3barDh(i32 %b.coerce) +void bar(__fp16 b) {} diff --git a/clang/test/Sema/arm-bfloat.cpp b/clang/test/Sema/arm-bfloat.cpp new file mode 100644 --- /dev/null +++ b/clang/test/Sema/arm-bfloat.cpp @@ -0,0 +1,29 @@ +// RUN: %clang_cc1 -fsyntax-only -verify -std=c++11 \ +// RUN: -triple aarch64-arm-none-eabi -target-cpu cortex-a75 \ +// RUN: -target-feature +bf16 -target-feature +neon %s +// RUN: %clang_cc1 -fsyntax-only -verify -std=c++11 \ +// RUN: -triple arm-arm-none-eabi -target-cpu cortex-a53 \ +// RUN: -target-feature +bf16 -target-feature +neon %s + +void test(bool b) { + __bf16 bf16; + + bf16 + bf16; // expected-error {{invalid operands to binary expression ('__bf16' and '__bf16')}} + bf16 - bf16; // expected-error {{invalid operands to binary expression ('__bf16' and '__bf16')}} + bf16 * bf16; // 
expected-error {{invalid operands to binary expression ('__bf16' and '__bf16')}} + bf16 / bf16; // expected-error {{invalid operands to binary expression ('__bf16' and '__bf16')}} + + __fp16 fp16; + + bf16 + fp16; // expected-error {{invalid operands to binary expression ('__bf16' and '__fp16')}} + fp16 + bf16; // expected-error {{invalid operands to binary expression ('__fp16' and '__bf16')}} + bf16 - fp16; // expected-error {{invalid operands to binary expression ('__bf16' and '__fp16')}} + fp16 - bf16; // expected-error {{invalid operands to binary expression ('__fp16' and '__bf16')}} + bf16 * fp16; // expected-error {{invalid operands to binary expression ('__bf16' and '__fp16')}} + fp16 * bf16; // expected-error {{invalid operands to binary expression ('__fp16' and '__bf16')}} + bf16 / fp16; // expected-error {{invalid operands to binary expression ('__bf16' and '__fp16')}} + fp16 / bf16; // expected-error {{invalid operands to binary expression ('__fp16' and '__bf16')}} + bf16 = fp16; // expected-error {{assigning to '__bf16' from incompatible type '__fp16'}} + fp16 = bf16; // expected-error {{assigning to '__fp16' from incompatible type '__bf16'}} + bf16 + (b ? fp16 : bf16); // expected-error {{incompatible operand types ('__fp16' and '__bf16')}} +} diff --git a/clang/tools/libclang/CXType.cpp b/clang/tools/libclang/CXType.cpp --- a/clang/tools/libclang/CXType.cpp +++ b/clang/tools/libclang/CXType.cpp @@ -607,6 +607,7 @@ TKIND(Elaborated); TKIND(Pipe); TKIND(Attributed); + TKIND(Bfloat16); #define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix) TKIND(Id); #include "clang/Basic/OpenCLImageTypes.def" #undef IMAGE_TYPE