diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -336,6 +336,8 @@ def ObjCNonFragileRuntime : LangOpt<"", "LangOpts.ObjCRuntime.allowsClassStubs()">; +def HLSL : LangOpt<"HLSL">; + // Language option for CMSE extensions def Cmse : LangOpt<"Cmse">; @@ -3937,3 +3939,11 @@ let Subjects = SubjectList<[Function], ErrorDiag>; let Documentation = [ErrorAttrDocs]; } + +def HLSLNumThreads: InheritableAttr { + let Spellings = [Microsoft<"numthreads">]; + let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">]; + let Subjects = SubjectList<[Function]>; + let LangOpts = [HLSL]; + let Documentation = [NumThreadsDocs]; +} diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -6368,3 +6368,12 @@ .. _Return-Oriented Programming: https://en.wikipedia.org/wiki/Return-oriented_programming }]; } + +def NumThreadsDocs : Documentation { + let Category = DocCatFunction; + let Content = [{ +The ``numthreads`` attribute applies to HLSL shaders where explicit thread +counts are required. The ``X``, ``Y``, and ``Z`` values set the thread group +size in each dimension. The total number of threads per group is ``X * Y * Z``.
+ }]; +} diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11561,4 +11561,12 @@ "'std::source_location::__impl' was not found; it must be defined before '__builtin_source_location' is called">; def err_std_source_location_impl_malformed : Error< "'std::source_location::__impl' must be standard-layout and have only two 'const char *' fields '_M_file_name' and '_M_function_name', and two integral fields '_M_line' and '_M_column'">; + +// HLSL Diagnostics +def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in %select{Pixel|Vertex|Geometry|Hull|Domain|Compute|Library|RayGeneration|Intersection|AnyHit|ClosestHit|Miss|Callable|Mesh|Amplification|Invalid}1 shaders, requires %2">; + +def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">; +def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">; + } // end of sema component. + diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -2783,7 +2783,8 @@ const IdentifierInfo *EnclosingScope = nullptr); void MaybeParseMicrosoftAttributes(ParsedAttributes &Attrs) { - if (getLangOpts().MicrosoftExt && Tok.is(tok::l_square)) { + if ((getLangOpts().MicrosoftExt || getLangOpts().HLSL) && + Tok.is(tok::l_square)) { ParsedAttributes AttrsWithRange(AttrFactory); ParseMicrosoftAttributes(AttrsWithRange); Attrs.takeAllFrom(AttrsWithRange); diff --git a/clang/lib/Parse/ParseDeclCXX.cpp b/clang/lib/Parse/ParseDeclCXX.cpp --- a/clang/lib/Parse/ParseDeclCXX.cpp +++ b/clang/lib/Parse/ParseDeclCXX.cpp @@ -4302,10 +4302,18 @@ ParsedAttr::Syntax Syntax = LO.CPlusPlus ? 
ParsedAttr::AS_CXX11 : ParsedAttr::AS_C2x; + // Use the Microsoft syntax if the attribute is known with that spelling. + if (getLangOpts().MicrosoftExt || getLangOpts().HLSL) { + if (hasAttribute(AttrSyntax::Microsoft, ScopeName, AttrName, + getTargetInfo(), getLangOpts())) + Syntax = ParsedAttr::AS_Microsoft; + } + // If the attribute isn't known, we will not attempt to parse any // arguments. - if (!hasAttribute(LO.CPlusPlus ? AttrSyntax::CXX : AttrSyntax::C, ScopeName, + if (Syntax != ParsedAttr::AS_Microsoft && + !hasAttribute(LO.CPlusPlus ? AttrSyntax::CXX : AttrSyntax::C, ScopeName, AttrName, getTargetInfo(), getLangOpts())) { // Eat the left paren, then skip to the ending right paren. ConsumeParen(); SkipUntil(tok::r_paren); @@ -4688,8 +4696,23 @@ break; if (Tok.getIdentifierInfo()->getName() == "uuid") ParseMicrosoftUuidAttributeArgs(Attrs); - else + else { + IdentifierInfo *II = Tok.getIdentifierInfo(); + SourceLocation NameLoc = Tok.getLocation(); ConsumeToken(); + // In HLSL every bracketed attribute's arguments go through + // ParseCXX11AttributeArgs. Otherwise only attributes known under the + // Microsoft spelling do, and other tokens are left untouched as before. + ParsedAttr::Kind AttrKind = + ParsedAttr::getParsedKind(II, nullptr, ParsedAttr::AS_Microsoft); + if ((getLangOpts().HLSL || AttrKind != ParsedAttr::UnknownAttribute) && + Tok.is(tok::l_paren)) { + CachedTokens OpenMPTokens; + ParseCXX11AttributeArgs(II, NameLoc, Attrs, &EndLoc, nullptr, + SourceLocation(), OpenMPTokens); + ReplayOpenMPAttributeTokens(OpenMPTokens); + } // FIXME: handle attributes that don't have arguments + } } T.consumeClose(); diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -11323,6 +11323,10 @@ return; } + // In HLSL, main is only the default entry point; skip the checks below. + if (getLangOpts().HLSL) + return; + QualType T = FD->getType(); assert(T->isFunctionType() && "function decl is not of function type"); const FunctionType* FT = T->castAs<FunctionType>(); diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -24,6 +24,7 @@ #include "clang/AST/Type.h" #include "clang/Basic/CharInfo.h" #include "clang/Basic/DarwinSDKInfo.h"
+#include "clang/Basic/LangOptions.h" #include "clang/Basic/SourceLocation.h" #include "clang/Basic/SourceManager.h" #include "clang/Basic/TargetBuiltins.h" @@ -6836,6 +6837,74 @@ D->addAttr(UA); } +static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) { + using llvm::Triple; + Triple Target = S.Context.getTargetInfo().getTriple(); + // The attribute is only accepted for these shader stages. + if (!llvm::is_contained({Triple::Compute, Triple::Mesh, Triple::Amplification, + Triple::Library}, + Target.getEnvironment())) { + // Turn the environment into an index for the %select stage list of the + // diagnostic, which begins at Pixel. This relies on the Triple + // environments from Pixel onward being declared in that same order. + uint32_t Pipeline = + (uint32_t)Target.getEnvironment() - (uint32_t)Triple::Pixel; + S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage) + << AL << Pipeline << "Compute, Amplification, Mesh or Library"; + return; + } + + // Before shader model 6, Z is capped at 64 (SM 5) or 1 (SM 4 and below), + // and SM 4 and below also lower the total thread cap to 768. X and Y are + // capped at 1024 for every shader model. + llvm::VersionTuple SMVersion = Target.getOSVersion(); + uint32_t ZMax = 1024; + uint32_t ThreadMax = 1024; + if (SMVersion.getMajor() <= 4) { + ZMax = 1; + ThreadMax = 768; + } else if (SMVersion.getMajor() == 5) { + ZMax = 64; + ThreadMax = 1024; + } + + // NOTE(review): zero is accepted for X, Y and Z; confirm whether an empty + // thread group should be diagnosed as well. + uint32_t X; + if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(0), X)) + return; + if (X > 1024) { + S.Diag(AL.getArgAsExpr(0)->getExprLoc(), + diag::err_hlsl_numthreads_argument_oor) << 0 << 1024; + return; + } + uint32_t Y; + if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(1), Y)) + return; + if (Y > 1024) { + S.Diag(AL.getArgAsExpr(1)->getExprLoc(), + diag::err_hlsl_numthreads_argument_oor) << 1 << 1024; + return; + } + uint32_t Z; + if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(2), Z)) + return; + if (Z > ZMax) { + S.Diag(AL.getArgAsExpr(2)->getExprLoc(), + diag::err_hlsl_numthreads_argument_oor) << 2 << ZMax; + return; + } + + // X and Y are at most 1024 and Z at most ZMax <= 1024, so the product + // cannot overflow 32 bits. + if (X * Y * Z > ThreadMax) { + S.Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax; + return; + } + + D->addAttr(::new (S.Context) HLSLNumThreadsAttr(S.Context, AL, X, Y, Z)); +} + static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) { if (!S.LangOpts.CPlusPlus) { S.Diag(AL.getLoc(), 
diag::err_attribute_not_supported_in_lang) @@ -8697,6 +8766,11 @@ case ParsedAttr::AT_Thread: handleDeclspecThreadAttr(S, D, AL); break; + + // HLSL attributes: + case ParsedAttr::AT_HLSLNumThreads: + handleHLSLNumThreadsAttr(S, D, AL); + break; case ParsedAttr::AT_AbiTag: handleAbiTagAttr(S, D, AL); diff --git a/clang/test/SemaHLSL/lit.local.cfg b/clang/test/SemaHLSL/lit.local.cfg new file mode 100644 --- /dev/null +++ b/clang/test/SemaHLSL/lit.local.cfg @@ -0,0 +1 @@ +config.suffixes = ['.hlsl'] diff --git a/clang/test/SemaHLSL/num_threads.hlsl b/clang/test/SemaHLSL/num_threads.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/SemaHLSL/num_threads.hlsl @@ -0,0 +1,49 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-amplification -x hlsl -ast-dump -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-pixel -x hlsl -ast-dump -o - %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-vertex -x hlsl -ast-dump -o - %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-hull -x hlsl -ast-dump -o - %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-domain -x hlsl -ast-dump -o - %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel5.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel4.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify + +#if __SHADER_TARGET_STAGE == __SHADER_STAGE_COMPUTE || __SHADER_TARGET_STAGE == __SHADER_STAGE_MESH || __SHADER_TARGET_STAGE == __SHADER_STAGE_AMPLIFICATION || __SHADER_TARGET_STAGE == __SHADER_STAGE_LIBRARY +#ifdef FAIL +#if 
__SHADER_TARGET_MAJOR == 6 +// expected-error@+1 {{'numthreads' attribute requires an integer constant}} +[numthreads("1",2,3)] +// expected-error@+1 {{argument 'X' to numthreads attribute cannot exceed 1024}} +[numthreads(-1,2,3)] +// expected-error@+1 {{argument 'Y' to numthreads attribute cannot exceed 1024}} +[numthreads(1,-2,3)] +// expected-error@+1 {{argument 'Z' to numthreads attribute cannot exceed 1024}} +[numthreads(1,2,-3)] +// expected-error@+1 {{total number of threads cannot exceed 1024}} +[numthreads(1024,1024,1024)] +#elif __SHADER_TARGET_MAJOR == 5 +// expected-error@+1 {{argument 'Z' to numthreads attribute cannot exceed 64}} +[numthreads(1,2,68)] +#else +// expected-error@+1 {{argument 'Z' to numthreads attribute cannot exceed 1}} +[numthreads(1,2,2)] +// expected-error@+1 {{total number of threads cannot exceed 768}} +[numthreads(1024,1,1)] +#endif +#endif +// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} 1 2 1 +[numthreads(1,2,1)] +int entry() { + return 1; +} +#else +// expected-error-re@+1 {{attribute 'numthreads' is unsupported in {{[A-Za-z]+}} shaders, requires Compute, Amplification, Mesh or Library}} +[numthreads(1,1,1)] +int entry() { + return 1; +} +#endif + +