diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -17,6 +17,7 @@
   DirectXRegisterInfo.cpp
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
+  DXILOpBuilder.cpp
   DXILOpLowering.cpp
   DXILPointerType.cpp
   DXILPrepare.cpp
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -26,11 +26,14 @@
 def ThreadIdInGroupClass : dxil_class<"ThreadIdInGroup">;
 def ThreadIdClass : dxil_class<"ThreadId">;
 def GroupIdClass : dxil_class<"GroupId">;
+def BufferLoadClass : dxil_class<"BufferLoad">;
+def BufferStoreClass : dxil_class<"BufferStore">;
+def CreateHandleClass : dxil_class<"CreateHandle">;
 
 def binary_uint : dxil_category<"Binary uint">;
 def unary_float : dxil_category<"Unary float">;
 def ComputeID : dxil_category<"Compute/Mesh/Amplification shader">;
-
+def Resources : dxil_category<"Resources">;
 
 // The parameter description for a DXIL instruction
 class dxil_param ]>, dxil_map_intrinsic;
+
+def BufferLoad : dxil_op< "BufferLoad", 68, BufferLoadClass,Resources, "reads from a TypedBuffer", "half;float;i16;i32;", "ro",
+  [
+    dxil_param<0, "dx.types.ResRet", "", "the loaded value">,
+    dxil_param<1, "i32", "opcode", "DXIL opcode">,
+    dxil_param<2, "dx.types.Handle", "srv", "handle of TypedBuffer SRV to sample">,
+    dxil_param<3, "i32", "index", "element index">,
+    dxil_param<4, "i32", "wot", "coordinate">
+  ],
+  ["tex_load"]>;
+
+def BufferStore : dxil_op< "BufferStore", 69, BufferStoreClass, Resources, "writes to a RWTypedBuffer", "half;float;i16;i32;", "",
+  [
+    dxil_param<0, "void", "", "">,
+    dxil_param<1, "i32", "opcode", "DXIL opcode">,
+    dxil_param<2, "dx.types.Handle", "uav", "handle of UAV to store to">,
+    dxil_param<3, "i32", "coord0", "coordinate in elements">,
+    dxil_param<4, "i32", "coord1", "coordinate (unused?)">,
+    dxil_param<5, "$o", "value0", "value">,
+    dxil_param<6, "$o", "value1", "value">,
+    dxil_param<7, "$o", "value2", "value">,
+    dxil_param<8, "$o", "value3", "value">,
+    dxil_param<9, "i8", "mask", "written value mask">
+  ],
+  ["tex_store"]>;
+
+def CreateHandle : dxil_op< "CreateHandle", 57, CreateHandleClass, Resources, "creates the handle to a resource",
+  "void;", "ro",
+  [
+    dxil_param<0, "dx.types.Handle", "", "the handle to the resource">,
+    dxil_param<1, "i32", "opcode", "DXIL opcode">,
+    dxil_param<2, "i8", "resourceClass", "the class of resource to create (SRV, UAV, CBuffer, Sampler)", 1>, // maps to DxilResourceBase::Class
+    dxil_param<3, "i32", "rangeId", "range identifier for resource", 1>,
+    dxil_param<4, "i32", "index", "zero-based index into range">,
+    dxil_param<5, "i1", "nonUniformIndex", "non-uniform resource index", 1>
+  ]>;
diff --git a/llvm/lib/Target/DirectX/DXILConstants.h b/llvm/lib/Target/DirectX/DXILConstants.h
--- a/llvm/lib/Target/DirectX/DXILConstants.h
+++ b/llvm/lib/Target/DirectX/DXILConstants.h
@@ -15,6 +15,8 @@
 namespace llvm {
 namespace DXIL {
 
+enum class ResourceClass { SRV = 0, UAV, CBuffer, Sampler, Invalid };
+
 #define DXIL_OP_ENUM
 #include "DXILOperation.inc"
 #undef DXIL_OP_ENUM
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -0,0 +1,53 @@
+//===- DXILOpBuilder.h - Helper class for building DXIL op functions ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains a class to help build DXIL op functions.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
+
+#include "DXILConstants.h"
+#include "llvm/ADT/iterator_range.h"
+
+namespace llvm {
+class Module;
+class IRBuilderBase;
+class CallInst;
+class Value;
+class Type;
+class FunctionType;
+class Use;
+
+namespace DXIL {
+
+class DXILOpBuilder {
+public:
+  DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
+  CallInst *createCreateHandle(Type *OverloadTy, int8_t ResClass, int RangeID,
+                               Value *Index, bool NonUniformIndex);
+  CallInst *createBufferLoad(Type *OverloadTy, Value *Hdl, Value *Coord0,
+                             Value *Coord1);
+  CallInst *createBufferStore(Type *OverloadTy, Value *Hdl, Value *Coord0,
+                              Value *Coord1, Value *V0, Value *V1, Value *V2,
+                              Value *V3, Value *Mask);
+  CallInst *createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy,
+                             llvm::iterator_range<Use *> Args);
+  Type *getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT,
+                      bool NoOpCodeParam);
+  static const char *getOpCodeName(DXIL::OpCode DXILOp);
+
+private:
+  Module &M;
+  IRBuilderBase &B;
+};
+
+} // namespace DXIL
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -0,0 +1,364 @@
+//===- DXILOpBuilder.cpp - Helper class for building DXIL op functions ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains a class to help build DXIL op functions.
+//===----------------------------------------------------------------------===//
+
+#include "DXILOpBuilder.h"
+#include "DXILConstants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace llvm;
+using namespace llvm::DXIL;
+
+constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
+
+namespace {
+
+enum OverloadKind : uint16_t {
+  VOID = 1,
+  HALF = 1 << 1,
+  FLOAT = 1 << 2,
+  DOUBLE = 1 << 3,
+  I1 = 1 << 4,
+  I8 = 1 << 5,
+  I16 = 1 << 6,
+  I32 = 1 << 7,
+  I64 = 1 << 8,
+  UserDefineType = 1 << 9,
+  ObjectType = 1 << 10,
+};
+// NOTE: keep in sync with ParameterKind in utils/TableGen/DXILEmitter.cpp.
+enum class ParameterKind : uint8_t {
+  INVALID = 0,
+  VOID,
+  HALF,
+  FLOAT,
+  DOUBLE,
+  I1,
+  I8,
+  I16,
+  I32,
+  I64,
+  OVERLOAD,
+  CBUFFER_RET,
+  RESOURCE_RET,
+  DXIL_HANDLE,
+};
+} // namespace
+
+static const char *getOverloadTypeName(OverloadKind Kind) {
+  switch (Kind) {
+  case OverloadKind::HALF:
+    return "f16";
+  case OverloadKind::FLOAT:
+    return "f32";
+  case OverloadKind::DOUBLE:
+    return "f64";
+  case OverloadKind::I1:
+    return "i1";
+  case OverloadKind::I8:
+    return "i8";
+  case OverloadKind::I16:
+    return "i16";
+  case OverloadKind::I32:
+    return "i32";
+  case OverloadKind::I64:
+    return "i64";
+  case OverloadKind::VOID:
+  case OverloadKind::ObjectType:
+  case OverloadKind::UserDefineType:
+    break;
+  }
+  llvm_unreachable("invalid overload type for name");
+  return "void";
+}
+
+static OverloadKind getOverloadKind(Type *Ty) {
+  Type::TypeID T = Ty->getTypeID();
+  switch (T) {
+  case Type::VoidTyID:
+    return OverloadKind::VOID;
+  case Type::HalfTyID:
+    return OverloadKind::HALF;
+  case Type::FloatTyID:
+    return OverloadKind::FLOAT;
+  case Type::DoubleTyID:
+    return OverloadKind::DOUBLE;
+  case Type::IntegerTyID: {
+    IntegerType *ITy = cast<IntegerType>(Ty);
+    unsigned Bits = ITy->getBitWidth();
+    switch (Bits) {
+    case 1:
+      return OverloadKind::I1;
+    case 8:
+      return OverloadKind::I8;
+    case 16:
+      return OverloadKind::I16;
+    case 32:
+      return OverloadKind::I32;
+    case 64:
+      return OverloadKind::I64;
+    default:
+      llvm_unreachable("invalid overload type");
+      return OverloadKind::VOID;
+    }
+  }
+  case Type::PointerTyID:
+    return OverloadKind::UserDefineType;
+  case Type::StructTyID:
+    return OverloadKind::ObjectType;
+  default:
+    llvm_unreachable("invalid overload type");
+    return OverloadKind::VOID;
+  }
+}
+
+static std::string getTypeName(OverloadKind Kind, Type *Ty) {
+  // Scalar overloads have a fixed short name such as "f32" or "i16".
+  if (Kind < OverloadKind::UserDefineType)
+    return getOverloadTypeName(Kind);
+  // ObjectType is only produced for StructTypes; prefer the struct's name.
+  if (Kind == OverloadKind::ObjectType) {
+    auto *ST = cast<StructType>(Ty);
+    if (ST->hasName())
+      return ST->getStructName().str();
+  }
+  // Pointers (UserDefineType) and literal structs: print the type itself.
+  std::string Str;
+  raw_string_ostream OS(Str);
+  Ty->print(OS);
+  return OS.str();
+}
+
+// Static properties.
+struct OpCodeProperty {
+  DXIL::OpCode OpCode;
+  // Offset in DXILOpCodeNameTable.
+  unsigned OpCodeNameOffset;
+  DXIL::OpCodeClass OpCodeClass;
+  // Offset in DXILOpCodeClassNameTable.
+  unsigned OpCodeClassNameOffset;
+  uint16_t OverloadTys;
+  llvm::Attribute::AttrKind FuncAttr;
+  int OverloadParamIndex; // Parameter index that controls the overload.
+                          // When < 0, there is only one overload type.
+  unsigned NumOfParameters; // Number of parameters, including return value.
+  unsigned ParameterTableOffset; // Offset in ParameterTable.
+};
+
+// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
+// getOpCodeParameterKind, which are generated by TableGen.
+#define DXIL_OP_OPERATION_TABLE
+#include "DXILOperation.inc"
+#undef DXIL_OP_OPERATION_TABLE
+
+static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
+                                         const OpCodeProperty &Prop) {
+  if (Kind == OverloadKind::VOID) {
+    return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
+  }
+  return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
+          getTypeName(Kind, Ty))
+      .str();
+}
+
+static std::string constructOverloadTypeName(OverloadKind Kind,
+                                             StringRef TypeName) {
+  if (Kind == OverloadKind::VOID)
+    return TypeName.str();
+
+  assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
+  return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
+}
+
+static StructType *getOrCreateStructType(StringRef Name,
+                                         ArrayRef<Type *> EltTys,
+                                         LLVMContext &Ctx) {
+  StructType *ST = StructType::getTypeByName(Ctx, Name);
+  if (ST)
+    return ST;
+
+  return StructType::create(Ctx, EltTys, Name);
+}
+
+static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
+  OverloadKind Kind = getOverloadKind(OverloadTy);
+  std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
+  Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
+                         Type::getInt32Ty(Ctx)};
+  return getOrCreateStructType(TypeName, FieldTypes, Ctx);
+}
+
+static StructType *getHandleType(LLVMContext &Ctx) {
+  return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx);
+}
+
+static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
+  auto &Ctx = OverloadTy->getContext();
+  switch (Kind) {
+  case ParameterKind::VOID:
+    return Type::getVoidTy(Ctx);
+  case ParameterKind::HALF:
+    return Type::getHalfTy(Ctx);
+  case ParameterKind::FLOAT:
+    return Type::getFloatTy(Ctx);
+  case ParameterKind::DOUBLE:
+    return Type::getDoubleTy(Ctx);
+  case ParameterKind::I1:
+    return Type::getInt1Ty(Ctx);
+  case ParameterKind::I8:
+    return Type::getInt8Ty(Ctx);
+  case ParameterKind::I16:
+    return Type::getInt16Ty(Ctx);
+  case ParameterKind::I32:
+    return Type::getInt32Ty(Ctx);
+  case ParameterKind::I64:
+    return Type::getInt64Ty(Ctx);
+  case ParameterKind::OVERLOAD:
+    return OverloadTy;
+  case ParameterKind::RESOURCE_RET:
+    return getResRetType(OverloadTy, Ctx);
+  case ParameterKind::DXIL_HANDLE:
+    return getHandleType(Ctx);
+  default:
+    break;
+  }
+  llvm_unreachable("Invalid parameter kind");
+  return nullptr;
+}
+
+static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
+                                           Type *OverloadTy) {
+  SmallVector<Type *> ArgTys;
+
+  auto ParamKinds = getOpCodeParameterKind(*Prop);
+
+  for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
+    ParameterKind Kind = ParamKinds[I];
+    ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
+  }
+  return FunctionType::get(
+      ArgTys[0], ArrayRef<Type *>(ArgTys).drop_front(), /*isVarArg=*/false);
+}
+
+static FunctionCallee getOrCreateDXILOpFunction(DXIL::OpCode DXILOp,
+                                                Type *OverloadTy, Module &M) {
+  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
+
+  OverloadKind Kind = getOverloadKind(OverloadTy);
+  // FIXME: find the issue and report error in clang instead of check it in
+  // backend.
+  if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
+    report_fatal_error("invalid overload", /*gen_crash_diag=*/false);
+  }
+
+  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
+  // Dependent on name to dedup.
+  if (auto *Fn = M.getFunction(FnName))
+    return FunctionCallee(Fn);
+
+  FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
+  return M.getOrInsertFunction(FnName, DXILOpFT);
+}
+
+namespace llvm {
+namespace DXIL {
+
+CallInst *DXILOpBuilder::createCreateHandle(Type *OverloadTy, int8_t ResClass,
+                                            int RangeID, Value *Index,
+                                            bool NonUniformIndex) {
+  auto Fn =
+      getOrCreateDXILOpFunction(DXIL::OpCode::CreateHandle, OverloadTy, M);
+  return B.CreateCall(Fn, {B.getInt32((int32_t)DXIL::OpCode::CreateHandle),
+                           B.getInt8(ResClass), B.getInt32(RangeID), Index,
+                           B.getInt1(NonUniformIndex)});
+}
+
+CallInst *DXILOpBuilder::createBufferLoad(Type *OverloadTy, Value *Hdl,
+                                          Value *Coord0, Value *Coord1) {
+  auto Fn = getOrCreateDXILOpFunction(DXIL::OpCode::BufferLoad, OverloadTy, M);
+  return B.CreateCall(
+      Fn, {B.getInt32((int32_t)DXIL::OpCode::BufferLoad), Hdl, Coord0, Coord1});
+}
+CallInst *DXILOpBuilder::createBufferStore(Type *OverloadTy, Value *Hdl,
+                                           Value *Coord0, Value *Coord1,
+                                           Value *V0, Value *V1, Value *V2,
+                                           Value *V3, Value *Mask) {
+  auto Fn = getOrCreateDXILOpFunction(DXIL::OpCode::BufferStore, OverloadTy, M);
+  return B.CreateCall(Fn, {B.getInt32((int32_t)DXIL::OpCode::BufferStore), Hdl,
+                           Coord0, Coord1, V0, V1, V2, V3, Mask});
+}
+
+CallInst *DXILOpBuilder::createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy,
+                                          llvm::iterator_range<Use *> Args) {
+  auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
+  SmallVector<Value *> FullArgs;
+  FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
+  FullArgs.append(Args.begin(), Args.end());
+  return B.CreateCall(Fn, FullArgs);
+}
+
+Type *DXILOpBuilder::getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT,
+                                   bool NoOpCodeParam) {
+  // Parameter index 0 is the return value and index 1 is the i32 opcode.
+  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
+  if (Prop->OverloadParamIndex < 0) {
+    auto &Ctx = FT->getContext();
+    // When there is only one overload type, return it directly.
+    switch (Prop->OverloadTys) {
+    case OverloadKind::VOID:
+      return Type::getVoidTy(Ctx);
+    case OverloadKind::HALF:
+      return Type::getHalfTy(Ctx);
+    case OverloadKind::FLOAT:
+      return Type::getFloatTy(Ctx);
+    case OverloadKind::DOUBLE:
+      return Type::getDoubleTy(Ctx);
+    case OverloadKind::I1:
+      return Type::getInt1Ty(Ctx);
+    case OverloadKind::I8:
+      return Type::getInt8Ty(Ctx);
+    case OverloadKind::I16:
+      return Type::getInt16Ty(Ctx);
+    case OverloadKind::I32:
+      return Type::getInt32Ty(Ctx);
+    case OverloadKind::I64:
+      return Type::getInt64Ty(Ctx);
+    default:
+      llvm_unreachable("invalid overload type");
+      return nullptr;
+    }
+  }
+
+  // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
+  Type *OverloadType = FT->getReturnType();
+  if (Prop->OverloadParamIndex != 0) {
+    // Skip the return type, and the opcode type when FT lacks that parameter.
+    const unsigned SkippedParams = NoOpCodeParam ? 2 : 1;
+    OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkippedParams);
+  }
+
+  auto ParamKinds = getOpCodeParameterKind(*Prop);
+  auto Kind = ParamKinds[Prop->OverloadParamIndex];
+  // For ResRet and CBufferRet, the overload type is the struct's first field.
+  if (Kind == ParameterKind::CBUFFER_RET ||
+      Kind == ParameterKind::RESOURCE_RET) {
+    auto *ST = cast<StructType>(OverloadType);
+    OverloadType = ST->getElementType(0);
+  }
+  return OverloadType;
+}
+
+const char *DXILOpBuilder::getOpCodeName(DXIL::OpCode DXILOp) {
+  return ::getOpCodeName(DXILOp);
+}
+} // namespace DXIL
+} // namespace llvm
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILConstants.h"
+#include "DXILOpBuilder.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/Passes.h"
@@ -28,168 +29,12 @@
 using namespace llvm;
 using namespace llvm::DXIL;
 
-constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
-
-enum OverloadKind : uint16_t {
-  VOID = 1,
-  HALF = 1 << 1,
-  FLOAT = 1 << 2,
-  DOUBLE = 1 << 3,
-  I1 = 1 << 4,
-  I8 = 1 << 5,
-  I16 = 1 << 6,
-  I32 = 1 << 7,
-  I64 = 1 << 8,
-  UserDefineType = 1 << 9,
-  ObjectType = 1 << 10,
-};
-
-static const char *getOverloadTypeName(OverloadKind Kind) {
-  switch (Kind) {
-  case OverloadKind::HALF:
-    return "f16";
-  case OverloadKind::FLOAT:
-    return "f32";
-  case OverloadKind::DOUBLE:
-    return "f64";
-  case OverloadKind::I1:
-    return "i1";
-  case OverloadKind::I8:
-    return "i8";
-  case OverloadKind::I16:
-    return "i16";
-  case OverloadKind::I32:
-    return "i32";
-  case OverloadKind::I64:
-    return "i64";
-  case OverloadKind::VOID:
-  case OverloadKind::ObjectType:
-  case OverloadKind::UserDefineType:
-    break;
-  }
-  llvm_unreachable("invalid overload type for name");
-  return "void";
-}
-
-static OverloadKind getOverloadKind(Type *Ty) {
-  Type::TypeID T = Ty->getTypeID();
-  switch (T) {
-  case Type::VoidTyID:
-    return OverloadKind::VOID;
-  case Type::HalfTyID:
-    return OverloadKind::HALF;
-  case Type::FloatTyID:
-    return OverloadKind::FLOAT;
-  case Type::DoubleTyID:
-    return OverloadKind::DOUBLE;
-  case Type::IntegerTyID: {
-    IntegerType *ITy = cast<IntegerType>(Ty);
-    unsigned Bits = ITy->getBitWidth();
-    switch (Bits) {
-    case 1:
-      return OverloadKind::I1;
-    case 8:
-      return OverloadKind::I8;
-    case 16:
-      return OverloadKind::I16;
-    case 32:
-      return OverloadKind::I32;
-    case 64:
-      return OverloadKind::I64;
-    default:
-      llvm_unreachable("invalid overload type");
-      return OverloadKind::VOID;
-    }
-  }
-  case Type::PointerTyID:
-    return OverloadKind::UserDefineType;
-  case Type::StructTyID:
-    return OverloadKind::ObjectType;
-  default:
-    llvm_unreachable("invalid overload type");
-    return OverloadKind::VOID;
-  }
-}
-
-static std::string getTypeName(OverloadKind Kind, Type *Ty) {
-  if (Kind < OverloadKind::UserDefineType) {
-    return getOverloadTypeName(Kind);
-  } else if (Kind == OverloadKind::UserDefineType) {
-    StructType *ST = cast<StructType>(Ty);
-    return ST->getStructName().str();
-  } else if (Kind == OverloadKind::ObjectType) {
-    StructType *ST = cast<StructType>(Ty);
-    return ST->getStructName().str();
-  } else {
-    std::string Str;
-    raw_string_ostream OS(Str);
-    Ty->print(OS);
-    return OS.str();
-  }
-}
-
-// Static properties.
-struct OpCodeProperty {
-  DXIL::OpCode OpCode;
-  // Offset in DXILOpCodeNameTable.
-  unsigned OpCodeNameOffset;
-  DXIL::OpCodeClass OpCodeClass;
-  // Offset in DXILOpCodeClassNameTable.
-  unsigned OpCodeClassNameOffset;
-  uint16_t OverloadTys;
-  llvm::Attribute::AttrKind FuncAttr;
-};
-
-// Include getOpCodeClassName getOpCodeProperty and getOpCodeName which
-// generated by tableGen.
-#define DXIL_OP_OPERATION_TABLE
-#include "DXILOperation.inc"
-#undef DXIL_OP_OPERATION_TABLE
-
-static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
-                                         const OpCodeProperty &Prop) {
-  if (Kind == OverloadKind::VOID) {
-    return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
-  }
-  return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
-          getTypeName(Kind, Ty))
-      .str();
-}
-
-static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
-                                           Module &M) {
-  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
-
-  // Get return type as overload type for DXILOp.
-  // Only simple mapping case here, so return type is good enough.
-  Type *OverloadTy = F.getReturnType();
-
-  OverloadKind Kind = getOverloadKind(OverloadTy);
-  // FIXME: find the issue and report error in clang instead of check it in
-  // backend.
-  if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
-    llvm_unreachable("invalid overload");
-  }
-
-  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
-  assert(!M.getFunction(FnName) && "Function already exists");
-
-  auto &Ctx = M.getContext();
-  Type *OpCodeTy = Type::getInt32Ty(Ctx);
-
-  SmallVector<Type *> ArgTypes;
-  // DXIL has i32 opcode as first arg.
-  ArgTypes.emplace_back(OpCodeTy);
-  FunctionType *FT = F.getFunctionType();
-  ArgTypes.append(FT->param_begin(), FT->param_end());
-  FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
-  return M.getOrInsertFunction(FnName, DXILOpFT);
-}
-
 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
