Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -7179,7 +7179,8 @@
   bool Ret = false;
+  // bfloat16 vectors (IsBF) match neither the integer nor the FP query.
 #define RVV_TYPE(Name, Id, SingletonId)
 #define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned,   \
-                        IsFP)                                                  \
-  if (ElBits == Bitwidth && IsFloat == IsFP)                                   \
+                        IsFP, IsBF)                                            \
+  if (!IsBF && ElBits == Bitwidth && IsFloat == IsFP)                          \
     Ret |= isSpecificBuiltinType(BuiltinType::Id);
 #include "clang/Basic/RISCVVTypes.def"
Index: clang/include/clang/Basic/RISCVVTypes.def
===================================================================
--- clang/include/clang/Basic/RISCVVTypes.def
+++ clang/include/clang/Basic/RISCVVTypes.def
@@ -12,7 +12,8 @@
 //   A builtin type that has not been covered by any other #define
 //   Defining this macro covers all the builtins.
 //
-// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, IsSigned, IsFP)
+// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP,
+//                   IsBF)
 //   A RISC-V V scalable vector.
 //
 // - RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls)
@@ -45,7 +46,7 @@
 #endif
 
 #ifndef RVV_VECTOR_TYPE
-#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP)\
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP, IsBF)\
   RVV_TYPE(Name, Id, SingletonId)
 #endif
 
@@ -56,12 +57,17 @@
 
 #ifndef RVV_VECTOR_TYPE_INT
 #define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned) \
-  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false)
+  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false, false)
 #endif
 
 #ifndef RVV_VECTOR_TYPE_FLOAT
 #define RVV_VECTOR_TYPE_FLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
-  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true)
+  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
+#endif
+
+#ifndef RVV_VECTOR_TYPE_BFLOAT
+#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
+  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
 #endif
 
 //===- Vector types -------------------------------------------------------===//
@@ -125,6 +131,13 @@
 RVV_VECTOR_TYPE_FLOAT("__rvv_float16m4_t", RvvFloat16m4, RvvFloat16m4Ty, 16, 16, 1)
 RVV_VECTOR_TYPE_FLOAT("__rvv_float16m8_t", RvvFloat16m8, RvvFloat16m8Ty, 32, 16, 1)
 
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf4_t",RvvBFloat16mf4,RvvBFloat16mf4Ty,1, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf2_t",RvvBFloat16mf2,RvvBFloat16mf2Ty,2, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m1_t", RvvBFloat16m1, RvvBFloat16m1Ty, 4, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m2_t", RvvBFloat16m2, RvvBFloat16m2Ty, 8, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m4_t", RvvBFloat16m4, RvvBFloat16m4Ty, 16, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m8_t", RvvBFloat16m8, RvvBFloat16m8Ty, 32, 16, 1)
+
 RVV_VECTOR_TYPE_FLOAT("__rvv_float32mf2_t",RvvFloat32mf2,RvvFloat32mf2Ty,1, 32, 1)
 RVV_VECTOR_TYPE_FLOAT("__rvv_float32m1_t", RvvFloat32m1, RvvFloat32m1Ty, 2, 32, 1)
 RVV_VECTOR_TYPE_FLOAT("__rvv_float32m2_t", RvvFloat32m2, RvvFloat32m2Ty, 4, 32, 1)
@@ -148,6 +161,7 @@
 
 RVV_VECTOR_TYPE_INT("__rvv_int32m1x2_t", RvvInt32m1x2, RvvInt32m1x2Ty, 2, 32, 2, true)
 
+#undef RVV_VECTOR_TYPE_BFLOAT
 #undef RVV_VECTOR_TYPE_FLOAT
 #undef RVV_VECTOR_TYPE_INT
 #undef RVV_VECTOR_TYPE
Index: clang/include/clang/Basic/riscv_vector_common.td
===================================================================
--- clang/include/clang/Basic/riscv_vector_common.td
+++ clang/include/clang/Basic/riscv_vector_common.td
@@ -41,6 +41,7 @@
 //   x: float16_t (half)
 //   f: float32_t (float)
 //   d: float64_t (double)
