diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4016,16 +4016,25 @@ let Spellings = [Microsoft<"shader">]; let Subjects = SubjectList<[HLSLEntry]>; let LangOpts = [HLSL]; - let Args = [EnumArgument<"Type", "ShaderType", - ["pixel", "vertex", "geometry", "hull", "domain", - "compute", "raygeneration", "intersection", - "anyhit", "closesthit", "miss", "callable", "mesh", - "amplification"], - ["Pixel", "Vertex", "Geometry", "Hull", "Domain", - "Compute", "RayGeneration", "Intersection", - "AnyHit", "ClosestHit", "Miss", "Callable", "Mesh", - "Amplification"] - >]; + // NOTE: + // The order of this enum must match llvm::Triple::EnvironmentType. + // ShaderType is converted to llvm::Triple::EnvironmentType as + // (llvm::Triple::EnvironmentType)((uint32_t)ShaderType + + // (uint32_t)llvm::Triple::EnvironmentType::Pixel). + // This avoids updating the conversion code when a shader type is added. 
+ let Args = [ + EnumArgument<"Type", "ShaderType", + [ + "pixel", "vertex", "geometry", "hull", "domain", "compute", + "library", "raygeneration", "intersection", "anyhit", + "closesthit", "miss", "callable", "mesh", "amplification" + ], + [ + "Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute", + "Library", "RayGeneration", "Intersection", "AnyHit", + "ClosestHit", "Miss", "Callable", "Mesh", "Amplification" + ]> + ]; let Documentation = [HLSLSV_ShaderTypeAttrDocs]; } diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11645,6 +11645,7 @@ def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">; def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">; +def err_hlsl_missing_numthreads : Error<"missing numthreads attribute for %0 shader entry">; def err_hlsl_attribute_param_mismatch : Error<"%0 attribute parameters do not match the previous declaration">; def err_hlsl_pointers_unsupported : Error< diff --git a/clang/include/clang/Basic/TargetOptions.h b/clang/include/clang/Basic/TargetOptions.h --- a/clang/include/clang/Basic/TargetOptions.h +++ b/clang/include/clang/Basic/TargetOptions.h @@ -113,6 +113,9 @@ /// The validator version for dxil. std::string DxilValidatorVersion; + + /// The entry point name for HLSL shader being compiled as specified by -E. + std::string HLSLEntry; }; } // end namespace clang diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -6909,3 +6909,12 @@ def enable_16bit_types : DXCFlag<"enable-16bit-types">, Alias<fnative_half_type>, HelpText<"Enable 16-bit types and disable min precision types." 
"Available in HLSL 2018 and shader model 6.2.">; +def hlsl_entrypoint : Option<["-"], "hlsl-entry", KIND_SEPARATE>, + Group<dxc_Group>, + Flags<[CC1Option]>, + MarshallingInfoString<TargetOpts<"HLSLEntry">>, + HelpText<"Entry point name for hlsl">; +def dxc_entrypoint : Option<["--", "/", "-"], "E", KIND_JOINED_OR_SEPARATE>, + Group<dxc_Group>, + Flags<[DXCOption, NoXarchOption]>, + HelpText<"Entry point name">; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -2898,6 +2898,7 @@ QualType NewT, QualType OldT); void CheckMain(FunctionDecl *FD, const DeclSpec &D); void CheckMSVCRTEntryPoint(FunctionDecl *FD); + void CheckHLSLEntryPoint(FunctionDecl *FD); Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD, bool IsDefinition); void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D); diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -3514,7 +3514,8 @@ options::OPT_S, options::OPT_emit_llvm, options::OPT_disable_llvm_passes, - options::OPT_fnative_half_type}; + options::OPT_fnative_half_type, + options::OPT_hlsl_entrypoint}; for (const auto &Arg : ForwardedArguments) if (const auto *A = Args.getLastArg(Arg)) diff --git a/clang/lib/Driver/ToolChains/HLSL.cpp b/clang/lib/Driver/ToolChains/HLSL.cpp --- a/clang/lib/Driver/ToolChains/HLSL.cpp +++ b/clang/lib/Driver/ToolChains/HLSL.cpp @@ -158,6 +158,12 @@ if (!isLegalValidatorVersion(ValVerStr, getDriver())) continue; } + if (A->getOption().getID() == options::OPT_dxc_entrypoint) { + DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_hlsl_entrypoint), + A->getValue()); + A->claim(); + continue; + } if (A->getOption().getID() == options::OPT_emit_pristine_llvm) { // Translate fcgl into -S -emit-llvm and -disable-llvm-passes. 
DAL->AddFlagArg(nullptr, Opts.getOption(options::OPT_S)); diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp --- a/clang/lib/Frontend/CompilerInvocation.cpp +++ b/clang/lib/Frontend/CompilerInvocation.cpp @@ -510,6 +510,10 @@ Diags.Report(diag::err_drv_argument_not_allowed_with) << "-fgnu89-inline" << GetInputKindName(IK); + if (Args.hasArg(OPT_hlsl_entrypoint) && !LangOpts.HLSL) + Diags.Report(diag::err_drv_argument_not_allowed_with) + << "-hlsl-entry" << GetInputKindName(IK); + if (Args.hasArg(OPT_fgpu_allow_device_init) && !LangOpts.HIP) Diags.Report(diag::warn_ignored_hip_only_option) << Args.getLastArg(OPT_fgpu_allow_device_init)->getAsString(Args); diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -10005,6 +10005,32 @@ } } + if (getLangOpts().HLSL) { + auto &TargetInfo = getASTContext().getTargetInfo(); + // Skip operator overloads, whose names are not identifiers. + // Also make sure NewFD is in translation-unit scope. + if (!NewFD->isInvalidDecl() && Name.isIdentifier() && + NewFD->getName() == TargetInfo.getTargetOpts().HLSLEntry && + S->getDepth() == 0) { + CheckHLSLEntryPoint(NewFD); + auto TripleShaderType = TargetInfo.getTriple().getEnvironment(); + // Only shader-stage environments map onto HLSLShaderAttr::ShaderType. + // Library (and any non-stage environment) gets no implicit attribute. + if (!NewFD->isInvalidDecl() && + TripleShaderType >= llvm::Triple::Pixel && + TripleShaderType <= llvm::Triple::Amplification && + TripleShaderType != llvm::Triple::Library) { + AttributeCommonInfo AL(NewFD->getBeginLoc()); + HLSLShaderAttr::ShaderType ShaderType = (HLSLShaderAttr::ShaderType)( + TripleShaderType - (uint32_t)llvm::Triple::Pixel); + // To share code with HLSLShaderAttr, add HLSLShaderAttr to entry + // function. 
+ if (HLSLShaderAttr *Attr = mergeHLSLShaderAttr(NewFD, AL, ShaderType)) + NewFD->addAttr(Attr); + } + } + } + if (!getLangOpts().CPlusPlus) { // Perform semantic checking on the function declaration. 
if (!NewFD->isInvalidDecl() && NewFD->isMain()) @@ -11833,6 +11859,23 @@ } } +void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) { + auto &TargetInfo = getASTContext().getTargetInfo(); + const auto &Triple = TargetInfo.getTriple(); + switch (Triple.getEnvironment()) { + default: + // FIXME: check all shader profiles. + break; + case llvm::Triple::EnvironmentType::Compute: + if (!FD->hasAttr<HLSLNumThreadsAttr>()) { + Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) + << Triple.getEnvironmentName(); + FD->setInvalidDecl(); + } + break; + } +} + bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) { // FIXME: Need strict checking. In C89, we need to check for // any assignment, increment, decrement, function-calls, or diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -6904,7 +6904,11 @@ return; HLSLShaderAttr::ShaderType ShaderType; - if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) { + if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType) || + // Library is added to help convert HLSLShaderAttr::ShaderType to + // llvm::Triple::EnvironmentType. It is not a legal + // HLSLShaderAttr::ShaderType. + ShaderType == HLSLShaderAttr::Library) { S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported) << AL << Str << ArgLoc; return; diff --git a/clang/test/Driver/dxc_E.hlsl b/clang/test/Driver/dxc_E.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/Driver/dxc_E.hlsl @@ -0,0 +1,4 @@ +// RUN: %clang_dxc -Efoo -Tlib_6_7 -### %s 2>&1 | FileCheck %s + +// Make sure the E option is translated into "-hlsl-entry". 
+// CHECK:"-hlsl-entry" "foo" diff --git a/clang/test/Driver/hlsl-entry.cpp b/clang/test/Driver/hlsl-entry.cpp new file mode 100644 --- /dev/null +++ b/clang/test/Driver/hlsl-entry.cpp @@ -0,0 +1,3 @@ +// RUN: not %clang -cc1 -triple dxil-pc-shadermodel6.3-compute -x c++ -hlsl-entry foo %s 2>&1 | FileCheck %s --check-prefix=NOTHLSL + +// NOTHLSL:invalid argument '-hlsl-entry' not allowed with 'C++' diff --git a/clang/test/SemaHLSL/entry.hlsl b/clang/test/SemaHLSL/entry.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/SemaHLSL/entry.hlsl @@ -0,0 +1,20 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry foo -DWITH_NUM_THREADS -ast-dump -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry foo -o - %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -hlsl-entry foo -ast-dump -o - %s | FileCheck %s --check-prefix=LIB + + +// Make sure HLSLShaderAttr is added along with HLSLNumThreadsAttr. +// CHECK:HLSLNumThreadsAttr 0x{{.*}} 1 1 1 +// CHECK:HLSLShaderAttr 0x{{.*}} Compute + +// The library profile is not a shader stage, so no HLSLShaderAttr is added. +// LIB:FunctionDecl {{.*}} foo 'void ()' +// LIB-NOT:HLSLShaderAttr + +#ifdef WITH_NUM_THREADS +[numthreads(1,1,1)] +#endif +// expected-error@+1 {{missing numthreads attribute for compute shader entry}} +void foo() { + +} diff --git a/clang/test/SemaHLSL/prohibit_pointer.hlsl b/clang/test/SemaHLSL/prohibit_pointer.hlsl --- a/clang/test/SemaHLSL/prohibit_pointer.hlsl +++ b/clang/test/SemaHLSL/prohibit_pointer.hlsl @@ -1,4 +1,4 @@ -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -o - -fsyntax-only %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -o - -fsyntax-only %s -verify // expected-error@+1 {{pointers are unsupported in HLSL}} typedef int (*fn_int)(int); diff --git a/clang/test/SemaHLSL/shader_type_attr.hlsl b/clang/test/SemaHLSL/shader_type_attr.hlsl --- a/clang/test/SemaHLSL/shader_type_attr.hlsl +++ b/clang/test/SemaHLSL/shader_type_attr.hlsl @@ -53,10 +53,11 @@ [shader(1)] // expected-warning@+1 {{'shader' attribute argument not supported: cs}} [shader("cs")] - +// expected-warning@+1 {{'shader' attribute argument not supported: library}} 
+[shader("library")] #endif // END of FAIL -// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute +// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute [shader("compute")] int entry() { return 1; @@ -64,11 +65,11 @@ // Because these two attributes match, they should both appear in the AST [shader("compute")] -// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute +// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute int secondFn(); [shader("compute")] -// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute +// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute int secondFn() { return 1; }