diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -3984,16 +3984,25 @@
   let Spellings = [Microsoft<"shader">];
   let Subjects = SubjectList<[HLSLEntry]>;
   let LangOpts = [HLSL];
-  let Args = [EnumArgument<"Type", "ShaderType",
-                           ["pixel", "vertex", "geometry", "hull", "domain",
-                            "compute", "raygeneration", "intersection",
-                            "anyhit", "closesthit", "miss", "callable", "mesh",
-                            "amplification"],
-                           ["Pixel", "Vertex", "Geometry", "Hull", "Domain",
-                            "Compute", "RayGeneration", "Intersection",
-                            "AnyHit", "ClosestHit", "Miss", "Callable", "Mesh",
-                            "Amplification"]
-                           >];
+  // NOTE:
+  // The order of this enum must match the shader-stage entries of
+  // llvm::Triple::EnvironmentType, starting at Pixel. Sema converts between
+  // the two purely by offset, i.e.
+  // (HLSLShaderAttr::ShaderType)(EnvironmentType - llvm::Triple::Pixel),
+  // so no conversion code has to change when a new shader type is added.
+  let Args = [
+    EnumArgument<"Type", "ShaderType",
+                 [
+                   "pixel", "vertex", "geometry", "hull", "domain", "compute",
+                   "library", "raygeneration", "intersection", "anyhit",
+                   "closesthit", "miss", "callable", "mesh", "amplification"
+                 ],
+                 [
+                   "Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
+                   "Library", "RayGeneration", "Intersection", "AnyHit",
+                   "ClosestHit", "Miss", "Callable", "Mesh", "Amplification"
+                 ]>
+  ];
   let Documentation = [HLSLSV_ShaderTypeAttrDocs];
 }
 
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11634,6 +11634,7 @@
 
 def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">;
 def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">;
+def err_hlsl_missing_numthreads : Error<"missing numthreads attribute for %0 shader entry">;
 def err_hlsl_attribute_param_mismatch : Error<"%0 attribute parameters do not match the previous declaration">;
 def err_hlsl_pointers_unsupported : Error<
   "%select{pointers|references}0 are unsupported in HLSL">;
diff --git a/clang/include/clang/Basic/TargetOptions.h b/clang/include/clang/Basic/TargetOptions.h
--- a/clang/include/clang/Basic/TargetOptions.h
+++ b/clang/include/clang/Basic/TargetOptions.h
@@ -113,6 +113,9 @@
 
   /// The validator version for dxil.
   std::string DxilValidatorVersion;
+
+  /// The HLSL entry point name (-E for clang-dxc, -hlsl-entry for cc1).
+  std::string HLSLEntry;
 };
 
 }  // end namespace clang
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -6813,3 +6813,12 @@
   HelpText<"Emit pristine LLVM IR from the frontend by not running any LLVM passes at all."
            "Same as -S + -emit-llvm + -disable-llvm-passes.">;
 def fcgl : DXCFlag<"fcgl">, Alias<emit_pristine_llvm>;
+def hlsl_entrypoint : Option<["-"], "hlsl-entry", KIND_SEPARATE>,
+                      Group<dxc_Group>,
+                      Flags<[CC1Option]>,
+                      MarshallingInfoString<TargetOpts<"HLSLEntry">>,
+                      HelpText<"Entry point name for hlsl">;
+def dxc_entrypoint : Option<["--", "/", "-"], "E", KIND_JOINED_OR_SEPARATE>,
+                     Group<dxc_Group>,
+                     Flags<[DXCOption, NoXarchOption]>,
+                     HelpText<"Entry point name">;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2820,6 +2820,7 @@
                                     QualType NewT, QualType OldT);
   void CheckMain(FunctionDecl *FD, const DeclSpec &D);
   void CheckMSVCRTEntryPoint(FunctionDecl *FD);
+  void CheckHLSLEntryPoint(FunctionDecl *FD);
   Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD,
                                                    bool IsDefinition);
   void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D);
diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp
--- a/clang/lib/Driver/ToolChains/Clang.cpp
+++ b/clang/lib/Driver/ToolChains/Clang.cpp
@@ -3477,9 +3477,12 @@
 
 static void RenderHLSLOptions(const ArgList &Args, ArgStringList &CmdArgs,
                               types::ID InputType) {
-  const unsigned ForwardedArguments[] = {
-      options::OPT_dxil_validator_version, options::OPT_D, options::OPT_S,
-      options::OPT_emit_llvm, options::OPT_disable_llvm_passes};
+  const unsigned ForwardedArguments[] = {options::OPT_dxil_validator_version,
+                                         options::OPT_D,
+                                         options::OPT_S,
+                                         options::OPT_emit_llvm,
+                                         options::OPT_disable_llvm_passes,
+                                         options::OPT_hlsl_entrypoint};
 
   for (const auto &Arg : ForwardedArguments)
     if (const auto *A = Args.getLastArg(Arg))
diff --git a/clang/lib/Driver/ToolChains/HLSL.cpp b/clang/lib/Driver/ToolChains/HLSL.cpp
--- a/clang/lib/Driver/ToolChains/HLSL.cpp
+++ b/clang/lib/Driver/ToolChains/HLSL.cpp
@@ -158,6 +158,12 @@
       if (!isLegalValidatorVersion(ValVerStr, getDriver()))
         continue;
     }
+    if (A->getOption().getID() == options::OPT_dxc_entrypoint) {
+      DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_hlsl_entrypoint),
+                          A->getValue());
+      A->claim();
+      continue;
+    }
     if (A->getOption().getID() == options::OPT_emit_pristine_llvm) {
       // Translate fcgl into -S -emit-llvm and -disable-llvm-passes.
       DAL->AddFlagArg(nullptr, Opts.getOption(options::OPT_S));
diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -510,6 +510,10 @@
     Diags.Report(diag::err_drv_argument_not_allowed_with)
         << "-fgnu89-inline" << GetInputKindName(IK);
 
+  if (Args.hasArg(OPT_hlsl_entrypoint) && !LangOpts.HLSL)
+    Diags.Report(diag::err_drv_argument_not_allowed_with)
+        << "-hlsl-entry" << GetInputKindName(IK);
+
   if (Args.hasArg(OPT_fgpu_allow_device_init) && !LangOpts.HIP)
     Diags.Report(diag::warn_ignored_hip_only_option)
         << Args.getLastArg(OPT_fgpu_allow_device_init)->getAsString(Args);
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -9835,6 +9835,40 @@
     }
   }
 
+  if (getLangOpts().HLSL) {
+    auto &TargetInfo = getASTContext().getTargetInfo();
+    if (!NewFD->isInvalidDecl() &&
+        // Skip operator overloads, whose names are not identifiers.
+        Name.isIdentifier() &&
+        NewFD->getName() == TargetInfo.getTargetOpts().HLSLEntry &&
+        // Make sure it is in translation-unit scope.
+        S->getDepth() == 0) {
+      CheckHLSLEntryPoint(NewFD);
+      if (!NewFD->isInvalidDecl()) {
+        static_assert(llvm::Triple::Amplification - llvm::Triple::Pixel ==
+                          HLSLShaderAttr::Amplification - HLSLShaderAttr::Pixel,
+                      "HLSLShaderAttr::ShaderType must mirror the shader "
+                      "stages of llvm::Triple::EnvironmentType");
+        auto Env = TargetInfo.getTriple().getEnvironment();
+        // Only environments in the shader-stage subrange starting at Pixel
+        // map onto HLSLShaderAttr::ShaderType; converting anything else
+        // (e.g. an unknown environment) would yield an out-of-range enum.
+        // Library is skipped because it is not a legal shader attribute.
+        if (Env >= llvm::Triple::Pixel && Env <= llvm::Triple::Amplification &&
+            Env != llvm::Triple::Library) {
+          AttributeCommonInfo AL(NewFD->getBeginLoc());
+          auto ShaderType = static_cast<HLSLShaderAttr::ShaderType>(
+              Env - llvm::Triple::Pixel);
+          // To share code with HLSLShaderAttr, add HLSLShaderAttr to entry
+          // function.
+          if (HLSLShaderAttr *Attr =
+                  mergeHLSLShaderAttr(NewFD, AL, ShaderType))
+            NewFD->addAttr(Attr);
+        }
+      }
+    }
+  }
+
   if (!getLangOpts().CPlusPlus) {
     // Perform semantic checking on the function declaration.
     if (!NewFD->isInvalidDecl() && NewFD->isMain())
@@ -11667,6 +11701,23 @@
   }
 }
 