+//   y: bfloat16_t (bfloat16)
 //
 // This way, given an LMUL, a record with a TypeRange "sil" will cause the
 // definition of 3 builtins. Each type "t" in the TypeRange (in this example
Index: clang/include/clang/Support/RISCVVIntrinsicUtils.h
===================================================================
--- clang/include/clang/Support/RISCVVIntrinsicUtils.h
+++ clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -193,10 +193,11 @@
   Int16 = 1 << 1,
   Int32 = 1 << 2,
   Int64 = 1 << 3,
-  Float16 = 1 << 4,
-  Float32 = 1 << 5,
-  Float64 = 1 << 6,
-  MaxOffset = 6,
+  BFloat16 = 1 << 4,
+  Float16 = 1 << 5,
+  Float32 = 1 << 6,
+  Float64 = 1 << 7,
+  MaxOffset = 7,
   LLVM_MARK_AS_BITMASK_ENUM(Float64),
 };
 
@@ -211,6 +212,7 @@
   SignedInteger,
   UnsignedInteger,
   Float,
+  BFloat,
   Invalid,
 };
 
@@ -285,6 +287,7 @@
     return isVector() && ElementBitwidth == Width;
   }
   bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
+  bool isBFloat() const { return ScalarType == ScalarTypeKind::BFloat; }
   bool isSignedInteger() const {
     return ScalarType == ScalarTypeKind::SignedInteger;
   }
Index: clang/lib/AST/ASTContext.cpp
===================================================================
--- clang/lib/AST/ASTContext.cpp
+++ clang/lib/AST/ASTContext.cpp
@@ -2230,7 +2230,7 @@
       break;
 #include "clang/Basic/PPCTypes.def"
 #define RVV_VECTOR_TYPE(Name, Id, SingletonId, ElKind, ElBits, NF, IsSigned,   \
-                        IsFP)                                                  \
+                        IsFP, IsBF)                                            \
   case BuiltinType::Id:                                                        \
     Width = 0;                                                                 \
     Align = ElBits;                                                            \
@@ -3998,6 +3998,9 @@
   case BuiltinType::Id:                                                        \
     return {ElBits == 16 ? Float16Ty : (ElBits == 32 ? FloatTy : DoubleTy),    \
             llvm::ElementCount::getScalable(NumEls), NF};
+#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF)      \
+  case BuiltinType::Id:                                                        \
+    return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
 #define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls)                      \
   case BuiltinType::Id:                                                        \
     return {BoolTy, llvm::ElementCount::getScalable(NumEls), 1};
@@ -4045,11 +4048,14 @@
   } else if (Target->hasRISCVVTypes()) {
     uint64_t EltTySize = getTypeSize(EltTy);
 #define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned,   \
-                        IsFP)                                                  \
+                        IsFP, IsBF)                                            \
   if (!EltTy->isBooleanType() &&                                               \
       ((EltTy->hasIntegerRepresentation() &&                                   \
         EltTy->hasSignedIntegerRepresentation() == IsSigned) ||                \
-       (EltTy->hasFloatingRepresentation() && IsFP)) &&                        \
+       (EltTy->hasFloatingRepresentation() && !EltTy->isBFloat16Type() &&      \
+        IsFP && !IsBF) ||                                                      \
+       (EltTy->hasFloatingRepresentation() && EltTy->isBFloat16Type() &&       \
+        IsBF && !IsFP)) &&                                                     \
       EltTySize == ElBits && NumElts == NumEls && NumFields == NF)             \
     return SingletonId;
 #define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls)                      \