-  auto DXILOpFn = createDXILOpFunction(DXILOp, F, M);
   IRBuilder<> B(M.getContext());
   Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
+  DXILOpBuilder DXILB(M, B);
+  Type *OverloadTy =
+      DXILB.getOverloadTy(DXILOp, F.getFunctionType(), /*NoOpCodeParam*/ true);
   for (User *U : make_early_inc_range(F.users())) {
     CallInst *CI = dyn_cast<CallInst>(U);
     if (!CI)
@@ -199,8 +44,8 @@
     Args.emplace_back(DXILOpArg);
     Args.append(CI->arg_begin(), CI->arg_end());
     B.SetInsertPoint(CI);
-    CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
-    LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp)));
+    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());
+    LLVM_DEBUG(DXILCI->setName(DXILOpBuilder::getOpCodeName(DXILOp)));
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();
   }
diff --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt
--- a/llvm/unittests/Target/DirectX/CMakeLists.txt
+++ b/llvm/unittests/Target/DirectX/CMakeLists.txt
@@ -11,5 +11,6 @@
   )
 
 add_llvm_target_unittest(DirectXTests
+  DXILOpBuilderTest.cpp
   PointerTypeAnalysisTests.cpp
   )
diff --git a/llvm/unittests/Target/DirectX/DXILOpBuilderTest.cpp b/llvm/unittests/Target/DirectX/DXILOpBuilderTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/DXILOpBuilderTest.cpp
@@ -0,0 +1,144 @@
+//===- llvm/unittest/Target/DirectX/DXILOpBuilderTest.cpp -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILOpBuilder.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace llvm::DXIL;
+using ::testing::UnorderedElementsAre;
+
+namespace {
+
+class DXILOpBuilderTest : public testing::Test {
+protected:
+  void SetUp() override {
+    M = std::make_unique<Module>("MyModule", Ctx);
+    FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx),
+                                          /*isVarArg=*/false);
+    F = Function::Create(FTy, Function::ExternalLinkage, "", M.get());
+    BB = BasicBlock::Create(Ctx, "", F);
+    GV = new GlobalVariable(*M, Type::getFloatTy(Ctx), true,
+                            GlobalValue::ExternalLinkage, nullptr);
+  }
+
+  void TearDown() override {
+    BB = nullptr;
+    M.reset();
+  }
+
+  LLVMContext Ctx;
+  std::unique_ptr<Module> M;
+  Function *F;
+  BasicBlock *BB;
+  GlobalVariable *GV;
+};
+
+TEST_F(DXILOpBuilderTest, createDXILOpCall) {
+  IRBuilder<> Builder(BB);
+  Value *V;
+  Instruction *I;
+  CallInst *Call;
+  V = Builder.CreateLoad(GV->getValueType(), GV);
+  I = cast<Instruction>(Builder.CreateFAdd(V, V));
+  Call = Builder.CreateUnaryIntrinsic(Intrinsic::sin, V, I);
+
+  DXILOpBuilder B(*M, Builder);
+  auto *Sin = B.createDXILOpCall(DXIL::OpCode::Sin, Type::getFloatTy(Ctx),
+                                 Call->args());
+  auto *SinFn = cast<Function>(Sin->getCalledFunction());
+  EXPECT_EQ(SinFn->getName(), "dx.op.unary.f32");
+  auto *FT = SinFn->getFunctionType();
+  EXPECT_EQ(FT->getNumParams(), 2u);
+  EXPECT_TRUE(FT->getReturnType()->isFloatTy());
+  EXPECT_TRUE(FT->getParamType(0)->isIntegerTy(32));
+  EXPECT_TRUE(FT->getParamType(1)->isFloatTy());
+}
+
+TEST_F(DXILOpBuilderTest, createBufferLoad) {
+  IRBuilder<> Builder(BB);
+
+  DXILOpBuilder B(*M, Builder);
+  CallInst *Hdl = B.createCreateHandle(Type::getVoidTy(Ctx),
+                                       (int8_t)DXIL::ResourceClass::SRV, 0,
+                                       Builder.getInt32(0),
+                                       /*NonUniformIndex*/ false);
+  CallInst *BufLoad =
+      B.createBufferLoad(Type::getInt32Ty(Ctx), Hdl, Builder.getInt32(0),
+                         PoisonValue::get(Builder.getInt32Ty()));
+  Type *HdlTy = Hdl->getType();
+  ASSERT_TRUE(HdlTy->isStructTy());
+  EXPECT_EQ(cast<StructType>(HdlTy)->getName(), "dx.types.Handle");
+  EXPECT_EQ(Hdl->getCalledFunction()->getName(), "dx.op.createHandle");
+
+  Type *ResRetTy = BufLoad->getType();
+  ASSERT_TRUE(ResRetTy->isStructTy());
+  EXPECT_EQ(cast<StructType>(ResRetTy)->getName(), "dx.types.ResRet.i32");
+  EXPECT_EQ(BufLoad->getCalledFunction()->getName(), "dx.op.bufferLoad.i32");
+  // BufferLoad's overload is the ResRet element, not the handle type.
+  Type *LoadTy = B.getOverloadTy(DXIL::OpCode::BufferLoad,
+                                 BufLoad->getFunctionType(), false);
+  EXPECT_TRUE(LoadTy->isIntegerTy(32));
+}
+
+TEST_F(DXILOpBuilderTest, createBufferStore) {
+  IRBuilder<> Builder(BB);
+
+  DXILOpBuilder B(*M, Builder);
+  CallInst *Hdl = B.createCreateHandle(Type::getVoidTy(Ctx),
+                                       (int8_t)DXIL::ResourceClass::UAV, 0,
+                                       Builder.getInt32(0),
+                                       /*NonUniformIndex*/ false);
+  Value *V = Builder.CreateLoad(GV->getValueType(), GV);
+  auto *UndefV = PoisonValue::get(V->getType());
+
+  CallInst *BufStore =
+      B.createBufferStore(Type::getFloatTy(Ctx), Hdl, Builder.getInt32(0),
+                          PoisonValue::get(Builder.getInt32Ty()), V, UndefV,
+                          UndefV, UndefV, Builder.getInt8(1));
+
+  Type *HdlTy = Hdl->getType();
+  ASSERT_TRUE(HdlTy->isStructTy());
+  EXPECT_EQ(cast<StructType>(HdlTy)->getName(), "dx.types.Handle");
+  EXPECT_EQ(Hdl->getCalledFunction()->getName(), "dx.op.createHandle");
+
+  EXPECT_TRUE(BufStore->getType()->isVoidTy());
+  EXPECT_EQ(BufStore->getCalledFunction()->getName(), "dx.op.bufferStore.f32");
+  Type *StoreTy = B.getOverloadTy(DXIL::OpCode::BufferStore,
+                                  BufStore->getFunctionType(), false);
+  EXPECT_TRUE(StoreTy->isFloatTy());
+}
+
+TEST_F(DXILOpBuilderTest, getOverloadTy) {
+  IRBuilder<> Builder(BB);
+  Function *ThreadID =
+      Intrinsic::getDeclaration(M.get(), Intrinsic::dxil_thread_id);
+  DXILOpBuilder B(*M, Builder);
+  Type *OverloadTy =
+      B.getOverloadTy(DXIL::OpCode::ThreadId, ThreadID->getFunctionType(),
+                      /*NoOpCodeParam*/ false);
+  EXPECT_TRUE(OverloadTy->isIntegerTy(32));
+}
+
+TEST_F(DXILOpBuilderTest, getOpCodeName) {
+  IRBuilder<> Builder(BB);
+  DXILOpBuilder B(*M, Builder);
+  auto UMax = B.getOpCodeName(DXIL::OpCode::UMax);
+  EXPECT_STREQ(UMax, "UMax");
+}
+
+} // namespace
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -27,25 +27,33 @@
   int Major;
   int Minor;
 };
