diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -14936,9 +14936,9 @@ if (auto *ComplexTy = OrigType->getAs()) Type = ComplexTy->getElementType(); if (Type->isRealFloatingType()) { - llvm::APFloat InitValue = - llvm::APFloat::getAllOnesValue(Context.getTypeSize(Type), - /*isIEEE=*/true); + llvm::APFloat InitValue = llvm::APFloat::getAllOnesValue( + Context.getFloatTypeSemantics(Type), + Context.getTypeSize(Type)); Init = FloatingLiteral::Create(Context, InitValue, /*isexact=*/true, Type, ELoc); } else if (Type->isScalarType()) { diff --git a/llvm/docs/BitCodeFormat.rst b/llvm/docs/BitCodeFormat.rst --- a/llvm/docs/BitCodeFormat.rst +++ b/llvm/docs/BitCodeFormat.rst @@ -1107,6 +1107,14 @@ The ``HALF`` record (code 10) adds a ``half`` (16-bit floating point) type to the type table. +TYPE_CODE_BFLOAT Record +^^^^^^^^^^^^^^^^^^^^^^^ + +``[BFLOAT]`` + +The ``BFLOAT`` record (code 23) adds a ``bfloat`` (16-bit brain floating point) +type to the type table. + TYPE_CODE_FLOAT Record ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -2963,6 +2963,12 @@ * - ``half`` - 16-bit floating-point value + * - ``bfloat`` + - 16-bit "brain" floating-point value (7-bit significand). Provides the + same number of exponent bits as ``float``, so that it matches its dynamic + range, but with greatly reduced precision. Used in Intel's AVX-512 BF16 + extensions and Arm's ARMv8.6-A extensions, among others.
+ * - ``float`` - 32-bit floating-point value @@ -2970,7 +2976,7 @@ - 64-bit floating-point value * - ``fp128`` - - 128-bit floating-point value (112-bit mantissa) + - 128-bit floating-point value (112-bit significand) * - ``x86_fp80`` - 80-bit floating-point value (X87) @@ -3303,20 +3309,20 @@ values are represented in their IEEE hexadecimal format so that assembly and disassembly do not cause any bits to change in the constants. -When using the hexadecimal form, constants of types half, float, and -double are represented using the 16-digit form shown above (which -matches the IEEE754 representation for double); half and float values -must, however, be exactly representable as IEEE 754 half and single -precision, respectively. Hexadecimal format is always used for long -double, and there are three forms of long double. The 80-bit format used -by x86 is represented as ``0xK`` followed by 20 hexadecimal digits. The -128-bit format used by PowerPC (two adjacent doubles) is represented by -``0xM`` followed by 32 hexadecimal digits. The IEEE 128-bit format is -represented by ``0xL`` followed by 32 hexadecimal digits. Long doubles -will only work if they match the long double format on your target. -The IEEE 16-bit format (half precision) is represented by ``0xH`` -followed by 4 hexadecimal digits. All hexadecimal formats are big-endian -(sign bit at the left). +When using the hexadecimal form, constants of types bfloat, half, float, and +double are represented using the 16-digit form shown above (which matches the +IEEE 754 representation for double); bfloat, half, and float values must, however, +be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single +precision, respectively. Hexadecimal format is always used for long double, and +there are three forms of long double. The 80-bit format used by x86 is +represented as ``0xK`` followed by 20 hexadecimal digits.
The 128-bit format +used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32 +hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed +by 32 hexadecimal digits. Long doubles will only work if they match the long +double format on your target. The IEEE 16-bit format (half precision) is +represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit +format is represented by ``0xR`` followed by 4 hexadecimal digits. All +hexadecimal formats are big-endian (sign bit at the left). There are no constants of type x86_mmx. diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h --- a/llvm/include/llvm-c/Core.h +++ b/llvm/include/llvm-c/Core.h @@ -146,6 +146,7 @@ typedef enum { LLVMVoidTypeKind, /**< type with no size */ LLVMHalfTypeKind, /**< 16 bit floating point type */ + LLVMBFloatTypeKind, /**< 16 bit brain floating point type. NOTE(review): added mid-enum, so every later kind's implicit value shifts by one - confirm this is acceptable for existing C API clients or append it at the end instead */ LLVMFloatTypeKind, /**< 32 bit floating point type */ LLVMDoubleTypeKind, /**< 64 bit floating point type */ LLVMX86_FP80TypeKind, /**< 80 bit floating point type (X87) */ @@ -1164,6 +1165,11 @@ LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C); /** + * Obtain a 16-bit brain floating point type from a context. + */ +LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C); + +/** * Obtain a 32-bit floating point type from a context. */ LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C); @@ -1195,6 +1201,7 @@ * These map to the functions in this group of the same name.
*/ LLVMTypeRef LLVMHalfType(void); +LLVMTypeRef LLVMBFloatType(void); LLVMTypeRef LLVMFloatType(void); LLVMTypeRef LLVMDoubleType(void); LLVMTypeRef LLVMX86FP80Type(void); diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -151,6 +151,7 @@ /// @{ enum Semantics { S_IEEEhalf, + S_BFloat, S_IEEEsingle, S_IEEEdouble, S_x87DoubleExtended, @@ -162,6 +163,7 @@ static Semantics SemanticsToEnum(const llvm::fltSemantics &Sem); static const fltSemantics &IEEEhalf() LLVM_READNONE; + static const fltSemantics &BFloat() LLVM_READNONE; static const fltSemantics &IEEEsingle() LLVM_READNONE; static const fltSemantics &IEEEdouble() LLVM_READNONE; static const fltSemantics &IEEEquad() LLVM_READNONE; @@ -541,6 +543,7 @@ /// @} APInt convertHalfAPFloatToAPInt() const; + APInt convertBFloatAPFloatToAPInt() const; APInt convertFloatAPFloatToAPInt() const; APInt convertDoubleAPFloatToAPInt() const; APInt convertQuadrupleAPFloatToAPInt() const; @@ -548,6 +551,7 @@ APInt convertPPCDoubleDoubleAPFloatToAPInt() const; void initFromAPInt(const fltSemantics *Sem, const APInt &api); void initFromHalfAPInt(const APInt &api); + void initFromBFloatAPInt(const APInt &api); void initFromFloatAPInt(const APInt &api); void initFromDoubleAPInt(const APInt &api); void initFromQuadrupleAPInt(const APInt &api); @@ -954,9 +958,10 @@ /// Returns a float which is bitcasted from an all one value int. /// + /// \param Semantics - type float semantics /// \param BitWidth - Select float type - /// \param isIEEE - If 128 bit number, select between PPC and IEEE - static APFloat getAllOnesValue(unsigned BitWidth, bool isIEEE = false); + static APFloat getAllOnesValue(const fltSemantics &Semantics, + unsigned BitWidth); /// Used to insert APFloat objects, or objects that contain APFloat objects, /// into FoldingSets. 
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -166,7 +166,9 @@ TYPE_CODE_FUNCTION = 21, // FUNCTION: [vararg, retty, paramty x N] - TYPE_CODE_TOKEN = 22 // TOKEN + TYPE_CODE_TOKEN = 22, // TOKEN + + TYPE_CODE_BFLOAT = 23 // BRAIN FLOATING POINT }; enum OperandBundleTagCode { diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -721,14 +721,15 @@ return getImpl(Data, Ty); } - /// getFP() constructors - Return a constant with array type with an element - /// count and element type of float with precision matching the number of - /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, - /// double for 64bits) Note that this can return a ConstantAggregateZero - /// object. - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); + /// getFP() constructors - Return a constant of array type with a float + /// element type taken from argument `ElementType', and count taken from + /// argument `Elts'. The amount of bits of the contained type must match the + /// number of bits of the type contained in the passed in ArrayRef. + /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note + /// that this can return a ConstantAggregateZero object. + static Constant *getFP(Type *ElementType, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); /// This method constructs a CDS and initializes it with a text string. 
/// The default behavior (AddNull==true) causes a null terminator to @@ -780,14 +781,15 @@ static Constant *get(LLVMContext &Context, ArrayRef Elts); static Constant *get(LLVMContext &Context, ArrayRef Elts); - /// getFP() constructors - Return a constant with vector type with an element - /// count and element type of float with the precision matching the number of - /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, - /// double for 64bits) Note that this can return a ConstantAggregateZero - /// object. - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); + /// getFP() constructors - Return a constant of vector type with a float + /// element type taken from argument `ElementType', and count taken from + /// argument `Elts'. The amount of bits of the contained type must match the + /// number of bits of the type contained in the passed in ArrayRef. + /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note + /// that this can return a ConstantAggregateZero object. + static Constant *getFP(Type *ElementType, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); /// Return a ConstantVector with the specified constant in each element. 
/// The specified constant has to be a of a compatible type (i8/i16/ diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h --- a/llvm/include/llvm/IR/DataLayout.h +++ b/llvm/include/llvm/IR/DataLayout.h @@ -651,6 +651,7 @@ case Type::IntegerTyID: return TypeSize::Fixed(Ty->getIntegerBitWidth()); case Type::HalfTyID: + case Type::BFloatTyID: return TypeSize::Fixed(16); case Type::FloatTyID: return TypeSize::Fixed(32); diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h --- a/llvm/include/llvm/IR/IRBuilder.h +++ b/llvm/include/llvm/IR/IRBuilder.h @@ -477,6 +477,11 @@ return Type::getHalfTy(Context); } + /// Fetch the type representing a 16-bit brain floating point value. + Type *getBFloatTy() { + return Type::getBFloatTy(Context); + } + /// Fetch the type representing a 32-bit floating point value. Type *getFloatTy() { return Type::getFloatTy(Context); diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -54,27 +54,28 @@ /// enum TypeID { // PrimitiveTypes - make sure LastPrimitiveTyID stays up to date. 
- VoidTyID = 0, ///< 0: type with no size - HalfTyID, ///< 1: 16-bit floating point type - FloatTyID, ///< 2: 32-bit floating point type - DoubleTyID, ///< 3: 64-bit floating point type - X86_FP80TyID, ///< 4: 80-bit floating point type (X87) - FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa) - PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC) - LabelTyID, ///< 7: Labels - MetadataTyID, ///< 8: Metadata - X86_MMXTyID, ///< 9: MMX vectors (64 bits, X86 specific) - TokenTyID, ///< 10: Tokens + VoidTyID = 0, ///< 0: type with no size + HalfTyID, ///< 1: 16-bit floating point type + BFloatTyID, ///< 2: 16-bit floating point type (7-bit significand) + FloatTyID, ///< 3: 32-bit floating point type + DoubleTyID, ///< 4: 64-bit floating point type + X86_FP80TyID, ///< 5: 80-bit floating point type (X87) + FP128TyID, ///< 6: 128-bit floating point type (112-bit significand) + PPC_FP128TyID, ///< 7: 128-bit floating point type (two 64-bits, PowerPC) + LabelTyID, ///< 8: Labels + MetadataTyID, ///< 9: Metadata + X86_MMXTyID, ///< 10: MMX vectors (64 bits, X86 specific) + TokenTyID, ///< 11: Tokens // Derived types... see DerivedTypes.h file. // Make sure FirstDerivedTyID stays up to date! - IntegerTyID, ///< 11: Arbitrary bit width integers - FunctionTyID, ///< 12: Functions - StructTyID, ///< 13: Structures - ArrayTyID, ///< 14: Arrays - PointerTyID, ///< 15: Pointers - FixedVectorTyID, ///< 16: Fixed width SIMD vector type - ScalableVectorTyID ///< 17: Scalable SIMD vector type + IntegerTyID, ///< 12: Arbitrary bit width integers + FunctionTyID, ///< 13: Functions + StructTyID, ///< 14: Structures + ArrayTyID, ///< 15: Arrays + PointerTyID, ///< 16: Pointers + FixedVectorTyID, ///< 17: Fixed width SIMD vector type + ScalableVectorTyID ///< 18: Scalable SIMD vector type }; private: @@ -140,6 +141,9 @@ /// Return true if this is 'half', a 16-bit IEEE fp type. 
bool isHalfTy() const { return getTypeID() == HalfTyID; } + /// Return true if this is 'bfloat', a 16-bit brain floating-point type. + bool isBFloatTy() const { return getTypeID() == BFloatTyID; } + /// Return true if this is 'float', a 32-bit IEEE fp type. bool isFloatTy() const { return getTypeID() == FloatTyID; } @@ -157,8 +161,8 @@ - /// Return true if this is one of the six floating-point types + /// Return true if this is one of the seven floating-point types bool isFloatingPointTy() const { - return getTypeID() == HalfTyID || getTypeID() == FloatTyID || - getTypeID() == DoubleTyID || + return getTypeID() == HalfTyID || getTypeID() == BFloatTyID || + getTypeID() == FloatTyID || getTypeID() == DoubleTyID || getTypeID() == X86_FP80TyID || getTypeID() == FP128TyID || getTypeID() == PPC_FP128TyID; } @@ -166,6 +170,7 @@ const fltSemantics &getFltSemantics() const { switch (getTypeID()) { case HalfTyID: return APFloat::IEEEhalf(); + case BFloatTyID: return APFloat::BFloat(); case FloatTyID: return APFloat::IEEEsingle(); case DoubleTyID: return APFloat::IEEEdouble(); case X86_FP80TyID: return APFloat::x87DoubleExtended(); @@ -387,6 +392,7 @@ static Type *getVoidTy(LLVMContext &C); static Type *getLabelTy(LLVMContext &C); static Type *getHalfTy(LLVMContext &C); + static Type *getBFloatTy(LLVMContext &C); static Type *getFloatTy(LLVMContext &C); static Type *getDoubleTy(LLVMContext &C); static Type *getMetadataTy(LLVMContext &C); @@ -422,6 +428,7 @@ // types as pointee.
// static PointerType *getHalfPtrTy(LLVMContext &C, unsigned AS = 0); + static PointerType *getBFloatPtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getFloatPtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getDoublePtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getX86_FP80PtrTy(LLVMContext &C, unsigned AS = 0); diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp --- a/llvm/lib/AsmParser/LLLexer.cpp +++ b/llvm/lib/AsmParser/LLLexer.cpp @@ -820,6 +820,7 @@ TYPEKEYWORD("void", Type::getVoidTy(Context)); TYPEKEYWORD("half", Type::getHalfTy(Context)); + TYPEKEYWORD("bfloat", Type::getBFloatTy(Context)); TYPEKEYWORD("float", Type::getFloatTy(Context)); TYPEKEYWORD("double", Type::getDoubleTy(Context)); TYPEKEYWORD("x86_fp80", Type::getX86_FP80Ty(Context)); @@ -985,11 +986,13 @@ /// HexFP128Constant 0xL[0-9A-Fa-f]+ /// HexPPC128Constant 0xM[0-9A-Fa-f]+ /// HexHalfConstant 0xH[0-9A-Fa-f]+ +/// HexBFloatConstant 0xR[0-9A-Fa-f]+ lltok::Kind LLLexer::Lex0x() { CurPtr = TokStart + 2; char Kind; - if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H') { + if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' || + CurPtr[0] == 'R') { Kind = *CurPtr++; } else { Kind = 'J'; @@ -1007,7 +1010,7 @@ if (Kind == 'J') { // HexFPConstant - Floating point constant represented in IEEE format as a // hexadecimal number for when exponential notation is not precise enough. - // Half, Float, and double only. + // Half, BFloat, Float, and double only. 
APFloatVal = APFloat(APFloat::IEEEdouble(), APInt(64, HexIntToVal(TokStart + 2, CurPtr))); return lltok::APFloat; @@ -1035,6 +1038,11 @@ APFloatVal = APFloat(APFloat::IEEEhalf(), APInt(16,HexIntToVal(TokStart+3, CurPtr))); return lltok::APFloat; + case 'R': + // Brain floating point + APFloatVal = APFloat(APFloat::BFloat(), + APInt(16, HexIntToVal(TokStart + 3, CurPtr))); + return lltok::APFloat; } } diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -5247,13 +5247,16 @@ !ConstantFP::isValueValidForType(Ty, ID.APFloatVal)) return Error(ID.Loc, "floating point constant invalid for type"); - // The lexer has no type info, so builds all half, float, and double FP - // constants as double. Fix this here. Long double does not need this. + // The lexer has no type info, so builds all half, bfloat, float, and double + // FP constants as double. Fix this here. Long double does not need this. 
if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) { bool Ignored; if (Ty->isHalfTy()) ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored); + else if (Ty->isBFloatTy()) + ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, + &Ignored); else if (Ty->isFloatTy()) ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &Ignored); diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1720,6 +1720,9 @@ case bitc::TYPE_CODE_HALF: // HALF ResultTy = Type::getHalfTy(Context); break; + case bitc::TYPE_CODE_BFLOAT: // BFLOAT + ResultTy = Type::getBFloatTy(Context); + break; case bitc::TYPE_CODE_FLOAT: // FLOAT ResultTy = Type::getFloatTy(Context); break; @@ -2429,6 +2432,9 @@ if (CurTy->isHalfTy()) V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(), APInt(16, (uint16_t)Record[0]))); + else if (CurTy->isBFloatTy()) + V = ConstantFP::get(Context, APFloat(APFloat::BFloat(), + APInt(16, (uint16_t)Record[0]))); else if (CurTy->isFloatTy()) V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(), APInt(32, (uint32_t)Record[0]))); @@ -2526,21 +2532,27 @@ } else if (EltTy->isHalfTy()) { SmallVector Elts(Record.begin(), Record.end()); if (isa(CurTy)) - V = ConstantDataVector::getFP(Context, Elts); + V = ConstantDataVector::getFP(EltTy, Elts); + else + V = ConstantDataArray::getFP(EltTy, Elts); + } else if (EltTy->isBFloatTy()) { + SmallVector Elts(Record.begin(), Record.end()); + if (isa(CurTy)) + V = ConstantDataVector::getFP(EltTy, Elts); else - V = ConstantDataArray::getFP(Context, Elts); } else if (EltTy->isFloatTy()) { SmallVector Elts(Record.begin(), Record.end()); if (isa(CurTy)) - V = ConstantDataVector::getFP(Context, Elts); + V = ConstantDataVector::getFP(EltTy, Elts); else - V =
ConstantDataArray::getFP(Context, Elts); + V = ConstantDataArray::getFP(EltTy, Elts); } else if (EltTy->isDoubleTy()) { SmallVector Elts(Record.begin(), Record.end()); if (isa(CurTy)) - V = ConstantDataVector::getFP(Context, Elts); + V = ConstantDataVector::getFP(EltTy, Elts); else - V = ConstantDataArray::getFP(Context, Elts); + V = ConstantDataArray::getFP(EltTy, Elts); } else { return error("Invalid type for value"); } diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -881,6 +881,7 @@ switch (T->getTypeID()) { case Type::VoidTyID: Code = bitc::TYPE_CODE_VOID; break; case Type::HalfTyID: Code = bitc::TYPE_CODE_HALF; break; + case Type::BFloatTyID: Code = bitc::TYPE_CODE_BFLOAT; break; case Type::FloatTyID: Code = bitc::TYPE_CODE_FLOAT; break; case Type::DoubleTyID: Code = bitc::TYPE_CODE_DOUBLE; break; case Type::X86_FP80TyID: Code = bitc::TYPE_CODE_X86_FP80; break; @@ -2387,7 +2388,8 @@ } else if (const ConstantFP *CFP = dyn_cast(C)) { Code = bitc::CST_CODE_FLOAT; Type *Ty = CFP->getType(); - if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) { + if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || + Ty->isDoubleTy()) { Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue()); } else if (Ty->isX86_FP80Ty()) { // api needed to prevent premature destruction diff --git a/llvm/lib/CodeGen/MIRParser/MILexer.cpp b/llvm/lib/CodeGen/MIRParser/MILexer.cpp --- a/llvm/lib/CodeGen/MIRParser/MILexer.cpp +++ b/llvm/lib/CodeGen/MIRParser/MILexer.cpp @@ -534,7 +534,7 @@ } static bool isValidHexFloatingPointPrefix(char C) { - return C == 'H' || C == 'K' || C == 'L' || C == 'M'; + return C == 'H' || C == 'K' || C == 'L' || C == 'M' || C == 'R'; } static Cursor lexFloatingPointLiteral(Cursor Range, Cursor C, MIToken &Token) { diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp --- 
a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -588,6 +588,7 @@ switch (Ty->getTypeID()) { case Type::VoidTyID: OS << "void"; return; case Type::HalfTyID: OS << "half"; return; + case Type::BFloatTyID: OS << "bfloat"; return; case Type::FloatTyID: OS << "float"; return; case Type::DoubleTyID: OS << "double"; return; case Type::X86_FP80TyID: OS << "x86_fp80"; return; @@ -1379,7 +1380,7 @@ return; } - // Either half, or some form of long double. + // Either half, bfloat or some form of long double. // These appear as a magic letter identifying the type, then a // fixed number of hex digits. Out << "0x"; @@ -1407,6 +1408,10 @@ Out << 'H'; Out << format_hex_no_prefix(API.getZExtValue(), 4, /*Upper=*/true); + } else if (&APF.getSemantics() == &APFloat::BFloat()) { + Out << 'R'; + Out << format_hex_no_prefix(API.getZExtValue(), 4, + /*Upper=*/true); } else llvm_unreachable("Unsupported floating point type"); return; diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -332,6 +332,9 @@ case Type::HalfTyID: return ConstantFP::get(Ty->getContext(), APFloat::getZero(APFloat::IEEEhalf())); + case Type::BFloatTyID: + return ConstantFP::get(Ty->getContext(), + APFloat::getZero(APFloat::BFloat())); case Type::FloatTyID: return ConstantFP::get(Ty->getContext(), APFloat::getZero(APFloat::IEEEsingle())); @@ -386,8 +389,8 @@ APInt::getAllOnesValue(ITy->getBitWidth())); if (Ty->isFloatingPointTy()) { - APFloat FL = APFloat::getAllOnesValue(Ty->getPrimitiveSizeInBits(), - !Ty->isPPC_FP128Ty()); + APFloat FL = APFloat::getAllOnesValue(Ty->getFltSemantics(), + Ty->getPrimitiveSizeInBits()); return ConstantFP::get(Ty->getContext(), FL); } @@ -763,6 +766,8 @@ static const fltSemantics *TypeToFloatSemantics(Type *Ty) { if (Ty->isHalfTy()) return &APFloat::IEEEhalf(); + if (Ty->isBFloatTy()) + return &APFloat::BFloat(); if (Ty->isFloatTy()) return &APFloat::IEEEsingle(); if (Ty->isDoubleTy()) @@ 
-880,6 +885,8 @@ Type *Ty; if (&V.getSemantics() == &APFloat::IEEEhalf()) Ty = Type::getHalfTy(Context); + else if (&V.getSemantics() == &APFloat::BFloat()) + Ty = Type::getBFloatTy(Context); else if (&V.getSemantics() == &APFloat::IEEEsingle()) Ty = Type::getFloatTy(Context); else if (&V.getSemantics() == &APFloat::IEEEdouble()) @@ -1029,7 +1036,7 @@ Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); else return nullptr; - return SequentialTy::getFP(V[0]->getContext(), Elts); + return SequentialTy::getFP(V[0]->getType(), Elts); } template @@ -1048,7 +1055,7 @@ else if (CI->getType()->isIntegerTy(64)) return getIntSequenceIfElementsMatch(V); } else if (ConstantFP *CFP = dyn_cast(C)) { - if (CFP->getType()->isHalfTy()) + if (CFP->getType()->isHalfTy() || CFP->getType()->isBFloatTy()) return getFPSequenceIfElementsMatch(V); else if (CFP->getType()->isFloatTy()) return getFPSequenceIfElementsMatch(V); @@ -1421,6 +1428,12 @@ Val2.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &losesInfo); return !losesInfo; } + case Type::BFloatTyID: { + if (&Val2.getSemantics() == &APFloat::BFloat()) + return true; + Val2.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &losesInfo); + return !losesInfo; + } case Type::FloatTyID: { if (&Val2.getSemantics() == &APFloat::IEEEsingle()) return true; @@ -1429,6 +1442,7 @@ } case Type::DoubleTyID: { if (&Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::BFloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble()) return true; @@ -1437,16 +1451,19 @@ } case Type::X86_FP80TyID: return &Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::BFloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble() || &Val2.getSemantics() == &APFloat::x87DoubleExtended(); case Type::FP128TyID: return &Val2.getSemantics() == &APFloat::IEEEhalf() || + 
&Val2.getSemantics() == &APFloat::BFloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble() || &Val2.getSemantics() == &APFloat::IEEEquad(); case Type::PPC_FP128TyID: return &Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::BFloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble() || &Val2.getSemantics() == &APFloat::PPCDoubleDouble(); @@ -2562,7 +2579,8 @@ } bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) { - if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true; + if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy()) + return true; if (auto *IT = dyn_cast(Ty)) { switch (IT->getBitWidth()) { case 8: @@ -2680,26 +2698,29 @@ Next = nullptr; } -/// getFP() constructors - Return a constant with array type with an element -/// count and element type of float with precision matching the number of -/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, -/// double for 64bits) Note that this can return a ConstantAggregateZero -/// object. -Constant *ConstantDataArray::getFP(LLVMContext &Context, - ArrayRef Elts) { - Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size()); +/// getFP() constructors - Return a constant of array type with a float +/// element type taken from argument `ElementType', and count taken from +/// argument `Elts'. The amount of bits of the contained type must match the +/// number of bits of the type contained in the passed in ArrayRef. +/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note +/// that this can return a ConstantAggregateZero object. 
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef Elts) { + assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) && + "Element type is not a 16-bit float type"); + Type *Ty = ArrayType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 2), Ty); } -Constant *ConstantDataArray::getFP(LLVMContext &Context, - ArrayRef Elts) { - Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size()); +Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef Elts) { + assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type"); + Type *Ty = ArrayType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 4), Ty); } -Constant *ConstantDataArray::getFP(LLVMContext &Context, - ArrayRef Elts) { - Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size()); +Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef Elts) { + assert(ElementType->isDoubleTy() && + "Element type is not a 64-bit float type"); + Type *Ty = ArrayType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 8), Ty); } @@ -2751,26 +2772,32 @@ return getImpl(StringRef(Data, Elts.size() * 8), Ty); } -/// getFP() constructors - Return a constant with vector type with an element -/// count and element type of float with the precision matching the number of -/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, -/// double for 64bits) Note that this can return a ConstantAggregateZero -/// object. -Constant *ConstantDataVector::getFP(LLVMContext &Context, +/// getFP() constructors - Return a constant of vector type with a float +/// element type taken from argument `ElementType', and count taken from +/// argument `Elts'. 
The amount of bits of the contained type must match the +/// number of bits of the type contained in the passed in ArrayRef. +/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note +/// that this can return a ConstantAggregateZero object. +Constant *ConstantDataVector::getFP(Type *ElementType, ArrayRef Elts) { - Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size()); + assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) && + "Element type is not a 16-bit float type"); + Type *Ty = VectorType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 2), Ty); } -Constant *ConstantDataVector::getFP(LLVMContext &Context, +Constant *ConstantDataVector::getFP(Type *ElementType, ArrayRef Elts) { - Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size()); + assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type"); + Type *Ty = VectorType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 4), Ty); } -Constant *ConstantDataVector::getFP(LLVMContext &Context, +Constant *ConstantDataVector::getFP(Type *ElementType, ArrayRef Elts) { - Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size()); + assert(ElementType->isDoubleTy() && + "Element type is not a 64-bit float type"); + Type *Ty = VectorType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 8), Ty); } @@ -2800,17 +2827,22 @@ if (CFP->getType()->isHalfTy()) { SmallVector Elts( NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); - return getFP(V->getContext(), Elts); + return getFP(V->getType(), Elts); + } + if (CFP->getType()->isBFloatTy()) { + SmallVector Elts( + NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); + return getFP(V->getType(), Elts); } if (CFP->getType()->isFloatTy()) { 
SmallVector Elts( NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); - return getFP(V->getContext(), Elts); + return getFP(V->getType(), Elts); } if (CFP->getType()->isDoubleTy()) { SmallVector Elts( NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); - return getFP(V->getContext(), Elts); + return getFP(V->getType(), Elts); } } return ConstantVector::getSplat({NumElts, false}, V); @@ -2875,6 +2907,10 @@ auto EltVal = *reinterpret_cast(EltPtr); return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal)); } + case Type::BFloatTyID: { + auto EltVal = *reinterpret_cast(EltPtr); + return APFloat(APFloat::BFloat(), APInt(16, EltVal)); + } case Type::FloatTyID: { auto EltVal = *reinterpret_cast(EltPtr); return APFloat(APFloat::IEEEsingle(), APInt(32, EltVal)); @@ -2899,8 +2935,8 @@ } Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const { - if (getElementType()->isHalfTy() || getElementType()->isFloatTy() || - getElementType()->isDoubleTy()) + if (getElementType()->isHalfTy() || getElementType()->isBFloatTy() || + getElementType()->isFloatTy() || getElementType()->isDoubleTy()) return ConstantFP::get(getContext(), getElementAsAPFloat(Elt)); return ConstantInt::get(getElementType(), getElementAsInteger(Elt)); diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp --- a/llvm/lib/IR/Core.cpp +++ b/llvm/lib/IR/Core.cpp @@ -477,6 +477,8 @@ return LLVMVoidTypeKind; case Type::HalfTyID: return LLVMHalfTypeKind; + case Type::BFloatTyID: + return LLVMBFloatTypeKind; case Type::FloatTyID: return LLVMFloatTypeKind; case Type::DoubleTyID: @@ -595,6 +597,9 @@ LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C) { return (LLVMTypeRef) Type::getHalfTy(*unwrap(C)); } +LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C) { + return (LLVMTypeRef) Type::getBFloatTy(*unwrap(C)); +} LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C) { return (LLVMTypeRef) Type::getFloatTy(*unwrap(C)); } @@ -617,6 +622,9 @@ LLVMTypeRef LLVMHalfType(void) { return 
LLVMHalfTypeInContext(LLVMGetGlobalContext()); } +LLVMTypeRef LLVMBFloatType(void) { + return LLVMBFloatTypeInContext(LLVMGetGlobalContext()); +} LLVMTypeRef LLVMFloatType(void) { return LLVMFloatTypeInContext(LLVMGetGlobalContext()); } diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp --- a/llvm/lib/IR/DataLayout.cpp +++ b/llvm/lib/IR/DataLayout.cpp @@ -162,7 +162,7 @@ {INTEGER_ALIGN, 16, Align(2), Align(2)}, // i16 {INTEGER_ALIGN, 32, Align(4), Align(4)}, // i32 {INTEGER_ALIGN, 64, Align(4), Align(8)}, // i64 - {FLOAT_ALIGN, 16, Align(2), Align(2)}, // half + {FLOAT_ALIGN, 16, Align(2), Align(2)}, // half, bfloat {FLOAT_ALIGN, 32, Align(4), Align(4)}, // float {FLOAT_ALIGN, 64, Align(8), Align(8)}, // double {FLOAT_ALIGN, 128, Align(16), Align(16)}, // ppcf128, quad, ... @@ -732,6 +732,7 @@ AlignType = INTEGER_ALIGN; break; case Type::HalfTyID: + case Type::BFloatTyID: case Type::FloatTyID: case Type::DoubleTyID: // PPC_FP128TyID and FP128TyID have different data contents, but the diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp --- a/llvm/lib/IR/Function.cpp +++ b/llvm/lib/IR/Function.cpp @@ -655,6 +655,7 @@ case Type::VoidTyID: Result += "isVoid"; break; case Type::MetadataTyID: Result += "Metadata"; break; case Type::HalfTyID: Result += "f16"; break; + case Type::BFloatTyID: Result += "bf16"; break; case Type::FloatTyID: Result += "f32"; break; case Type::DoubleTyID: Result += "f64"; break; case Type::X86_FP80TyID: Result += "f80"; break; diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h --- a/llvm/lib/IR/LLVMContextImpl.h +++ b/llvm/lib/IR/LLVMContextImpl.h @@ -1342,7 +1342,8 @@ std::unique_ptr TheNoneToken; // Basic type instances. 
- Type VoidTy, LabelTy, HalfTy, FloatTy, DoubleTy, MetadataTy, TokenTy; + Type VoidTy, LabelTy, HalfTy, BFloatTy, FloatTy, DoubleTy, MetadataTy, + TokenTy; Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy; IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty; diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp --- a/llvm/lib/IR/LLVMContextImpl.cpp +++ b/llvm/lib/IR/LLVMContextImpl.cpp @@ -26,6 +26,7 @@ VoidTy(C, Type::VoidTyID), LabelTy(C, Type::LabelTyID), HalfTy(C, Type::HalfTyID), + BFloatTy(C, Type::BFloatTyID), FloatTy(C, Type::FloatTyID), DoubleTy(C, Type::DoubleTyID), MetadataTy(C, Type::MetadataTyID), diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -40,6 +40,7 @@ switch (IDNumber) { case VoidTyID : return getVoidTy(C); case HalfTyID : return getHalfTy(C); + case BFloatTyID : return getBFloatTy(C); case FloatTyID : return getFloatTy(C); case DoubleTyID : return getDoubleTy(C); case X86_FP80TyID : return getX86_FP80Ty(C); @@ -112,6 +113,7 @@ TypeSize Type::getPrimitiveSizeInBits() const { switch (getTypeID()) { case Type::HalfTyID: return TypeSize::Fixed(16); + case Type::BFloatTyID: return TypeSize::Fixed(16); case Type::FloatTyID: return TypeSize::Fixed(32); case Type::DoubleTyID: return TypeSize::Fixed(64); case Type::X86_FP80TyID: return TypeSize::Fixed(80); @@ -142,6 +144,7 @@ return VTy->getElementType()->getFPMantissaWidth(); assert(isFloatingPointTy() && "Not a floating point type!"); if (getTypeID() == HalfTyID) return 11; + if (getTypeID() == BFloatTyID) return 8; if (getTypeID() == FloatTyID) return 24; if (getTypeID() == DoubleTyID) return 53; if (getTypeID() == X86_FP80TyID) return 64; @@ -167,6 +170,7 @@ Type *Type::getVoidTy(LLVMContext &C) { return &C.pImpl->VoidTy; } Type *Type::getLabelTy(LLVMContext &C) { return &C.pImpl->LabelTy; } Type *Type::getHalfTy(LLVMContext &C) { return &C.pImpl->HalfTy; } +Type *Type::getBFloatTy(LLVMContext &C) 
{ return &C.pImpl->BFloatTy; } Type *Type::getFloatTy(LLVMContext &C) { return &C.pImpl->FloatTy; } Type *Type::getDoubleTy(LLVMContext &C) { return &C.pImpl->DoubleTy; } Type *Type::getMetadataTy(LLVMContext &C) { return &C.pImpl->MetadataTy; } @@ -191,6 +195,10 @@ return getHalfTy(C)->getPointerTo(AS); } +PointerType *Type::getBFloatPtrTy(LLVMContext &C, unsigned AS) { + return getBFloatTy(C)->getPointerTo(AS); +} + PointerType *Type::getFloatPtrTy(LLVMContext &C, unsigned AS) { return getFloatTy(C)->getPointerTo(AS); } diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -69,6 +69,7 @@ }; static const fltSemantics semIEEEhalf = {15, -14, 11, 16}; + static const fltSemantics semBFloat = {127, -126, 8, 16}; static const fltSemantics semIEEEsingle = {127, -126, 24, 32}; static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64}; static const fltSemantics semIEEEquad = {16383, -16382, 113, 128}; @@ -117,6 +118,8 @@ switch (S) { case S_IEEEhalf: return IEEEhalf(); + case S_BFloat: + return BFloat(); case S_IEEEsingle: return IEEEsingle(); case S_IEEEdouble: @@ -135,6 +138,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) { if (&Sem == &llvm::APFloat::IEEEhalf()) return S_IEEEhalf; + else if (&Sem == &llvm::APFloat::BFloat()) + return S_BFloat; else if (&Sem == &llvm::APFloat::IEEEsingle()) return S_IEEEsingle; else if (&Sem == &llvm::APFloat::IEEEdouble()) @@ -152,6 +157,9 @@ const fltSemantics &APFloatBase::IEEEhalf() { return semIEEEhalf; } + const fltSemantics &APFloatBase::BFloat() { + return semBFloat; + } const fltSemantics &APFloatBase::IEEEsingle() { return semIEEEsingle; } @@ -3255,6 +3263,33 @@ (mysignificand & 0x7fffff))); } +APInt IEEEFloat::convertBFloatAPFloatToAPInt() const { + assert(semantics == (const llvm::fltSemantics *)&semBFloat); + assert(partCount() == 1); + + uint32_t myexponent, mysignificand; + + if (isFiniteNonZero()) { + 
myexponent = exponent + 127; // bias + mysignificand = (uint32_t)*significandParts(); + if (myexponent == 1 && !(mysignificand & 0x80)) + myexponent = 0; // denormal + } else if (category == fcZero) { + myexponent = 0; + mysignificand = 0; + } else if (category == fcInfinity) { + myexponent = 0xff; // all-ones 8-bit exponent + mysignificand = 0; + } else { + assert(category == fcNaN && "Unknown category!"); + myexponent = 0xff; // all-ones 8-bit exponent + mysignificand = (uint32_t)*significandParts(); + } + + return APInt(16, (((sign & 1) << 15) | ((myexponent & 0xff) << 7) | + (mysignificand & 0x7f))); +} + APInt IEEEFloat::convertHalfAPFloatToAPInt() const { assert(semantics == (const llvm::fltSemantics*)&semIEEEhalf); assert(partCount()==1); @@ -3290,6 +3325,9 @@ if (semantics == (const llvm::fltSemantics*)&semIEEEhalf) return convertHalfAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semBFloat) + return convertBFloatAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics*)&semIEEEsingle) return convertFloatAPFloatToAPInt(); @@ -3486,6 +3524,37 @@ } } +void IEEEFloat::initFromBFloatAPInt(const APInt &api) { + assert(api.getBitWidth() == 16); + uint32_t i = (uint32_t)*api.getRawData(); + uint32_t myexponent = (i >> 7) & 0xff; + uint32_t mysignificand = i & 0x7f; + + initialize(&semBFloat); + assert(partCount() == 1); + + sign = i >> 15; + if (myexponent == 0 && mysignificand == 0) { + // exponent, significand meaningless + category = fcZero; + } else if (myexponent == 0xff && mysignificand == 0) { + // exponent, significand meaningless + category = fcInfinity; + } else if (myexponent == 0xff && mysignificand != 0) { + // sign, exponent, significand meaningless + category = fcNaN; + *significandParts() = mysignificand; + } else { + category = fcNormal; + exponent = myexponent - 127; // bias + *significandParts() = mysignificand; + if (myexponent == 0) // denormal + exponent = -126; + else + *significandParts() |= 0x80; // integer bit + } +} + void IEEEFloat::initFromHalfAPInt(const APInt 
&api) { assert(api.getBitWidth()==16); uint32_t i = (uint32_t)*api.getRawData(); @@ -3524,6 +3593,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { if (Sem == &semIEEEhalf) return initFromHalfAPInt(api); + if (Sem == &semBFloat) + return initFromBFloatAPInt(api); if (Sem == &semIEEEsingle) return initFromFloatAPInt(api); if (Sem == &semIEEEdouble) @@ -4763,26 +4834,9 @@ llvm_unreachable("Unexpected semantics"); } -APFloat APFloat::getAllOnesValue(unsigned BitWidth, bool isIEEE) { - if (isIEEE) { - switch (BitWidth) { - case 16: - return APFloat(semIEEEhalf, APInt::getAllOnesValue(BitWidth)); - case 32: - return APFloat(semIEEEsingle, APInt::getAllOnesValue(BitWidth)); - case 64: - return APFloat(semIEEEdouble, APInt::getAllOnesValue(BitWidth)); - case 80: - return APFloat(semX87DoubleExtended, APInt::getAllOnesValue(BitWidth)); - case 128: - return APFloat(semIEEEquad, APInt::getAllOnesValue(BitWidth)); - default: - llvm_unreachable("Unknown floating bit width"); - } - } else { - assert(BitWidth == 128); - return APFloat(semPPCDoubleDouble, APInt::getAllOnesValue(BitWidth)); - } +APFloat APFloat::getAllOnesValue(const fltSemantics &Semantics, + unsigned BitWidth) { + return APFloat(Semantics, APInt::getAllOnesValue(BitWidth)); } void APFloat::print(raw_ostream &OS) const { diff --git a/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp b/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp --- a/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp @@ -323,6 +323,7 @@ } case Type::FunctionTyID: case Type::VoidTyID: + case Type::BFloatTyID: case Type::X86_FP80TyID: case Type::FP128TyID: case Type::PPC_FP128TyID: diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -11526,8 +11526,9 @@ MVT LogicVT = VT; if (EltVT == MVT::f32 || EltVT == MVT::f64) { 
Zero = DAG.getConstantFP(0.0, DL, EltVT); - AllOnes = DAG.getConstantFP( - APFloat::getAllOnesValue(EltVT.getSizeInBits(), true), DL, EltVT); + APFloat AllOnesValue = APFloat::getAllOnesValue( + SelectionDAG::EVTToAPFloatSemantics(EltVT), EltVT.getSizeInBits()); + AllOnes = DAG.getConstantFP(AllOnesValue, DL, EltVT); LogicVT = MVT::getVectorVT(EltVT == MVT::f64 ? MVT::i64 : MVT::i32, Mask.size()); } else { diff --git a/llvm/test/Assembler/bfloat.ll b/llvm/test/Assembler/bfloat.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Assembler/bfloat.ll @@ -0,0 +1,38 @@ +; RUN: llvm-as < %s | llvm-dis | FileCheck %s --check-prefix=ASSEM-DISASS +; RUN: opt < %s -O3 -S | FileCheck %s --check-prefix=OPT +; RUN: verify-uselistorder %s +; Basic smoke tests for bfloat type. + +define bfloat @check_bfloat(bfloat %A) { +; ASSEM-DISASS: ret bfloat %A + ret bfloat %A +} + +define bfloat @check_bfloat_literal() { +; ASSEM-DISASS: ret bfloat 0xR3149 + ret bfloat 0xR3149 +} + +define <4 x bfloat> @check_fixed_vector() { +; ASSEM-DISASS: ret <4 x bfloat> %tmp + %tmp = fadd <4 x bfloat> undef, undef + ret <4 x bfloat> %tmp +} + +define @check_vector() { +; ASSEM-DISASS: ret %tmp + %tmp = fadd undef, undef + ret %tmp +} + +define bfloat @check_bfloat_constprop() { + %tmp = fadd bfloat 0xR40C0, 0xR40C0 +; OPT: 0xR4140 + ret bfloat %tmp +} + +define float @check_bfloat_convert() { + %tmp = fpext bfloat 0xR4C8D to float +; OPT: 0x4191A00000000000 + ret float %tmp +} diff --git a/llvm/tools/llvm-c-test/echo.cpp b/llvm/tools/llvm-c-test/echo.cpp --- a/llvm/tools/llvm-c-test/echo.cpp +++ b/llvm/tools/llvm-c-test/echo.cpp @@ -72,6 +72,8 @@ return LLVMVoidTypeInContext(Ctx); case LLVMHalfTypeKind: return LLVMHalfTypeInContext(Ctx); + case LLVMBFloatTypeKind: + return LLVMBFloatTypeInContext(Ctx); case LLVMFloatTypeKind: return LLVMFloatTypeInContext(Ctx); case LLVMDoubleTypeKind: