diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2171,9 +2171,17 @@
       Align = Target->getLongFractAlign();
       break;
     case BuiltinType::BFloat16:
-      if (Target->hasBFloat16Type()) {
+      if (Target->hasBFloat16Type() || !getLangOpts().OpenMP ||
+          !getLangOpts().OpenMPIsDevice) {
         Width = Target->getBFloat16Width();
         Align = Target->getBFloat16Align();
+      } else {
+        // OpenMP device compilation for a target without native __bf16
+        // support: lay the type out as the aux (host) target does.
+        assert(AuxTarget &&
+               "Expected an aux target in OpenMP device compilation.");
+        Width = AuxTarget->getBFloat16Width();
+        Align = AuxTarget->getBFloat16Align();
       }
       break;
     case BuiltinType::Float16:
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -1518,7 +1518,10 @@
     break;
   case DeclSpec::TST_half:    Result = Context.HalfTy; break;
   case DeclSpec::TST_BFloat16:
-    if (!S.Context.getTargetInfo().hasBFloat16Type())
+    // Likewise, CUDA host and device may have different __bf16 support, and
+    // so may OpenMP host and device; do not diagnose in those modes.
+    if (!S.Context.getTargetInfo().hasBFloat16Type() && !S.getLangOpts().CUDA &&
+        !(S.getLangOpts().OpenMP && S.getLangOpts().OpenMPIsDevice))
       S.Diag(DS.getTypeSpecTypeLoc(), diag::err_type_unsupported)
         << "__bf16";
     Result = Context.BFloat16Ty;
diff --git a/clang/test/SemaCUDA/amdgpu-bf16.cu b/clang/test/SemaCUDA/amdgpu-bf16.cu
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCUDA/amdgpu-bf16.cu
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 -fsyntax-only -triple amdgcn-amd-amdhsa -aux-triple x86_64-unknown-linux-gnu -verify %s
+// expected-no-diagnostics
+
+// If AMDGPU is the main target and X86 the aux target, ensure we
+// don't complain about unsupported BF16 types in x86 code.
+
+#include "Inputs/cuda.h"
+
+__device__ void devicefn() {
+}
+
+__bf16 hostfn(__bf16 a) {
+  return a;
+}