diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -3972,16 +3972,25 @@ let Spellings = [Microsoft<"shader">]; let Subjects = SubjectList<[HLSLEntry]>; let LangOpts = [HLSL]; - let Args = [EnumArgument<"Type", "ShaderType", - ["pixel", "vertex", "geometry", "hull", "domain", - "compute", "raygeneration", "intersection", - "anyhit", "closesthit", "miss", "callable", "mesh", - "amplification"], - ["Pixel", "Vertex", "Geometry", "Hull", "Domain", - "Compute", "RayGeneration", "Intersection", - "AnyHit", "ClosestHit", "Miss", "Callable", "Mesh", - "Amplification"] - >]; + // NOTE: + // The enum order must match the order in llvm::Triple::EnvironmentType, + // because ShaderType is converted to llvm::Triple::EnvironmentType as + // (llvm::Triple::EnvironmentType)((uint32_t)ShaderType + + // (uint32_t)llvm::Triple::EnvironmentType::Pixel). + // This avoids updating the conversion code when a new shader type is added.
+ let Args = [ + EnumArgument<"Type", "ShaderType", + [ + "pixel", "vertex", "geometry", "hull", "domain", "compute", + "library", "raygeneration", "intersection", "anyhit", + "closesthit", "miss", "callable", "mesh", "amplification" + ], + [ + "Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute", + "Library", "RayGeneration", "Intersection", "AnyHit", + "ClosestHit", "Miss", "Callable", "Mesh", "Amplification" + ]> + ]; let Documentation = [HLSLSV_ShaderTypeAttrDocs]; } diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11611,6 +11611,7 @@ def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">; def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">; +def err_hlsl_missing_numthreads : Error<"missing numthreads attribute for %0 shader entry">; def err_hlsl_attribute_param_mismatch : Error<"%0 attribute parameters do not match the previous declaration">; def err_hlsl_pointers_unsupported : Error< diff --git a/clang/include/clang/Basic/TargetOptions.h b/clang/include/clang/Basic/TargetOptions.h --- a/clang/include/clang/Basic/TargetOptions.h +++ b/clang/include/clang/Basic/TargetOptions.h @@ -113,6 +113,9 @@ /// The validator version for dxil. std::string DxilValidatorVersion; + + /// The entry point name for HLSL shader being compiled as specified by -E. + std::string HLSLEntry; }; } // end namespace clang diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -6762,3 +6762,8 @@ HelpText<"Emit pristine LLVM IR from the frontend by not running any LLVM passes at all."
"Same as -S + -emit-llvm + -disable-llvm-passes.">; def fcgl : DXCFlag<"fcgl">, Alias<emit_pristine_llvm>; +def hlsl_entrypoint : Option<["--", "/", "-"], "E", KIND_JOINED_OR_SEPARATE>, + Group<dxc_Group>, + Flags<[DXCOption, CC1Option, NoXarchOption]>, + MarshallingInfoString<TargetOpts<"HLSLEntry">>, + HelpText<"Entry point name">; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -2814,6 +2814,7 @@ QualType NewT, QualType OldT); void CheckMain(FunctionDecl *FD, const DeclSpec &D); void CheckMSVCRTEntryPoint(FunctionDecl *FD); + void CheckHLSLEntryPoint(FunctionDecl *FD); Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD, bool IsDefinition); void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D); diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -3473,7 +3473,8 @@ types::ID InputType) { const unsigned ForwardedArguments[] = {options::OPT_dxil_validator_version, options::OPT_S, options::OPT_emit_llvm, - options::OPT_disable_llvm_passes}; + options::OPT_disable_llvm_passes, + options::OPT_hlsl_entrypoint}; for (const auto &Arg : ForwardedArguments) if (const auto *A = Args.getLastArg(Arg)) diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -9867,6 +9867,32 @@ } } + if (getLangOpts().HLSL) { + auto &TargetInfo = getASTContext().getTargetInfo(); + if (!NewFD->isInvalidDecl() && + // Skip operator overloads and other names that are not identifiers. + Name.isIdentifier() && + NewFD->getName() == TargetInfo.getTargetOpts().HLSLEntry && + // Make sure it is in translation-unit scope.
+ S->getDepth() == 0) { + CheckHLSLEntryPoint(NewFD); + auto TripleShaderType = TargetInfo.getTriple().getEnvironment(); + // Only concrete pipeline stages map onto HLSLShaderAttr::ShaderType; + // Library (and any non-shader environment) gets no implicit attribute. + if (!NewFD->isInvalidDecl() && TripleShaderType >= llvm::Triple::Pixel && + TripleShaderType <= llvm::Triple::Amplification && + TripleShaderType != llvm::Triple::Library) { + AttributeCommonInfo AL(NewFD->getBeginLoc()); + HLSLShaderAttr::ShaderType ShaderType = (HLSLShaderAttr::ShaderType)( + TripleShaderType - (uint32_t)llvm::Triple::Pixel); + // To share code with HLSLShaderAttr, add HLSLShaderAttr to entry + // function. + if (HLSLShaderAttr *Attr = mergeHLSLShaderAttr(NewFD, AL, ShaderType)) + NewFD->addAttr(Attr); + } + } + } + if (!getLangOpts().CPlusPlus) { // Perform semantic checking on the function declaration. if (!NewFD->isInvalidDecl() && NewFD->isMain()) @@ -11687,6 +11713,21 @@ } } +void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) { + auto &TargetInfo = getASTContext().getTargetInfo(); + switch (TargetInfo.getTriple().getEnvironment()) { + default: + // FIXME: check all shader profiles. + break; + case llvm::Triple::EnvironmentType::Compute: + if (!FD->hasAttr<HLSLNumThreadsAttr>()) { + Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) << "Compute"; + FD->setInvalidDecl(); + } + break; + } +} + bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) { // FIXME: Need strict checking. In C89, we need to check for // any assignment, increment, decrement, function-calls, or diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -6947,7 +6947,11 @@ return; HLSLShaderAttr::ShaderType ShaderType; - if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) { + if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType) || + // Library is added to help convert HLSLShaderAttr::ShaderType to + // llvm::Triple::EnvironmentType. It is not a legal + // HLSLShaderAttr::ShaderType.
+ ShaderType == HLSLShaderAttr::Library) { S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported) << AL << Str << ArgLoc; return; diff --git a/clang/test/SemaHLSL/entry.hlsl b/clang/test/SemaHLSL/entry.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/SemaHLSL/entry.hlsl @@ -0,0 +1,15 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -Efoo -DWITH_NUM_THREADS -ast-dump -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -Efoo -o - %s -verify + + +// Make sure HLSLShaderAttr is added along with HLSLNumThreadsAttr. +// CHECK:HLSLNumThreadsAttr 0x{{.*}} 1 1 1 +// CHECK:HLSLShaderAttr 0x{{.*}} Compute + +#ifdef WITH_NUM_THREADS +[numthreads(1,1,1)] +#endif +// expected-error@+1 {{missing numthreads attribute for Compute shader entry}} +void foo() { + +} diff --git a/clang/test/SemaHLSL/prohibit_pointer.hlsl b/clang/test/SemaHLSL/prohibit_pointer.hlsl --- a/clang/test/SemaHLSL/prohibit_pointer.hlsl +++ b/clang/test/SemaHLSL/prohibit_pointer.hlsl @@ -1,4 +1,4 @@ -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -o - -fsyntax-only %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -o - -fsyntax-only %s -verify // expected-error@+1 {{pointers are unsupported in HLSL}} typedef int (*fn_int)(int); diff --git a/clang/test/SemaHLSL/shader_type_attr.hlsl b/clang/test/SemaHLSL/shader_type_attr.hlsl --- a/clang/test/SemaHLSL/shader_type_attr.hlsl +++ b/clang/test/SemaHLSL/shader_type_attr.hlsl @@ -53,10 +53,11 @@ [shader(1)] // expected-warning@+1 {{'shader' attribute argument not supported: cs}} [shader("cs")] - +// expected-warning@+1 {{'shader' attribute argument not supported: library}} +[shader("library")] #endif // END of FAIL -// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute +// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute [shader("compute")] int entry() { return 1; @@ -64,11 +65,11 @@ // Because these two attributes match, they should both appear in the AST
[shader("compute")] -// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute +// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute int secondFn(); [shader("compute")] -// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute +// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute int secondFn() { return 1; }