diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -14736,9 +14736,9 @@ if (auto *ComplexTy = OrigType->getAs()) Type = ComplexTy->getElementType(); if (Type->isRealFloatingType()) { - llvm::APFloat InitValue = - llvm::APFloat::getAllOnesValue(Context.getTypeSize(Type), - /*isIEEE=*/true); + llvm::APFloat InitValue = llvm::APFloat::getAllOnesValue( + Context.getFloatTypeSemantics(Type), + Context.getTypeSize(Type)); Init = FloatingLiteral::Create(Context, InitValue, /*isexact=*/true, Type, ELoc); } else if (Type->isScalarType()) { diff --git a/llvm/docs/BitCodeFormat.rst b/llvm/docs/BitCodeFormat.rst --- a/llvm/docs/BitCodeFormat.rst +++ b/llvm/docs/BitCodeFormat.rst @@ -1107,6 +1107,14 @@ The ``HALF`` record (code 10) adds a ``half`` (16-bit floating point) type to the type table. +TYPE_CODE_BFLOAT Record +^^^^^^^^^^^^^^^^^^^^^^^ + +``[BFLOAT]`` + +The ``BFLOAT`` record (code 23) adds a ``bfloat`` (16-bit brain floating point) +type to the type table. + TYPE_CODE_FLOAT Record ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -2892,6 +2892,9 @@ * - ``half`` - 16-bit floating-point value + * - ``bfloat`` + - 16-bit brain floating-point value (8-bit mantissa) + * - ``float`` - 32-bit floating-point value @@ -3232,20 +3235,20 @@ values are represented in their IEEE hexadecimal format so that assembly and disassembly do not cause any bits to change in the constants. -When using the hexadecimal form, constants of types half, float, and -double are represented using the 16-digit form shown above (which -matches the IEEE754 representation for double); half and float values -must, however, be exactly representable as IEEE 754 half and single -precision, respectively. Hexadecimal format is always used for long -double, and there are three forms of long double.
The 80-bit format used -by x86 is represented as ``0xK`` followed by 20 hexadecimal digits. The -128-bit format used by PowerPC (two adjacent doubles) is represented by -``0xM`` followed by 32 hexadecimal digits. The IEEE 128-bit format is -represented by ``0xL`` followed by 32 hexadecimal digits. Long doubles -will only work if they match the long double format on your target. -The IEEE 16-bit format (half precision) is represented by ``0xH`` -followed by 4 hexadecimal digits. All hexadecimal formats are big-endian -(sign bit at the left). +When using the hexadecimal form, constants of types half, bfloat, float, and +double are represented using the 16-digit form shown above (which matches the +IEEE754 representation for double); half, bfloat, and float values must, +however, be exactly representable as IEEE 754 half precision, bfloat, and IEEE +754 single precision, respectively. Hexadecimal format is always used for long +double, and there are three forms of long double. The 80-bit format used by +x86 is represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit +format used by PowerPC (two adjacent doubles) is represented by ``0xM`` +followed by 32 hexadecimal digits. The IEEE 128-bit format is represented by +``0xL`` followed by 32 hexadecimal digits. Long doubles will only work if they +match the long double format on your target. The IEEE 16-bit format (half +precision) is represented by ``0xH`` followed by 4 hexadecimal digits. The +bfloat 16-bit format is represented by ``0xR`` followed by 4 hexadecimal +digits. All hexadecimal formats are big-endian (sign bit at the left). There are no constants of type x86_mmx.
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h --- a/llvm/include/llvm-c/Core.h +++ b/llvm/include/llvm-c/Core.h @@ -146,6 +146,7 @@ typedef enum { LLVMVoidTypeKind, /**< type with no size */ LLVMHalfTypeKind, /**< 16 bit floating point type */ + LLVMBfloatTypeKind, /**< 16 bit brain floating point type */ LLVMFloatTypeKind, /**< 32 bit floating point type */ LLVMDoubleTypeKind, /**< 64 bit floating point type */ LLVMX86_FP80TypeKind, /**< 80 bit floating point type (X87) */ @@ -1163,6 +1164,11 @@ LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C); /** + * Obtain a 16-bit brain floating point type from a context. + */ +LLVMTypeRef LLVMBfloatTypeInContext(LLVMContextRef C); + +/** * Obtain a 32-bit floating point type from a context. */ LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C); @@ -1194,6 +1200,7 @@ * These map to the functions in this group of the same name. */ LLVMTypeRef LLVMHalfType(void); +LLVMTypeRef LLVMBfloatType(void); LLVMTypeRef LLVMFloatType(void); LLVMTypeRef LLVMDoubleType(void); LLVMTypeRef LLVMX86FP80Type(void); diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -151,6 +151,7 @@ /// @{ enum Semantics { S_IEEEhalf, + S_Bfloat, S_IEEEsingle, S_IEEEdouble, S_x87DoubleExtended, @@ -162,6 +163,7 @@ static Semantics SemanticsToEnum(const llvm::fltSemantics &Sem); static const fltSemantics &IEEEhalf() LLVM_READNONE; + static const fltSemantics &Bfloat() LLVM_READNONE; static const fltSemantics &IEEEsingle() LLVM_READNONE; static const fltSemantics &IEEEdouble() LLVM_READNONE; static const fltSemantics &IEEEquad() LLVM_READNONE; @@ -541,6 +543,7 @@ /// @} APInt convertHalfAPFloatToAPInt() const; + APInt convertBfloatAPFloatToAPInt() const; APInt convertFloatAPFloatToAPInt() const; APInt convertDoubleAPFloatToAPInt() const; APInt convertQuadrupleAPFloatToAPInt() const; @@ -548,6 +551,7 @@ APInt 
convertPPCDoubleDoubleAPFloatToAPInt() const; void initFromAPInt(const fltSemantics *Sem, const APInt &api); void initFromHalfAPInt(const APInt &api); + void initFromBfloatAPInt(const APInt &api); void initFromFloatAPInt(const APInt &api); void initFromDoubleAPInt(const APInt &api); void initFromQuadrupleAPInt(const APInt &api); @@ -954,9 +958,10 @@ /// Returns a float which is bitcasted from an all one value int. /// + /// \param Semantics - type float semantics /// \param BitWidth - Select float type - /// \param isIEEE - If 128 bit number, select between PPC and IEEE - static APFloat getAllOnesValue(unsigned BitWidth, bool isIEEE = false); + static APFloat getAllOnesValue(const fltSemantics &Semantics, + unsigned BitWidth); /// Used to insert APFloat objects, or objects that contain APFloat objects, /// into FoldingSets. diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -166,7 +166,9 @@ TYPE_CODE_FUNCTION = 21, // FUNCTION: [vararg, retty, paramty x N] - TYPE_CODE_TOKEN = 22 // TOKEN + TYPE_CODE_TOKEN = 22, // TOKEN + + TYPE_CODE_BFLOAT = 23 // BRAIN FLOATING POINT }; enum OperandBundleTagCode { diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -725,9 +725,9 @@ /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, /// double for 64bits) Note that this can return a ConstantAggregateZero /// object. 
- static Constant *getFP(LLVMContext &Context, ArrayRef Elts); - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); /// This method constructs a CDS and initializes it with a text string. /// The default behavior (AddNull==true) causes a null terminator to @@ -784,9 +784,9 @@ /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, /// double for 64bits) Note that this can return a ConstantAggregateZero /// object. - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); - static Constant *getFP(LLVMContext &Context, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); + static Constant *getFP(Type *ElementType, ArrayRef Elts); /// Return a ConstantVector with the specified constant in each element. /// The specified constant has to be a of a compatible type (i8/i16/ diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h --- a/llvm/include/llvm/IR/DataLayout.h +++ b/llvm/include/llvm/IR/DataLayout.h @@ -652,6 +652,8 @@ return TypeSize::Fixed(Ty->getIntegerBitWidth()); case Type::HalfTyID: return TypeSize::Fixed(16); + case Type::BfloatTyID: + return TypeSize::Fixed(16); case Type::FloatTyID: return TypeSize::Fixed(32); case Type::DoubleTyID: diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -56,24 +56,25 @@ // PrimitiveTypes - make sure LastPrimitiveTyID stays up to date. 
VoidTyID = 0, ///< 0: type with no size HalfTyID, ///< 1: 16-bit floating point type - FloatTyID, ///< 2: 32-bit floating point type - DoubleTyID, ///< 3: 64-bit floating point type - X86_FP80TyID, ///< 4: 80-bit floating point type (X87) - FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa) - PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC) - LabelTyID, ///< 7: Labels - MetadataTyID, ///< 8: Metadata - X86_MMXTyID, ///< 9: MMX vectors (64 bits, X86 specific) - TokenTyID, ///< 10: Tokens + BfloatTyID, ///< 2: 16-bit floating point type + FloatTyID, ///< 3: 32-bit floating point type + DoubleTyID, ///< 4: 64-bit floating point type + X86_FP80TyID, ///< 5: 80-bit floating point type (X87) + FP128TyID, ///< 6: 128-bit floating point type (112-bit mantissa) + PPC_FP128TyID, ///< 7: 128-bit floating point type (two 64-bits, PowerPC) + LabelTyID, ///< 8: Labels + MetadataTyID, ///< 9: Metadata + X86_MMXTyID, ///< 10: MMX vectors (64 bits, X86 specific) + TokenTyID, ///< 11: Tokens // Derived types... see DerivedTypes.h file. // Make sure FirstDerivedTyID stays up to date! - IntegerTyID, ///< 11: Arbitrary bit width integers - FunctionTyID, ///< 12: Functions - StructTyID, ///< 13: Structures - ArrayTyID, ///< 14: Arrays - PointerTyID, ///< 15: Pointers - VectorTyID ///< 16: SIMD 'packed' format, or other vector type + IntegerTyID, ///< 12: Arbitrary bit width integers + FunctionTyID, ///< 13: Functions + StructTyID, ///< 14: Structures + ArrayTyID, ///< 15: Arrays + PointerTyID, ///< 16: Pointers + VectorTyID ///< 17: SIMD 'packed' format, or other vector type }; private: @@ -139,6 +140,9 @@ /// Return true if this is 'half', a 16-bit IEEE fp type. bool isHalfTy() const { return getTypeID() == HalfTyID; } + /// Return true if this is 'bfloat', a 16-bit bfloat type. + bool isBfloatTy() const { return getTypeID() == BfloatTyID; } + /// Return true if this is 'float', a 32-bit IEEE fp type. 
bool isFloatTy() const { return getTypeID() == FloatTyID; } @@ -156,8 +160,8 @@ /// Return true if this is one of the six floating-point types bool isFloatingPointTy() const { - return getTypeID() == HalfTyID || getTypeID() == FloatTyID || - getTypeID() == DoubleTyID || + return getTypeID() == HalfTyID || getTypeID() == BfloatTyID || + getTypeID() == FloatTyID || getTypeID() == DoubleTyID || getTypeID() == X86_FP80TyID || getTypeID() == FP128TyID || getTypeID() == PPC_FP128TyID; } @@ -165,6 +169,7 @@ const fltSemantics &getFltSemantics() const { switch (getTypeID()) { case HalfTyID: return APFloat::IEEEhalf(); + case BfloatTyID: return APFloat::Bfloat(); case FloatTyID: return APFloat::IEEEsingle(); case DoubleTyID: return APFloat::IEEEdouble(); case X86_FP80TyID: return APFloat::x87DoubleExtended(); @@ -399,6 +404,7 @@ static Type *getVoidTy(LLVMContext &C); static Type *getLabelTy(LLVMContext &C); static Type *getHalfTy(LLVMContext &C); + static Type *getBfloatTy(LLVMContext &C); static Type *getFloatTy(LLVMContext &C); static Type *getDoubleTy(LLVMContext &C); static Type *getMetadataTy(LLVMContext &C); @@ -434,6 +440,7 @@ // types as pointee. 
// static PointerType *getHalfPtrTy(LLVMContext &C, unsigned AS = 0); + static PointerType *getBfloatPtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getFloatPtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getDoublePtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getX86_FP80PtrTy(LLVMContext &C, unsigned AS = 0); diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp --- a/llvm/lib/AsmParser/LLLexer.cpp +++ b/llvm/lib/AsmParser/LLLexer.cpp @@ -818,6 +818,7 @@ TYPEKEYWORD("void", Type::getVoidTy(Context)); TYPEKEYWORD("half", Type::getHalfTy(Context)); + TYPEKEYWORD("bfloat", Type::getBfloatTy(Context)); TYPEKEYWORD("float", Type::getFloatTy(Context)); TYPEKEYWORD("double", Type::getDoubleTy(Context)); TYPEKEYWORD("x86_fp80", Type::getX86_FP80Ty(Context)); @@ -983,11 +984,13 @@ /// HexFP128Constant 0xL[0-9A-Fa-f]+ /// HexPPC128Constant 0xM[0-9A-Fa-f]+ /// HexHalfConstant 0xH[0-9A-Fa-f]+ +/// HexBfloatConstant 0xR[0-9A-Fa-f]+ lltok::Kind LLLexer::Lex0x() { CurPtr = TokStart + 2; char Kind; - if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H') { + if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' || + CurPtr[0] == 'R') { Kind = *CurPtr++; } else { Kind = 'J'; @@ -1005,7 +1008,7 @@ if (Kind == 'J') { // HexFPConstant - Floating point constant represented in IEEE format as a // hexadecimal number for when exponential notation is not precise enough. - // Half, Float, and double only. + // Half, Bfloat, Float, and double only. 
APFloatVal = APFloat(APFloat::IEEEdouble(), APInt(64, HexIntToVal(TokStart + 2, CurPtr))); return lltok::APFloat; @@ -1033,6 +1036,11 @@ APFloatVal = APFloat(APFloat::IEEEhalf(), APInt(16,HexIntToVal(TokStart+3, CurPtr))); return lltok::APFloat; + case 'R': + // Brain floating point + APFloatVal = APFloat(APFloat::Bfloat(), + APInt(16, HexIntToVal(TokStart + 3, CurPtr))); + return lltok::APFloat; } } diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -5188,13 +5188,16 @@ !ConstantFP::isValueValidForType(Ty, ID.APFloatVal)) return Error(ID.Loc, "floating point constant invalid for type"); - // The lexer has no type info, so builds all half, float, and double FP - // constants as double. Fix this here. Long double does not need this. + // The lexer has no type info, so builds all half, bfloat, float, and double + // FP constants as double. Fix this here. Long double does not need this. 
if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) { bool Ignored; if (Ty->isHalfTy()) ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored); + else if (Ty->isBfloatTy()) + ID.APFloatVal.convert(APFloat::Bfloat(), APFloat::rmNearestTiesToEven, + &Ignored); else if (Ty->isFloatTy()) ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &Ignored); diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1730,6 +1730,9 @@ case bitc::TYPE_CODE_HALF: // HALF ResultTy = Type::getHalfTy(Context); break; + case bitc::TYPE_CODE_BFLOAT: // BFLOAT + ResultTy = Type::getBfloatTy(Context); + break; case bitc::TYPE_CODE_FLOAT: // FLOAT ResultTy = Type::getFloatTy(Context); break; @@ -2439,6 +2442,9 @@ if (CurTy->isHalfTy()) V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(), APInt(16, (uint16_t)Record[0]))); + else if (CurTy->isBfloatTy()) + V = ConstantFP::get(Context, APFloat(APFloat::Bfloat(), + APInt(16, (uint16_t)Record[0]))); // keep only the 16 encoded bits, as for half else if (CurTy->isFloatTy()) V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(), APInt(32, (uint32_t)Record[0]))); @@ -2536,21 +2542,27 @@ } else if (EltTy->isHalfTy()) { SmallVector Elts(Record.begin(), Record.end()); if (isa(CurTy)) - V = ConstantDataVector::getFP(Context, Elts); + V = ConstantDataVector::getFP(EltTy, Elts); + else + V = ConstantDataArray::getFP(EltTy, Elts); + } else if (EltTy->isBfloatTy()) { + SmallVector Elts(Record.begin(), Record.end()); + if (isa(CurTy)) + V = ConstantDataVector::getFP(EltTy, Elts); else - V = ConstantDataArray::getFP(Context, Elts); + V = ConstantDataArray::getFP(EltTy, Elts); } else if (EltTy->isFloatTy()) { SmallVector Elts(Record.begin(), Record.end()); if (isa(CurTy)) - V = ConstantDataVector::getFP(Context, Elts); + V = ConstantDataVector::getFP(EltTy, Elts); else - V =
ConstantDataArray::getFP(Context, Elts); + V = ConstantDataArray::getFP(EltTy, Elts); } else if (EltTy->isDoubleTy()) { SmallVector Elts(Record.begin(), Record.end()); if (isa(CurTy)) - V = ConstantDataVector::getFP(Context, Elts); + V = ConstantDataVector::getFP(EltTy, Elts); else - V = ConstantDataArray::getFP(Context, Elts); + V = ConstantDataArray::getFP(EltTy, Elts); } else { return error("Invalid type for value"); } diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -878,6 +878,7 @@ switch (T->getTypeID()) { case Type::VoidTyID: Code = bitc::TYPE_CODE_VOID; break; case Type::HalfTyID: Code = bitc::TYPE_CODE_HALF; break; + case Type::BfloatTyID: Code = bitc::TYPE_CODE_BFLOAT; break; case Type::FloatTyID: Code = bitc::TYPE_CODE_FLOAT; break; case Type::DoubleTyID: Code = bitc::TYPE_CODE_DOUBLE; break; case Type::X86_FP80TyID: Code = bitc::TYPE_CODE_X86_FP80; break; @@ -2376,7 +2377,8 @@ } else if (const ConstantFP *CFP = dyn_cast(C)) { Code = bitc::CST_CODE_FLOAT; Type *Ty = CFP->getType(); - if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) { + if (Ty->isHalfTy() || Ty->isBfloatTy() || Ty->isFloatTy() || + Ty->isDoubleTy()) { Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue()); } else if (Ty->isX86_FP80Ty()) { // api needed to prevent premature destruction diff --git a/llvm/lib/CodeGen/MIRParser/MILexer.cpp b/llvm/lib/CodeGen/MIRParser/MILexer.cpp --- a/llvm/lib/CodeGen/MIRParser/MILexer.cpp +++ b/llvm/lib/CodeGen/MIRParser/MILexer.cpp @@ -534,7 +534,7 @@ } static bool isValidHexFloatingPointPrefix(char C) { - return C == 'H' || C == 'K' || C == 'L' || C == 'M'; + return C == 'H' || C == 'K' || C == 'L' || C == 'M' || C == 'R'; } static Cursor lexFloatingPointLiteral(Cursor Range, Cursor C, MIToken &Token) { diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp --- 
a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -588,6 +588,7 @@ switch (Ty->getTypeID()) { case Type::VoidTyID: OS << "void"; return; case Type::HalfTyID: OS << "half"; return; + case Type::BfloatTyID: OS << "bfloat"; return; case Type::FloatTyID: OS << "float"; return; case Type::DoubleTyID: OS << "double"; return; case Type::X86_FP80TyID: OS << "x86_fp80"; return; @@ -1377,7 +1378,7 @@ return; } - // Either half, or some form of long double. + // Either half, bfloat or some form of long double. // These appear as a magic letter identifying the type, then a // fixed number of hex digits. Out << "0x"; @@ -1405,6 +1406,10 @@ Out << 'H'; Out << format_hex_no_prefix(API.getZExtValue(), 4, /*Upper=*/true); + } else if (&APF.getSemantics() == &APFloat::Bfloat()) { + Out << 'R'; + Out << format_hex_no_prefix(API.getZExtValue(), 4, + /*Upper=*/true); } else llvm_unreachable("Unsupported floating point type"); return; diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -332,6 +332,9 @@ case Type::HalfTyID: return ConstantFP::get(Ty->getContext(), APFloat::getZero(APFloat::IEEEhalf())); + case Type::BfloatTyID: + return ConstantFP::get(Ty->getContext(), + APFloat::getZero(APFloat::Bfloat())); case Type::FloatTyID: return ConstantFP::get(Ty->getContext(), APFloat::getZero(APFloat::IEEEsingle())); @@ -385,8 +388,8 @@ APInt::getAllOnesValue(ITy->getBitWidth())); if (Ty->isFloatingPointTy()) { - APFloat FL = APFloat::getAllOnesValue(Ty->getPrimitiveSizeInBits(), - !Ty->isPPC_FP128Ty()); + APFloat FL = APFloat::getAllOnesValue(Ty->getFltSemantics(), + Ty->getPrimitiveSizeInBits()); return ConstantFP::get(Ty->getContext(), FL); } @@ -762,6 +765,8 @@ static const fltSemantics *TypeToFloatSemantics(Type *Ty) { if (Ty->isHalfTy()) return &APFloat::IEEEhalf(); + if (Ty->isBfloatTy()) + return &APFloat::Bfloat(); if (Ty->isFloatTy()) return &APFloat::IEEEsingle(); if (Ty->isDoubleTy()) @@ 
-879,6 +884,8 @@ Type *Ty; if (&V.getSemantics() == &APFloat::IEEEhalf()) Ty = Type::getHalfTy(Context); + else if (&V.getSemantics() == &APFloat::Bfloat()) + Ty = Type::getBfloatTy(Context); else if (&V.getSemantics() == &APFloat::IEEEsingle()) Ty = Type::getFloatTy(Context); else if (&V.getSemantics() == &APFloat::IEEEdouble()) @@ -1028,7 +1035,7 @@ Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); else return nullptr; - return SequentialTy::getFP(V[0]->getContext(), Elts); + return SequentialTy::getFP(V[0]->getType(), Elts); } template @@ -1047,7 +1054,7 @@ else if (CI->getType()->isIntegerTy(64)) return getIntSequenceIfElementsMatch(V); } else if (ConstantFP *CFP = dyn_cast(C)) { - if (CFP->getType()->isHalfTy()) + if (CFP->getType()->isHalfTy() || CFP->getType()->isBfloatTy()) return getFPSequenceIfElementsMatch(V); else if (CFP->getType()->isFloatTy()) return getFPSequenceIfElementsMatch(V); @@ -1420,6 +1427,12 @@ Val2.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &losesInfo); return !losesInfo; } + case Type::BfloatTyID: { + if (&Val2.getSemantics() == &APFloat::Bfloat()) + return true; + Val2.convert(APFloat::Bfloat(), APFloat::rmNearestTiesToEven, &losesInfo); + return !losesInfo; + } case Type::FloatTyID: { if (&Val2.getSemantics() == &APFloat::IEEEsingle()) return true; @@ -1428,6 +1441,7 @@ } case Type::DoubleTyID: { if (&Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::Bfloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble()) return true; @@ -1436,16 +1450,19 @@ } case Type::X86_FP80TyID: return &Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::Bfloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble() || &Val2.getSemantics() == &APFloat::x87DoubleExtended(); case Type::FP128TyID: return &Val2.getSemantics() == &APFloat::IEEEhalf() || + 
&Val2.getSemantics() == &APFloat::Bfloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble() || &Val2.getSemantics() == &APFloat::IEEEquad(); case Type::PPC_FP128TyID: return &Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::Bfloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble() || &Val2.getSemantics() == &APFloat::PPCDoubleDouble(); @@ -2561,7 +2578,8 @@ } bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) { - if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true; + if (Ty->isHalfTy() || Ty->isBfloatTy() || Ty->isFloatTy() || Ty->isDoubleTy()) + return true; if (auto *IT = dyn_cast(Ty)) { switch (IT->getBitWidth()) { case 8: @@ -2684,21 +2702,18 @@ /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, /// double for 64bits) Note that this can return a ConstantAggregateZero /// object. -Constant *ConstantDataArray::getFP(LLVMContext &Context, - ArrayRef Elts) { - Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size()); +Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef Elts) { + Type *Ty = ArrayType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 2), Ty); } -Constant *ConstantDataArray::getFP(LLVMContext &Context, - ArrayRef Elts) { - Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size()); +Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef Elts) { + Type *Ty = ArrayType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 4), Ty); } -Constant *ConstantDataArray::getFP(LLVMContext &Context, - ArrayRef Elts) { - Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size()); +Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef Elts) { + Type *Ty = 
ArrayType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 8), Ty); } @@ -2755,21 +2770,21 @@ /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, /// double for 64bits) Note that this can return a ConstantAggregateZero /// object. -Constant *ConstantDataVector::getFP(LLVMContext &Context, +Constant *ConstantDataVector::getFP(Type *ElementType, ArrayRef Elts) { - Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size()); + Type *Ty = VectorType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 2), Ty); } -Constant *ConstantDataVector::getFP(LLVMContext &Context, +Constant *ConstantDataVector::getFP(Type *ElementType, ArrayRef Elts) { - Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size()); + Type *Ty = VectorType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 4), Ty); } -Constant *ConstantDataVector::getFP(LLVMContext &Context, +Constant *ConstantDataVector::getFP(Type *ElementType, ArrayRef Elts) { - Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size()); + Type *Ty = VectorType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 8), Ty); } @@ -2799,17 +2814,22 @@ if (CFP->getType()->isHalfTy()) { SmallVector Elts( NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); - return getFP(V->getContext(), Elts); + return getFP(V->getType(), Elts); + } + if (CFP->getType()->isBfloatTy()) { + SmallVector Elts( + NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); + return getFP(V->getType(), Elts); } if (CFP->getType()->isFloatTy()) { SmallVector Elts( NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); - return getFP(V->getContext(), Elts); + return getFP(V->getType(), 
Elts); } if (CFP->getType()->isDoubleTy()) { SmallVector Elts( NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); - return getFP(V->getContext(), Elts); + return getFP(V->getType(), Elts); } } return ConstantVector::getSplat({NumElts, false}, V); @@ -2874,6 +2894,10 @@ auto EltVal = *reinterpret_cast(EltPtr); return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal)); } + case Type::BfloatTyID: { + auto EltVal = *reinterpret_cast(EltPtr); + return APFloat(APFloat::Bfloat(), APInt(16, EltVal)); + } case Type::FloatTyID: { auto EltVal = *reinterpret_cast(EltPtr); return APFloat(APFloat::IEEEsingle(), APInt(32, EltVal)); @@ -2898,8 +2922,8 @@ } Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const { - if (getElementType()->isHalfTy() || getElementType()->isFloatTy() || - getElementType()->isDoubleTy()) + if (getElementType()->isHalfTy() || getElementType()->isBfloatTy() || + getElementType()->isFloatTy() || getElementType()->isDoubleTy()) return ConstantFP::get(getContext(), getElementAsAPFloat(Elt)); return ConstantInt::get(getElementType(), getElementAsInteger(Elt)); diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp --- a/llvm/lib/IR/Core.cpp +++ b/llvm/lib/IR/Core.cpp @@ -477,6 +477,8 @@ return LLVMVoidTypeKind; case Type::HalfTyID: return LLVMHalfTypeKind; + case Type::BfloatTyID: + return LLVMBfloatTypeKind; case Type::FloatTyID: return LLVMFloatTypeKind; case Type::DoubleTyID: @@ -593,6 +595,9 @@ LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C) { return (LLVMTypeRef) Type::getHalfTy(*unwrap(C)); } +LLVMTypeRef LLVMBfloatTypeInContext(LLVMContextRef C) { + return (LLVMTypeRef) Type::getBfloatTy(*unwrap(C)); +} LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C) { return (LLVMTypeRef) Type::getFloatTy(*unwrap(C)); } @@ -615,6 +620,9 @@ LLVMTypeRef LLVMHalfType(void) { return LLVMHalfTypeInContext(LLVMGetGlobalContext()); } +LLVMTypeRef LLVMBfloatType(void) { + return LLVMBfloatTypeInContext(LLVMGetGlobalContext()); +} 
LLVMTypeRef LLVMFloatType(void) { return LLVMFloatTypeInContext(LLVMGetGlobalContext()); } diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp --- a/llvm/lib/IR/DataLayout.cpp +++ b/llvm/lib/IR/DataLayout.cpp @@ -162,7 +162,7 @@ {INTEGER_ALIGN, 16, Align(2), Align(2)}, // i16 {INTEGER_ALIGN, 32, Align(4), Align(4)}, // i32 {INTEGER_ALIGN, 64, Align(4), Align(8)}, // i64 - {FLOAT_ALIGN, 16, Align(2), Align(2)}, // half + {FLOAT_ALIGN, 16, Align(2), Align(2)}, // half, bfloat {FLOAT_ALIGN, 32, Align(4), Align(4)}, // float {FLOAT_ALIGN, 64, Align(8), Align(8)}, // double {FLOAT_ALIGN, 128, Align(16), Align(16)}, // ppcf128, quad, ... @@ -729,6 +729,7 @@ AlignType = INTEGER_ALIGN; break; case Type::HalfTyID: + case Type::BfloatTyID: case Type::FloatTyID: case Type::DoubleTyID: // PPC_FP128TyID and FP128TyID have different data contents, but the diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp --- a/llvm/lib/IR/Function.cpp +++ b/llvm/lib/IR/Function.cpp @@ -654,6 +654,7 @@ case Type::VoidTyID: Result += "isVoid"; break; case Type::MetadataTyID: Result += "Metadata"; break; case Type::HalfTyID: Result += "f16"; break; + case Type::BfloatTyID: Result += "bf16"; break; case Type::FloatTyID: Result += "f32"; break; case Type::DoubleTyID: Result += "f64"; break; case Type::X86_FP80TyID: Result += "f80"; break; diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h --- a/llvm/lib/IR/LLVMContextImpl.h +++ b/llvm/lib/IR/LLVMContextImpl.h @@ -1332,7 +1332,8 @@ std::unique_ptr TheNoneToken; // Basic type instances. 
- Type VoidTy, LabelTy, HalfTy, FloatTy, DoubleTy, MetadataTy, TokenTy; + Type VoidTy, LabelTy, HalfTy, BfloatTy, FloatTy, DoubleTy, MetadataTy, + TokenTy; Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy; IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty; diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp --- a/llvm/lib/IR/LLVMContextImpl.cpp +++ b/llvm/lib/IR/LLVMContextImpl.cpp @@ -26,6 +26,7 @@ VoidTy(C, Type::VoidTyID), LabelTy(C, Type::LabelTyID), HalfTy(C, Type::HalfTyID), + BfloatTy(C, Type::BfloatTyID), FloatTy(C, Type::FloatTyID), DoubleTy(C, Type::DoubleTyID), MetadataTy(C, Type::MetadataTyID), diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -40,6 +40,7 @@ switch (IDNumber) { case VoidTyID : return getVoidTy(C); case HalfTyID : return getHalfTy(C); + case BfloatTyID : return getBfloatTy(C); case FloatTyID : return getFloatTy(C); case DoubleTyID : return getDoubleTy(C); case X86_FP80TyID : return getX86_FP80Ty(C); @@ -115,6 +116,7 @@ TypeSize Type::getPrimitiveSizeInBits() const { switch (getTypeID()) { case Type::HalfTyID: return TypeSize::Fixed(16); + case Type::BfloatTyID: return TypeSize::Fixed(16); case Type::FloatTyID: return TypeSize::Fixed(32); case Type::DoubleTyID: return TypeSize::Fixed(64); case Type::X86_FP80TyID: return TypeSize::Fixed(80); @@ -141,6 +143,7 @@ return VTy->getElementType()->getFPMantissaWidth(); assert(isFloatingPointTy() && "Not a floating point type!"); if (getTypeID() == HalfTyID) return 11; + if (getTypeID() == BfloatTyID) return 8; if (getTypeID() == FloatTyID) return 24; if (getTypeID() == DoubleTyID) return 53; if (getTypeID() == X86_FP80TyID) return 64; @@ -166,6 +169,7 @@ Type *Type::getVoidTy(LLVMContext &C) { return &C.pImpl->VoidTy; } Type *Type::getLabelTy(LLVMContext &C) { return &C.pImpl->LabelTy; } Type *Type::getHalfTy(LLVMContext &C) { return &C.pImpl->HalfTy; } +Type *Type::getBfloatTy(LLVMContext &C) 
{ return &C.pImpl->BfloatTy; } Type *Type::getFloatTy(LLVMContext &C) { return &C.pImpl->FloatTy; } Type *Type::getDoubleTy(LLVMContext &C) { return &C.pImpl->DoubleTy; } Type *Type::getMetadataTy(LLVMContext &C) { return &C.pImpl->MetadataTy; } @@ -190,6 +194,10 @@ return getHalfTy(C)->getPointerTo(AS); } +PointerType *Type::getBfloatPtrTy(LLVMContext &C, unsigned AS) { + return getBfloatTy(C)->getPointerTo(AS); +} + PointerType *Type::getFloatPtrTy(LLVMContext &C, unsigned AS) { return getFloatTy(C)->getPointerTo(AS); } diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -69,6 +69,7 @@ }; static const fltSemantics semIEEEhalf = {15, -14, 11, 16}; + static const fltSemantics semBfloat = {127, -126, 8, 16}; /* same exponent range as semIEEEsingle, 8-bit precision (incl. integer bit), 16 bits of storage */ static const fltSemantics semIEEEsingle = {127, -126, 24, 32}; static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64}; static const fltSemantics semIEEEquad = {16383, -16382, 113, 128}; @@ -117,6 +118,8 @@ switch (S) { case S_IEEEhalf: return IEEEhalf(); + case S_Bfloat: + return Bfloat(); case S_IEEEsingle: return IEEEsingle(); case S_IEEEdouble: @@ -135,6 +138,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) { if (&Sem == &llvm::APFloat::IEEEhalf()) return S_IEEEhalf; + else if (&Sem == &llvm::APFloat::Bfloat()) + return S_Bfloat; else if (&Sem == &llvm::APFloat::IEEEsingle()) return S_IEEEsingle; else if (&Sem == &llvm::APFloat::IEEEdouble()) @@ -152,6 +157,9 @@ const fltSemantics &APFloatBase::IEEEhalf() { return semIEEEhalf; } + const fltSemantics &APFloatBase::Bfloat() { + return semBfloat; + } const fltSemantics &APFloatBase::IEEEsingle() { return semIEEEsingle; } @@ -3255,6 +3263,33 @@ (mysignificand & 0x7fffff))); } +APInt IEEEFloat::convertBfloatAPFloatToAPInt() const { /* encodes as 1 sign bit | 8 exponent bits (bias 127) | 7 fraction bits */ + assert(semantics == (const llvm::fltSemantics *)&semBfloat); + assert(partCount() == 1); + + uint32_t myexponent, mysignificand; + + if (isFiniteNonZero()) { + 
myexponent = exponent + 127; // bias + mysignificand = (uint32_t)*significandParts(); + if (myexponent == 1 && !(mysignificand & 0x80)) + myexponent = 0; // denormal + } else if (category == fcZero) { + myexponent = 0; + mysignificand = 0; + } else if (category == fcInfinity) { + myexponent = 0xff; /* all-ones 8-bit exponent encodes Inf/NaN (0x1f is the 5-bit half-precision value) */ + mysignificand = 0; + } else { + assert(category == fcNaN && "Unknown category!"); + myexponent = 0xff; + mysignificand = (uint32_t)*significandParts(); + } + + return APInt(16, (((sign & 1) << 15) | ((myexponent & 0xff) << 7) | + (mysignificand & 0x7f))); +} + APInt IEEEFloat::convertHalfAPFloatToAPInt() const { assert(semantics == (const llvm::fltSemantics*)&semIEEEhalf); assert(partCount()==1); @@ -3290,6 +3325,9 @@ if (semantics == (const llvm::fltSemantics*)&semIEEEhalf) return convertHalfAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semBfloat) + return convertBfloatAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics*)&semIEEEsingle) return convertFloatAPFloatToAPInt(); @@ -3486,6 +3524,37 @@ } } +void IEEEFloat::initFromBfloatAPInt(const APInt &api) { + assert(api.getBitWidth() == 16); + uint32_t i = (uint32_t)*api.getRawData(); + uint32_t myexponent = (i >> 7) & 0xff; + uint32_t mysignificand = i & 0x7f; + + initialize(&semBfloat); + assert(partCount() == 1); + + sign = i >> 15; + if (myexponent == 0 && mysignificand == 0) { + // exponent, significand meaningless + category = fcZero; + } else if (myexponent == 0xff && mysignificand == 0) { + // exponent, significand meaningless + category = fcInfinity; + } else if (myexponent == 0xff && mysignificand != 0) { + // sign, exponent, significand meaningless + category = fcNaN; + *significandParts() = mysignificand; + } else { + category = fcNormal; + exponent = myexponent - 127; // bias + *significandParts() = mysignificand; + if (myexponent == 0) // denormal + exponent = -126; + else + *significandParts() |= 0x80; // integer bit + } +} + void IEEEFloat::initFromHalfAPInt(const APInt
&api) { assert(api.getBitWidth()==16); uint32_t i = (uint32_t)*api.getRawData(); @@ -3524,6 +3593,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { if (Sem == &semIEEEhalf) return initFromHalfAPInt(api); + if (Sem == &semBfloat) + return initFromBfloatAPInt(api); if (Sem == &semIEEEsingle) return initFromFloatAPInt(api); if (Sem == &semIEEEdouble) @@ -4763,26 +4834,9 @@ llvm_unreachable("Unexpected semantics"); } -APFloat APFloat::getAllOnesValue(unsigned BitWidth, bool isIEEE) { - if (isIEEE) { - switch (BitWidth) { - case 16: - return APFloat(semIEEEhalf, APInt::getAllOnesValue(BitWidth)); - case 32: - return APFloat(semIEEEsingle, APInt::getAllOnesValue(BitWidth)); - case 64: - return APFloat(semIEEEdouble, APInt::getAllOnesValue(BitWidth)); - case 80: - return APFloat(semX87DoubleExtended, APInt::getAllOnesValue(BitWidth)); - case 128: - return APFloat(semIEEEquad, APInt::getAllOnesValue(BitWidth)); - default: - llvm_unreachable("Unknown floating bit width"); - } - } else { - assert(BitWidth == 128); - return APFloat(semPPCDoubleDouble, APInt::getAllOnesValue(BitWidth)); - } +APFloat APFloat::getAllOnesValue(const fltSemantics &Semantics, + unsigned BitWidth) { + return APFloat(Semantics, APInt::getAllOnesValue(BitWidth)); /* NOTE(review): BitWidth must equal the width the initFrom*APInt helper for Semantics expects (initFromBfloatAPInt, for instance, asserts 16). A padded storage size, such as a 96/128-bit x87 long double, would presumably trip that width assert. Consider deriving the width from Semantics. */ } void APFloat::print(raw_ostream &OS) const { diff --git a/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp b/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp --- a/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp @@ -322,6 +322,7 @@ } case Type::FunctionTyID: case Type::VoidTyID: + case Type::BfloatTyID: case Type::X86_FP80TyID: case Type::FP128TyID: case Type::PPC_FP128TyID: diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -11438,8 +11438,9 @@ MVT LogicVT = VT; if (EltVT == MVT::f32 || EltVT == MVT::f64) { 
Zero = DAG.getConstantFP(0.0, DL, EltVT); - AllOnes = DAG.getConstantFP( - APFloat::getAllOnesValue(EltVT.getSizeInBits(), true), DL, EltVT); + APFloat AllOnesValue = APFloat::getAllOnesValue( + SelectionDAG::EVTToAPFloatSemantics(EltVT), EltVT.getSizeInBits()); + AllOnes = DAG.getConstantFP(AllOnesValue, DL, EltVT); LogicVT = MVT::getVectorVT(EltVT == MVT::f64 ? MVT::i64 : MVT::i32, Mask.size()); } else { diff --git a/llvm/test/Assembler/bfloat-constprop.ll b/llvm/test/Assembler/bfloat-constprop.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Assembler/bfloat-constprop.ll @@ -0,0 +1,17 @@ +; RUN: opt < %s -O3 -S | FileCheck %s +; RUN: verify-uselistorder %s +; Testing bfloat constant propagation. + +define bfloat @abc() nounwind { +entry: + %a = alloca bfloat, align 2 + %b = alloca bfloat, align 2 + %.compoundliteral = alloca float, align 4 + store bfloat 0xR40C0, bfloat* %a, align 2 + store bfloat 0xR40C0, bfloat* %b, align 2 + %tmp = load bfloat, bfloat* %a, align 2 + %tmp1 = load bfloat, bfloat* %b, align 2 + %add = fadd bfloat %tmp, %tmp1 ; 0xR40C0 is 6.0, so the folded sum 12.0 is 0xR4140 +; CHECK: 0xR4140 + ret bfloat %add +} diff --git a/llvm/test/Assembler/bfloat-conv.ll b/llvm/test/Assembler/bfloat-conv.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Assembler/bfloat-conv.ll @@ -0,0 +1,14 @@ +; RUN: opt < %s -O3 -S | FileCheck %s +; RUN: verify-uselistorder %s +; Testing bfloat to float conversion. + +define float @abc() nounwind { +entry: + %a = alloca bfloat, align 2 + %.compoundliteral = alloca float, align 4 + store bfloat 0xR4C8D, bfloat* %a, align 2 + %tmp = load bfloat, bfloat* %a, align 2 + %conv = fpext bfloat %tmp to float ; 0xR4C8D is 73924608.0, printed as the double-format hex 0x4191A00000000000 +; CHECK: 0x4191A00000000000 + ret float %conv +} diff --git a/llvm/test/Assembler/bfloat.ll b/llvm/test/Assembler/bfloat.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Assembler/bfloat.ll @@ -0,0 +1,15 @@ +; RUN: llvm-as < %s | llvm-dis | FileCheck %s +; RUN: verify-uselistorder %s +; Basic smoke test for bfloat type.
+ +; CHECK: define bfloat @check_bfloat +define bfloat @check_bfloat(bfloat %A) { +; CHECK: ret bfloat %A + ret bfloat %A +} + +; CHECK: define bfloat @check_bfloat_literal +define bfloat @check_bfloat_literal() { +; CHECK: ret bfloat 0xR3149 + ret bfloat 0xR3149 +} diff --git a/llvm/tools/llvm-c-test/echo.cpp b/llvm/tools/llvm-c-test/echo.cpp --- a/llvm/tools/llvm-c-test/echo.cpp +++ b/llvm/tools/llvm-c-test/echo.cpp @@ -72,6 +72,8 @@ return LLVMVoidTypeInContext(Ctx); case LLVMHalfTypeKind: return LLVMHalfTypeInContext(Ctx); + case LLVMBfloatTypeKind: + return LLVMBfloatTypeInContext(Ctx); /* clone bfloat as bfloat; returning half would silently change the type */ case LLVMFloatTypeKind: return LLVMFloatTypeInContext(Ctx); case LLVMDoubleTypeKind: