Index: clang/include/clang/Basic/RISCVVTypes.def =================================================================== --- clang/include/clang/Basic/RISCVVTypes.def +++ clang/include/clang/Basic/RISCVVTypes.def @@ -125,6 +125,13 @@ RVV_VECTOR_TYPE_FLOAT("__rvv_float16m4_t", RvvFloat16m4, RvvFloat16m4Ty, 16, 16, 1) RVV_VECTOR_TYPE_FLOAT("__rvv_float16m8_t", RvvFloat16m8, RvvFloat16m8Ty, 32, 16, 1) +RVV_VECTOR_TYPE_FLOAT("__rvv_bfloat16mf4_t",RvvBFloat16mf4,RvvBFloat16mf4Ty,1, 16, 1) +RVV_VECTOR_TYPE_FLOAT("__rvv_bfloat16mf2_t",RvvBFloat16mf2,RvvBFloat16mf2Ty,2, 16, 1) +RVV_VECTOR_TYPE_FLOAT("__rvv_bfloat16m1_t", RvvBFloat16m1, RvvBFloat16m1Ty, 4, 16, 1) +RVV_VECTOR_TYPE_FLOAT("__rvv_bfloat16m2_t", RvvBFloat16m2, RvvBFloat16m2Ty, 8, 16, 1) +RVV_VECTOR_TYPE_FLOAT("__rvv_bfloat16m4_t", RvvBFloat16m4, RvvBFloat16m4Ty, 16, 16, 1) +RVV_VECTOR_TYPE_FLOAT("__rvv_bfloat16m8_t", RvvBFloat16m8, RvvBFloat16m8Ty, 32, 16, 1) + RVV_VECTOR_TYPE_FLOAT("__rvv_float32mf2_t",RvvFloat32mf2,RvvFloat32mf2Ty,1, 32, 1) RVV_VECTOR_TYPE_FLOAT("__rvv_float32m1_t", RvvFloat32m1, RvvFloat32m1Ty, 2, 32, 1) RVV_VECTOR_TYPE_FLOAT("__rvv_float32m2_t", RvvFloat32m2, RvvFloat32m2Ty, 4, 32, 1) Index: clang/include/clang/Basic/riscv_vector_common.td =================================================================== --- clang/include/clang/Basic/riscv_vector_common.td +++ clang/include/clang/Basic/riscv_vector_common.td @@ -41,6 +41,7 @@ // x: float16_t (half) // f: float32_t (float) // d: float64_t (double) +// y: bfloat16_t (bfloat16) // // This way, given an LMUL, a record with a TypeRange "sil" will cause the // definition of 3 builtins. Each type "t" in the TypeRange (in this example Index: clang/include/clang/Support/RISCVVIntrinsicUtils.h =================================================================== --- clang/include/clang/Support/RISCVVIntrinsicUtils.h +++ clang/include/clang/Support/RISCVVIntrinsicUtils.h @@ -193,10 +193,11 @@ Int16 = 1 << 1, Int32 = 1 << 2, Int64 = 1 << 3, - Float16 = 1 << 4, - Float32 = 1 << 5, - Float64 = 1 << 6, - MaxOffset = 6, + BFloat16 = 1 << 4, + Float16 = 1 << 5, + Float32 = 1 << 6, + Float64 = 1 << 7, + MaxOffset = 7, LLVM_MARK_AS_BITMASK_ENUM(Float64), }; @@ -211,6 +212,7 @@ SignedInteger, UnsignedInteger, Float, + BFloat, Invalid, }; @@ -285,6 +287,7 @@ return isVector() && ElementBitwidth == Width; } bool isFloat() const { return ScalarType == ScalarTypeKind::Float; } + bool isBFloat() const { return ScalarType == ScalarTypeKind::BFloat; } bool isSignedInteger() const { return ScalarType == ScalarTypeKind::SignedInteger; } Index: clang/lib/Sema/SemaRISCVVectorLookup.cpp =================================================================== --- clang/lib/Sema/SemaRISCVVectorLookup.cpp +++ clang/lib/Sema/SemaRISCVVectorLookup.cpp @@ -117,6 +117,9 @@ case ScalarTypeKind::UnsignedInteger: QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false); break; + case ScalarTypeKind::BFloat: + QT = Context.BFloat16Ty; + break; case ScalarTypeKind::Float: switch (Type->getElementBitwidth()) { case 64: Index: clang/lib/Support/RISCVVIntrinsicUtils.cpp =================================================================== --- clang/lib/Support/RISCVVIntrinsicUtils.cpp +++ clang/lib/Support/RISCVVIntrinsicUtils.cpp @@ -102,6 +102,7 @@ // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 +// bfloat16 | N/A | nxv1bf16 | nxv2bf16| nxv4bf16| nxv8bf16 | nxv16bf16| nxv32bf16 // clang-format on bool RVVType::verifyType() const { @@ -113,6 +114,8 @@ return false; if (isFloat() && ElementBitwidth == 8) return false; + if (isBFloat() && ElementBitwidth != 16) + return false; if (IsTuple && (NF == 1 || NF > 8)) return false; unsigned V = *Scale; @@ -198,6 +201,9 @@ llvm_unreachable("Unhandled ElementBitwidth!"); } break; + case ScalarTypeKind::BFloat: + BuiltinStr += "y"; + break; default: llvm_unreachable("ScalarType is invalid!"); } @@ -233,6 +239,9 @@ case ScalarTypeKind::Float: ClangBuiltinStr += "float"; break; + case ScalarTypeKind::BFloat: + ClangBuiltinStr += "bfloat"; + break; case ScalarTypeKind::SignedInteger: ClangBuiltinStr += "int"; break; @@ -299,6 +308,15 @@ } else Str += getTypeString("float"); break; + case ScalarTypeKind::BFloat: + if (isScalar()) { + if (ElementBitwidth == 16) + Str += "__bf16"; + else + llvm_unreachable("Unhandled floating type."); + } else + Str += getTypeString("bfloat"); + break; case ScalarTypeKind::SignedInteger: Str += getTypeString("int"); break; @@ -321,6 +339,9 @@ case ScalarTypeKind::Float: ShortStr = "f" + utostr(ElementBitwidth); break; + case ScalarTypeKind::BFloat: + ShortStr = "bf" + utostr(ElementBitwidth); + break; case ScalarTypeKind::SignedInteger: ShortStr = "i" + utostr(ElementBitwidth); break; @@ -366,6 +387,10 @@ ElementBitwidth = 64; ScalarType = ScalarTypeKind::Float; break; + case BasicType::BFloat16: + ElementBitwidth = 16; + ScalarType = ScalarTypeKind::BFloat; + break; default: llvm_unreachable("Unhandled type code!"); } Index: clang/test/Sema/riscv-types.c =================================================================== --- clang/test/Sema/riscv-types.c +++ clang/test/Sema/riscv-types.c @@ -136,6 +136,25 @@ // CHECK: __rvv_int32m1x2_t x44; __rvv_int32m1x2_t x44; + + // CHECK: __rvv_bfloat16m1_t x45; + __rvv_bfloat16m1_t x45; + + // CHECK: __rvv_bfloat16m2_t x46; + __rvv_bfloat16m2_t x46; + + // CHECK: __rvv_bfloat16m4_t x47; + __rvv_bfloat16m4_t x47; + + // CHECK: __rvv_bfloat16m8_t x48; + __rvv_bfloat16m8_t x48; + + // CHECK: __rvv_bfloat16mf4_t x49; + __rvv_bfloat16mf4_t x49; + + // CHECK: __rvv_bfloat16mf2_t x50; + __rvv_bfloat16mf2_t x50; + } typedef __rvv_bool4_t vbool4_t; Index: clang/utils/TableGen/RISCVVEmitter.cpp =================================================================== --- clang/utils/TableGen/RISCVVEmitter.cpp +++ clang/utils/TableGen/RISCVVEmitter.cpp @@ -150,7 +150,9 @@ case 'd': return BasicType::Float64; break; - + case 'y': + return BasicType::BFloat16; + break; default: return BasicType::Unknown; } @@ -378,7 +380,7 @@ } for (BasicType BT : - {BasicType::Float16, BasicType::Float32, BasicType::Float64}) { + {BasicType::Float16, BasicType::Float32, BasicType::Float64, BasicType::BFloat16}) { for (int Log2LMUL : Log2LMULs) { auto T = TypeCache.computeType(BT, Log2LMUL, PrototypeDescriptor::Vector); if (T)