+
+enum class ParameterKind : uint8_t {
+  INVALID = 0,
+  VOID,
+  HALF,
+  FLOAT,
+  DOUBLE,
+  I1,
+  I8,
+  I16,
+  I32,
+  I64,
+  OVERLOAD,
+  CBUFFER_RET,
+  RESOURCE_RET,
+  DXIL_HANDLE,
+};
+
 struct DXILParam {
-  int Pos;        // position in parameter list
-  StringRef Type; // llvm type name, $o for overload, $r for resource
-                  // type, $cb for legacy cbuffer, $u4 for u4 struct
+  int Pos; // position in parameter list
+  ParameterKind Kind;
   StringRef Name; // short, unique name
   StringRef Doc;  // the documentation description of this parameter
   bool IsConst;   // whether this argument requires a constant value in the IR
   StringRef EnumName; // the name of the enum type if applicable
   int MaxValue;       // the maximum value for this parameter if applicable
-  DXILParam(const Record *R) {
-    Name = R->getValueAsString("name");
-    Pos = R->getValueAsInt("pos");
-    Type = R->getValueAsString("llvm_type");
-    if (R->getValue("doc"))
-      Doc = R->getValueAsString("doc");
-    IsConst = R->getValueAsBit("is_const");
-    EnumName = R->getValueAsString("enum_name");
-    MaxValue = R->getValueAsInt("max_value");
-  }
+  DXILParam(const Record *R);
 };
 
 struct DXILOperationData {
@@ -74,7 +82,9 @@
   DXILShaderModel ShaderModel;           // minimum shader model required
   DXILShaderModel ShaderModelTranslated; // minimum shader model required with
                                          // translation by linker
-  SmallVector<StringRef, 4> counters; // counters for this inst.
+  int OverloadParamIndex;             // parameter index which controls the overload.
+                                      // When < 0, there is only one overload type.
+  SmallVector<StringRef, 4> counters; // counters for this inst.
   DXILOperationData(const Record *R) {
     Name = R->getValueAsString("name");
     DXILOp = R->getValueAsString("dxil_op");
@@ -93,9 +103,18 @@
     Doc = R->getValueAsString("doc");
 
     ListInit *ParamList = R->getValueAsListInit("ops");
-    for (unsigned i = 0; i < ParamList->size(); ++i) {
-      Record *Param = ParamList->getElementAsRecord(i);
+    OverloadParamIndex = -1;
+    for (unsigned I = 0; I < ParamList->size(); ++I) {
+      Record *Param = ParamList->getElementAsRecord(I);
       Params.emplace_back(DXILParam(Param));
+      auto &CurParam = Params.back();
+      // The first overload-carrying parameter decides the overload. A
+      // dx.types.Handle parameter never carries the overload type.
+      if (OverloadParamIndex < 0 &&
+          (CurParam.Kind == ParameterKind::OVERLOAD ||
+           CurParam.Kind == ParameterKind::CBUFFER_RET ||
+           CurParam.Kind == ParameterKind::RESOURCE_RET))
+        OverloadParamIndex = I;
     }
     OverloadTypes = R->getValueAsString("oload_types");
     FnAttr = R->getValueAsString("fn_attr");
@@ -103,6 +122,69 @@
   };
 } // end anonymous namespace
 
+static ParameterKind parameterTypeNameToKind(StringRef Name) {
+  return StringSwitch<ParameterKind>(Name)
+      .Case("void", ParameterKind::VOID)
+      .Case("half", ParameterKind::HALF)
+      .Case("float", ParameterKind::FLOAT)
+      .Case("double", ParameterKind::DOUBLE)
+      .Case("i1", ParameterKind::I1)
+      .Case("i8", ParameterKind::I8)
+      .Case("i16", ParameterKind::I16)
+      .Case("i32", ParameterKind::I32)
+      .Case("i64", ParameterKind::I64)
+      .Case("$o", ParameterKind::OVERLOAD)
+      .Case("dx.types.Handle", ParameterKind::DXIL_HANDLE)
+      .Case("dx.types.CBufRet", ParameterKind::CBUFFER_RET)
+      .Case("dx.types.ResRet", ParameterKind::RESOURCE_RET)
+      .Default(ParameterKind::INVALID);
+}
+
+DXILParam::DXILParam(const Record *R) {
+  Name = R->getValueAsString("name");
+  Pos = R->getValueAsInt("pos");
+  Kind = parameterTypeNameToKind(R->getValueAsString("llvm_type"));
+  if (R->getValue("doc"))
+    Doc = R->getValueAsString("doc");
+  IsConst = R->getValueAsBit("is_const");
+  EnumName = R->getValueAsString("enum_name");
+  MaxValue = R->getValueAsInt("max_value");
+}
+
+static std::string parameterKindToString(ParameterKind Kind) {
+  switch (Kind) {
+  case ParameterKind::INVALID:
+    return "INVALID";
+  case ParameterKind::VOID:
+    return "VOID";
+  case ParameterKind::HALF:
+    return "HALF";
+  case ParameterKind::FLOAT:
+    return "FLOAT";
+  case ParameterKind::DOUBLE:
+    return "DOUBLE";
+  case ParameterKind::I1:
+    return "I1";
+  case ParameterKind::I8:
+    return "I8";
+  case ParameterKind::I16:
+    return "I16";
+  case ParameterKind::I32:
+    return "I32";
+  case ParameterKind::I64:
+    return "I64";
+  case ParameterKind::OVERLOAD:
+    return "OVERLOAD";
+  case ParameterKind::CBUFFER_RET:
+    return "CBUFFER_RET";
+  case ParameterKind::RESOURCE_RET:
+    return "RESOURCE_RET";
+  case ParameterKind::DXIL_HANDLE:
+    return "DXIL_HANDLE";
+  }
+  llvm_unreachable("Unknown ParameterKind");
+}
+
 static void emitDXILOpEnum(DXILOperationData &DXILOp, raw_ostream &OS) {
   // Name = ID, // Doc
   OS << DXILOp.Name << " = " << DXILOp.DXILOpID << ", // " << DXILOp.Doc
@@ -274,7 +356,9 @@
   // Collect Names.
   SequenceToOffsetTable<std::string> OpClassStrings;
   SequenceToOffsetTable<std::string> OpStrings;
+  SequenceToOffsetTable<SmallVector<ParameterKind>> Parameters;
 
+  StringMap<SmallVector<ParameterKind>> ParameterMap;
   StringSet<> ClassSet;
   for (auto &DXILOp : DXILOps) {
     OpStrings.add(DXILOp.DXILOp.str());
@@ -283,16 +367,24 @@
       continue;
     ClassSet.insert(DXILOp.DXILClass);
     OpClassStrings.add(getDXILOpClassName(DXILOp.DXILClass));
+    SmallVector<ParameterKind> ParamKindVec;
+    for (auto &Param : DXILOp.Params) {
+      ParamKindVec.emplace_back(Param.Kind);
+    }
+    ParameterMap[DXILOp.DXILClass] = ParamKindVec;
+    Parameters.add(ParamKindVec);
   }
 
   // Layout names.
   OpStrings.layout();
   OpClassStrings.layout();
+  Parameters.layout();
 
   // Emit the DXIL operation table.
   //{DXIL::OpCode::Sin, OpCodeNameIndex, OpCodeClass::Unary,
   //                    OpCodeClassNameIndex,
-  //                    OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone},
+  //                    OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone, 0,
+  //                    3, ParameterTableOffset},
   OS << "static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) "
         "{\n";
 
@@ -303,7 +395,9 @@
        << ", OpCodeClass::" << DXILOp.DXILClass << ", "
        << OpClassStrings.get(getDXILOpClassName(DXILOp.DXILClass)) << ", "
        << getDXILOperationOverload(DXILOp.OverloadTypes) << ", "
-       << emitDXILOperationFnAttr(DXILOp.FnAttr) << " },\n";
+       << emitDXILOperationFnAttr(DXILOp.FnAttr) << ", "
+       << DXILOp.OverloadParamIndex << ", " << DXILOp.Params.size() << ", "
+       << Parameters.get(ParameterMap[DXILOp.DXILClass]) << " },\n";
   }
   OS << "  };\n";
 
@@ -341,6 +435,21 @@
   OS << "  unsigned Index = Prop.OpCodeClassNameOffset;\n";
   OS << "  return DXILOpCodeClassNameTable + Index;\n";
   OS << "}\n ";
+
+  OS << "static const ParameterKind *getOpCodeParameterKind(const "
+        "OpCodeProperty &Prop) "
+        "{\n\n";
+  OS << "  static const ParameterKind DXILOpParameterKindTable[] = {\n";
+  Parameters.emit(
+      OS,
+      [](raw_ostream &ParamOS, ParameterKind Kind) {
+        ParamOS << "ParameterKind::" << parameterKindToString(Kind);
+      },
+      "ParameterKind::INVALID");
+  OS << "  };\n\n";
+  OS << "  unsigned Index = Prop.ParameterTableOffset;\n";
+  OS << "  return DXILOpParameterKindTable + Index;\n";
+  OS << "}\n ";
 }
 
 namespace llvm {