Index: clang/lib/AST/Type.cpp
===================================================================
--- clang/lib/AST/Type.cpp
+++ clang/lib/AST/Type.cpp
@@ -2449,7 +2449,7 @@
   if (const BuiltinType *BT = getAs<BuiltinType>()) {
     switch (BT->getKind()) {
     // FIXME: Support more than LMUL 1.
-#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP) \
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP, IsBF) \
     case BuiltinType::Id: \
       return NF == 1 && (NumEls * ElBits) == llvm::RISCV::RVVBitsPerBlock;
 #include "clang/Basic/RISCVVTypes.def"
Index: clang/lib/Sema/SemaRISCVVectorLookup.cpp
===================================================================
--- clang/lib/Sema/SemaRISCVVectorLookup.cpp
+++ clang/lib/Sema/SemaRISCVVectorLookup.cpp
@@ -117,6 +117,9 @@
   case ScalarTypeKind::UnsignedInteger:
     QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
     break;
+  case ScalarTypeKind::BFloat:
+    QT = Context.BFloat16Ty;
+    break;
   case ScalarTypeKind::Float:
     switch (Type->getElementBitwidth()) {
     case 64:
Index: clang/lib/Support/RISCVVIntrinsicUtils.cpp
===================================================================
--- clang/lib/Support/RISCVVIntrinsicUtils.cpp
+++ clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -102,6 +102,7 @@
 // double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
 // float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
 // half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
+// bfloat16  | N/A    | nxv1bf16 | nxv2bf16| nxv4bf16| nxv8bf16 | nxv16bf16| nxv32bf16
 // clang-format on
 
 bool RVVType::verifyType() const {
@@ -113,6 +114,8 @@
     return false;
   if (isFloat() && ElementBitwidth == 8)
     return false;
+  if (isBFloat() && ElementBitwidth != 16)
+    return false;
   if (IsTuple && (NF == 1 || NF > 8))
     return false;
   unsigned V = *Scale;
@@ -198,6 +201,9 @@
       llvm_unreachable("Unhandled ElementBitwidth!");
     }
     break;
+  case ScalarTypeKind::BFloat:
+    BuiltinStr += "y";
+    break;
   default:
     llvm_unreachable("ScalarType is invalid!");
   }
@@ -233,6 +239,9 @@
   case ScalarTypeKind::Float:
     ClangBuiltinStr += "float";
     break;
+  case ScalarTypeKind::BFloat:
+    ClangBuiltinStr += "bfloat";
+    break;
   case ScalarTypeKind::SignedInteger:
     ClangBuiltinStr += "int";
     break;
@@ -299,6 +308,15 @@
     } else
       Str += getTypeString("float");
     break;
+  case ScalarTypeKind::BFloat:
+    if (isScalar()) {
+      if (ElementBitwidth == 16)
+        Str += "__bf16";
+      else
+        llvm_unreachable("Unhandled floating type.");
+    } else
+      Str += getTypeString("bfloat");
+    break;
   case ScalarTypeKind::SignedInteger:
     Str += getTypeString("int");
     break;
@@ -321,6 +339,9 @@
   case ScalarTypeKind::Float:
     ShortStr = "f" + utostr(ElementBitwidth);
     break;
+  case ScalarTypeKind::BFloat:
+    ShortStr = "bf" + utostr(ElementBitwidth);
+    break;
   case ScalarTypeKind::SignedInteger:
     ShortStr = "i" + utostr(ElementBitwidth);
     break;
@@ -366,6 +387,10 @@
     ElementBitwidth = 64;
     ScalarType = ScalarTypeKind::Float;
     break;
+  case BasicType::BFloat16:
+    ElementBitwidth = 16;
+    ScalarType = ScalarTypeKind::BFloat;
+    break;
   default:
     llvm_unreachable("Unhandled type code!");
   }
Index: clang/test/CodeGen/RISCV/rvv-intrinsics-handcrafted/rvv-intrinsic-datatypes.cpp
===================================================================
--- clang/test/CodeGen/RISCV/rvv-intrinsics-handcrafted/rvv-intrinsic-datatypes.cpp
+++ clang/test/CodeGen/RISCV/rvv-intrinsics-handcrafted/rvv-intrinsic-datatypes.cpp
@@ -63,6 +63,12 @@
 // CHECK-NEXT:    [[F16M2:%.*]] = alloca <vscale x 8 x half>, align 2
 // CHECK-NEXT:    [[F16M4:%.*]] = alloca <vscale x 16 x half>, align 2
 // CHECK-NEXT:    [[F16M8:%.*]] = alloca <vscale x 32 x half>, align 2
