diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -23,6 +23,7 @@
   DXILOpLowering.cpp
   DXILPrepare.cpp
   DXILTranslateMetadata.cpp
+  DXILTypedBufferLowering.cpp
   MemAccessLowerHelper.cpp
   PointerTypeAnalysis.cpp
 
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -26,6 +26,8 @@
 def ThreadIdInGroupClass : dxil_class<"ThreadIdInGroup">;
 def ThreadIdClass : dxil_class<"ThreadId">;
 def GroupIdClass : dxil_class<"GroupId">;
+def BufferLoadClass : dxil_class<"BufferLoad">;
+def BufferStoreClass : dxil_class<"BufferStore">;
 def CBufferLoadClass : dxil_class<"CBufferLoad">;
 def CreateHandleClass : dxil_class<"CreateHandle">;
 
@@ -165,3 +167,26 @@
     dxil_param<4, "i32", "index", "zero-based index into range">,
     dxil_param<5, "i1", "nonUniformIndex", "non-uniform resource index", 1>
   ]>;
+
+def BufferLoad : dxil_op< "BufferLoad", 68, BufferLoadClass, Resources, "reads from a TypedBuffer", "half;float;i16;i32;", "ro",
+  [
+    dxil_param<0, "dx.types.ResRet", "", "the loaded value">,
+    dxil_param<1, "i32", "opcode", "DXIL opcode">,
+    dxil_param<2, "dx.types.Handle", "srv", "handle of the TypedBuffer SRV or UAV to read from">,
+    dxil_param<3, "i32", "index", "element index">,
+    dxil_param<4, "i32", "offset", "Used for offset into element for StructuredBuffer in sm6.0/6.1. Always undef for ByteAddressBuffer/TypedBuffer. Always undef for shader model higher than 6.1">
+  ]>;
+
+def BufferStore : dxil_op< "BufferStore", 69, BufferStoreClass, Resources, "writes to a RWTypedBuffer", "half;float;i16;i32;", "",
+  [
+    dxil_param<0, "void", "", "">,
+    dxil_param<1, "i32", "opcode", "DXIL opcode">,
+    dxil_param<2, "dx.types.Handle", "uav", "handle of UAV to store to">,
+    dxil_param<3, "i32", "coord0", "coordinate in elements">,
+    dxil_param<4, "i32", "coord1", "offset into element for StructuredBuffer. Always undef for ByteAddressBuffer/TypedBuffer">,
+    dxil_param<5, "$o", "value0", "value">,
+    dxil_param<6, "$o", "value1", "value">,
+    dxil_param<7, "$o", "value2", "value">,
+    dxil_param<8, "$o", "value3", "value">,
+    dxil_param<9, "i8", "mask", "written value mask">
+  ]>;
diff --git a/llvm/lib/Target/DirectX/DXILCBufferLowering.cpp b/llvm/lib/Target/DirectX/DXILCBufferLowering.cpp
--- a/llvm/lib/Target/DirectX/DXILCBufferLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILCBufferLowering.cpp
@@ -46,7 +46,7 @@
     auto *GVMD = cast<ValueAsMetadata>(Res->getOperand(0).get());
     auto *GV = cast<GlobalVariable>(GVMD->getValue());
     std::vector<MemAccessLowerHelper::MemAccess> AccessList;
-    MemAccessLowerHelper::collectZeroOffsetMemAccess(GV, AccessList, DL);
+    MemAccessLowerHelper::collectMemAccess(GV, AccessList, DL);
 
     SmallDenseMap<Function *, CallInst *> HandleMap;
     uint64_t RangeID =
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -39,6 +39,14 @@
                                 bool NonUniformIndex);
   CallInst *createCBufferLoad(Type *OverloadTy, Value *Hdl, Value *ByteOffset,
                               uint32_t Alignment);
