diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h --- a/llvm/include/llvm/IR/Intrinsics.h +++ b/llvm/include/llvm/IR/Intrinsics.h @@ -106,6 +106,7 @@ Token, Metadata, Half, + BFloat, Float, Double, Quad, diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -280,6 +280,9 @@ def llvm_v2f16_ty : LLVMType; // 2 x half (__fp16) def llvm_v4f16_ty : LLVMType; // 4 x half (__fp16) def llvm_v8f16_ty : LLVMType; // 8 x half (__fp16) +def llvm_v2bf16_ty : LLVMType; // 2 x bfloat (__bf16) +def llvm_v4bf16_ty : LLVMType; // 4 x bfloat (__bf16) +def llvm_v8bf16_ty : LLVMType; // 8 x bfloat (__bf16) def llvm_v1f32_ty : LLVMType; // 1 x float def llvm_v2f32_ty : LLVMType; // 2 x float def llvm_v4f32_ty : LLVMType; // 4 x float diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td --- a/llvm/include/llvm/IR/IntrinsicsAArch64.td +++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -783,6 +783,7 @@ def llvm_nxv4i32_ty : LLVMType; def llvm_nxv2i64_ty : LLVMType; def llvm_nxv8f16_ty : LLVMType; +def llvm_nxv8bf16_ty : LLVMType; def llvm_nxv4f32_ty : LLVMType; def llvm_nxv2f64_ty : LLVMType; diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp --- a/llvm/lib/IR/Function.cpp +++ b/llvm/lib/IR/Function.cpp @@ -747,7 +747,8 @@ IIT_SUBDIVIDE2_ARG = 44, IIT_SUBDIVIDE4_ARG = 45, IIT_VEC_OF_BITCASTS_TO_INT = 46, - IIT_V128 = 47 + IIT_V128 = 47, + IIT_BF16 = 48 }; static void DecodeIITType(unsigned &NextElt, ArrayRef Infos, @@ -782,6 +783,9 @@ case IIT_F16: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Half, 0)); return; + case IIT_BF16: + OutputTable.push_back(IITDescriptor::get(IITDescriptor::BFloat, 0)); + return; case IIT_F32: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Float, 0)); return; @@ -1005,6 +1009,7 @@ case IITDescriptor::Token: return Type::getTokenTy(Context); case IITDescriptor::Metadata: return Type::getMetadataTy(Context); case IITDescriptor::Half: return Type::getHalfTy(Context); + case IITDescriptor::BFloat: return Type::getBFloatTy(Context); case IITDescriptor::Float: return Type::getFloatTy(Context); case IITDescriptor::Double: return Type::getDoubleTy(Context); case IITDescriptor::Quad: return Type::getFP128Ty(Context); @@ -1183,6 +1188,7 @@ case IITDescriptor::Token: return !Ty->isTokenTy(); case IITDescriptor::Metadata: return !Ty->isMetadataTy(); case IITDescriptor::Half: return !Ty->isHalfTy(); + case IITDescriptor::BFloat: return !Ty->isBFloatTy(); case IITDescriptor::Float: return !Ty->isFloatTy(); case IITDescriptor::Double: return !Ty->isDoubleTy(); case IITDescriptor::Quad: return !Ty->isFP128Ty(); diff --git a/llvm/utils/TableGen/IntrinsicEmitter.cpp b/llvm/utils/TableGen/IntrinsicEmitter.cpp --- a/llvm/utils/TableGen/IntrinsicEmitter.cpp +++ b/llvm/utils/TableGen/IntrinsicEmitter.cpp @@ -245,7 +245,8 @@ IIT_SUBDIVIDE2_ARG = 44, IIT_SUBDIVIDE4_ARG = 45, IIT_VEC_OF_BITCASTS_TO_INT = 46, - IIT_V128 = 47 + IIT_V128 = 47, + IIT_BF16 = 48 }; static void EncodeFixedValueType(MVT::SimpleValueType VT, @@ -266,6 +267,7 @@ switch (VT) { default: PrintFatalError("unhandled MVT in intrinsic!"); case MVT::f16: return Sig.push_back(IIT_F16); + case MVT::bf16: return Sig.push_back(IIT_BF16); case MVT::f32: return Sig.push_back(IIT_F32); case MVT::f64: return Sig.push_back(IIT_F64); case MVT::f128: return Sig.push_back(IIT_F128);