diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4041,6 +4041,13 @@
   let Documentation = [HLSLSV_GroupIndexDocs];
 }
 
+def HLSLSV_DispatchThreadID: HLSLAnnotationAttr {
+  let Spellings = [HLSLSemantic<"SV_DispatchThreadID">];
+  let Subjects = SubjectList<[ParmVar, Field]>;
+  let LangOpts = [HLSL];
+  let Documentation = [HLSLSV_DispatchThreadIDDocs];
+}
+
 def HLSLShader : InheritableAttr {
   let Spellings = [Microsoft<"shader">];
   let Subjects = SubjectList<[HLSLEntry]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -6590,6 +6590,19 @@
   }];
 }
 
+def HLSLSV_DispatchThreadIDDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``SV_DispatchThreadID`` semantic, when applied to an input parameter, specifies a
+data binding that maps the global thread offset within the Dispatch call (per dimension of the group) to the specified parameter.
+When applied to a field of a struct, the data binding is specified to the field when the struct is used as a parameter type.
+The semantic on the field is ignored when not used as a parameter.
+This attribute is only supported in compute shaders.
+
+The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
+  }];
+}
+
 def AnnotateTypeDocs : Documentation {
   let Category = DocCatType;
   let Heading = "annotate_type";
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11652,7 +11652,8 @@
 
 // HLSL Diagnostics
 def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in %select{Pixel|Vertex|Geometry|Hull|Domain|Compute|Library|RayGeneration|Intersection|AnyHit|ClosestHit|Miss|Callable|Mesh|Amplification|Invalid}1 shaders, requires %2">;
-
+def err_hlsl_attr_invalid_type : Error<
+  "attribute %0 only applies to a field or parameter of type '%1'">;
 def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">;
 def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">;
 def err_hlsl_missing_numthreads : Error<"missing numthreads attribute for %0 shader entry">;
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -39,7 +39,8 @@
   uint32_t ResourceCounters[static_cast<uint32_t>(
       hlsl::ResourceClass::NumClasses)] = {0};
 
-  llvm::Value *emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D);
+  llvm::Value *emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D,
+                                 llvm::Type *Ty);
 
 public:
   CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -109,14 +109,34 @@
                 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
 }
 
+// Builds a value of type Ty by calling F once per component, passing the
+// component index as an i32. Scalar types are read as component 0.
+static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
+  if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
+    Value *Result = PoisonValue::get(Ty);
+    for (unsigned I = 0; I < VT->getNumElements(); ++I) {
+      Value *Elt = B.CreateCall(F, {B.getInt32(I)});
+      Result = B.CreateInsertElement(Result, Elt, I);
+    }
+    return Result;
+  }
+  return B.CreateCall(F, {B.getInt32(0)});
+}
+
 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
-                                              const ParmVarDecl &D) {
+                                              const ParmVarDecl &D,
+                                              llvm::Type *Ty) {
   assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
   if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
     llvm::Function *DxGroupIndex =
         CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
     return B.CreateCall(FunctionCallee(DxGroupIndex));
   }
+  if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
+    llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id);
+    // Component I of SV_DispatchThreadID is read with dx.thread.id(I).
+    return buildVectorInput(B, DxThreadID, Ty);
+  }
   assert(false && "Unhandled parameter attribute");
   return nullptr;
 }
@@ -144,8 +164,18 @@
   llvm::SmallVector<Value *> Args;
   // FIXME: support struct parameters where semantics are on members.
   // See: https://github.com/llvm/llvm-project/issues/57874
-  for (const auto *Param : FD->parameters()) {
-    Args.push_back(emitInputSemantic(B, *Param));
+  // An sret argument has no ParmVarDecl; SRetOffset maps the IR argument
+  // numbers that follow it back to their source-level parameter index.
+  unsigned SRetOffset = 0;
+  for (const auto &Param : Fn->args()) {
+    if (Param.hasStructRetAttr()) {
+      // FIXME: support output.
+      // See: https://github.com/llvm/llvm-project/issues/57874
+      SRetOffset = 1;
+      continue;
+    }
+    const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
+    Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
   }
 
   CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6910,6 +6910,44 @@
   D->addAttr(::new (S.Context) HLSLSV_GroupIndexAttr(S.Context, AL));
 }
 
