diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2140,6 +2140,13 @@
       if (Target->hasBFloat16Type()) {
         Width = Target->getBFloat16Width();
         Align = Target->getBFloat16Align();
+      } else if ((getLangOpts().SYCLIsDevice ||
+                  (getLangOpts().OpenMP && getLangOpts().OpenMPIsDevice)) &&
+                 AuxTarget && AuxTarget->hasBFloat16Type()) {
+        // Borrow the auxiliary (host) target's __bf16 layout. AuxTarget may
+        // be null (e.g. no auxiliary triple was given), hence the check.
+        Width = AuxTarget->getBFloat16Width();
+        Align = AuxTarget->getBFloat16Align();
       }
       break;
     case BuiltinType::Float16:
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3051,7 +3051,14 @@
     break;
   }
   case BuiltinType::BFloat16: {
-    const TargetInfo *TI = &getASTContext().getTargetInfo();
+    // Device compilations mangle __bf16 like the auxiliary (host) target;
+    // fall back to the device target when no auxiliary target exists.
+    const ASTContext &Ctx = getASTContext();
+    const LangOptions &LO = Ctx.getLangOpts();
+    const TargetInfo *TI = &Ctx.getTargetInfo();
+    if (((LO.OpenMP && LO.OpenMPIsDevice) || LO.SYCLIsDevice) &&
+        Ctx.getAuxTargetInfo())
+      TI = Ctx.getAuxTargetInfo();
     Out << TI->getBFloat16Mangling();
     break;
   }
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -1975,6 +1975,8 @@
         (Ty->isIbm128Type() && !Context.getTargetInfo().hasIbm128Type()) ||
         (Ty->isIntegerType() && Context.getTypeSize(Ty) == 128 &&
          !Context.getTargetInfo().hasInt128Type()) ||
+        (Ty->isBFloat16Type() && !Context.getTargetInfo().hasBFloat16Type() &&
+         !LangOpts.CUDAIsDevice) ||
         LongDoubleMismatched) {
       PartialDiagnostic PD = PDiag(diag::err_target_unsupported_type);
       if (D)
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -1518,9 +1518,14 @@
     break;
   case DeclSpec::TST_half:    Result = Context.HalfTy; break;
   case DeclSpec::TST_BFloat16:
-    if (!S.Context.getTargetInfo().hasBFloat16Type())
-      S.Diag(DS.getTypeSpecTypeLoc(), diag::err_type_unsupported)
-        << "__bf16";
+    // In SYCL/OpenMP device code __bf16 is accepted when the auxiliary
+    // (host) target supports it, since its layout is taken from there.
+    if (!S.Context.getTargetInfo().hasBFloat16Type() &&
+        !((S.getLangOpts().SYCLIsDevice ||
+           (S.getLangOpts().OpenMP && S.getLangOpts().OpenMPIsDevice)) &&
+          S.Context.getAuxTargetInfo() &&
+          S.Context.getAuxTargetInfo()->hasBFloat16Type()))
+      S.Diag(DS.getTypeSpecTypeLoc(), diag::err_type_unsupported) << "__bf16";
     Result = Context.BFloat16Ty;
     break;
   case DeclSpec::TST_float:   Result = Context.FloatTy; break;
diff --git a/clang/test/SemaSYCL/bf16.cpp b/clang/test/SemaSYCL/bf16.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaSYCL/bf16.cpp
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -triple spir64 -aux-triple x86_64-unknown-linux-gnu -fsycl-is-device -verify -fsyntax-only %s
+
+template <typename Name, typename Func>
+__attribute__((sycl_kernel)) void kernel(Func kernelFunc) {
+  kernelFunc(); // expected-note {{called by 'kernel}}
+}
+
+void host_ok(void) {
+  __bf16 A;
+}
+
+int main() {
+  host_ok();
+  __bf16 var; // expected-note {{'var' defined here}}
+  kernel<class variables>([=]() {
+    (void)var; // expected-error {{'var' requires 16 bit size '__bf16' type support, but target 'spir64' does not support it}}
+    int B = sizeof(__bf16);
+  });
+
+  return 0;
+}
+