Index: clang/include/clang/Basic/DiagnosticSemaKinds.td =================================================================== --- clang/include/clang/Basic/DiagnosticSemaKinds.td +++ clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -9858,7 +9858,7 @@ def err_record_with_pointers_kernel_param : Error< "%select{struct|union}0 kernel parameters may not contain pointers">; def note_within_field_of_type : Note< - "within field of type %0 declared here">; + "within field %select{|or base class }1of type %0 declared here">; def note_illegal_field_declared_here : Note< "field of illegal %select{type|pointer type}0 %1 declared here">; def err_opencl_type_struct_or_union_field : Error< Index: clang/include/clang/Basic/LangOptions.def =================================================================== --- clang/include/clang/Basic/LangOptions.def +++ clang/include/clang/Basic/LangOptions.def @@ -253,6 +253,9 @@ ENUM_LANGOPT(SYCLVersion , SYCLMajorVersion, 1, SYCL_None, "Version of the SYCL standard used") LANGOPT(HIPUseNewLaunchAPI, 1, 0, "Use new kernel launching API for HIP") +LANGOPT(HIPAllowHalfArg, 1, 1, "Allow half precision types or aggregate types " + "containing half precision types as host " + "function parameter and return types for HIP") LANGOPT(SizedDeallocation , 1, 0, "sized deallocation") LANGOPT(AlignedAllocation , 1, 0, "aligned allocation") Index: clang/include/clang/Driver/Options.td =================================================================== --- clang/include/clang/Driver/Options.td +++ clang/include/clang/Driver/Options.td @@ -921,6 +921,12 @@ LangOpts<"HIPUseNewLaunchAPI">, DefaultFalse, PosFlag, NegFlag, BothFlags<[], " new kernel launching API for HIP">>; +defm hip_allow_half_arg : BoolFOption<"hip-allow-half-arg", + LangOpts<"HIPAllowHalfArg">, DefaultTrue, + PosFlag, NegFlag, + BothFlags<[], " half precision types or aggregate types containing half " + "precision types as host function parameter type or return type">>, + ShouldParseIf; defm 
gpu_allow_device_init : BoolFOption<"gpu-allow-device-init", LangOpts<"GPUAllowDeviceInit">, DefaultFalse, PosFlag, NegFlag, Index: clang/lib/Headers/__clang_hip_cmath.h =================================================================== --- clang/lib/Headers/__clang_hip_cmath.h +++ clang/lib/Headers/__clang_hip_cmath.h @@ -225,7 +225,9 @@ template struct __numeric_type { static void __test(...); - static _Float16 __test(_Float16); + // _Float16 is not allowed as host function arguments until ABI compatibility + // issue with gcc is resolved. + static __device__ _Float16 __test(_Float16); static float __test(float); static double __test(char); static double __test(int); Index: clang/lib/Sema/SemaDecl.cpp =================================================================== --- clang/lib/Sema/SemaDecl.cpp +++ clang/lib/Sema/SemaDecl.cpp @@ -8591,6 +8591,130 @@ } } +// Result type returned by the functor checking struct field. +enum class CheckFieldResult { + Valid, // The field is valid + Recurse, // The field is a struct which needs to be checked recursively + Invalid, // The field is invalid +}; + +// Check whether struct or array type contains invalid fields or elements by +// recursively visiting fields of the structs with the functor CheckField. +// Returns true if the type is valid. CheckField returns Valid if the field is +// valid, emits a diagnostic message and returns Invalid if the field is +// invalid, returns Recurse if the field is a struct which needs further check. +// ValidTypes contain known valid types. +static bool +checkStructOrArrayType(Sema &S, QualType PT, + llvm::SmallPtrSetImpl &ValidTypes, + std::function CheckFieldType, + std::function DiagInvalidParam) { + // Track nested structs we will inspect + SmallVector VisitStack; + + // Track where we are in the nested structs. Items will migrate from + // VisitStack to HistoryStack as we do the DFS for bad field. 
+ SmallVector HistoryStack; + HistoryStack.push_back(nullptr); + + // At this point we already handled everything except of a RecordType or + // an ArrayType of a RecordType. + assert((PT->isArrayType() || PT->isRecordType()) && "Unexpected type."); + const RecordType *RecTy = + PT->getPointeeOrArrayElementType()->getAs(); + const RecordDecl *OrigRecDecl = RecTy->getDecl(); + + VisitStack.push_back(RecTy->getDecl()); + assert(VisitStack.back() && "First decl null?"); + + do { + const Decl *Next = VisitStack.pop_back_val(); + if (!Next) { + // HistoryStack is empty if a struct has no fields or base. + if (HistoryStack.empty()) + continue; + // Found a marker, we have gone up a level + if (const FieldDecl *Hist = HistoryStack.pop_back_val()) + ValidTypes.insert(Hist->getType().getTypePtr()); + + continue; + } + + // Adds everything except the original parameter declaration (which is not a + // field itself) to the history stack. + const RecordDecl *RD; + if (const FieldDecl *Field = dyn_cast(Next)) { + HistoryStack.push_back(Field); + + QualType FieldTy = Field->getType(); + // Other field types (known to be valid or invalid) are handled while we + // walk around RecordDecl::fields(). + assert((FieldTy->isArrayType() || FieldTy->isRecordType()) && + "Unexpected type."); + const Type *FieldRecTy = FieldTy->getPointeeOrArrayElementType(); + + RD = FieldRecTy->castAs()->getDecl(); + } else { + RD = cast(Next); + } + + RD = RD->getDefinition(); + // A struct return type can be undefined. + if (!RD) + continue; + + // Add a null marker so we know when we've gone back up a level + VisitStack.push_back(nullptr); + + if (const auto *CXXRD = dyn_cast(RD)) + for (auto Base : CXXRD->bases()) { + // Skip non-record type, e.g. 
TemplateSpecializationType + if (const auto *RT = + Base.getType().getCanonicalType()->getAs()) { + VisitStack.push_back(RT->getDecl()); + } + } + + for (const auto *FD : RD->fields()) { + QualType QT = FD->getType(); + + if (ValidTypes.count(QT.getTypePtr())) + continue; + + auto Result = CheckFieldType(QT); + + if (Result == CheckFieldResult::Valid) + continue; + + if (Result == CheckFieldResult::Recurse) { + VisitStack.push_back(FD); + continue; + } + + assert(Result == CheckFieldResult::Invalid); + DiagInvalidParam(QT); + S.Diag(OrigRecDecl->getLocation(), diag::note_within_field_of_type) + << OrigRecDecl->getDeclName() << S.getLangOpts().CPlusPlus; + + // We have an error, now let's go back up through history and show where + // the offending field came from + for (ArrayRef::const_iterator + I = HistoryStack.begin() + 1, + E = HistoryStack.end(); + I != E; ++I) { + const FieldDecl *OuterField = *I; + S.Diag(OuterField->getLocation(), diag::note_within_field_of_type) + << OuterField->getType() << S.getLangOpts().CPlusPlus; + } + + S.Diag(FD->getLocation(), diag::note_illegal_field_declared_here) + << QT->isPointerType() << QT; + return false; + } + } while (!VisitStack.empty()); + return true; +} + enum OpenCLParamType { ValidKernelParam, PtrPtrKernelParam, @@ -8752,106 +8876,109 @@ break; } - // Track nested structs we will inspect - SmallVector VisitStack; - - // Track where we are in the nested structs. Items will migrate from - // VisitStack to HistoryStack as we do the DFS for bad field. - SmallVector HistoryStack; - HistoryStack.push_back(nullptr); - - // At this point we already handled everything except of a RecordType or - // an ArrayType of a RecordType. 
- assert((PT->isArrayType() || PT->isRecordType()) && "Unexpected type."); - const RecordType *RecTy = - PT->getPointeeOrArrayElementType()->getAs(); - const RecordDecl *OrigRecDecl = RecTy->getDecl(); - - VisitStack.push_back(RecTy->getDecl()); - assert(VisitStack.back() && "First decl null?"); - - do { - const Decl *Next = VisitStack.pop_back_val(); - if (!Next) { - assert(!HistoryStack.empty()); - // Found a marker, we have gone up a level - if (const FieldDecl *Hist = HistoryStack.pop_back_val()) - ValidTypes.insert(Hist->getType().getTypePtr()); - - continue; - } - - // Adds everything except the original parameter declaration (which is not a - // field itself) to the history stack. - const RecordDecl *RD; - if (const FieldDecl *Field = dyn_cast(Next)) { - HistoryStack.push_back(Field); - - QualType FieldTy = Field->getType(); - // Other field types (known to be valid or invalid) are handled while we - // walk around RecordDecl::fields(). - assert((FieldTy->isArrayType() || FieldTy->isRecordType()) && - "Unexpected type."); - const Type *FieldRecTy = FieldTy->getPointeeOrArrayElementType(); - - RD = FieldRecTy->castAs()->getDecl(); + auto DiagInvalidParam = [&](QualType ParamTy) { + OpenCLParamType ParamType = getOpenCLKernelParameterType(S, ParamTy); + // OpenCL v1.2 s6.9.p: + // Arguments to kernel functions that are declared to be a struct or union + // do not allow OpenCL objects to be passed as elements of the struct or + // union. 
+ if (ParamType == PtrKernelParam || ParamType == PtrPtrKernelParam || + ParamType == InvalidAddrSpacePtrKernelParam) { + S.Diag(Param->getLocation(), diag::err_record_with_pointers_kernel_param) + << PT->isUnionType() << PT; } else { - RD = cast(Next); + S.Diag(Param->getLocation(), diag::err_bad_kernel_param_type) << PT; } + }; + auto CheckFieldType = [&](QualType QT) { + OpenCLParamType ParamType = getOpenCLKernelParameterType(S, QT); + if (ParamType == ValidKernelParam) + return CheckFieldResult::Valid; - // Add a null marker so we know when we've gone back up a level - VisitStack.push_back(nullptr); + if (ParamType == RecordKernelParam) { + return CheckFieldResult::Recurse; + } - for (const auto *FD : RD->fields()) { - QualType QT = FD->getType(); + return CheckFieldResult::Invalid; + }; + if (!checkStructOrArrayType(S, PT, ValidTypes, CheckFieldType, + DiagInvalidParam)) + D.setInvalidType(); +} - if (ValidTypes.count(QT.getTypePtr())) - continue; +// Check whether HIP host function has parameters of half precision type or +// struct type containing half precision type and diagnose them. This is +// because gcc and clang do not have a consistent ABI for half precision +// type for now. +// TODO: disable the diagnostics once gcc and clang have a consistent ABI +// about half precision types. +static void checkHIPFunctionParameters(Sema &S, FunctionDecl *FD) { + if (S.getLangOpts().HIPAllowHalfArg || FD->hasAttr() || + FD->hasAttr()) + return; - OpenCLParamType ParamType = getOpenCLKernelParameterType(S, QT); - if (ParamType == ValidKernelParam) - continue; + auto IsInvalidType = [](QualType T) { + if (T->isArrayType()) + T = QualType(T->getPointeeOrArrayElementType(), 0); + if (T->isVectorType()) + T = T->getAs()->getElementType(); + return T->isFloat16Type() || T->isHalfType(); + }; - if (ParamType == RecordKernelParam) { - VisitStack.push_back(FD); - continue; - } + // Check field type. 
+ auto CheckFieldType = [&](QualType FT) { + if (IsInvalidType(FT)) { + return CheckFieldResult::Invalid; + } + if (FT->isRecordType()) + return CheckFieldResult::Recurse; + return CheckFieldResult::Valid; + }; - // OpenCL v1.2 s6.9.p: - // Arguments to kernel functions that are declared to be a struct or union - // do not allow OpenCL objects to be passed as elements of the struct or - // union. - if (ParamType == PtrKernelParam || ParamType == PtrPtrKernelParam || - ParamType == InvalidAddrSpacePtrKernelParam) { - S.Diag(Param->getLocation(), - diag::err_record_with_pointers_kernel_param) - << PT->isUnionType() - << PT; - } else { - S.Diag(Param->getLocation(), diag::err_bad_kernel_param_type) << PT; - } + // Cache for known valid types to avoid repeated check. + llvm::SmallPtrSet ValidTypes; - S.Diag(OrigRecDecl->getLocation(), diag::note_within_field_of_type) - << OrigRecDecl->getDeclName(); + // Information about parameter or return types to be checked. + struct TypeCheckInfo { + QualType Ty; + SourceLocation Loc; + bool IsRet; // Whether it is return type + TypeCheckInfo(QualType T, SourceLocation L, bool _IsRet) + : Ty(T), Loc(L), IsRet(_IsRet) {} + }; + llvm::SmallVector TCInfo; + for (auto *ParmVar : FD->parameters()) + TCInfo.emplace_back( + TypeCheckInfo{ParmVar->getType(), ParmVar->getLocation(), false}); + TCInfo.emplace_back( + TypeCheckInfo{FD->getReturnType(), FD->getLocation(), true}); + + for (auto Info : TCInfo) { + QualType T = Info.Ty; + + // Diagnose invalid parameter type for the current parameter. 
+ auto DiagInvalidType = [&](QualType Ty) { + unsigned DiagID = S.getDiagnostics().getCustomDiagID( + DiagnosticsEngine::Error, + "Invalid function %select{parameter|return}0 type: %1"); + S.Diag(Info.Loc, DiagID) << Info.IsRet << T; + }; + + if (IsInvalidType(T)) { + DiagInvalidType(T); + FD->setInvalidDecl(); + continue; + } - // We have an error, now let's go back up through history and show where - // the offending field came from - for (ArrayRef::const_iterator - I = HistoryStack.begin() + 1, - E = HistoryStack.end(); - I != E; ++I) { - const FieldDecl *OuterField = *I; - S.Diag(OuterField->getLocation(), diag::note_within_field_of_type) - << OuterField->getType(); - } + if (!T->isRecordType() && !T->isArrayType()) + continue; - S.Diag(FD->getLocation(), diag::note_illegal_field_declared_here) - << QT->isPointerType() - << QT; - D.setInvalidType(); + if (!checkStructOrArrayType(S, T, ValidTypes, CheckFieldType, + DiagInvalidType)) { + FD->setInvalidDecl(); return; } - } while (!VisitStack.empty()); + } } /// Find the DeclContext in which a tag is implicitly declared if we see an @@ -10866,6 +10993,10 @@ if (LangOpts.OpenMP) ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(NewFD); + // Check HIP host function parameter types. + if (getLangOpts().HIP) + checkHIPFunctionParameters(*this, NewFD); + // Semantic checking for this function declaration (in isolation). 
if (getLangOpts().CPlusPlus) { Index: clang/test/SemaCUDA/half-arg.cu =================================================================== --- /dev/null +++ clang/test/SemaCUDA/half-arg.cu @@ -0,0 +1,136 @@ +// RUN: %clang_cc1 -std=c++11 -fsyntax-only -verify -fno-hip-allow-half-arg -x hip %s +// RUN: %clang_cc1 -std=c++11 -fcuda-is-device -fsyntax-only -verify -fno-hip-allow-half-arg -x hip %s +// RUN: %clang_cc1 -std=c++11 -fsyntax-only -verify=allow -x hip %s + +// allow-no-diagnostics + +#include "Inputs/cuda.h" + +// Check _Float16/__fp16 or structs containing them are not allowed as function +// parameter in HIP host functions. + +typedef _Float16 half; + +typedef _Float16 half2 __attribute__((ext_vector_type(2))); + +struct A { // expected-note 4{{within field or base class of type 'A' declared here}} + _Float16 x; // expected-note 7{{field of illegal type '_Float16' declared here}} +}; + +struct B { // expected-note {{within field or base class of type 'B' declared here}} + _Float16 x[2]; // expected-note {{field of illegal type '_Float16 [2]' declared here}} +}; + +struct C { // expected-note {{within field or base class of type 'C' declared here}} + _Float16 x[2][2]; // expected-note {{field of illegal type '_Float16 [2][2]' declared here}} +}; + +struct D { // expected-note {{within field or base class of type 'D' declared here}} + A x; // expected-note {{within field or base class of type 'A' declared here}} +}; + +struct E : public A { // expected-note {{within field or base class of type 'E' declared here}} +}; + +struct F : virtual public A { // expected-note {{within field or base class of type 'F' declared here}} +}; + +struct G { // expected-note {{within field or base class of type 'G' declared here}} + __fp16 x; // expected-note {{field of illegal type '__fp16' declared here}} +}; + +struct H { + void f(A x); + // expected-error@-1 {{Invalid function parameter type: 'A'}} +}; + +template +struct I { + T x; + void f(T x); + // expected-error@-1 
{{Invalid function parameter type: 'A'}} +}; + +struct J { // expected-note {{within field or base class of type 'J' declared here}} + half2 v; // expected-note {{field of illegal type 'half2' (vector of 2 '_Float16' values) declared here}} +}; + +struct empty {}; + +struct K : public empty { + int x; +}; + +struct undefined; + +void fa1(_Float16 x); +// expected-error@-1 {{Invalid function parameter type: '_Float16'}} + +void fa2(A x); +// expected-error@-1 {{Invalid function parameter type: 'A'}} + +void fa3(B x); +// expected-error@-1 {{Invalid function parameter type: 'B'}} + +void fa4(C x); +// expected-error@-1 {{Invalid function parameter type: 'C'}} + +void fa5(D x); +// expected-error@-1 {{Invalid function parameter type: 'D'}} + +void fa6(E x); +// expected-error@-1 {{Invalid function parameter type: 'E'}} + +void fa7(F x); +// expected-error@-1 {{Invalid function parameter type: 'F'}} + +void fa8(G x); +// expected-error@-1 {{Invalid function parameter type: 'G'}} + +template void fa9(T x); +// expected-error@-1 {{Invalid function parameter type: 'A'}} +// expected-note@-2 {{candidate template ignored: substitution failure [with T = A]}} +void fa9_caller() { + A x; + fa9(x); + // expected-error@-1 {{no matching function for call to 'fa9'}} + // expected-note@-2 {{in instantiation of function template specialization 'fa9' requested here}} +} + +void fa10() { + I x; + // expected-note@-1 {{in instantiation of template class 'I' requested here}} +} + +void fa11(half x); +// expected-error@-1 {{Invalid function parameter type: 'half' (aka '_Float16')}} + +void fa12(half2 x); +// expected-error@-1 {{Invalid function parameter type: 'half2' (vector of 2 '_Float16' values)}} + +void fa13(J x); +// expected-error@-1 {{Invalid function parameter type: 'J'}} + +void fa14(int x, _Float16 y); +// expected-error@-1 {{Invalid function parameter type: '_Float16'}} + +_Float16 fa15(); +// expected-error@-1 {{Invalid function return type: '_Float16'}} + +void fa16(K x); 
+ +undefined fa17(); + +// Check reference or pointers to _Float16/__fp16 or structs containing +// them are allowed as function parameters in HIP host functions. + +void fb1(_Float16 &x); +void fb2(_Float16 *x); +void fb3(A &x); +void fb4(A *x); + +// Check device function can use _Float16/__fp16 or struct containing +// them as parameter type. +__device__ void fc1(A x); +__global__ void fc2(A x); +__host__ __device__ void fc3(A x);