diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -35,6 +35,7 @@
 class VarDecl;
 class ParmVarDecl;
 class HLSLBufferDecl;
+class HLSLResourceBindingAttr;
 class CallExpr;
 class Type;
 class DeclContext;
@@ -47,13 +48,20 @@
 
 class CGHLSLRuntime {
 public:
+  struct BufferResBinding {
+    // Register number, e.g. 2 in register(b2, space1); unset when unbound.
+    llvm::Optional<unsigned> Reg;
+    // Register space, e.g. 1 in register(b2, space1).
+    // Defaults to 0 when no binding attribute is present.
+    unsigned Space;
+    explicit BufferResBinding(HLSLResourceBindingAttr *Attr);
+  };
   struct Buffer {
     Buffer(const HLSLBufferDecl *D);
     llvm::StringRef Name;
     // IsCBuffer - Whether the buffer is a cbuffer (and not a tbuffer).
     bool IsCBuffer;
-    llvm::Optional<unsigned> Reg;
-    unsigned Space;
+    BufferResBinding Binding;
     // Global variable and offset for each constant.
     std::vector<std::pair<llvm::GlobalVariable *, unsigned>> Constants;
     llvm::StructType *LayoutStruct = nullptr;
@@ -82,6 +90,12 @@
   void setHLSLFunctionAttributes(llvm::Function *, const FunctionDecl *);
 
 private:
+  // Add GV to the hlsl.uavs / hlsl.srvs / hlsl.cbufs named metadata list
+  // selected by RC, together with its type name, counter, register and space.
+  void addBufferResourceAnnotation(llvm::GlobalVariable *GV,
+                                   llvm::StringRef TyName,
+                                   hlsl::ResourceClass RC,
+                                   const BufferResBinding &Binding);
   void addConstant(VarDecl *D, Buffer &CB);
   void addBufferDecls(const DeclContext *DC, Buffer &CB);
   llvm::SmallVector<Buffer> Buffers;
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -88,11 +89,11 @@
 
 GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
   // Create global variable for CB.
-  GlobalVariable *CBGV =
-      new GlobalVariable(Buf.LayoutStruct, /*isConstant*/ true,
-                         GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
-                         Buf.Name + (Buf.IsCBuffer ? ".cb." : ".tb."),
-                         GlobalValue::NotThreadLocal);
+  GlobalVariable *CBGV = new GlobalVariable(
+      Buf.LayoutStruct, /*isConstant*/ true,
+      GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
+      llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
+      GlobalValue::NotThreadLocal);
 
   IRBuilder<> B(CBGV->getContext());
   Value *ZeroIdx = B.getInt32(0);
@@ -179,25 +180,52 @@
     layoutBuffer(Buf, DL);
     GlobalVariable *GV = replaceBuffer(Buf);
     M.getGlobalList().push_back(GV);
-    // FIXME: generate resource binding.
-    // See https://github.com/llvm/llvm-project/issues/57915.
+    hlsl::ResourceClass RC =
+        Buf.IsCBuffer ? hlsl::ResourceClass::CBuffer : hlsl::ResourceClass::SRV;
+    std::string TyName =
+        Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty";
+    addBufferResourceAnnotation(GV, TyName, RC, Buf.Binding);
   }
 }
 
-CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D) {
-  Name = D->getName();
-  IsCBuffer = D->isCBuffer();
-  if (auto *Binding = D->getAttr<HLSLResourceBindingAttr>()) {
-    llvm::APInt RegInt(64, 0);
-    Binding->getSlot().substr(1).getAsInteger(10, RegInt);
-    Reg = RegInt.getLimitedValue();
+CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
+    : Name(D->getName()), IsCBuffer(D->isCBuffer()),
+      Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
 
-    llvm::APInt SpaceInt(64, 0);
-    Binding->getSpace().substr(5).getAsInteger(10, RegInt);
-    Space = SpaceInt.getLimitedValue();
-  } else {
-    Space = 0;
+// Append !{GV, TyName, Counter, Reg, Space} to the named metadata list that
+// matches RC. An unbound register is emitted as UINT_MAX (i32 -1).
+void CGHLSLRuntime::addBufferResourceAnnotation(
+    llvm::GlobalVariable *GV, llvm::StringRef TyName, hlsl::ResourceClass RC,
+    const BufferResBinding &Binding) {
+  uint32_t Counter = ResourceCounters[static_cast<uint32_t>(RC)]++;
+  llvm::Module &M = CGM.getModule();
+
+  NamedMDNode *ResourceMD = nullptr;
+  switch (RC) {
+  case hlsl::ResourceClass::UAV:
+    ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
+    break;
+  case hlsl::ResourceClass::SRV:
+    ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
+    break;
+  case hlsl::ResourceClass::CBuffer:
+    ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
+    break;
+  default:
+    assert(false && "Unsupported buffer type!");
+    return;
   }
+
+  assert(ResourceMD != nullptr &&
+         "ResourceMD must have been set by the switch above.");
+
+  auto &Ctx = M.getContext();
+  IRBuilder<> B(Ctx);
+  ResourceMD->addOperand(MDNode::get(
+      Ctx, {ValueAsMetadata::get(GV), MDString::get(Ctx, TyName),
+            ConstantAsMetadata::get(B.getInt32(Counter)),
+            ConstantAsMetadata::get(B.getInt32(Binding.Reg.value_or(UINT_MAX))),
+            ConstantAsMetadata::get(B.getInt32(Binding.Space))}));
 }
 
 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
@@ -212,27 +240,24 @@
     return;
 
   HLSLResourceAttr::ResourceClass RC = Attr->getResourceType();
-  uint32_t Counter = ResourceCounters[static_cast<uint32_t>(RC)]++;
+  QualType QT(Ty, 0);
+  BufferResBinding Binding(RD->getAttr<HLSLResourceBindingAttr>());
+  addBufferResourceAnnotation(GV, QT.getAsString(),
+                              static_cast<hlsl::ResourceClass>(RC), Binding);
+}
 
-  NamedMDNode *ResourceMD = nullptr;
-  switch (RC) {
-  case HLSLResourceAttr::ResourceClass::UAV:
-    ResourceMD = CGM.getModule().getOrInsertNamedMetadata("hlsl.uavs");
-    break;
-  default:
-    assert(false && "Unsupported buffer type!");
-    return;
+CGHLSLRuntime::BufferResBinding::BufferResBinding(
+    HLSLResourceBindingAttr *Binding) {
+  if (Binding) {
+    llvm::APInt RegInt(64, 0);
+    Binding->getSlot().substr(1).getAsInteger(10, RegInt);
+    Reg = RegInt.getLimitedValue();
+    llvm::APInt SpaceInt(64, 0);
+    Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
+    Space = SpaceInt.getLimitedValue();
+  } else {
+    Space = 0;
   }
-
-  assert(ResourceMD != nullptr &&
-         "ResourceMD must have been set by the switch above.");
-
-  auto &Ctx = CGM.getModule().getContext();
-  IRBuilder<> B(Ctx);
-  QualType QT(Ty, 0);
-  ResourceMD->addOperand(MDNode::get(
-      Ctx, {ValueAsMetadata::get(GV), MDString::get(Ctx, QT.getAsString()),
-            ConstantAsMetadata::get(B.getInt32(Counter))}));
 }
 
 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-annotations.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-annotations.hlsl
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-annotations.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-annotations.hlsl
@@ -8,5 +8,5 @@
 }
 
 // CHECK: !hlsl.uavs = !{![[Single:[0-9]+]], ![[Array:[0-9]+]]}
-// CHECK-DAG: ![[Single]] = !{ptr @"?Buffer1@@3V?$RWBuffer@M@hlsl@@A", !"RWBuffer<float>", i32 0}
-// CHECK-DAG: ![[Array]] = !{ptr @"?BufferArray@@3PAV?$RWBuffer@T?$__vector@M$03@__clang@@@hlsl@@A", !"RWBuffer<vector<float, 4> >", i32 1}
+// CHECK-DAG: ![[Single]] = !{ptr @"?Buffer1@@3V?$RWBuffer@M@hlsl@@A", !"RWBuffer<float>", i32 0, i32 -1, i32 0}
+// CHECK-DAG: ![[Array]] = !{ptr @"?BufferArray@@3PAV?$RWBuffer@T?$__vector@M$03@__clang@@@hlsl@@A", !"RWBuffer<vector<float, 4> >", i32 1, i32 -1, i32 0}
diff --git a/clang/test/CodeGenHLSL/cbuf.hlsl b/clang/test/CodeGenHLSL/cbuf.hlsl
--- a/clang/test/CodeGenHLSL/cbuf.hlsl
+++ b/clang/test/CodeGenHLSL/cbuf.hlsl
@@ -3,7 +3,7 @@
 // RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s
 
 // CHECK: @[[CB:.+]] = external constant { float, double }
-cbuffer A : register(b0, space1) {
+cbuffer A : register(b0, space2) {
   float a;
   double b;
 }
@@ -21,3 +21,8 @@
   // CHECK: load double, ptr getelementptr inbounds ({ float, double }, ptr @[[TB]], i32 0, i32 1), align 8
   return a + b + c*d;
 }
+
+// CHECK: !hlsl.cbufs = !{![[CBMD:[0-9]+]]}
+// CHECK: !hlsl.srvs = !{![[TBMD:[0-9]+]]}
+// CHECK: ![[CBMD]] = !{ptr @[[CB]], !"A.cb.ty", i32 0, i32 0, i32 2}
+// CHECK: ![[TBMD]] = !{ptr @[[TB]], !"A.tb.ty", i32 0, i32 2, i32 1}