+void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
+  auto &TargetInfo = getASTContext().getTargetInfo();
+  const llvm::Triple &TargetTriple = TargetInfo.getTriple();
+  switch (TargetTriple.getEnvironment()) {
+  default:
+    // FIXME: check all shader profiles.
+    break;
+  case llvm::Triple::EnvironmentType::Compute:
+    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
+      Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
+          << TargetTriple.getEnvironmentName();
+      FD->setInvalidDecl();
+    }
+    break;
+  }
+}
+
 bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) {
   // FIXME: Need strict checking.  In C89, we need to check for
   // any assignment, increment, decrement, function-calls, or
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6951,7 +6951,11 @@
     return;
 
   HLSLShaderAttr::ShaderType ShaderType;
-  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
+  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType) ||
+      // Library only exists so that HLSLShaderAttr::ShaderType lines up with
+      // llvm::Triple::EnvironmentType; it is not a legal shader attribute
+      // argument.
+      ShaderType == HLSLShaderAttr::Library) {
     S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
         << AL << Str << ArgLoc;
     return;
diff --git a/clang/test/Driver/dxc_E.hlsl b/clang/test/Driver/dxc_E.hlsl
new file mode 100644
--- /dev/null
+++ b/clang/test/Driver/dxc_E.hlsl
@@ -0,0 +1,4 @@
+// RUN: %clang_dxc -Efoo -Tcs_6_7 -### %s 2>&1 | FileCheck %s
+
+// Make sure the -E option is translated into "-hlsl-entry".
+// CHECK:"-hlsl-entry" "foo"
diff --git a/clang/test/Driver/hlsl-entry.cpp b/clang/test/Driver/hlsl-entry.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/Driver/hlsl-entry.cpp
@@ -0,0 +1,3 @@
+// RUN:not %clang -cc1 -triple dxil-pc-shadermodel6.3-compute -x c++ -hlsl-entry foo %s 2>&1 | FileCheck %s --check-prefix=NOTHLSL
+
+// NOTHLSL:invalid argument '-hlsl-entry' not allowed with 'C++'
diff --git a/clang/test/SemaHLSL/entry.hlsl b/clang/test/SemaHLSL/entry.hlsl
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaHLSL/entry.hlsl
@@ -0,0 +1,15 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry foo -DWITH_NUM_THREADS -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry foo -o - %s -verify
+
+
+// Make sure HLSLShaderAttr is added along with HLSLNumThreadsAttr.
+// CHECK:HLSLNumThreadsAttr 0x{{.*}} 1 1 1
+// CHECK:HLSLShaderAttr 0x{{.*}} Compute
+
+#ifdef WITH_NUM_THREADS
+[numthreads(1,1,1)]
+#endif
+// expected-error@+1 {{missing numthreads attribute for compute shader entry}}
+void foo() {
+
+}
diff --git a/clang/test/SemaHLSL/entry_library.hlsl b/clang/test/SemaHLSL/entry_library.hlsl
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaHLSL/entry_library.hlsl
@@ -0,0 +1,7 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -hlsl-entry foo -ast-dump -o - %s | FileCheck %s
+
+// Library is not a legal shader attribute argument, so naming an entry point
+// for a library target must not add an implicit HLSLShaderAttr.
+// CHECK: FunctionDecl {{.*}} foo 'void ()'
+// CHECK-NOT: HLSLShaderAttr
+void foo() {}
diff --git a/clang/test/SemaHLSL/prohibit_pointer.hlsl b/clang/test/SemaHLSL/prohibit_pointer.hlsl
--- a/clang/test/SemaHLSL/prohibit_pointer.hlsl
+++ b/clang/test/SemaHLSL/prohibit_pointer.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -o - -fsyntax-only %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -o - -fsyntax-only %s -verify
 
 // expected-error@+1 {{pointers are unsupported in HLSL}}
 typedef int (*fn_int)(int);
diff --git a/clang/test/SemaHLSL/shader_type_attr.hlsl b/clang/test/SemaHLSL/shader_type_attr.hlsl
--- a/clang/test/SemaHLSL/shader_type_attr.hlsl
+++ b/clang/test/SemaHLSL/shader_type_attr.hlsl
@@ -53,10 +53,11 @@
 [shader(1)]
 // expected-warning@+1 {{'shader' attribute argument not supported: cs}}
 [shader("cs")]
-
+// expected-warning@+1 {{'shader' attribute argument not supported: library}}
+[shader("library")]
 #endif // END of FAIL
 
-// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:60:2, col:18> Compute
+// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:61:2, col:18> Compute
 [shader("compute")]
 int entry() {
  return 1;
@@ -64,11 +65,11 @@
 
 // Because these two attributes match, they should both appear in the AST
 [shader("compute")]
-// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:66:2, col:18> Compute
+// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:67:2, col:18> Compute
 int secondFn();
 
 [shader("compute")]
-// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:70:2, col:18> Compute
+// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:71:2, col:18> Compute
 int secondFn() {
   return 1;
 }