+  /// Create a dx.op.bufferLoad reading element \p Index of the buffer behind
+  /// \p Hdl. \p OverloadTy is the scalar type of the loaded components.
+  CallInst *createBufferLoad(Type *OverloadTy, Value *Hdl, Value *Index);
+  /// Create a dx.op.bufferStore writing \p V0 - \p V3 to element \p Index of
+  /// the buffer behind \p Hdl; bit N of \p Mask marks component N as written.
+  CallInst *createBufferStore(Type *OverloadTy, Value *Hdl, Value *Index,
+                              Value *V0, Value *V1, Value *V2, Value *V3,
+                              uint8_t Mask);
 
 private:
   Module &M;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -11,6 +11,7 @@
 
 #include "DXILOpBuilder.h"
 #include "DXILConstants.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/DXILOperationCommon.h"
@@ -339,5 +340,21 @@
                              ByteOffset, B.getInt32(Alignment)});
 }
 
+CallInst *DXILOpBuilder::createBufferLoad(Type *OverloadTy, Value *Hdl,
+                                          Value *Index) {
+  auto Fn = getOrCreateDXILOpFunction(dxil::OpCode::BufferLoad, OverloadTy, M);
+  return B.CreateCall(Fn, {B.getInt32((int32_t)dxil::OpCode::BufferLoad), Hdl,
+                           Index, PoisonValue::get(B.getInt32Ty())});
+}
+
+CallInst *DXILOpBuilder::createBufferStore(Type *OverloadTy, Value *Hdl,
+                                           Value *Index, Value *V0, Value *V1,
+                                           Value *V2, Value *V3, uint8_t Mask) {
+  auto Fn = getOrCreateDXILOpFunction(dxil::OpCode::BufferStore, OverloadTy, M);
+  return B.CreateCall(Fn, {B.getInt32((int32_t)dxil::OpCode::BufferStore), Hdl,
+                           Index, PoisonValue::get(B.getInt32Ty()), V0, V1, V2,
+                           V3, B.getInt8(Mask)});
+}
+
 } // namespace dxil
 } // namespace llvm
