diff --git a/llvm/include/llvm/Support/DXILOperationCommon.h b/llvm/include/llvm/Support/DXILOperationCommon.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Support/DXILOperationCommon.h
@@ -0,0 +1,70 @@
+//===-- DXILOperationCommon.h - DXIL Operation ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions shared by the DXILOpBuilder and the DXIL
+// TableGen backend.
+// Documentation for DXIL can be found in
+// https://github.com/Microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_DXILOPERATIONCOMMON_H
+#define LLVM_SUPPORT_DXILOPERATIONCOMMON_H
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSwitch.h"
+#include <cstdint>
+
+namespace llvm {
+namespace DXIL {
+
+/// Kind of a DXIL operation parameter (or return value), parsed from the
+/// "llvm_type" strings of the operation records. Keep this in sync with
+/// parameterTypeNameToKind below.
+enum class ParameterKind : uint8_t {
+  INVALID = 0,
+  VOID,
+  HALF,
+  FLOAT,
+  DOUBLE,
+  I1,
+  I8,
+  I16,
+  I32,
+  I64,
+  OVERLOAD,
+  CBUFFER_RET,
+  RESOURCE_RET,
+  DXIL_HANDLE,
+};
+
+/// Map a type name from a DXIL operation record ("i32", "$o" for the overload
+/// type, "dx.types.Handle", ...) to its ParameterKind. Unknown names map to
+/// ParameterKind::INVALID.
+inline ParameterKind parameterTypeNameToKind(StringRef Name) {
+  return StringSwitch<ParameterKind>(Name)
+      .Case("void", ParameterKind::VOID)
+      .Case("half", ParameterKind::HALF)
+      .Case("float", ParameterKind::FLOAT)
+      .Case("double", ParameterKind::DOUBLE)
+      .Case("i1", ParameterKind::I1)
+      .Case("i8", ParameterKind::I8)
+      .Case("i16", ParameterKind::I16)
+      .Case("i32", ParameterKind::I32)
+      .Case("i64", ParameterKind::I64)
+      .Case("$o", ParameterKind::OVERLOAD)
+      .Case("dx.types.Handle", ParameterKind::DXIL_HANDLE)
+      .Case("dx.types.CBufRet", ParameterKind::CBUFFER_RET)
+      .Case("dx.types.ResRet", ParameterKind::RESOURCE_RET)
+      .Default(ParameterKind::INVALID);
+}
+
+} // namespace DXIL
+} // namespace llvm
+
+#endif
diff
--git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -17,6 +17,7 @@
   DirectXRegisterInfo.cpp
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
+  DXILOpBuilder.cpp
   DXILOpLowering.cpp
   DXILPointerType.cpp
   DXILPrepare.cpp
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -0,0 +1,54 @@
+//===- DXILOpBuilder.h - Helper class for building DXILOp functions -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains a class to help build DXIL op functions.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
+
+#include "DXILConstants.h"
+#include "llvm/ADT/iterator_range.h"
+
+namespace llvm {
+class Module;
+class IRBuilderBase;
+class CallInst;
+class Value;
+class Type;
+class FunctionType;
+class Use;
+
+namespace DXIL {
+
+/// Helper for creating calls to dx.op functions. Functions are declared in the
+/// module given to the constructor; calls are inserted with its IRBuilder.
+class DXILOpBuilder {
+public:
+  DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
+  /// Create a call to \p OpCode overloaded on \p OverloadTy. The i32 opcode
+  /// argument is prepended to \p Args.
+  CallInst *createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy,
+                             llvm::iterator_range<Use *> Args);
+  /// Return the overload type of \p OpCode given the function type \p FT of
+  /// the callee. Set \p NoOpCodeParam when \p FT lacks the leading i32 opcode
+  /// parameter (e.g. the type of the intrinsic being lowered).
+  Type *getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT,
+                      bool NoOpCodeParam);
+  /// Return the name of \p DXILOp.
+  static const char *getOpCodeName(DXIL::OpCode DXILOp);
+
+private:
+  Module &M;
+  IRBuilderBase &B;
+};
+
+} // namespace DXIL
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
new file mode 100644
--- /dev/null
+++
b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -0,0 +1,342 @@
+//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains a class to help build DXIL op functions.
+//===----------------------------------------------------------------------===//
+
+#include "DXILOpBuilder.h"
+#include "DXILConstants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/DXILOperationCommon.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace llvm;
+using namespace llvm::DXIL;
+
+constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
+
+namespace {
+
+/// Overload type bits. OpCodeProperty::OverloadTys is an OR of these bits
+/// describing every overload type a DXIL operation accepts.
+enum OverloadKind : uint16_t {
+  VOID = 1,
+  HALF = 1 << 1,
+  FLOAT = 1 << 2,
+  DOUBLE = 1 << 3,
+  I1 = 1 << 4,
+  I8 = 1 << 5,
+  I16 = 1 << 6,
+  I32 = 1 << 7,
+  I64 = 1 << 8,
+  UserDefineType = 1 << 9,
+  ObjectType = 1 << 10,
+};
+
+/// Static properties of one DXIL operation. The table of these is generated
+/// by TableGen (DXILEmitter) and included from DXILOperation.inc below.
+struct OpCodeProperty {
+  DXIL::OpCode OpCode;
+  // Offset in DXILOpCodeNameTable.
+  unsigned OpCodeNameOffset;
+  DXIL::OpCodeClass OpCodeClass;
+  // Offset in DXILOpCodeClassNameTable.
+  unsigned OpCodeClassNameOffset;
+  uint16_t OverloadTys;
+  llvm::Attribute::AttrKind FuncAttr;
+  // Index of the parameter (0 is the return value) that controls the
+  // overload. When < 0, the operation has exactly one overload type.
+  int OverloadParamIndex;
+  // Number of parameters, including the return value.
+  unsigned NumOfParameters;
+  // Offset in DXILOpParameterKindTable.
+  unsigned ParameterTableOffset;
+};
+
+} // namespace
+
+/// Return the scalar type suffix used in dx.op function names.
+static const char *getOverloadTypeName(OverloadKind Kind) {
+  switch (Kind) {
+  case OverloadKind::HALF:
+    return "f16";
+  case OverloadKind::FLOAT:
+    return "f32";
+  case OverloadKind::DOUBLE:
+    return "f64";
+  case OverloadKind::I1:
+    return "i1";
+  case OverloadKind::I8:
+    return "i8";
+  case OverloadKind::I16:
+    return "i16";
+  case OverloadKind::I32:
+    return "i32";
+  case OverloadKind::I64:
+    return "i64";
+  case OverloadKind::VOID:
+  case OverloadKind::ObjectType:
+  case OverloadKind::UserDefineType:
+    break;
+  }
+  llvm_unreachable("invalid overload type for name");
+}
+
+/// Map an IR type to its OverloadKind bit.
+static OverloadKind getOverloadKind(Type *Ty) {
+  Type::TypeID T = Ty->getTypeID();
+  switch (T) {
+  case Type::VoidTyID:
+    return OverloadKind::VOID;
+  case Type::HalfTyID:
+    return OverloadKind::HALF;
+  case Type::FloatTyID:
+    return OverloadKind::FLOAT;
+  case Type::DoubleTyID:
+    return OverloadKind::DOUBLE;
+  case Type::IntegerTyID: {
+    IntegerType *ITy = cast<IntegerType>(Ty);
+    unsigned Bits = ITy->getBitWidth();
+    switch (Bits) {
+    case 1:
+      return OverloadKind::I1;
+    case 8:
+      return OverloadKind::I8;
+    case 16:
+      return OverloadKind::I16;
+    case 32:
+      return OverloadKind::I32;
+    case 64:
+      return OverloadKind::I64;
+    default:
+      llvm_unreachable("invalid overload type");
+    }
+  }
+  case Type::PointerTyID:
+    return OverloadKind::UserDefineType;
+  case Type::StructTyID:
+    return OverloadKind::ObjectType;
+  default:
+    llvm_unreachable("invalid overload type");
+  }
+}
+
+/// Return the type suffix appended to dx.op function names for \p Ty.
+static std::string getTypeName(OverloadKind Kind, Type *Ty) {
+  if (Kind < OverloadKind::UserDefineType)
+    return getOverloadTypeName(Kind);
+  // Identified struct types are mangled by their name. Pointers
+  // (UserDefineType) are not StructTypes and literal structs have no name,
+  // so both fall back to the printed IR type.
+  if (Kind == OverloadKind::ObjectType) {
+    auto *ST = cast<StructType>(Ty);
+    if (!ST->isLiteral())
+      return ST->getName().str();
+  }
+  std::string Str;
+  raw_string_ostream OS(Str);
+  Ty->print(OS);
+  return OS.str();
+}
+
+// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
+// getOpCodeParameterKind, which are generated by TableGen.
+#define DXIL_OP_OPERATION_TABLE
+#include "DXILOperation.inc"
+#undef DXIL_OP_OPERATION_TABLE
+
+/// Build the dx.op function name: "dx.op.<class>" for void overloads and
+/// "dx.op.<class>.<type>" otherwise.
+static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
+                                         const OpCodeProperty &Prop) {
+  if (Kind == OverloadKind::VOID)
+    return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
+
+  return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
+          getTypeName(Kind, Ty))
+      .str();
+}
+
+/// Append the scalar overload suffix to \p TypeName (unless Kind is VOID).
+static std::string constructOverloadTypeName(OverloadKind Kind,
+                                             StringRef TypeName) {
+  if (Kind == OverloadKind::VOID)
+    return TypeName.str();
+
+  assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
+  return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
+}
+
+/// Return the identified struct type \p Name, creating it with \p EltTys if
+/// it does not exist yet. An existing type is returned as-is.
+static StructType *getOrCreateStructType(StringRef Name,
+                                         ArrayRef<Type *> EltTys,
+                                         LLVMContext &Ctx) {
+  if (StructType *ST = StructType::getTypeByName(Ctx, Name))
+    return ST;
+  return StructType::create(Ctx, EltTys, Name);
+}
+
+/// Return dx.types.ResRet.<ty> = { ty, ty, ty, ty, i32 } for OverloadTy.
+static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
+  OverloadKind Kind = getOverloadKind(OverloadTy);
+  std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
+  Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
+                         Type::getInt32Ty(Ctx)};
+  return getOrCreateStructType(TypeName, FieldTypes, Ctx);
+}
+
+/// Return the struct type dx.types.Handle = { i8* }.
+static StructType *getHandleType(LLVMContext &Ctx) {
+  return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx);
+}
+
+/// Return the IR type for a parameter of kind \p Kind, resolving OVERLOAD
+/// (and the overloaded ResRet struct) against \p OverloadTy.
+static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
+  auto &Ctx = OverloadTy->getContext();
+  switch (Kind) {
+  case ParameterKind::VOID:
+    return Type::getVoidTy(Ctx);
+  case ParameterKind::HALF:
+    return Type::getHalfTy(Ctx);
+  case ParameterKind::FLOAT:
+    return Type::getFloatTy(Ctx);
+  case ParameterKind::DOUBLE:
+    return Type::getDoubleTy(Ctx);
+  case ParameterKind::I1:
+    return Type::getInt1Ty(Ctx);
+  case ParameterKind::I8:
+    return Type::getInt8Ty(Ctx);
+  case ParameterKind::I16:
+    return Type::getInt16Ty(Ctx);
+  case ParameterKind::I32:
+    return Type::getInt32Ty(Ctx);
+  case ParameterKind::I64:
+    return Type::getInt64Ty(Ctx);
+  case ParameterKind::OVERLOAD:
+    return OverloadTy;
+  case ParameterKind::RESOURCE_RET:
+    return getResRetType(OverloadTy, Ctx);
+  case ParameterKind::DXIL_HANDLE:
+    return getHandleType(Ctx);
+  case ParameterKind::INVALID:
+  case ParameterKind::CBUFFER_RET: // TODO: dx.types.CBufRet not supported yet.
+    break;
+  }
+  llvm_unreachable("Invalid parameter kind");
+}
+
+/// Build the function type of a dx.op function from the generated parameter
+/// kind table: entry 0 is the return type and the remaining entries are the
+/// parameters (the first one being the i32 opcode).
+static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
+                                           Type *OverloadTy) {
+  assert(Prop->NumOfParameters > 0 && "missing DXIL op return type");
+  const ParameterKind *ParamKinds = getOpCodeParameterKind(*Prop);
+  SmallVector<Type *> ArgTys;
+  for (unsigned I = 0; I < Prop->NumOfParameters; ++I)
+    ArgTys.emplace_back(getTypeFromParameterKind(ParamKinds[I], OverloadTy));
+  return FunctionType::get(ArgTys[0], ArrayRef<Type *>(ArgTys).drop_front(),
+                           /*isVarArg=*/false);
+}
+
+/// Return the dx.op function for \p DXILOp overloaded on \p OverloadTy,
+/// declaring it in \p M on first use.
+static FunctionCallee getOrCreateDXILOpFunction(DXIL::OpCode DXILOp,
+                                                Type *OverloadTy, Module &M) {
+  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
+
+  OverloadKind Kind = getOverloadKind(OverloadTy);
+  // FIXME: find the issue and report error in clang instead of check it in
+  // backend. This is reachable from input IR, so it must be a real error:
+  // llvm_unreachable is undefined behavior in release builds.
+  if ((Prop->OverloadTys & static_cast<uint16_t>(Kind)) == 0)
+    report_fatal_error("invalid overload type for DXIL operation",
+                       /*gen_crash_diag=*/false);
+
+  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
+  // Functions are deduplicated by name.
+  if (auto *Fn = M.getFunction(FnName))
+    return FunctionCallee(Fn);
+
+  FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
+  return M.getOrInsertFunction(FnName, DXILOpFT);
+}
+
+namespace llvm {
+namespace DXIL {
+
+CallInst *DXILOpBuilder::createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy,
+                                          llvm::iterator_range<Use *> Args) {
+  auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
+  // DXIL operations take the i32 opcode as their first argument.
+  SmallVector<Value *> FullArgs;
+  FullArgs.emplace_back(B.getInt32(static_cast<int32_t>(OpCode)));
+  FullArgs.append(Args.begin(), Args.end());
+  return B.CreateCall(Fn, FullArgs);
+}
+
+Type *DXILOpBuilder::getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT,
+                                   bool NoOpCodeParam) {
+  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
+  if (Prop->OverloadParamIndex < 0) {
+    auto &Ctx = FT->getContext();
+    // With a single overload type, return it directly.
+    switch (Prop->OverloadTys) {
+    case OverloadKind::VOID:
+      return Type::getVoidTy(Ctx);
+    case OverloadKind::HALF:
+      return Type::getHalfTy(Ctx);
+    case OverloadKind::FLOAT:
+      return Type::getFloatTy(Ctx);
+    case OverloadKind::DOUBLE:
+      return Type::getDoubleTy(Ctx);
+    case OverloadKind::I1:
+      return Type::getInt1Ty(Ctx);
+    case OverloadKind::I8:
+      return Type::getInt8Ty(Ctx);
+    case OverloadKind::I16:
+      return Type::getInt16Ty(Ctx);
+    case OverloadKind::I32:
+      return Type::getInt32Ty(Ctx);
+    case OverloadKind::I64:
+      return Type::getInt64Ty(Ctx);
+    default:
+      llvm_unreachable("invalid overload type");
+    }
+  }
+
+  // Index 0 means the return type carries the overload.
+  Type *OverloadType = FT->getReturnType();
+  if (Prop->OverloadParamIndex != 0) {
+    // Table index 0 is the return value and index 1 the opcode. FT never
+    // lists the return value, and lacks the opcode when NoOpCodeParam is set.
+    const unsigned SkippedParams = NoOpCodeParam ? 2 : 1;
+    assert(static_cast<unsigned>(Prop->OverloadParamIndex) >= SkippedParams &&
+           "overload parameter index out of range");
+    OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkippedParams);
+  }
+
+  const ParameterKind *ParamKinds = getOpCodeParameterKind(*Prop);
+  ParameterKind Kind = ParamKinds[Prop->OverloadParamIndex];
+  // For ResRet and CBufferRet, the overload type is the struct's first field.
+  if (Kind == ParameterKind::CBUFFER_RET ||
+      Kind == ParameterKind::RESOURCE_RET) {
+    auto *ST = cast<StructType>(OverloadType);
+    OverloadType = ST->getElementType(0);
+  }
+  return OverloadType;
+}
+
+const char *DXILOpBuilder::getOpCodeName(DXIL::OpCode DXILOp) {
+  return ::getOpCodeName(DXILOp);
+}
+} // namespace DXIL
+} // namespace llvm
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILConstants.h"
+#include "DXILOpBuilder.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/Passes.h"
@@ -28,168 +29,12 @@
 using namespace llvm;
 using namespace llvm::DXIL;
 
-constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
-
-enum OverloadKind : uint16_t {
-  VOID = 1,
-  HALF = 1 << 1,
-  FLOAT = 1 << 2,
-  DOUBLE = 1 << 3,
-  I1 = 1 << 4,
-  I8 = 1 << 5,
-  I16 = 1 << 6,
-  I32 = 1 << 7,
-  I64 = 1 << 8,
-  UserDefineType = 1 << 9,
-  ObjectType = 1 << 10,
-};
-
-static const char *getOverloadTypeName(OverloadKind Kind) {
-  switch (Kind) {
-  case OverloadKind::HALF:
-    return "f16";
-  case OverloadKind::FLOAT:
-    return "f32";
-  case OverloadKind::DOUBLE:
-    return "f64";
-  case OverloadKind::I1:
-    return "i1";
-  case OverloadKind::I8:
-    return "i8";
-  case OverloadKind::I16:
-    return "i16";
-  case OverloadKind::I32:
-    return "i32";
-  case OverloadKind::I64:
-    return "i64";
-  case OverloadKind::VOID:
-  case OverloadKind::ObjectType:
-  case OverloadKind::UserDefineType:
-    break;
-  }
-  llvm_unreachable("invalid overload type for name");
-  return "void";
-}
-
-static OverloadKind getOverloadKind(Type *Ty) {
-  Type::TypeID T = Ty->getTypeID();
-  switch (T) {
-  case Type::VoidTyID:
-    return OverloadKind::VOID;
-  case Type::HalfTyID:
-    return OverloadKind::HALF;
-  case Type::FloatTyID:
-    return OverloadKind::FLOAT;
-  case Type::DoubleTyID:
-    return OverloadKind::DOUBLE;
-  case Type::IntegerTyID: {
-    IntegerType *ITy = cast<IntegerType>(Ty);
-    unsigned Bits = ITy->getBitWidth();
-    switch (Bits) {
-    case 1:
-      return OverloadKind::I1;
-    case 8:
-      return OverloadKind::I8;
-    case 16:
-      return OverloadKind::I16;
-    case 32:
-      return OverloadKind::I32;
-    case 64:
-      return OverloadKind::I64;
-    default:
-      llvm_unreachable("invalid overload type");
-      return OverloadKind::VOID;
-    }
-  }
-  case Type::PointerTyID:
-    return OverloadKind::UserDefineType;
-  case Type::StructTyID:
-    return OverloadKind::ObjectType;
-  default:
-    llvm_unreachable("invalid overload type");
-    return OverloadKind::VOID;
-  }
-}
-
-static std::string getTypeName(OverloadKind Kind, Type *Ty) {
-  if (Kind < OverloadKind::UserDefineType) {
-    return getOverloadTypeName(Kind);
-  } else if (Kind == OverloadKind::UserDefineType) {
-    StructType *ST = cast<StructType>(Ty);
-    return ST->getStructName().str();
-  } else if (Kind == OverloadKind::ObjectType) {
-    StructType *ST = cast<StructType>(Ty);
-    return ST->getStructName().str();
-  } else {
-    std::string Str;
-    raw_string_ostream OS(Str);
-    Ty->print(OS);
-    return OS.str();
-  }
-}
-
-// Static properties.
-struct OpCodeProperty {
-  DXIL::OpCode OpCode;
-  // Offset in DXILOpCodeNameTable.
-  unsigned OpCodeNameOffset;
-  DXIL::OpCodeClass OpCodeClass;
-  // Offset in DXILOpCodeClassNameTable.
-  unsigned OpCodeClassNameOffset;
-  uint16_t OverloadTys;
-  llvm::Attribute::AttrKind FuncAttr;
-};
-
-// Include getOpCodeClassName getOpCodeProperty and getOpCodeName which
-// generated by tableGen.
-#define DXIL_OP_OPERATION_TABLE
-#include "DXILOperation.inc"
-#undef DXIL_OP_OPERATION_TABLE
-
-static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
-                                         const OpCodeProperty &Prop) {
-  if (Kind == OverloadKind::VOID) {
-    return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
-  }
-  return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
-          getTypeName(Kind, Ty))
-      .str();
-}
-
-static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
-                                           Module &M) {
-  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
-
-  // Get return type as overload type for DXILOp.
-  // Only simple mapping case here, so return type is good enough.
-  Type *OverloadTy = F.getReturnType();
-
-  OverloadKind Kind = getOverloadKind(OverloadTy);
-  // FIXME: find the issue and report error in clang instead of check it in
-  // backend.
-  if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
-    llvm_unreachable("invalid overload");
-  }
-
-  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
-  assert(!M.getFunction(FnName) && "Function already exists");
-
-  auto &Ctx = M.getContext();
-  Type *OpCodeTy = Type::getInt32Ty(Ctx);
-
-  SmallVector<Type *> ArgTypes;
-  // DXIL has i32 opcode as first arg.
-  ArgTypes.emplace_back(OpCodeTy);
-  FunctionType *FT = F.getFunctionType();
-  ArgTypes.append(FT->param_begin(), FT->param_end());
-  FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
-  return M.getOrInsertFunction(FnName, DXILOpFT);
-}
-
 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
