diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -17,6 +17,7 @@
 #include "clang/Basic/TargetOptions.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -94,4 +95,13 @@
     F->addFnAttr(ShaderAttrKindStr,
                  ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
   }
+  if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
+    const StringRef NumThreadsKindStr = "hlsl.numthreads";
+    // formatv produces a temporary; keep the "X,Y,Z" text in an owning
+    // std::string (a StringRef bound to the temporary would dangle).
+    std::string NumThreadsStr =
+        formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
+                NumThreadsAttr->getZ());
+    F->addFnAttr(NumThreadsKindStr, NumThreadsStr);
+  }
 }
diff --git a/clang/test/CodeGenHLSL/entry.hlsl b/clang/test/CodeGenHLSL/entry.hlsl
--- a/clang/test/CodeGenHLSL/entry.hlsl
+++ b/clang/test/CodeGenHLSL/entry.hlsl
@@ -3,8 +3,9 @@
 // Make sure not mangle entry.
 // CHECK:define void @foo()
 // Make sure add function attribute.
 // CHECK:"dx.shader"="compute"
-[numthreads(1,1,1)]
+// CHECK-SAME:"hlsl.numthreads"="16,8,1"
+[numthreads(16,8,1)]
 void foo() {
 
 }