diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -17,6 +17,7 @@
   DirectXRegisterInfo.cpp
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
+  DXILCBufferLowering.cpp
   DXILMetadata.cpp
   DXILOpBuilder.cpp
   DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -26,11 +26,13 @@
 def ThreadIdInGroupClass : dxil_class<"ThreadIdInGroup">;
 def ThreadIdClass : dxil_class<"ThreadId">;
 def GroupIdClass : dxil_class<"GroupId">;
+def CBufferLoadClass : dxil_class<"CBufferLoad">;
+def CreateHandleClass : dxil_class<"CreateHandle">;
 
 def binary_uint : dxil_category<"Binary uint">;
 def unary_float : dxil_category<"Unary float">;
 def ComputeID : dxil_category<"Compute/Mesh/Amplification shader">;
-
+def Resources : dxil_category<"Resources">;
 
 // The parameter description for a DXIL instruction
 class dxil_param ]>, dxil_map_intrinsic;
+
+
+def CBufferLoad : dxil_op< "CBufferLoad", 58, CBufferLoadClass, Resources, "loads a value from a constant buffer resource", "half;float;double;i8;i16;i32;i64;", "ro",
+  [
+    dxil_param<0, "$o", "", "the value for the constant buffer variable">,
+    dxil_param<1, "i32", "opcode", "DXIL opcode">,
+    dxil_param<2, "dx.types.Handle", "srv", "cbuffer handle">,
+    dxil_param<3, "i32", "byteOffset", "linear byte offset of value">,
+    dxil_param<4, "i32", "alignment", "load access alignment", 1>
+  ]>;
+
+def CreateHandle : dxil_op< "CreateHandle", 57, CreateHandleClass, Resources, "creates the handle to a resource",
+  "void;", "ro",
+  [
+    dxil_param<0, "dx.types.Handle", "", "the handle to the resource">,
+    dxil_param<1, "i32", "opcode", "DXIL opcode">,
+    dxil_param<2, "i8", "resourceClass", "the class of resource to create (SRV, UAV, CBuffer, Sampler)", 1>, // maps to DxilResourceBase::Class
+    dxil_param<3, "i32", "rangeId", "range identifier for resource", 1>,
+    dxil_param<4, "i32", "index", "zero-based index into range">,
+    dxil_param<5, "i1", "nonUniformIndex", "non-uniform resource index", 1>
+  ]>;
diff --git a/llvm/lib/Target/DirectX/DXILCBufferLowering.cpp b/llvm/lib/Target/DirectX/DXILCBufferLowering.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILCBufferLowering.cpp
@@ -0,0 +1,198 @@
+//===- DXILCBufferLowering.cpp - Lowering CBuffer to DXIL -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains passes to lower cbuffer load to DXIL.
+//===----------------------------------------------------------------------===//
+
+#include "DXILConstants.h"
+#include "DXILOpBuilder.h"
+#include "DirectX.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/Utils/Local.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "dxil-cbuf-lower"
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+namespace {
+/// A single load from a cbuffer global, paired with its byte offset.
+struct BufAccess {
+  Value *Index;      // Byte offset of the access within the cbuffer.
+  Instruction *User; // The instruction which accesses the cbuffer.
+};
+
+} // namespace
+
+/// Walk the use chain starting at \p U and record every load that is reached,
+/// together with its accumulated byte offset. \p Addr is the byte offset
+/// carried by the value \p U uses. GEPs add their offset, bitcasts and
+/// addrspacecasts are looked through; any other user is unsupported.
+static void collectBufUserAccess(User *U, Value *Addr,
+                                 std::vector<BufAccess> &AccessList,
+                                 const DataLayout &DL) {
+  if (auto *GEP = dyn_cast<GEPOperator>(U)) {
+    // Calculate new Addr.
+    Value *NewAddr = Addr;
+    IRBuilder<> B(GEP->getContext());
+    if (auto *Inst = dyn_cast<GetElementPtrInst>(GEP))
+      B.SetInsertPoint(Inst);
+    if (GEP->hasAllConstantIndices()) {
+      SmallVector<Value *, 8> IdxList(GEP->idx_begin(), GEP->idx_end());
+      NewAddr = B.CreateAdd(Addr, B.getInt32(DL.getIndexedOffsetInType(
+                                      GEP->getSourceElementType(), IdxList)));
+    } else {
+      Value *Offset = EmitGEPOffset(&B, DL, GEP, /*NoAssumptions=*/true);
+      NewAddr = B.CreateAdd(Addr, Offset);
+    }
+
+    for (User *GEPU : GEP->users())
+      collectBufUserAccess(GEPU, NewAddr, AccessList, DL);
+  } else if (isa<BitCastOperator>(U) || isa<AddrSpaceCastOperator>(U)) {
+    for (User *AU : U->users())
+      collectBufUserAccess(AU, Addr, AccessList, DL);
+  } else if (auto *LI = dyn_cast<LoadInst>(U)) {
+    BufAccess Access = {Addr, LI};
+    AccessList.emplace_back(Access);
+  } else
+    llvm_unreachable("unsupported user");
+}
+
+/// Collect every load reachable from the cbuffer global \p GV, starting the
+/// byte offset at zero.
+static std::vector<BufAccess> collectBufAccess(GlobalVariable *GV,
+                                               const DataLayout &DL) {
+  auto &Ctx = GV->getContext();
+  Value *OffsetZero = ConstantInt::get(Type::getInt32Ty(Ctx), 0);
+  std::vector<BufAccess> AccessList;
+  for (User *U : GV->users())
+    collectBufUserAccess(U, OffsetZero, AccessList, DL);
+  return AccessList;
+}
+
+/// Replace the loads from every cbuffer listed in the "hlsl.cbufs" named
+/// metadata with createHandle/cbufferLoad DXIL op calls.
+/// \returns true if any load was rewritten.
+static bool lowerCBufferAccess(Module &M) {
+  // FIXME: Allocate resource binding first.
+  const StringRef Name = "hlsl.cbufs";
+  auto *ResTable = M.getNamedMetadata(Name);
+  if (!ResTable)
+    return false;
+  const DataLayout &DL = M.getDataLayout();
+
+  bool Changed = false;
+  for (auto *Res : ResTable->operands()) {
+    assert(Res->getNumOperands() == 5 && "invalid resource metadata");
+    auto *GVMD = cast<ValueAsMetadata>(Res->getOperand(0).get());
+    auto *GV = cast<GlobalVariable>(GVMD->getValue());
+    std::vector<BufAccess> AccessList = collectBufAccess(GV, DL);
+
+    // One handle per function is enough for a given cbuffer.
+    SmallDenseMap<Function *, CallInst *> HandleMap;
+    uint64_t RangeID =
+        mdconst::extract<ConstantInt>(Res->getOperand(2))->getLimitedValue();
+    ConstantInt *CBIndex = mdconst::extract<ConstantInt>(Res->getOperand(3));
+    assert(!CBIndex->isMinusOne() && "unallocated binding");
+    for (auto &Access : AccessList) {
+      Value *Index = Access.Index;
+      Instruction *User = Access.User;
+      // Make the handle in the function.
+      Function *F = User->getParent()->getParent();
+      auto it = HandleMap.find(F);
+      CallInst *Hdl = nullptr;
+      if (it == HandleMap.end()) {
+        IRBuilder<> B(&*F->getEntryBlock().getFirstInsertionPt());
+        DXILOpBuilder DXILB(M, B);
+        Hdl = DXILB.createCreateHandle(
+            static_cast<int8_t>(dxil::ResourceClass::CBuffer), RangeID,
+            CBIndex, false);
+        HandleMap[F] = Hdl;
+      } else
+        Hdl = it->second;
+
+      LoadInst *LI = cast<LoadInst>(User);
+      Type *Ty = LI->getType();
+      Value *CBLd = nullptr;
+
+      IRBuilder<> B(LI);
+      DXILOpBuilder DXILB(M, B);
+      if (Ty->isIntegerTy() || Ty->isFloatingPointTy()) {
+        CBLd = DXILB.createCBufferLoad(Ty, Hdl, Index,
+                                       DL.getPrefTypeAlign(Ty).value());
+      } else if (isa<FixedVectorType>(Ty)) {
+        // Only fixed vector types are supported; load element by element.
+        auto *VT = cast<FixedVectorType>(Ty);
+        Value *Result = PoisonValue::get(VT);
+        Type *EltTy = VT->getElementType();
+        uint64_t Align = DL.getPrefTypeAlign(EltTy).value();
+        for (unsigned i = 0; i < VT->getNumElements(); ++i) {
+          Value *Offset =
+              B.CreateAdd(Index, B.getInt32(i * DL.getTypeAllocSize(EltTy)));
+          Value *Elt = DXILB.createCBufferLoad(EltTy, Hdl, Offset, Align);
+          Result = B.CreateInsertElement(Result, Elt, i);
+        }
+        CBLd = Result;
+      } else {
+        llvm_unreachable("invalid type for cbuffer load.");
+      }
+      LI->replaceAllUsesWith(CBLd);
+      LI->eraseFromParent();
+      Changed = true;
+    }
+  }
+  return Changed;
+}
+
+namespace {
+
+/// A pass that lowers cbuffer accesses into DXIL ops.
+class DXILCBufLowering : public PassInfoMixin<DXILCBufLowering> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
+    if (lowerCBufferAccess(M))
+      return PreservedAnalyses::none();
+    return PreservedAnalyses::all();
+  }
+};
+} // namespace
+
+namespace {
+/// Legacy pass manager wrapper around lowerCBufferAccess.
+class DXILCBufferLoweringLegacy : public ModulePass {
+public:
+  bool runOnModule(Module &M) override { return lowerCBufferAccess(M); }
+  StringRef getPassName() const override { return "DXIL CBuffer lowering"; }
+  DXILCBufferLoweringLegacy() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+char DXILCBufferLoweringLegacy::ID = 0;
+
+} // end anonymous namespace
+
+INITIALIZE_PASS_BEGIN(DXILCBufferLoweringLegacy, DEBUG_TYPE,
+                      "DXIL CBuffer lowering", false, false)
+INITIALIZE_PASS_END(DXILCBufferLoweringLegacy, DEBUG_TYPE,
+                    "DXIL CBuffer lowering", false, false)
+
+ModulePass *llvm::createDXILCBufferLoweringLegacyPass() {
+  return new DXILCBufferLoweringLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILConstants.h b/llvm/lib/Target/DirectX/DXILConstants.h
--- a/llvm/lib/Target/DirectX/DXILConstants.h
+++ b/llvm/lib/Target/DirectX/DXILConstants.h
@@ -15,6 +15,8 @@
 namespace llvm {
 namespace dxil {
 
+enum class ResourceClass { SRV = 0, UAV, CBuffer, Sampler, Invalid };
+
 #define DXIL_OP_ENUM
 #include "DXILOperation.inc"
 #undef DXIL_OP_ENUM
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -35,6 +35,11 @@
                                  bool NoOpCodeParam);
   static const char *getOpCodeName(dxil::OpCode DXILOp);
 
+  CallInst *createCreateHandle(int8_t ResClass, int RangeID, Value *Index,
+                               bool NonUniformIndex);
+  CallInst *createCBufferLoad(Type *OverloadTy, Value *Hdl, Value *ByteOffset,
+                              uint32_t Alignment);
+
 private:
   Module &M;
   IRBuilderBase &B;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -320,5 +320,24 @@
 const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
   return ::getOpCodeName(DXILOp);
 }
+
+CallInst *DXILOpBuilder::createCreateHandle(int8_t ResClass, int RangeID,
+                                            Value *Index,
+                                            bool NonUniformIndex) {
+  auto Fn =
+      getOrCreateDXILOpFunction(dxil::OpCode::CreateHandle, B.getVoidTy(), M);
+  return B.CreateCall(Fn, {B.getInt32((int32_t)dxil::OpCode::CreateHandle),
+                           B.getInt8(ResClass), B.getInt32(RangeID), Index,
+                           B.getInt1(NonUniformIndex)});
+}
+
+CallInst *DXILOpBuilder::createCBufferLoad(Type *OverloadTy, Value *Hdl,
+                                           Value *ByteOffset,
+                                           uint32_t Alignment) {
+  auto Fn = getOrCreateDXILOpFunction(dxil::OpCode::CBufferLoad, OverloadTy, M);
+  return B.CreateCall(Fn, {B.getInt32((int32_t)dxil::OpCode::CBufferLoad), Hdl,
+                           ByteOffset, B.getInt32(Alignment)});
+}
+
 } // namespace dxil
 } // namespace llvm
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -33,6 +33,12 @@
 /// Pass to lowering LLVM intrinsic call to DXIL op function call.
 ModulePass *createDXILOpLoweringLegacyPass();
 