+static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
+  if (!T->hasUnsignedIntegerRepresentation())
+    return false;
+  if (const auto *VT = T->getAs<VectorType>())
+    return VT->getNumElements() <= 3;
+  return true;
+}
+
+static void handleHLSLSV_DispatchThreadIDAttr(Sema &S, Decl *D,
+                                              const ParsedAttr &AL) {
+  using llvm::Triple;
+  const Triple &Target = S.Context.getTargetInfo().getTriple();
+  // FIXME: this checks the target's shader stage, but a compute shader entry
+  // and a pixel shader entry may live in the same HLSL file. See
+  // https://github.com/llvm/llvm-project/issues/57880.
+  if (Target.getEnvironment() != Triple::Compute &&
+      Target.getEnvironment() != Triple::Library) {
+    uint32_t Pipeline =
+        static_cast<uint32_t>(Target.getEnvironment()) -
+        static_cast<uint32_t>(Triple::Pixel);
+    S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
+        << AL << Pipeline << "Compute";
+    return;
+  }
+
+  // FIXME: support semantic on field.
+  // See https://github.com/llvm/llvm-project/issues/57889.
+
+  const auto *VD = cast<ValueDecl>(D);
+  if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
+    S.Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
+        << AL << "uint/uint2/uint3";
+    return;
+  }
+
+  D->addAttr(::new (S.Context) HLSLSV_DispatchThreadIDAttr(S.Context, AL));
+}
+
 static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   StringRef Str;
   SourceLocation ArgLoc;
@@ -8923,6 +8961,9 @@
   case ParsedAttr::AT_HLSLSV_GroupIndex:
     handleHLSLSVGroupIndexAttr(S, D, AL);
     break;
+  case ParsedAttr::AT_HLSLSV_DispatchThreadID:
+    handleHLSLSV_DispatchThreadIDAttr(S, D, AL);
+    break;
   case ParsedAttr::AT_HLSLShader:
     handleHLSLShaderAttr(S, D, AL);
     break;
diff --git a/clang/test/CodeGenHLSL/semantics/DispatchThreadID.hlsl b/clang/test/CodeGenHLSL/semantics/DispatchThreadID.hlsl
new file mode 100644
--- /dev/null
+++ b/clang/test/CodeGenHLSL/semantics/DispatchThreadID.hlsl
@@ -0,0 +1,26 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s
+
+// Make sure SV_DispatchThreadID is translated into dx.thread.id.
+
+const RWBuffer<float> In;
+RWBuffer<float> Out;
+
+// CHECK: define void @foo()
+// CHECK: %[[ID:[0-9a-zA-Z]+]] = call i32 @llvm.dx.thread.id(i32 0)
+// CHECK: call void @"?foo@@YAXI@Z"(i32 %[[ID]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void foo(uint Idx : SV_DispatchThreadID) {
+  Out[Idx] = In[Idx];
+}
+
+// CHECK: define void @bar()
+// CHECK: call i32 @llvm.dx.thread.id(i32 0)
+// CHECK: call i32 @llvm.dx.thread.id(i32 1)
+// CHECK: call void @"?bar@@YAXT?$__vector@I$01@__clang@@@Z"(<2 x i32> %{{.*}})
+[shader("compute")]
+[numthreads(8,8,1)]
+void bar(uint2 Idx : SV_DispatchThreadID) {
+  Out[Idx.y] = In[Idx.x];
+}
+
diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
--- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
@@ -1,10 +1,13 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump -verify -o - %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -finclude-default-header -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -finclude-default-header -ast-dump -verify -o - %s
 
 [numthreads(8,8, 1)]
-// expected-error@+1 {{attribute 'SV_GroupIndex' is unsupported in Mesh shaders, requires Compute}}
-void CSMain(int GI : SV_GroupIndex) {
-// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int)'
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in Mesh shaders, requires Compute}}
+// expected-error@+1 {{attribute 'SV_DispatchThreadID' is unsupported in Mesh shaders, requires Compute}}
+void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint)'
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
 // CHECK-NEXT: HLSLSV_GroupIndexAttr
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint'
+// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
 }
diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -finclude-default-header -ast-dump -verify -o - %s
+
+[numthreads(8,8, 1)]
+// expected-error@+1 {{attribute 'SV_DispatchThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain(float ID : SV_DispatchThreadID) {
+
+}
+
+struct ST {
+  int a;
+  float b;
+};
+[numthreads(8,8, 1)]
+// expected-error@+1 {{attribute 'SV_DispatchThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain2(ST ID : SV_DispatchThreadID) {
+
+}
+
+[numthreads(8,8, 1)]
+// expected-error@+1 {{attribute 'SV_DispatchThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain3(int ID : SV_DispatchThreadID) {
+
+}
+
+[numthreads(8,8, 1)]
+// expected-error@+1 {{attribute 'SV_DispatchThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain4(uint4 ID : SV_DispatchThreadID) {
+
+}
diff --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
@@ -0,0 +1,20 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -finclude-default-header -ast-dump -o - %s | FileCheck %s
+
+[numthreads(8,8, 1)]
+void CSMain(uint ID : SV_DispatchThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (uint)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:18 ID 'uint'
+// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
+}
+[numthreads(8,8, 1)]
+void CSMain1(uint2 ID : SV_DispatchThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1 'void (uint2)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:20 ID 'uint2'
+// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
+}
+[numthreads(8,8, 1)]
+void CSMain2(uint3 ID : SV_DispatchThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2 'void (uint3)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:20 ID 'uint3'
+// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
+}