diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt --- a/llvm/lib/Target/DirectX/CMakeLists.txt +++ b/llvm/lib/Target/DirectX/CMakeLists.txt @@ -22,6 +22,7 @@ DXILOpLowering.cpp DXILPrepare.cpp DXILResource.cpp + DXILResourceAnalysis.cpp DXILTranslateMetadata.cpp PointerTypeAnalysis.cpp diff --git a/llvm/lib/Target/DirectX/DXILResource.h b/llvm/lib/Target/DirectX/DXILResource.h --- a/llvm/lib/Target/DirectX/DXILResource.h +++ b/llvm/lib/Target/DirectX/DXILResource.h @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/Metadata.h" +#include "llvm/Support/Compiler.h" #include namespace llvm { @@ -53,6 +54,8 @@ void write(LLVMContext &Ctx, MutableArrayRef Entries); + void print(raw_ostream &O, StringRef IDPrefix, StringRef BindingPrefix) const; + // The value ordering of this enumeration is part of the DXIL ABI. Elements // can only be added to the end, and not removed. enum class Kinds : uint32_t { @@ -78,6 +81,11 @@ NumEntries, }; + static StringRef getKindName(Kinds Kind); + static void printKind(Kinds Kind, unsigned Alignment, raw_ostream &OS, + bool SRV = false, bool HasCounter = false, + uint32_t SampleCount = 0); + // The value ordering of this enumeration is part of the DXIL ABI. Elements // can only be added to the end, and not removed. enum class ComponentType : uint32_t { @@ -103,6 +111,10 @@ LastEntry }; + static StringRef getComponentTypeName(ComponentType CompType); + static void printComponentType(Kinds Kind, ComponentType CompType, + unsigned Alignment, raw_ostream &OS); + public: struct ExtendedProperties { llvm::Optional ElementType; @@ -133,6 +145,7 @@ UAVResource(uint32_t I, FrontendResource R); MDNode *write(); + void print(raw_ostream &O) const; }; // FIXME: Fully computing the resource structures requires analyzing the IR @@ -140,15 +153,16 @@ // resource. This partial patch handles some of the leg work, but not all of it.
// See issue https://github.com/llvm/llvm-project/issues/57936. class Resources { - Module &Mod; llvm::SmallVector UAVs; - void collectUAVs(); + void collectUAVs(Module &M); public: - Resources(Module &M) : Mod(M) { collectUAVs(); } + void collect(Module &M); - void write(); + void write(Module &M); + void print(raw_ostream &O) const; + LLVM_DUMP_METHOD void dump() const; }; } // namespace dxil diff --git a/llvm/lib/Target/DirectX/DXILResource.cpp b/llvm/lib/Target/DirectX/DXILResource.cpp --- a/llvm/lib/Target/DirectX/DXILResource.cpp +++ b/llvm/lib/Target/DirectX/DXILResource.cpp @@ -15,6 +15,9 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/Format.h" using namespace llvm; using namespace llvm::dxil; @@ -32,8 +35,8 @@ return cast(Entry->getOperand(2))->getValue(); } -void Resources::collectUAVs() { - NamedMDNode *Entry = Mod.getNamedMetadata("hlsl.uavs"); +void Resources::collectUAVs(Module &M) { + NamedMDNode *Entry = M.getNamedMetadata("hlsl.uavs"); if (!Entry || Entry->getNumOperands() == 0) return; @@ -43,6 +46,8 @@ } } +void Resources::collect(Module &M) { collectUAVs(M); } + ResourceBase::ResourceBase(uint32_t I, FrontendResource R) : ID(I), GV(R.getGlobalVariable()), Name(""), Space(0), LowerBound(0), RangeSize(1) { @@ -50,12 +55,196 @@ RangeSize = ArrTy->getNumElements(); } +StringRef ResourceBase::getComponentTypeName(ComponentType CompType) { + switch (CompType) { + case ComponentType::LastEntry: + case ComponentType::Invalid: + return "invalid"; + case ComponentType::I1: + return "i1"; + case ComponentType::I16: + return "i16"; + case ComponentType::U16: + return "u16"; + case ComponentType::I32: + return "i32"; + case ComponentType::U32: + return "u32"; + case ComponentType::I64: + return "i64"; + case ComponentType::U64: + return "u64"; + case ComponentType::F16: + return "f16"; + case ComponentType::F32: + return "f32"; + case ComponentType::F64:
+ return "f64"; + case ComponentType::SNormF16: + return "snorm_f16"; + case ComponentType::UNormF16: + return "unorm_f16"; + case ComponentType::SNormF32: + return "snorm_f32"; + case ComponentType::UNormF32: + return "unorm_f32"; + case ComponentType::SNormF64: + return "snorm_f64"; + case ComponentType::UNormF64: + return "unorm_f64"; + case ComponentType::PackedS8x32: + return "p32i8"; + case ComponentType::PackedU8x32: + return "p32u8"; + } + llvm_unreachable("All ComponentType enums are handled in switch"); +} + +void ResourceBase::printComponentType(Kinds Kind, ComponentType CompType, + unsigned Alignment, raw_ostream &OS) { + switch (Kind) { + default: + // TODO: add vector size. + OS << right_justify(getComponentTypeName(CompType), Alignment); + break; + case Kinds::RawBuffer: + OS << right_justify("byte", Alignment); + break; + case Kinds::StructuredBuffer: + OS << right_justify("struct", Alignment); + break; + case Kinds::CBuffer: + case Kinds::Sampler: + OS << right_justify("NA", Alignment); + break; + case Kinds::Invalid: + case Kinds::NumEntries: + break; + } +} + +StringRef ResourceBase::getKindName(Kinds Kind) { + switch (Kind) { + case Kinds::NumEntries: + case Kinds::Invalid: + return "invalid"; + case Kinds::Texture1D: + return "1d"; + case Kinds::Texture2D: + return "2d"; + case Kinds::Texture2DMS: + return "2dMS"; + case Kinds::Texture3D: + return "3d"; + case Kinds::TextureCube: + return "cube"; + case Kinds::Texture1DArray: + return "1darray"; + case Kinds::Texture2DArray: + return "2darray"; + case Kinds::Texture2DMSArray: + return "2darrayMS"; + case Kinds::TextureCubeArray: + return "cubearray"; + case Kinds::TypedBuffer: + return "buf"; + case Kinds::RawBuffer: + return "rawbuf"; + case Kinds::StructuredBuffer: + return "structbuf"; + case Kinds::CBuffer: + return "cbuffer"; + case Kinds::Sampler: + return "sampler"; + case Kinds::TBuffer: + return "tbuffer"; + case Kinds::RTAccelerationStructure: + return "ras"; + case Kinds::FeedbackTexture2D: + return "fbtex2d"; + case Kinds::FeedbackTexture2DArray:
+ return "fbtex2darray"; + } + llvm_unreachable("All Kinds enums are handled in switch"); +} + +void ResourceBase::printKind(Kinds Kind, unsigned Alignment, raw_ostream &OS, + bool SRV, bool HasCounter, uint32_t SampleCount) { + switch (Kind) { + default: + OS << right_justify(getKindName(Kind), Alignment); + break; + + case Kinds::RawBuffer: + case Kinds::StructuredBuffer: + if (SRV) + OS << right_justify("r/o", Alignment); + else { + if (!HasCounter) + OS << right_justify("r/w", Alignment); + else + OS << right_justify("r/w+cnt", Alignment); + } + break; + case Kinds::TypedBuffer: + OS << right_justify("buf", Alignment); + break; + case Kinds::Texture2DMS: + case Kinds::Texture2DMSArray: { + std::string DimName = getKindName(Kind).str(); + if (SampleCount) + DimName += std::to_string(SampleCount); + OS << right_justify(DimName, Alignment); + } break; + case Kinds::CBuffer: + case Kinds::Sampler: + OS << right_justify("NA", Alignment); + break; + case Kinds::Invalid: + case Kinds::NumEntries: + break; + } +} + +void ResourceBase::print(raw_ostream &OS, StringRef IDPrefix, + StringRef BindingPrefix) const { + std::string ResID = IDPrefix.str(); + ResID += std::to_string(ID); + OS << right_justify(ResID, 8); + + std::string Bind = BindingPrefix.str(); + Bind += std::to_string(LowerBound); + if (Space) + Bind += ",space" + std::to_string(Space); + + OS << right_justify(Bind, 15); + if (RangeSize != UINT_MAX) + OS << right_justify(std::to_string(RangeSize), 6) << "\n"; + else + OS << right_justify("unbounded", 6) << "\n"; +} + UAVResource::UAVResource(uint32_t I, FrontendResource R) : ResourceBase(I, R), Shape(Kinds::Invalid), GloballyCoherent(false), HasCounter(false), IsROV(false), ExtProps() { parseSourceType(R.getSourceType()); } +void UAVResource::print(raw_ostream &OS) const { + OS << "; " << left_justify(Name, 31); + + OS << right_justify("UAV", 10); + + printComponentType( + Shape, ExtProps.ElementType.value_or(ComponentType::Invalid), 8, OS); + + // FIXME: support SampleCount.
+ // See https://github.com/llvm/llvm-project/issues/58175 + printKind(Shape, 12, OS, /*SRV*/ false, HasCounter); + // Print the binding part. + ResourceBase::print(OS, "U", "u"); +} + // FIXME: Capture this in HLSL source. I would go do this right now, but I want // to get this in first so that I can make sure to capture all the extra // information we need to remove the source type string from here (See issue: @@ -140,19 +329,36 @@ return MDNode::get(Ctx, Entries); } -void Resources::write() { +void Resources::write(Module &M) { Metadata *ResourceMDs[4] = {nullptr, nullptr, nullptr, nullptr}; SmallVector UAVMDs; for (auto &UAV : UAVs) UAVMDs.emplace_back(UAV.write()); if (!UAVMDs.empty()) - ResourceMDs[1] = MDNode::get(Mod.getContext(), UAVMDs); + ResourceMDs[1] = MDNode::get(M.getContext(), UAVMDs); - NamedMDNode *DXResMD = Mod.getOrInsertNamedMetadata("dx.resources"); - DXResMD->addOperand(MDNode::get(Mod.getContext(), ResourceMDs)); + NamedMDNode *DXResMD = M.getOrInsertNamedMetadata("dx.resources"); + DXResMD->addOperand(MDNode::get(M.getContext(), ResourceMDs)); - NamedMDNode *Entry = Mod.getNamedMetadata("hlsl.uavs"); + NamedMDNode *Entry = M.getNamedMetadata("hlsl.uavs"); if (Entry) Entry->eraseFromParent(); } + +void Resources::print(raw_ostream &O) const { + O << ";\n" + << "; Resource Bindings:\n" + << ";\n" + << "; Name Type Format Dim " + "ID HLSL Bind Count\n" + << "; ------------------------------ ---------- ------- ----------- " + "------- -------------- ------\n"; + + for (auto &UAV : UAVs) + UAV.print(O); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void Resources::dump() const { print(dbgs()); } +#endif diff --git a/llvm/lib/Target/DirectX/DXILResourceAnalysis.h b/llvm/lib/Target/DirectX/DXILResourceAnalysis.h new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILResourceAnalysis.h @@ -0,0 +1,61 @@ +//===- DXILResourceAnalysis.h - DXIL Resource analysis-------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains Analysis for information about DXIL resources. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEANALYSIS_H +#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEANALYSIS_H + +#include "DXILResource.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include + +namespace llvm { +/// Analysis pass that exposes the \c DXILResource for a module. +class DXILResourceAnalysis : public AnalysisInfoMixin { + friend AnalysisInfoMixin; + static AnalysisKey Key; + +public: + typedef dxil::Resources Result; + dxil::Resources run(Module &M, ModuleAnalysisManager &AM); +}; + +/// Printer pass for the \c DXILResourceAnalysis results. +class DXILResourcePrinterPass : public PassInfoMixin { + raw_ostream &OS; + +public: + explicit DXILResourcePrinterPass(raw_ostream &OS) : OS(OS) {} + PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); +}; + +/// The legacy pass manager's analysis pass to compute DXIL resource +/// information. +class DXILResourceWrapper : public ModulePass { + dxil::Resources Resources; + +public: + static char ID; // Pass identification, replacement for typeid + + DXILResourceWrapper(); + + dxil::Resources &getDXILResource() { return Resources; } + const dxil::Resources &getDXILResource() const { return Resources; } + + /// Calculate the DXILResource for the module.
+ bool runOnModule(Module &M) override; + + void print(raw_ostream &O, const Module *M = nullptr) const override; +}; +} // namespace llvm + +#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEANALYSIS_H diff --git a/llvm/lib/Target/DirectX/DXILResourceAnalysis.cpp b/llvm/lib/Target/DirectX/DXILResourceAnalysis.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILResourceAnalysis.cpp @@ -0,0 +1,54 @@ +//===- DXILResourceAnalysis.cpp - DXIL Resource analysis-------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains Analysis for information about DXIL resources. +/// +//===----------------------------------------------------------------------===// + +#include "DXILResourceAnalysis.h" +#include "DirectX.h" +#include "llvm/IR/PassManager.h" + +using namespace llvm; + +#define DEBUG_TYPE "dxil-resource-analysis" + +dxil::Resources DXILResourceAnalysis::run(Module &M, + ModuleAnalysisManager &AM) { + dxil::Resources R; + R.collect(M); + return R; +} + +AnalysisKey DXILResourceAnalysis::Key; + +PreservedAnalyses DXILResourcePrinterPass::run(Module &M, + ModuleAnalysisManager &AM) { + const dxil::Resources &Res = AM.getResult<DXILResourceAnalysis>(M); + Res.print(OS); + return PreservedAnalyses::all(); +} + +char DXILResourceWrapper::ID = 0; +INITIALIZE_PASS_BEGIN(DXILResourceWrapper, DEBUG_TYPE, + "DXIL resource Information", true, true) +INITIALIZE_PASS_END(DXILResourceWrapper, DEBUG_TYPE, + "DXIL resource Information", true, true) + +bool DXILResourceWrapper::runOnModule(Module &M) { + // Start from an empty result in case this pass instance is run again. + Resources = dxil::Resources(); + Resources.collect(M); + return false; +} + +DXILResourceWrapper::DXILResourceWrapper() : ModulePass(ID) {} + +void DXILResourceWrapper::print(raw_ostream &OS, const Module *) const { + Resources.print(OS); +} diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -10,6 +10,7 @@ #include "DXILMetadata.h" #include "DXILResource.h" +#include "DXILResourceAnalysis.h" #include "DirectX.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/Triple.h" @@ -28,6 +29,11 @@ StringRef getPassName() const override { return "DXIL Metadata Emit"; } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + AU.addRequired(); + } + bool runOnModule(Module &M) override; }; @@ -40,8 +46,8 @@ ValVerMD.update(VersionTuple(1, 0)); dxil::createShaderModelMD(M); - dxil::Resources Res(M); - Res.write(); + dxil::Resources &Res = getAnalysis().getDXILResource(); + Res.write(M); return false; } @@ -51,5 +57,8 @@ return new DXILTranslateMetadata(); } -INITIALIZE_PASS(DXILTranslateMetadata, "dxil-metadata-emit", - "DXIL Metadata Emit", false, false) +INITIALIZE_PASS_BEGIN(DXILTranslateMetadata, "dxil-metadata-emit", + "DXIL Metadata Emit", false, false) +INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapper) +INITIALIZE_PASS_END(DXILTranslateMetadata, "dxil-metadata-emit", + "DXIL Metadata Emit", false, false) diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h --- a/llvm/lib/Target/DirectX/DirectX.h +++ b/llvm/lib/Target/DirectX/DirectX.h @@ -38,6 +38,10 @@ /// Pass to emit metadata for DXIL. ModulePass *createDXILTranslateMetadataPass(); + +/// Initializer for DXILResourceWrapper.
+void initializeDXILResourceWrapperPass(PassRegistry &); + } // namespace llvm #endif // LLVM_LIB_TARGET_DIRECTX_DIRECTX_H diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.h b/llvm/lib/Target/DirectX/DirectXTargetMachine.h --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.h +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.h @@ -45,6 +45,7 @@ } TargetTransformInfo getTargetTransformInfo(const Function &F) const override; + void registerPassBuilderCallbacks(PassBuilder &PB) override; }; } // namespace llvm diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "DirectXTargetMachine.h" +#include "DXILResourceAnalysis.h" #include "DXILWriter/DXILWriterPass.h" #include "DirectX.h" #include "DirectXSubtarget.h" @@ -25,6 +26,7 @@ #include "llvm/MC/MCSectionDXContainer.h" #include "llvm/MC/SectionKind.h" #include "llvm/MC/TargetRegistry.h" +#include "llvm/Passes/PassBuilder.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" @@ -39,6 +41,7 @@ initializeEmbedDXILPassPass(*PR); initializeDXILOpLoweringLegacyPass(*PR); initializeDXILTranslateMetadataPass(*PR); + initializeDXILResourceWrapperPass(*PR); } class DXILTargetObjectFile : public TargetLoweringObjectFile { @@ -92,6 +95,22 @@ DirectXTargetMachine::~DirectXTargetMachine() {} +void DirectXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) { + PB.registerPipelineParsingCallback( + [](StringRef PassName, ModulePassManager &PM, + ArrayRef) { + if (PassName == "print-dxil-resource") { + PM.addPass(DXILResourcePrinterPass(dbgs())); + return true; + } + return false; + }); + + PB.registerAnalysisRegistrationCallback([](ModuleAnalysisManager &MAM) { + MAM.registerPass([] {
+ return DXILResourceAnalysis(); }); + }); +} + bool DirectXTargetMachine::addPassesToEmitFile( PassManagerBase &PM, raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut, CodeGenFileType FileType, bool DisableVerify, diff --git a/llvm/test/CodeGen/DirectX/UAVMetadata.ll b/llvm/test/CodeGen/DirectX/UAVMetadata.ll --- a/llvm/test/CodeGen/DirectX/UAVMetadata.ll +++ b/llvm/test/CodeGen/DirectX/UAVMetadata.ll @@ -1,10 +1,27 @@ ; RUN: opt -S -dxil-metadata-emit < %s | FileCheck %s -; ModuleID = '/home/cbieneman/dev/shuffle.hlsl' +; RUN: opt -S --passes="print-dxil-resource" < %s 2>&1 | FileCheck %s --check-prefix=PRINT + target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" target triple = "dxil-pc-shadermodel6.0-compute" %"class.hlsl::RWBuffer" = type { ptr } + +; PRINT:; Resource Bindings: +; PRINT-NEXT:; +; PRINT-NEXT:; Name Type Format Dim ID HLSL Bind Count +; PRINT-NEXT:; ------------------------------ ---------- ------- ----------- ------- -------------- ------ +; PRINT-NEXT:; UAV f16 buf U0 u0 1 +; PRINT-NEXT:; UAV f32 buf U1 u0 1 +; PRINT-NEXT:; UAV f64 buf U2 u0 1 +; PRINT-NEXT:; UAV i1 buf U3 u0 2 +; PRINT-NEXT:; UAV byte r/w U4 u0 1 +; PRINT-NEXT:; UAV struct r/w U5 u0 1 +; PRINT-NEXT:; UAV i32 buf U6 u0 1 +; PRINT-NEXT:; UAV struct r/w U7 u0 1 +; PRINT-NEXT:; UAV byte r/w U8 u0 1 +; PRINT-NEXT:; UAV u64 buf U9 u0 1 + @Zero = local_unnamed_addr global %"class.hlsl::RWBuffer" zeroinitializer, align 4 @One = local_unnamed_addr global %"class.hlsl::RWBuffer" zeroinitializer, align 4 @Two = local_unnamed_addr global %"class.hlsl::RWBuffer" zeroinitializer, align 4 @@ -30,7 +47,6 @@ !8 = !{ptr @Eight, !"RasterizerOrderedByteAddressBuffer", i32 8} !9 = !{ptr @Nine, !"RWBuffer", i32 9} - ; CHECK: !dx.resources = !{[[ResList:[!][0-9]+]]} ; CHECK: [[ResList]] = !{null, [[UAVList:[!][0-9]+]], null, null}