diff --git a/llvm/lib/Target/DirectX/DXILMetadata.cpp b/llvm/lib/Target/DirectX/DXILMetadata.cpp --- a/llvm/lib/Target/DirectX/DXILMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILMetadata.cpp @@ -95,11 +95,11 @@ unsigned NumThreads[3]; } CS; - EntryProps(Function &F, Triple::EnvironmentType ModuleShaderKind) + EntryProps(Function &F, Attribute EntryAttr, + Triple::EnvironmentType ModuleShaderKind) : ShaderKind(ModuleShaderKind) { if (ShaderKind == Triple::EnvironmentType::Library) { - Attribute EntryAttr = F.getFnAttribute("hlsl.shader"); StringRef EntryProfile = EntryAttr.getValueAsString(); Triple T("", "", "", EntryProfile); ShaderKind = T.getEnvironment(); @@ -108,6 +108,7 @@ if (ShaderKind == Triple::EnvironmentType::Compute) { auto NumThreadsStr = F.getFnAttribute("hlsl.numthreads").getValueAsString(); + F.removeFnAttr("hlsl.numthreads"); SmallVector<StringRef> NumThreads; NumThreadsStr.split(NumThreads, ','); assert(NumThreads.size() == 3 && "invalid numthreads"); @@ -206,8 +207,9 @@ EntryProps Props; public: - EntryMD(Function &F, Triple::EnvironmentType ModuleShaderKind) - : F(F), Ctx(F.getContext()), Props(F, ModuleShaderKind) {} + EntryMD(Function &F, Attribute EntryAttr, + Triple::EnvironmentType ModuleShaderKind) + : F(F), Ctx(F.getContext()), Props(F, EntryAttr, ModuleShaderKind) {} MDTuple *emitEntryTuple(MDTuple *Resources, uint64_t RawShaderFlag) { // FIXME: add signature for profile other than CS.
@@ -256,11 +258,13 @@ } // namespace void dxil::createEntryMD(Module &M, const uint64_t ShaderFlags) { - SmallVector<Function *> EntryList; + SmallVector<std::pair<Function *, Attribute>> EntryList; for (auto &F : M.functions()) { if (!F.hasFnAttribute("hlsl.shader")) continue; - EntryList.emplace_back(&F); + EntryList.emplace_back(&F, F.getFnAttribute("hlsl.shader")); + // EntryList holds the value now; drop the attribute from the IR. + F.removeFnAttr("hlsl.shader"); } auto &Ctx = M.getContext(); @@ -279,8 +283,8 @@ EntryMD::emitEmptyEntryForLib(MDResources, ShaderFlags, Ctx); Entries.emplace_back(EmptyEntry); - for (Function *Entry : EntryList) { - EntryMD MD(*Entry, T.getEnvironment()); + for (auto &[Entry, EntryAttr] : EntryList) { + EntryMD MD(*Entry, EntryAttr, T.getEnvironment()); Entries.emplace_back(MD.emitEntryTupleForLib(0)); } } break; @@ -294,7 +298,8 @@ case Triple::EnvironmentType::Pixel: { assert(EntryList.size() == 1 && "non-lib profiles should only have one entry"); - EntryMD MD(*EntryList.front(), T.getEnvironment()); + auto &[Entry, EntryAttr] = EntryList.front(); + EntryMD MD(*Entry, EntryAttr, T.getEnvironment()); Entries.emplace_back(MD.emitEntryTuple(MDResources, ShaderFlags)); } break; default: diff --git a/llvm/test/CodeGen/DirectX/hlsl_attr.ll b/llvm/test/CodeGen/DirectX/hlsl_attr.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/hlsl_attr.ll @@ -0,0 +1,66 @@ +; RUN: opt -S -dxil-metadata-emit < %s | FileCheck %s + + +; CHECK-LABEL:define void @CSMain() + +; Make sure the hlsl.* attributes are removed from the entry function.
+; CHECK-NOT:"hlsl.numthreads" +; CHECK-NOT:"hlsl.shader" +; The entry metadata must still carry the numthreads value read before removal. +; CHECK:!dx.shaderModel = !{![[ShaderModel:[0-9]+]]} +; CHECK:!dx.entryPoints = !{![[ENTRY:[0-9]+]]} +; CHECK:![[ShaderModel]] = !{!"cs", i32 6, i32 0} +; CHECK:![[RES:[0-9]+]] = !{null, ![[UAV:[0-9]+]], null, null} +; CHECK:![[UAV]] = !{![[SRCBUF:[0-9]+]], ![[DSTBUF:[0-9]+]]} +; CHECK:![[SRCBUF]] = !{i32 0, ptr @"?srcBuffer@@3V?$RWBuffer@T?$__vector@H$01@__clang@@@hlsl@@A", !"", i32 0, i32 1, i32 1, i32 10, i1 false, i1 false, i1 false, null} +; CHECK:![[DSTBUF]] = !{i32 1, ptr @"?dstBuffer@@3V?$RWBuffer@T?$__vector@H$01@__clang@@@hlsl@@A", !"", i32 0, i32 0, i32 1, i32 10, i1 false, i1 false, i1 false, null} +; CHECK:![[ENTRY]] = !{ptr @CSMain, !"CSMain", null, ![[RES]], ![[EXTRA:[0-9]+]]} +; CHECK:![[EXTRA]] = !{i32 4, ![[NUMTHREADS:[0-9]+]]} +; CHECK:![[NUMTHREADS]] = !{i32 1024, i32 1, i32 1} + +target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" +target triple = "dxil-unknown-shadermodel6.0-compute" + +%"class.hlsl::RWBuffer" = type { ptr } +%dx.types.Handle = type { ptr } +%dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 } + +@"?srcBuffer@@3V?$RWBuffer@T?$__vector@H$01@__clang@@@hlsl@@A" = local_unnamed_addr global %"class.hlsl::RWBuffer" zeroinitializer, align 4 +@"?dstBuffer@@3V?$RWBuffer@T?$__vector@H$01@__clang@@@hlsl@@A" = local_unnamed_addr global %"class.hlsl::RWBuffer" zeroinitializer, align 4 + +; Function Attrs: mustprogress nofree nounwind willreturn memory(readwrite, inaccessiblemem: read) +define void @CSMain() local_unnamed_addr #0 { +entry: + %0 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 1, i32 0, i1 false) + %1 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 1, i1 false) + %2 = call i32 @dx.op.threadId.i32(i32 93, i32 0) + %3 = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle %1, i32 %2, i32 poison) + %4 = extractvalue %dx.types.ResRet.i32 %3, 0 + %5 = extractvalue 
%dx.types.ResRet.i32 %3, 1 + %add.i.i0 = add i32 %4, 10 + %add.i.i1 = add i32 %5, 10 + call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %0, i32 %2, i32 poison, i32 %add.i.i0, i32 %add.i.i1, i32 %add.i.i0, i32 %add.i.i0, i8 15) + ret void +} + +declare %dx.types.Handle @dx.op.createHandle(i32 %0, i8 %1, i32 %2, i32 %3, i1 %4) + +declare %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 %0, %dx.types.Handle %1, i32 %2, i32 %3) + +declare void @dx.op.bufferStore.i32(i32 %0, %dx.types.Handle %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i8 %8) + +declare i32 @dx.op.threadId.i32(i32 %0, i32 %1) +; Attribute group #0 holds the "hlsl.shader" and "hlsl.numthreads" attributes to be stripped. +attributes #0 = { mustprogress nofree nounwind willreturn memory(readwrite, inaccessiblemem: read) "frame-pointer"="all" "hlsl.numthreads"="1024,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } + +!hlsl.uavs = !{!0, !1} +!llvm.module.flags = !{!2, !3} +!dx.valver = !{!4} +!llvm.ident = !{!5} + +!0 = !{ptr @"?srcBuffer@@3V?$RWBuffer@T?$__vector@H$01@__clang@@@hlsl@@A", !"RWBuffer", i32 10, i32 1, i32 0} +!1 = !{ptr @"?dstBuffer@@3V?$RWBuffer@T?$__vector@H$01@__clang@@@hlsl@@A", !"RWBuffer", i32 10, i32 0, i32 0} +!2 = !{i32 1, !"wchar_size", i32 4} +!3 = !{i32 7, !"frame-pointer", i32 2} +!4 = !{i32 1, i32 1} +!5 = !{!"clang version 16.0.0 (https://github.com/llvm/llvm-project.git 9641897a6b7b727fd5139d7bf92cb6b027037193)"}