+// CHECK-NEXT:    [[BF16MF4:%.*]] = alloca <vscale x 1 x bfloat>, align 2
+// CHECK-NEXT:    [[BF16MF2:%.*]] = alloca <vscale x 2 x bfloat>, align 2
+// CHECK-NEXT:    [[BF16M1:%.*]] = alloca <vscale x 4 x bfloat>, align 2
+// CHECK-NEXT:    [[BF16M2:%.*]] = alloca <vscale x 8 x bfloat>, align 2
+// CHECK-NEXT:    [[BF16M4:%.*]] = alloca <vscale x 16 x bfloat>, align 2
+// CHECK-NEXT:    [[BF16M8:%.*]] = alloca <vscale x 32 x bfloat>, align 2
 // CHECK-NEXT:    [[F32MF2:%.*]] = alloca <vscale x 1 x float>, align 4
 // CHECK-NEXT:    [[F32M1:%.*]] = alloca <vscale x 2 x float>, align 4
 // CHECK-NEXT:    [[F32M2:%.*]] = alloca <vscale x 4 x float>, align 4
@@ -140,6 +146,13 @@
   vfloat16m4_t f16m4;
   vfloat16m8_t f16m8;
 
+  vbfloat16mf4_t bf16mf4;
+  vbfloat16mf2_t bf16mf2;
+  vbfloat16m1_t bf16m1;
+  vbfloat16m2_t bf16m2;
+  vbfloat16m4_t bf16m4;
+  vbfloat16m8_t bf16m8;
+
   vfloat32mf2_t f32mf2;
   vfloat32m1_t f32m1;
   vfloat32m2_t f32m2;
Index: clang/test/Sema/riscv-types.c
===================================================================
--- clang/test/Sema/riscv-types.c
+++ clang/test/Sema/riscv-types.c
@@ -136,6 +136,25 @@
 
   // CHECK: __rvv_int32m1x2_t x44;
   __rvv_int32m1x2_t x44;
+
+  // CHECK: __rvv_bfloat16m1_t x45;
+  __rvv_bfloat16m1_t x45;
+
+  // CHECK: __rvv_bfloat16m2_t x46;
+  __rvv_bfloat16m2_t x46;
+
+  // CHECK: __rvv_bfloat16m4_t x47;
+  __rvv_bfloat16m4_t x47;
+
+  // CHECK: __rvv_bfloat16m8_t x48;
+  __rvv_bfloat16m8_t x48;
+
+  // CHECK: __rvv_bfloat16mf4_t x49;
+  __rvv_bfloat16mf4_t x49;
+
+  // CHECK: __rvv_bfloat16mf2_t x50;
+  __rvv_bfloat16mf2_t x50;
+
 }
 
 typedef __rvv_bool4_t vbool4_t;
Index: clang/utils/TableGen/RISCVVEmitter.cpp
===================================================================
--- clang/utils/TableGen/RISCVVEmitter.cpp
+++ clang/utils/TableGen/RISCVVEmitter.cpp
@@ -150,7 +150,9 @@
   case 'd':
     return BasicType::Float64;
     break;
-
+  case 'y':
+    return BasicType::BFloat16;
+    break;
   default:
     return BasicType::Unknown;
   }
@@ -378,7 +380,8 @@
   }
 
   for (BasicType BT :
-       {BasicType::Float16, BasicType::Float32, BasicType::Float64}) {
+       {BasicType::Float16, BasicType::Float32, BasicType::Float64,
+        BasicType::BFloat16}) {
     for (int Log2LMUL : Log2LMULs) {
       auto T = TypeCache.computeType(BT, Log2LMUL, PrototypeDescriptor::Vector);
       if (T)