diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -3968,6 +3968,23 @@
   let Documentation = [HLSLSV_GroupIndexDocs];
 }
 
+def HLSLShader : InheritableAttr {
+  let Spellings = [Microsoft<"shader">];
+  let Subjects = SubjectList<[HLSLEntry]>;
+  let LangOpts = [HLSL];
+  let Args = [EnumArgument<"Type", "ShaderType",
+                           ["pixel", "vertex", "geometry", "hull", "domain",
+                            "compute", "raygeneration", "intersection",
+                            "anyhit", "closesthit", "miss", "callable", "mesh",
+                            "amplification"],
+                           ["Pixel", "Vertex", "Geometry", "Hull", "Domain",
+                            "Compute", "RayGeneration", "Intersection",
+                            "AnyHit", "ClosestHit", "Miss", "Callable", "Mesh",
+                            "Amplification"]
+                           >];
+  let Documentation = [HLSLSV_ShaderTypeAttrDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -6380,6 +6380,29 @@
   }];
 }
 
+def HLSLSV_ShaderTypeAttrDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``shader`` type attribute applies to HLSL shader entry functions to
+identify the shader type for the entry function.
+The syntax is:
+
+.. code-block:: text
+
+  [shader(string-literal)]
+
+where the string literal is one of: "pixel", "vertex", "geometry", "hull",
+"domain", "compute", "raygeneration", "intersection", "anyhit", "closesthit",
+"miss", "callable", "mesh", "amplification".
+Normally the shader type is set by the shader target specified with the ``-T``
+option, such as ``-Tps_6_1``.
+When compiling to a library target such as ``lib_6_3``, the shader type
+attribute can help the compiler identify the shader type.
+It is mostly used by raytracing shaders, which must be compiled into a library
+and linked at runtime.
+  }];
+}
+
 def ClangRandomizeLayoutDocs : Documentation {
   let Category = DocCatDecl;
   let Heading = "randomize_layout, no_randomize_layout";
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -3494,6 +3494,8 @@
   HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D,
                                               const AttributeCommonInfo &AL,
                                               int X, int Y, int Z);
+  HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                                      HLSLShaderAttr::ShaderType ShaderType);
 
   void mergeDeclAttributes(NamedDecl *New, Decl *Old,
                            AvailabilityMergeKind AMK = AMK_Redeclaration);
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2810,6 +2810,8 @@
   else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
     NewAttr =
         S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ());
+  else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
+    NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getType());
   else if (Attr->shouldInheritEvenIfAlreadyPresent() || !DeclHasAttr(D, Attr))
     NewAttr = cast<InheritableAttr>(Attr->clone(S.Context));
 
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6940,6 +6940,43 @@
   D->addAttr(::new (S.Context) HLSLSV_GroupIndexAttr(S.Context, AL));
 }
 
+// Handle [shader("...")]: the argument must be a string literal naming a known
+// shader stage; unknown stage names are diagnosed and the attribute dropped.
+static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
+  StringRef Str;
+  SourceLocation ArgLoc;
+  if (!S.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
+    return;
+
+  HLSLShaderAttr::ShaderType ShaderType;
+  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
+    S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
+        << AL << Str << ArgLoc;
+    return;
+  }
+
+  // FIXME: check that the function matches the shader stage.
+
+  HLSLShaderAttr *NewAttr = S.mergeHLSLShaderAttr(D, AL, ShaderType);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
+// Build a shader attribute for D, or return nullptr if D already carries one;
+// an existing attribute with a different shader type is diagnosed.
+HLSLShaderAttr *
+Sema::mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                          HLSLShaderAttr::ShaderType ShaderType) {
+  if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
+    if (NT->getType() != ShaderType) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return HLSLShaderAttr::Create(Context, ShaderType, AL);
+}
+
 static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!S.LangOpts.CPlusPlus) {
     S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
@@ -8815,6 +8852,9 @@
   case ParsedAttr::AT_HLSLSV_GroupIndex:
     handleHLSLSVGroupIndexAttr(S, D, AL);
     break;
+  case ParsedAttr::AT_HLSLShader:
+    handleHLSLShaderAttr(S, D, AL);
+    break;
 
   case ParsedAttr::AT_AbiTag:
     handleAbiTagAttr(S, D, AL);
diff --git a/clang/test/SemaHLSL/shader_type_attr.hlsl b/clang/test/SemaHLSL/shader_type_attr.hlsl
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaHLSL/shader_type_attr.hlsl
@@ -0,0 +1,74 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -ast-dump -o - %s -DFAIL -verify
+
+// The FileCheck run makes sure HLSLShaderAttr is generated in the AST.
+// The -verify run makes sure the shader attribute is validated as expected.
+
+#ifdef FAIL
+
+// expected-warning@+1 {{'shader' attribute only applies to global functions}}
+[shader("compute")]
+struct Fido {
+  // expected-warning@+1 {{'shader' attribute only applies to global functions}}
+  [shader("pixel")]
+  void wag() {}
+  // expected-warning@+1 {{'shader' attribute only applies to global functions}}
+  [shader("vertex")]
+  static void oops() {}
+};
+
+// expected-warning@+1 {{'shader' attribute only applies to global functions}}
+[shader("vertex")]
+static void oops() {}
+
+namespace spec {
+// expected-warning@+1 {{'shader' attribute only applies to global functions}}
+[shader("vertex")]
+static void oops() {}
+} // namespace spec
+
+// expected-error@+1 {{'shader' attribute parameters do not match the previous declaration}}
+[shader("compute")]
+// expected-note@+1 {{conflicting attribute is here}}
+[shader("vertex")]
+int doubledUp() {
+  return 1;
+}
+
+// expected-note@+1 {{conflicting attribute is here}}
+[shader("vertex")]
+int forwardDecl();
+
+// expected-error@+1 {{'shader' attribute parameters do not match the previous declaration}}
+[shader("compute")]
+int forwardDecl() {
+  return 1;
+}
+
+// expected-error@+1 {{'shader' attribute takes one argument}}
+[shader()]
+// expected-error@+1 {{'shader' attribute takes one argument}}
+[shader(1, 2)]
+// expected-error@+1 {{'shader' attribute requires a string}}
+[shader(1)]
+// expected-warning@+1 {{'shader' attribute argument not supported: cs}}
+[shader("cs")]
+
+#endif // END of FAIL
+
+// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute
+[shader("compute")]
+int entry() {
+  return 1;
+}
+
+// Because these two attributes match, they should both appear in the AST.
+[shader("compute")]
+// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute
+int secondFn();
+
+[shader("compute")]
+// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute
+int secondFn() {
+  return 1;
+}