+/// Initializer for DXILCBufferLowering
+void initializeDXILCBufferLoweringLegacyPass(PassRegistry &);
+
+/// Pass to lowering CBuffer to DXIL op function call.
+ModulePass *createDXILCBufferLoweringLegacyPass();
+
 /// Initializer for DXILTranslateMetadata.
 void initializeDXILTranslateMetadataPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -39,6 +39,7 @@
   initializeEmbedDXILPassPass(*PR);
   initializeDXILOpLoweringLegacyPass(*PR);
   initializeDXILTranslateMetadataPass(*PR);
+  initializeDXILCBufferLoweringLegacyPass(*PR);
 }
 
 class DXILTargetObjectFile : public TargetLoweringObjectFile {
@@ -68,6 +69,7 @@
 
   FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
   void addCodeGenPrepare() override {
+    addPass(createDXILCBufferLoweringLegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILPrepareModulePass());
     addPass(createDXILTranslateMetadataPass());
diff --git a/llvm/test/CodeGen/DirectX/cbuf.ll b/llvm/test/CodeGen/DirectX/cbuf.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/cbuf.ll
@@ -0,0 +1,36 @@
+; RUN: opt -S -dxil-cbuf-lower < %s | FileCheck %s
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-unknown-shadermodel6.7-library"
+
+; Make sure generate create handle.
+; CHECK:%[[HDL:.+]] = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 2, i32 0, i32 2, i1 false)
+
+; Make sure load at offset 0/8/20 for @A.cb.float/double/<2 x i32>.y
+; CHECK:call float @dx.op.cbufferLoad.f32(i32 58, %dx.types.Handle %[[HDL]], i32 0, i32 4)
+; CHECK:call double @dx.op.cbufferLoad.f64(i32 58, %dx.types.Handle %[[HDL]], i32 8, i32 8)
+; CHECK:call i32 @dx.op.cbufferLoad.i32(i32 58, %dx.types.Handle %[[HDL]], i32 20, i32 4)
+@A.cb. = external constant { float, i32, double, <2 x i32> }
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @"?foo@@YAMXZ"() #0 {
+entry:
+  %0 = load float, ptr @A.cb., align 4
+  %conv = fpext float %0 to double
+  %1 = load double, ptr getelementptr inbounds ({ float, i32, double, <2 x i32> }, ptr @A.cb., i32 0, i32 2), align 8
+  %2 = load <2 x i32>, ptr getelementptr inbounds ({ float, i32, double, <2 x i32> }, ptr @A.cb., i32 0, i32 3), align 8
+  %3 = extractelement <2 x i32> %2, i32 1
+  %conv1 = sitofp i32 %3 to double
+  %4 = call double @llvm.fmuladd.f64(double %1, double %conv1, double %conv)
+  %conv2 = fptrunc double %4 to float
+  ret float %conv2
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare double @llvm.fmuladd.f64(double, double, double) #1
+
+attributes #0 = { noinline nounwind }
+attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
+
+!hlsl.cbufs = !{!1}
+
+!1 = !{ptr @A.cb., !"A.cb.ty", i32 0, i32 2, i32 1}