diff --git a/clang/lib/Basic/Targets/NVPTX.h b/clang/lib/Basic/Targets/NVPTX.h --- a/clang/lib/Basic/Targets/NVPTX.h +++ b/clang/lib/Basic/Targets/NVPTX.h @@ -176,6 +176,7 @@ } bool hasBitIntType() const override { return true; } + bool hasBFloat16Type() const override { return true; } }; } // namespace targets } // namespace clang diff --git a/clang/lib/Basic/Targets/NVPTX.cpp b/clang/lib/Basic/Targets/NVPTX.cpp --- a/clang/lib/Basic/Targets/NVPTX.cpp +++ b/clang/lib/Basic/Targets/NVPTX.cpp @@ -102,6 +102,8 @@ IntAlign = HostTarget->getIntAlign(); HalfWidth = HostTarget->getHalfWidth(); HalfAlign = HostTarget->getHalfAlign(); + BFloat16Width = HostTarget->getBFloat16Width(); + BFloat16Align = HostTarget->getBFloat16Align(); FloatWidth = HostTarget->getFloatWidth(); FloatAlign = HostTarget->getFloatAlign(); DoubleWidth = HostTarget->getDoubleWidth(); diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -1519,7 +1519,9 @@ break; case DeclSpec::TST_half: Result = Context.HalfTy; break; case DeclSpec::TST_BFloat16: - if (!S.Context.getTargetInfo().hasBFloat16Type()) + if (!(S.Context.getTargetInfo().hasBFloat16Type() || + (S.getLangOpts().CUDAIsDevice && S.Context.getAuxTargetInfo() && + S.Context.getAuxTargetInfo()->hasBFloat16Type()))) S.Diag(DS.getTypeSpecTypeLoc(), diag::err_type_unsupported) << "__bf16"; Result = Context.BFloat16Ty;