diff --git a/llvm/lib/Target/DirectX/DXILTypedBufferLowering.cpp b/llvm/lib/Target/DirectX/DXILTypedBufferLowering.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILTypedBufferLowering.cpp
@@ -0,0 +1,177 @@
+//===- DXILTypedBufferLowering.cpp - Lowering TypedBuffer to DXIL ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains passes to lower typed buffer accesses to DXIL.
+//===----------------------------------------------------------------------===//
+
+#include "DXILConstants.h"
+#include "DXILOpBuilder.h"
+#include "DirectX.h"
+#include "MemAccessLowerHelper.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/Utils/Local.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "dxil-typedbuf-lower"
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+/// A typed buffer element holds at most four components (value0..value3 of
+/// dx.op.bufferStore).
+static constexpr unsigned MaxTypedBufferComponents = 4;
+
+/// Return \p Index as the i32 element index taken by the DXIL buffer ops. A
+/// null index (the handle pointer accessed without a GEP) means element 0.
+static Value *getElementIndex(IRBuilder<> &B, Value *Index) {
+  if (!Index)
+    return B.getInt32(0);
+  return B.CreateSExtOrTrunc(Index, B.getInt32Ty());
+}
+
+/// Return the UAV handle of resource (\p RangeID, \p UAVIndex) in the function
+/// containing \p I, creating it at the top of the entry block on first use.
+static CallInst *
+getOrCreateUAVHandle(Module &M, Instruction *I, uint64_t RangeID,
+                     ConstantInt *UAVIndex,
+                     SmallDenseMap<Function *, CallInst *> &HandleMap) {
+  Function *F = I->getFunction();
+  auto It = HandleMap.find(F);
+  if (It != HandleMap.end())
+    return It->second;
+  IRBuilder<> B(&*F->getEntryBlock().getFirstInsertionPt());
+  DXILOpBuilder DXILB(M, B);
+  CallInst *Hdl = DXILB.createCreateHandle(
+      static_cast<int8_t>(dxil::ResourceClass::UAV), RangeID, UAVIndex,
+      /*NonUniformIndex=*/false);
+  HandleMap[F] = Hdl;
+  return Hdl;
+}
+
+/// Replace \p SI, a store of a scalar or of a fixed vector with up to four
+/// components into a typed buffer element, with a dx.op.bufferStore on \p Hdl.
+static void lowerTypedBufferStore(Module &M, StoreInst *SI, Value *Hdl,
+                                  Value *Index) {
+  Value *V = SI->getValueOperand();
+  Type *Ty = V->getType();
+  IRBuilder<> B(SI);
+  DXILOpBuilder DXILB(M, B);
+  Type *OverloadTy = Ty->getScalarType();
+  Value *UnusedV = PoisonValue::get(OverloadTy);
+  Value *Elts[MaxTypedBufferComponents] = {UnusedV, UnusedV, UnusedV, UnusedV};
+  uint8_t Mask = 0;
+  if (Ty->isIntegerTy() || Ty->isFloatingPointTy()) {
+    Elts[0] = V;
+    Mask = 1;
+  } else if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {
+    unsigned NumElts = VT->getNumElements();
+    if (NumElts > MaxTypedBufferComponents)
+      report_fatal_error("typed buffer element has more than 4 components");
+    // V is a vector value, so its lanes are read with extractelement
+    // (extractvalue only applies to aggregates).
+    for (unsigned I = 0; I < NumElts; ++I)
+      Elts[I] = B.CreateExtractElement(V, I);
+    Mask = (1u << NumElts) - 1;
+  } else {
+    report_fatal_error("invalid type for typed buffer store");
+  }
+  DXILB.createBufferStore(OverloadTy, Hdl, getElementIndex(B, Index), Elts[0],
+                          Elts[1], Elts[2], Elts[3], Mask);
+  SI->eraseFromParent();
+}
+
+/// Replace \p LI, a load of a scalar or of a fixed vector with up to four
+/// components from a typed buffer element, with a dx.op.bufferLoad on \p Hdl.
+static void lowerTypedBufferLoad(Module &M, LoadInst *LI, Value *Hdl,
+                                 Value *Index) {
+  Type *Ty = LI->getType();
+  IRBuilder<> B(LI);
+  DXILOpBuilder DXILB(M, B);
+  Value *BufLd = DXILB.createBufferLoad(Ty->getScalarType(), Hdl,
+                                        getElementIndex(B, Index));
+  Value *Result = nullptr;
+  if (Ty->isIntegerTy() || Ty->isFloatingPointTy()) {
+    Result = B.CreateExtractValue(BufLd, 0);
+  } else if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {
+    unsigned NumElts = VT->getNumElements();
+    if (NumElts > MaxTypedBufferComponents)
+      report_fatal_error("typed buffer element has more than 4 components");
+    // The op returns a dx.types.ResRet struct; repack its leading fields into
+    // the vector the original load produced.
+    Result = PoisonValue::get(VT);
+    for (unsigned I = 0; I < NumElts; ++I) {
+      Value *Elt = B.CreateExtractValue(BufLd, I);
+      Result = B.CreateInsertElement(Result, Elt, I);
+    }
+  } else {
+    report_fatal_error("invalid type for typed buffer load");
+  }
+  LI->replaceAllUsesWith(Result);
+  LI->eraseFromParent();
+}
+
+/// Lower every load/store through the UAV globals listed in !hlsl.uavs to DXIL
+/// buffer ops. Returns true if the module has a !hlsl.uavs table.
+static bool lowerUAVAccess(Module &M) {
+  // FIXME: Allocate resource binding first.
+  const StringRef Name = "hlsl.uavs";
+  auto *ResTable = M.getNamedMetadata(Name);
+  if (!ResTable)
+    return false;
+  const DataLayout &DL = M.getDataLayout();
+
+  for (auto *Res : ResTable->operands()) {
+    assert(Res->getNumOperands() == 5 && "invalid resource metadata");
+    auto *GVMD = cast<ValueAsMetadata>(Res->getOperand(0).get());
+    auto *GV = cast<GlobalVariable>(GVMD->getValue());
+    std::vector<MemAccessLowerHelper::MemAccess> AccessList;
+    // Stores that put the resource handle into GV. Accesses are collected
+    // while these stores still exist and the stores are erased only after
+    // lowering, so AccessList never refers to a deleted instruction.
+    SmallVector<StoreInst *, 4> HandleStores;
+    SmallPtrSet<Value *, 4> SeenHandles;
+    for (User *U : make_early_inc_range(GV->users())) {
+      if (auto *SI = dyn_cast<StoreInst>(U)) {
+        // FIXME: use createHandle generated in clangCodeGen.
+        // See https://github.com/llvm/llvm-project/issues/58031.
+        HandleStores.push_back(SI);
+        Value *V = SI->getValueOperand();
+        // Walk each handle value once, even if it is stored more than once.
+        if (SeenHandles.insert(V).second)
+          MemAccessLowerHelper::collectMemAccess(V, AccessList, DL);
+      } else if (auto *IntrinsicCI = dyn_cast<IntrinsicInst>(U)) {
+        if (IntrinsicCI->getIntrinsicID() == Intrinsic::invariant_start ||
+            IntrinsicCI->getIntrinsicID() == Intrinsic::invariant_end) {
+          // An invariant.start may still be used by an invariant.end that is
+          // visited later; drop those uses before erasing it.
+          if (!IntrinsicCI->use_empty())
+            IntrinsicCI->replaceAllUsesWith(
+                PoisonValue::get(IntrinsicCI->getType()));
+          IntrinsicCI->eraseFromParent();
+        }
+      } else if (auto *LI = dyn_cast<LoadInst>(U))
+        MemAccessLowerHelper::collectMemAccess(LI, AccessList, DL);
+    }
+
+    SmallDenseMap<Function *, CallInst *> HandleMap;
+    uint64_t RangeID =
+        mdconst::extract<ConstantInt>(Res->getOperand(2))->getLimitedValue();
+    ConstantInt *UAVIndex = mdconst::extract<ConstantInt>(Res->getOperand(3));
+    assert(UAVIndex->getLimitedValue() != -1ULL && "unallocated binding");
+    for (auto &Access : AccessList) {
+      Instruction *AccessI = Access.User;
+      // The stores of the handle into GV are not buffer accesses.
+      if (auto *SI = dyn_cast<StoreInst>(AccessI))
+        if (SI->getPointerOperand() == GV)
+          continue;
+      // Typed buffers are addressed per element; an offset inside an element
+      // cannot be expressed with bufferLoad/bufferStore.
+      auto *ConstOffset = dyn_cast<ConstantInt>(Access.Offset);
+      if (!ConstOffset || !ConstOffset->isZero())
+        report_fatal_error("unsupported access inside a typed buffer element");
+      CallInst *Hdl =
+          getOrCreateUAVHandle(M, AccessI, RangeID, UAVIndex, HandleMap);
+      if (auto *SI = dyn_cast<StoreInst>(AccessI))
+        lowerTypedBufferStore(M, SI, Hdl, Access.Index);
+      else
+        lowerTypedBufferLoad(M, cast<LoadInst>(AccessI), Hdl, Access.Index);
+    }
+    for (StoreInst *SI : HandleStores)
+      SI->eraseFromParent();
+  }
+  return true;
+}
+
+namespace {
+
+/// A new-pass-manager pass that lowers typed buffer accesses into DXIL ops.
+class DXILTypedBufLowering : public PassInfoMixin<DXILTypedBufLowering> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
+    if (lowerUAVAccess(M))
+      return PreservedAnalyses::none();
+    return PreservedAnalyses::all();
+  }
+};
+} // namespace
+
+namespace {
+/// Legacy-pass-manager wrapper around lowerUAVAccess.
+class DXILTypedBufferLoweringLegacy : public ModulePass {
+public:
+  bool runOnModule(Module &M) override { return lowerUAVAccess(M); }
+  StringRef getPassName() const override { return "DXIL TypedBuffer lowering"; }
+  DXILTypedBufferLoweringLegacy() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+char DXILTypedBufferLoweringLegacy::ID = 0;
+
+} // end anonymous namespace
+
+INITIALIZE_PASS_BEGIN(DXILTypedBufferLoweringLegacy, DEBUG_TYPE,
+                      "DXIL TypedBuffer lowering", false, false)
+INITIALIZE_PASS_END(DXILTypedBufferLoweringLegacy, DEBUG_TYPE,
+                    "DXIL TypedBuffer lowering", false, false)
+
+ModulePass *llvm::createDXILTypedBufferLoweringLegacyPass() {
+  return new DXILTypedBufferLoweringLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -39,6 +39,12 @@
 /// Pass to lowering CBuffer to DXIL op function call.
 ModulePass *createDXILCBufferLoweringLegacyPass();
 
+/// Initializer for DXILTypedBufferLowering.
+void initializeDXILTypedBufferLoweringLegacyPass(PassRegistry &);
+
+/// Pass to lower TypedBuffer accesses to DXIL op function calls.
+ModulePass *createDXILTypedBufferLoweringLegacyPass();
+
 /// Initializer for DXILTranslateMetadata.
 void initializeDXILTranslateMetadataPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -40,6 +40,7 @@
   initializeDXILOpLoweringLegacyPass(*PR);
   initializeDXILTranslateMetadataPass(*PR);
   initializeDXILCBufferLoweringLegacyPass(*PR);
+  initializeDXILTypedBufferLoweringLegacyPass(*PR);
 }
 
 class DXILTargetObjectFile : public TargetLoweringObjectFile {
@@ -70,6 +71,7 @@
   FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
   void addCodeGenPrepare() override {
     addPass(createDXILCBufferLoweringLegacyPass());
+    addPass(createDXILTypedBufferLoweringLegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILPrepareModulePass());
     addPass(createDXILTranslateMetadataPass());
diff --git a/llvm/lib/Target/DirectX/MemAccessLowerHelper.h b/llvm/lib/Target/DirectX/MemAccessLowerHelper.h
---
a/llvm/lib/Target/DirectX/MemAccessLowerHelper.h
+++ b/llvm/lib/Target/DirectX/MemAccessLowerHelper.h
@@ -27,16 +27,17 @@
 namespace MemAccessLowerHelper {
 
 struct MemAccess {
-  Value *Index;    // Address when access memory.
-  uint32_t Offset; // Immediate Offset when access memory for legacy cbuffer and
-                   // structured buffer.
-  Instruction *User; // The instruction which access memory like Load/Store.
+  Value *Index;      // Element index: the first index of the outermost GEP, or
+                     // i32 0 for a load/store of the buffer pointer itself.
+  Value *Offset;     // Byte offset computed from the remaining GEP indices.
+  Instruction *User; // The load or store that performs the access.
 };
 
-/// Collect all memory access for non-legacy cbuffer/ typed buffer global
-/// variable \c GV.
-void collectZeroOffsetMemAccess(Value *Ptr, std::vector<MemAccess> &AccessList,
-                                const DataLayout &DL);
+/// Collect every load and store reachable from \p Ptr through GEPs and pointer
+/// casts into \p AccessList, splitting each address into an element index and
+/// a byte offset.
+void collectMemAccess(Value *Ptr, std::vector<MemAccess> &AccessList,
+                      const DataLayout &DL);
 
 } // namespace MemAccessLowerHelper
 } // namespace dxil
diff --git a/llvm/lib/Target/DirectX/MemAccessLowerHelper.cpp b/llvm/lib/Target/DirectX/MemAccessLowerHelper.cpp
--- a/llvm/lib/Target/DirectX/MemAccessLowerHelper.cpp
+++ b/llvm/lib/Target/DirectX/MemAccessLowerHelper.cpp
@@ -19,43 +19,82 @@
 
 using namespace llvm;
 using namespace llvm::dxil::MemAccessLowerHelper;
 
-static void collectZeroOffsetAccess(User *U, Value *Addr,
-                                    std::vector<MemAccess> &AccessList,
-                                    const DataLayout &DL) {
+/// Element index recorded for a load/store. A direct access through the buffer
+/// pointer, with no GEP in between, addresses element 0.
+static Value *getAccessIndex(Value *Index, LLVMContext &Ctx) {
+  return Index ? Index : ConstantInt::get(Type::getInt32Ty(Ctx), 0);
+}
+
+/// Recursively walk \p U, a user of a buffer pointer, and append every load or
+/// store reached through GEPs and pointer casts to \p AccessList.
+///
+/// \p Index is the first index of the outermost GEP, or null while no GEP has
+/// been seen. \p Offset is the byte offset accumulated from the other GEP
+/// indices.
+static void collectAccess(User *U, Value *Index, Value *Offset,
+                          std::vector<MemAccess> &AccessList,
+                          const DataLayout &DL) {
   if (auto *GEP = dyn_cast<GEPOperator>(U)) {
-    // Calculate new Addr.
-    Value *NewAddr = Addr;
+    // A GEP without indices does not move the pointer; look through it.
+    if (GEP->getNumIndices() == 0) {
+      for (User *GEPU : GEP->users())
+        collectAccess(GEPU, Index, Offset, AccessList, DL);
+      return;
+    }
+    // The outermost GEP supplies the element index, so its first index must
+    // not be counted again in the byte offset.
+    Value *NewIndex = Index;
+    Value *ZeroFirstIdx = nullptr;
+    if (!Index) {
+      NewIndex = *GEP->idx_begin();
+      ZeroFirstIdx = ConstantInt::get(NewIndex->getType(), 0);
+    }
+    // Calculate new offset.
+    Value *NewOffset = Offset;
     IRBuilder<> B(GEP->getContext());
     if (auto *Inst = dyn_cast<Instruction>(GEP))
       B.SetInsertPoint(Inst);
-    if (GEP->hasAllConstantIndices()) {
-      SmallVector<Value *> IdxList(GEP->idx_begin(), GEP->idx_end());
-      NewAddr = B.CreateAdd(Addr, B.getInt32(DL.getIndexedOffsetInType(
-                                      GEP->getSourceElementType(), IdxList)));
+    SmallVector<Value *> IdxList(GEP->idx_begin(), GEP->idx_end());
+    if (ZeroFirstIdx)
+      IdxList[0] = ZeroFirstIdx;
+    if (all_of(IdxList, [](Value *Idx) { return isa<ConstantInt>(Idx); })) {
+      // Fold the offset from a copy of the indices so the GEP itself, which
+      // may be a constant expression, is never modified.
+      NewOffset =
+          B.CreateAdd(Offset, B.getInt32(DL.getIndexedOffsetInType(
+                                  GEP->getSourceElementType(), IdxList)));
     } else {
-      Value *Offset = EmitGEPOffset(&B, DL, GEP, /*NoAssumptions=*/true);
-      NewAddr = B.CreateAdd(Addr, Offset);
+      // EmitGEPOffset reads the indices from the GEP itself, so zero its first
+      // index while emitting and restore it afterwards. Only instructions may
+      // be rewritten this way; constants must not be modified.
+      assert(isa<Instruction>(GEP) && "non-constant index on a constant GEP");
+      if (ZeroFirstIdx)
+        GEP->setOperand(1, ZeroFirstIdx);
+      Value *GEPOffset = EmitGEPOffset(&B, DL, GEP, /*NoAssumptions=*/true);
+      if (ZeroFirstIdx)
+        GEP->setOperand(1, NewIndex);
+      NewOffset = B.CreateAdd(Offset, GEPOffset);
     }
     for (User *GEPU : GEP->users())
-      collectZeroOffsetAccess(GEPU, NewAddr, AccessList, DL);
+      collectAccess(GEPU, NewIndex, NewOffset, AccessList, DL);
   } else if (isa<BitCastOperator>(U) || isa<AddrSpaceCastOperator>(U)) {
     for (User *AU : U->users())
-      collectZeroOffsetAccess(AU, Addr, AccessList, DL);
+      collectAccess(AU, Index, Offset, AccessList, DL);
   } else if (auto *LI = dyn_cast<LoadInst>(U)) {
-    MemAccess Access = {Addr, 0, LI};
+    MemAccess Access = {getAccessIndex(Index, LI->getContext()), Offset, LI};
     AccessList.emplace_back(Access);
   } else if (auto *SI = dyn_cast<StoreInst>(U)) {
-    MemAccess Access = {Addr, 0, SI};
+    MemAccess Access = {getAccessIndex(Index, SI->getContext()), Offset, SI};
     AccessList.emplace_back(Access);
   } else
     llvm_unreachable("unsupported user");
 }
 
-void llvm::dxil::MemAccessLowerHelper::collectZeroOffsetMemAccess(
+void llvm::dxil::MemAccessLowerHelper::collectMemAccess(
     Value *Ptr, std::vector<MemAccess> &AccessList, const DataLayout &DL) {
   auto &Ctx = Ptr->getContext();
   Value *OffsetZero = ConstantInt::get(Type::getInt32Ty(Ctx), 0);
   for (User *U : Ptr->users())
-    collectZeroOffsetAccess(U, OffsetZero, AccessList, DL);
+    collectAccess(U, nullptr, OffsetZero, AccessList, DL);
 }
diff --git a/llvm/test/CodeGen/DirectX/buf_ld_st.ll b/llvm/test/CodeGen/DirectX/buf_ld_st.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/buf_ld_st.ll
@@
-0,0 +1,54 @@
+; RUN: opt -S -dxil-typedbuf-lower < %s | FileCheck %s
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-unknown-shadermodel6.7-compute"
+
+; Make sure a createHandle is generated for each UAV.
+; CHECK-DAG:%[[HDL_IN:.+]] = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 2, i1 false)
+; CHECK-DAG:%[[HDL_OUT:.+]] = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 1, i32 3, i1 false)
+
+; Make sure Out[TID] = In[TID] is lowered to bufferLoad + bufferStore.
+; CHECK:%[[TID:.+]] = tail call i32 @llvm.dx.flattened.thread.id.in.group()
+; CHECK:%[[LD:.+]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle %[[HDL_IN]], i32 %[[TID]], i32 poison)
+; CHECK:%[[LD_ELT:.+]] = extractvalue %dx.types.ResRet.f32 %[[LD]], 0
+; CHECK:call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %[[HDL_OUT]], i32 %[[TID]], i32 poison, float %[[LD_ELT]], float poison, float poison, float poison, i8 1)
+
+%"class.hlsl::RWBuffer" = type { ptr }
+
+@In = internal global %"class.hlsl::RWBuffer" zeroinitializer, align 4
+@"?Out@@3V?$RWBuffer@M@hlsl@@A" = local_unnamed_addr global %"class.hlsl::RWBuffer" zeroinitializer, align 4
+
+; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn
+declare ptr @llvm.invariant.start.p0(i64 immarg %0, ptr nocapture %1) #0
+
+; Function Attrs: mustprogress norecurse nounwind willreturn
+define void @main() local_unnamed_addr #1 {
+entry:
+  %0 = tail call ptr @llvm.dx.create.handle(i8 1)
+  store ptr %0, ptr @In, align 4
+  %1 = tail call ptr @llvm.invariant.start.p0(i64 4, ptr nonnull @In)
+  %2 = tail call ptr @llvm.dx.create.handle(i8 1)
+  store ptr %2, ptr @"?Out@@3V?$RWBuffer@M@hlsl@@A", align 4
+  %3 = tail call i32 @llvm.dx.flattened.thread.id.in.group()
+  %4 = load ptr, ptr @In, align 4
+  %arrayidx.i.i = getelementptr inbounds float, ptr %4, i32 %3
+  %5 = load float, ptr %arrayidx.i.i, align 4
+  %arrayidx.i3.i = getelementptr inbounds float, ptr %2, i32 %3
+  store float %5, ptr %arrayidx.i3.i, align 4
+  ret void
+}
+
+; Function Attrs: mustprogress nofree nosync nounwind readnone willreturn
+declare i32 @llvm.dx.flattened.thread.id.in.group() #2
+
+; Function Attrs: mustprogress nounwind willreturn
+declare ptr @llvm.dx.create.handle(i8 %0) #3
+
+attributes #0 = { argmemonly mustprogress nocallback nofree nosync nounwind willreturn }
+attributes #1 = { mustprogress norecurse nounwind willreturn "frame-pointer"="all" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { mustprogress nofree nosync nounwind readnone willreturn }
+attributes #3 = { mustprogress nounwind willreturn }
+
+!hlsl.uavs = !{!0, !1}
+
+!0 = !{ptr @In, !"RWBuffer", i32 0, i32 2, i32 0}
+!1 = !{ptr @"?Out@@3V?$RWBuffer@M@hlsl@@A", !"RWBuffer", i32 1, i32 3, i32 0}