diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11611,6 +11611,7 @@ def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">; def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">; +def err_hlsl_missing_numthreads : Error<"missing numthreads attribute for %0 shader entry">; def err_hlsl_attribute_param_mismatch : Error<"%0 attribute parameters do not match the previous declaration">; def err_hlsl_pointers_unsupported : Error< diff --git a/clang/include/clang/Basic/TargetOptions.h b/clang/include/clang/Basic/TargetOptions.h --- a/clang/include/clang/Basic/TargetOptions.h +++ b/clang/include/clang/Basic/TargetOptions.h @@ -113,6 +113,9 @@ /// The validator version for dxil. std::string DxilValidatorVersion; + + /// The entry point name for the HLSL shader being compiled, as specified by -E. + std::string HLSLEntry; }; } // end namespace clang diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -6762,3 +6762,8 @@ HelpText<"Emit pristine LLVM IR from the frontend by not running any LLVM passes at all."
"Same as -S + -emit-llvm + -disable-llvm-passes.">; def fcgl : DXCFlag<"fcgl">, Alias<emit_pristine_llvm>; +def hlsl_entrypoint : Option<["--", "/", "-"], "E", KIND_JOINED_OR_SEPARATE>, + Group<dxc_Group>, + Flags<[DXCOption, CC1Option, NoXarchOption]>, + MarshallingInfoString<TargetOpts<"HLSLEntry">>, + HelpText<"Entry point name">; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -2814,6 +2814,7 @@ QualType NewT, QualType OldT); void CheckMain(FunctionDecl *FD, const DeclSpec &D); void CheckMSVCRTEntryPoint(FunctionDecl *FD); + void CheckHLSLEntryPoint(FunctionDecl *FD); Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD, bool IsDefinition); void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D); diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -3473,7 +3473,8 @@ types::ID InputType) { const unsigned ForwardedArguments[] = {options::OPT_dxil_validator_version, options::OPT_S, options::OPT_emit_llvm, - options::OPT_disable_llvm_passes}; + options::OPT_disable_llvm_passes, + options::OPT_hlsl_entrypoint}; for (const auto &Arg : ForwardedArguments) if (const auto *A = Args.getLastArg(Arg)) diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -9867,6 +9867,55 @@ } } + if (getLangOpts().HLSL) { + auto &TargetInfo = getASTContext().getTargetInfo(); + if (!NewFD->isInvalidDecl() && + // Skip operator overloads and other names that are not plain identifiers. + Name.isIdentifier() && + NewFD->getName() == TargetInfo.getTargetOpts().HLSLEntry && + // Make sure it is in translation-unit scope.
+ S->getDepth() == 0) { + CheckHLSLEntryPoint(NewFD); + if (!NewFD->isInvalidDecl()) { + auto TripleShaderType = TargetInfo.getTriple().getEnvironment(); + AttributeCommonInfo AL(NewFD->getBeginLoc()); + HLSLShaderAttr::ShaderType ShaderType = + HLSLShaderAttr::ShaderType::Callable; + switch (TripleShaderType) { + default: + break; + case llvm::Triple::EnvironmentType::Compute: + ShaderType = HLSLShaderAttr::ShaderType::Compute; + break; + case llvm::Triple::EnvironmentType::Vertex: + ShaderType = HLSLShaderAttr::ShaderType::Vertex; + break; + case llvm::Triple::EnvironmentType::Hull: + ShaderType = HLSLShaderAttr::ShaderType::Hull; + break; + case llvm::Triple::EnvironmentType::Domain: + ShaderType = HLSLShaderAttr::ShaderType::Domain; + break; + case llvm::Triple::EnvironmentType::Geometry: + ShaderType = HLSLShaderAttr::ShaderType::Geometry; + break; + case llvm::Triple::EnvironmentType::Pixel: + ShaderType = HLSLShaderAttr::ShaderType::Pixel; + break; + case llvm::Triple::EnvironmentType::Mesh: + ShaderType = HLSLShaderAttr::ShaderType::Mesh; + break; + case llvm::Triple::EnvironmentType::Amplification: + ShaderType = HLSLShaderAttr::ShaderType::Amplification; + break; + } + // Attach an implicit HLSLShaderAttr so the entry shares HLSLShaderAttr handling. + if (HLSLShaderAttr *Attr = mergeHLSLShaderAttr(NewFD, AL, ShaderType)) + NewFD->addAttr(Attr); + } + } + } + if (!getLangOpts().CPlusPlus) { // Perform semantic checking on the function declaration. if (!NewFD->isInvalidDecl() && NewFD->isMain()) @@ -11687,6 +11736,21 @@ } } +void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) { + auto &TargetInfo = getASTContext().getTargetInfo(); + switch (TargetInfo.getTriple().getEnvironment()) { + default: + // FIXME: check all shader profiles.
+ break; + case llvm::Triple::EnvironmentType::Compute: + if (!FD->hasAttr<HLSLNumThreadsAttr>()) { + Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) << "Compute"; + FD->setInvalidDecl(); + } + break; + } +} + bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) { // FIXME: Need strict checking. In C89, we need to check for // any assignment, increment, decrement, function-calls, or diff --git a/clang/test/SemaHLSL/entry.hlsl b/clang/test/SemaHLSL/entry.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/SemaHLSL/entry.hlsl @@ -0,0 +1,15 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -Efoo -DWITH_NUM_THREADS -ast-dump -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -Efoo -o - %s -verify + + +// Make sure an HLSLShaderAttr is added along with the HLSLNumThreadsAttr. +// CHECK:HLSLNumThreadsAttr 0x{{.*}} 1 1 1 +// CHECK:HLSLShaderAttr 0x{{.*}} Compute + +#ifdef WITH_NUM_THREADS +[numthreads(1,1,1)] +#endif +// expected-error@+1 {{missing numthreads attribute for Compute shader entry}} +void foo() { + +} diff --git a/clang/test/SemaHLSL/prohibit_pointer.hlsl b/clang/test/SemaHLSL/prohibit_pointer.hlsl --- a/clang/test/SemaHLSL/prohibit_pointer.hlsl +++ b/clang/test/SemaHLSL/prohibit_pointer.hlsl @@ -1,4 +1,4 @@ -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -o - -fsyntax-only %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -o - -fsyntax-only %s -verify // expected-error@+1 {{pointers are unsupported in HLSL}} typedef int (*fn_int)(int);