-  auto DXILOpFn = createDXILOpFunction(DXILOp, F, M);
   IRBuilder<> B(M.getContext());
-  Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
+  DXILOpBuilder DXILB(M, B);
+  // F is the intrinsic, so its type has no leading i32 opcode parameter.
+  Type *OverloadTy =
+      DXILB.getOverloadTy(DXILOp, F.getFunctionType(), /*NoOpCodeParam*/ true);
   for (User *U : make_early_inc_range(F.users())) {
     CallInst *CI = dyn_cast<CallInst>(U);
     if (!CI)
@@ -196,11 +41,8 @@
       continue;
 
-    SmallVector<Value *> Args;
-    Args.emplace_back(DXILOpArg);
-    Args.append(CI->arg_begin(), CI->arg_end());
     B.SetInsertPoint(CI);
-    CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
-    LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp)));
+    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());
+    LLVM_DEBUG(DXILCI->setName(DXILOpBuilder::getOpCodeName(DXILOp)));
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();
   }
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -16,10 +16,12 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/DXILOperationCommon.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 
 using namespace llvm;
+using namespace llvm::DXIL;
 
 namespace {
 
@@ -27,25 +29,16 @@
   int Major;
   int Minor;
 };
+
 struct DXILParam {
-  int Pos;        // position in parameter list
-  StringRef Type; // llvm type name, $o for overload, $r for resource
-                  // type, $cb for legacy cbuffer, $u4 for u4 struct
+  int Pos;            // position in parameter list
+  ParameterKind Kind; // kind parsed from the record's llvm_type string
   StringRef Name; // short, unique name
   StringRef Doc;  // the documentation description of this parameter
   bool IsConst;   // whether this argument requires a constant value in the IR
   StringRef EnumName; // the name of the enum type if applicable
   int MaxValue;       // the maximum value for this parameter if applicable
-  DXILParam(const Record *R) {
-    Name = R->getValueAsString("name");
-    Pos = R->getValueAsInt("pos");
-    Type = R->getValueAsString("llvm_type");
-    if (R->getValue("doc"))
-      Doc = R->getValueAsString("doc");
-    IsConst = R->getValueAsBit("is_const");
-    EnumName = R->getValueAsString("enum_name");
-    MaxValue = R->getValueAsInt("max_value");
-  }
+  DXILParam(const Record *R);
 };
 
 struct DXILOperationData {
@@ -74,7 +67,9 @@
   DXILShaderModel ShaderModel;           // minimum shader model required
   DXILShaderModel ShaderModelTranslated; // minimum shader model required with
                                          // translation by linker
-  SmallVector<StringRef, 4> counters; // counters for this inst.
+  int OverloadParamIndex; // index of the parameter (0 = return value) that
+                          // controls the overload; -1 for a single fixed type.
+  SmallVector<StringRef, 4> counters; // counters for this inst.
   DXILOperationData(const Record *R) {
     Name = R->getValueAsString("name");
     DXILOp = R->getValueAsString("dxil_op");
@@ -93,9 +88,18 @@
     Doc = R->getValueAsString("doc");
 
     ListInit *ParamList = R->getValueAsListInit("ops");
-    for (unsigned i = 0; i < ParamList->size(); ++i) {
-      Record *Param = ParamList->getElementAsRecord(i);
+    OverloadParamIndex = -1;
+    for (unsigned I = 0; I < ParamList->size(); ++I) {
+      Record *Param = ParamList->getElementAsRecord(I);
       Params.emplace_back(DXILParam(Param));
+      // Index 0 is the return value. DXIL_HANDLE sorts after OVERLOAD in
+      // ParameterKind but is not overloaded, so list the kinds explicitly.
+      ParameterKind Kind = Params.back().Kind;
+      bool IsOverloaded = Kind == ParameterKind::OVERLOAD ||
+                          Kind == ParameterKind::CBUFFER_RET ||
+                          Kind == ParameterKind::RESOURCE_RET;
+      if (IsOverloaded)
+        OverloadParamIndex = I;
     }
     OverloadTypes = R->getValueAsString("oload_types");
     FnAttr = R->getValueAsString("fn_attr");
@@ -103,6 +107,53 @@
 };
 } // end anonymous namespace
 
+DXILParam::DXILParam(const Record *R) {
+  Name = R->getValueAsString("name");
+  Pos = R->getValueAsInt("pos");
+  Kind = parameterTypeNameToKind(R->getValueAsString("llvm_type"));
+  if (R->getValue("doc"))
+    Doc = R->getValueAsString("doc");
+  IsConst = R->getValueAsBit("is_const");
+  EnumName = R->getValueAsString("enum_name");
+  MaxValue = R->getValueAsInt("max_value");
+}
+
+/// Return the enumerator name of \p Kind, used to print
+/// "ParameterKind::<Name>" into the generated parameter table.
+static std::string parameterKindToString(ParameterKind Kind) {
+  switch (Kind) {
+  case ParameterKind::INVALID:
+    return "INVALID";
+  case ParameterKind::VOID:
+    return "VOID";
+  case ParameterKind::HALF:
+    return "HALF";
+  case ParameterKind::FLOAT:
+    return "FLOAT";
+  case ParameterKind::DOUBLE:
+    return "DOUBLE";
+  case ParameterKind::I1:
+    return "I1";
+  case ParameterKind::I8:
+    return "I8";
+  case ParameterKind::I16:
+    return "I16";
+  case ParameterKind::I32:
+    return "I32";
+  case ParameterKind::I64:
+    return "I64";
+  case ParameterKind::OVERLOAD:
+    return "OVERLOAD";
+  case ParameterKind::CBUFFER_RET:
+    return "CBUFFER_RET";
+  case ParameterKind::RESOURCE_RET:
+    return "RESOURCE_RET";
+  case ParameterKind::DXIL_HANDLE:
+    return "DXIL_HANDLE";
+  }
+  llvm_unreachable("Unknown llvm::DXIL::ParameterKind enum");
+}
+
 static void emitDXILOpEnum(DXILOperationData &DXILOp, raw_ostream &OS) {
   // Name = ID, // Doc
   OS << DXILOp.Name << " = " << DXILOp.DXILOpID << ", // " << DXILOp.Doc
@@ -271,7 +322,9 @@
   // Collect Names.
   SequenceToOffsetTable<std::string> OpClassStrings;
   SequenceToOffsetTable<std::string> OpStrings;
+  SequenceToOffsetTable<SmallVector<ParameterKind>> Parameters;
 
+  StringMap<SmallVector<ParameterKind>> ParameterMap;
   StringSet<> ClassSet;
   for (auto &DXILOp : DXILOps) {
     OpStrings.add(DXILOp.DXILOp.str());
@@ -280,16 +333,23 @@
       continue;
     ClassSet.insert(DXILOp.DXILClass);
     OpClassStrings.add(getDXILOpClassName(DXILOp.DXILClass));
+    SmallVector<ParameterKind> ParamKindVec;
+    for (auto &Param : DXILOp.Params)
+      ParamKindVec.emplace_back(Param.Kind);
+    Parameters.add(ParamKindVec);
+    ParameterMap[DXILOp.DXILClass] = std::move(ParamKindVec);
   }
 
   // Layout names.
   OpStrings.layout();
   OpClassStrings.layout();
+  Parameters.layout();
 
   // Emit the DXIL operation table.
   //{DXIL::OpCode::Sin, OpCodeNameIndex, OpCodeClass::Unary,
   // OpCodeClassNameIndex,
-  // OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone},
+  // OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone, 0,
+  // 3, ParameterTableOffset},
   OS << "static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) "
         "{\n";
 
@@ -300,7 +360,9 @@
          << ", OpCodeClass::" << DXILOp.DXILClass << ", "
          << OpClassStrings.get(getDXILOpClassName(DXILOp.DXILClass)) << ", "
          << getDXILOperationOverload(DXILOp.OverloadTypes) << ", "
-         << emitDXILOperationFnAttr(DXILOp.FnAttr) << " },\n";
+         << emitDXILOperationFnAttr(DXILOp.FnAttr) << ", "
+         << DXILOp.OverloadParamIndex << ", " << DXILOp.Params.size() << ", "
+         << Parameters.get(ParameterMap[DXILOp.DXILClass]) << " },\n";
   }
   OS << "  };\n";
 
@@ -338,6 +400,21 @@
   OS << "  unsigned Index = Prop.OpCodeClassNameOffset;\n";
   OS << "  return DXILOpCodeClassNameTable + Index;\n";
   OS << "}\n ";
+
+  OS << "static const ParameterKind *getOpCodeParameterKind(const "
+        "OpCodeProperty &Prop) "
+        "{\n\n";
+  OS << "  static const ParameterKind DXILOpParameterKindTable[] = {\n";
+  Parameters.emit(
+      OS,
+      [](raw_ostream &ParamOS, ParameterKind Kind) {
+        ParamOS << "ParameterKind::" << parameterKindToString(Kind);
+      },
+      "ParameterKind::INVALID");
+  OS << "  };\n\n";
+  OS << "  unsigned Index = Prop.ParameterTableOffset;\n";
+  OS << "  return DXILOpParameterKindTable + Index;\n";
+  OS << "}\n ";
 }
 
 namespace llvm {