diff --git a/llvm/include/llvm/IR/IntrinsicsDXIL.td b/llvm/include/llvm/IR/IntrinsicsDXIL.td --- a/llvm/include/llvm/IR/IntrinsicsDXIL.td +++ b/llvm/include/llvm/IR/IntrinsicsDXIL.td @@ -17,4 +17,12 @@ def int_dxil_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>; def int_dxil_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMem, IntrWillReturn]>; +def int_dxil_create_handle : Intrinsic<[ llvm_i64_ty ], [llvm_i8_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty], [IntrNoMem, IntrWillReturn]>; + +def int_dxil_buffer_load : Intrinsic<[ llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>, llvm_i32_ty ], + [ llvm_i64_ty, llvm_i32_ty, llvm_i32_ty], [IntrReadMem, IntrWillReturn]>; +def int_dxil_buffer_store : Intrinsic<[ ], [ llvm_i64_ty, llvm_i32_ty, llvm_i32_ty, llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>, llvm_i8_ty], + [ IntrWriteMem, IntrWillReturn]>; + + } diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -26,11 +26,14 @@ def ThreadIdInGroupClass : dxil_class<"ThreadIdInGroup">; def ThreadIdClass : dxil_class<"ThreadId">; def GroupIdClass : dxil_class<"GroupId">; +def BufferLoadClass : dxil_class<"BufferLoad">; +def BufferStoreClass : dxil_class<"BufferStore">; +def CreateHandleClass : dxil_class<"CreateHandle">; def binary_uint : dxil_category<"Binary uint">; def unary_float : dxil_category<"Unary float">; def ComputeID : dxil_category<"Compute/Mesh/Amplification shader">; - +def Resources : dxil_category<"Resources">; // The parameter description for a DXIL instruction class dxil_param ]>, dxil_map_intrinsic; + +def BufferLoad : dxil_op< "BufferLoad", 68, BufferLoadClass,Resources, "reads from a TypedBuffer", "half;float;i16;i32;", "ro", + [ + dxil_param<0, "dx.types.ResRet", "", "the loaded value">, + dxil_param<1, "i32", "opcode", "DXIL opcode">, + dxil_param<2, "dx.types.Handle", "srv", "handle of TypedBuffer SRV to sample">, + dxil_param<3, "i32", "index", "element index">, + dxil_param<4, "i32", "wot", "coordinate"> + ], + ["tex_load"]>, + dxil_map_intrinsic; + +def BufferStore : dxil_op< "BufferStore", 69, BufferStoreClass, Resources, "writes to a RWTypedBuffer", "half;float;i16;i32;", "", + [ + dxil_param<0, "v", "", "">, + dxil_param<1, "i32", "opcode", "DXIL opcode">, + dxil_param<2, "dx.types.Handle", "uav", "handle of UAV to store to">, + dxil_param<3, "i32", "coord0", "coordinate in elements">, + dxil_param<4, "i32", "coord1", "coordinate (unused?)">, + dxil_param<5, "$o", "value0", "value">, + dxil_param<6, "$o", "value1", "value">, + dxil_param<7, "$o", "value2", "value">, + dxil_param<8, "$o", "value3", "value">, + dxil_param<9, "i8", "mask", "written value mask"> + ], + ["tex_store"]>, + dxil_map_intrinsic; + +def CreateHandle : dxil_op< "CreateHandle", 57, CreateHandleClass, Resources, "creates the handle to a resource", + "void;", "ro", + [ + dxil_param<0, "dx.types.Handle", "", "the handle to the resource">, + dxil_param<1, "i32", "opcode", "DXIL opcode">, + dxil_param<2, "i8", "resourceClass", "the class of resource to create (SRV, UAV, CBuffer, Sampler)", 1>, // maps to DxilResourceBase::Class + dxil_param<3, "i32", "rangeId", "range identifier for resource", 1>, + dxil_param<4, "i32", "index", "zero-based index into range">, + dxil_param<5, "i1", "nonUniformIndex", "non-uniform resource index", 1> + ]>, + dxil_map_intrinsic; diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -138,6 +138,15 @@ unsigned OpCodeClassNameOffset; uint16_t OverloadTys; llvm::Attribute::AttrKind FuncAttr; + bool HasOverload; // If only has one type in OverloadTypes, HasOverload will + // be false. + int OverloadParamIndex; // parameter index which control the overload + bool OverloadTypeInStruct; // The overload parameter type is struct, the + // overload type is first field type of the struct + // type. This happens for things like buffer load. + SmallVector> + StructParams; // Param which is struct type and need to mutate to DXIL + // type. }; // Include getOpCodeClassName getOpCodeProperty and getOpCodeName which @@ -156,13 +165,101 @@ .str(); } +static Type *getOverloadType(const OpCodeProperty *Prop, FunctionType *FT) { + if (!Prop->HasOverload) { + auto &Ctx = FT->getContext(); + // When only has 1 overload type, just return it. + switch (Prop->OverloadTys) { + case OverloadKind::VOID: + return Type::getVoidTy(Ctx); + case OverloadKind::HALF: + return Type::getHalfTy(Ctx); + case OverloadKind::FLOAT: + return Type::getFloatTy(Ctx); + case OverloadKind::DOUBLE: + return Type::getDoubleTy(Ctx); + case OverloadKind::I1: + return Type::getInt1Ty(Ctx); + case OverloadKind::I8: + return Type::getInt8Ty(Ctx); + case OverloadKind::I16: + return Type::getInt16Ty(Ctx); + case OverloadKind::I32: + return Type::getInt32Ty(Ctx); + case OverloadKind::I64: + return Type::getInt64Ty(Ctx); + default: + break; + } + } + + // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). + Type *OverloadType = FT->getReturnType(); + if (Prop->OverloadParamIndex != 0) { + // Skip Return Type and Type for DXIL opcode. + OverloadType = FT->getParamType(Prop->OverloadParamIndex - 2); + } + + if (Prop->OverloadTypeInStruct) { + auto *ST = cast(OverloadType); + OverloadType = ST->getElementType(0); + } + return OverloadType; +} + +static std::string constructOverloadTypeName(OverloadKind Kind, + StringRef TypeName) { + if (Kind == OverloadKind::VOID) + return TypeName.str(); + + assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); + return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); +} + +static StructType *getOrCreateStructType(StringRef Name, + ArrayRef EltTys, + LLVMContext &Ctx) { + StructType *ST = StructType::getTypeByName(Ctx, Name); + if (ST) + return ST; + + return StructType::create(Ctx, EltTys, Name); +} + +static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { + OverloadKind Kind = getOverloadKind(OverloadTy); + std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); + Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, + Type::getInt32Ty(Ctx)}; + return getOrCreateStructType(TypeName, FieldTypes, Ctx); +} + +static StructType *getHandleType(LLVMContext &Ctx) { + return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx); +} + +static StructType *getDXILStructType(StringRef Name, Type *CurTy, + LLVMContext &Ctx) { + StructType *ST = StructType::getTypeByName(Ctx, Name); + if (ST) + return ST; + + if (Name == "dx.types.Handle") + return getHandleType(Ctx); + + if (Name == "dx.types.ResRet") + return getResRetType(CurTy, Ctx); + + llvm_unreachable("invalid DXIL struct type"); + return nullptr; +} + static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F, Module &M) { const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); - // Get return type as overload type for DXILOp. - // Only simple mapping case here, so return type is good enough. - Type *OverloadTy = F.getReturnType(); + FunctionType *FT = F.getFunctionType(); + Type *OverloadTy = getOverloadType(Prop, FT); OverloadKind Kind = getOverloadKind(OverloadTy); // FIXME: find the issue and report error in clang instead of check it in @@ -177,30 +274,95 @@ auto &Ctx = M.getContext(); Type *OpCodeTy = Type::getInt32Ty(Ctx); + auto StructParamIt = Prop->StructParams.begin(); + auto StructParamEnd = Prop->StructParams.end(); + auto *RetTy = FT->getReturnType(); + // Change struct type to DXIL type. + if (StructParamIt != StructParamEnd && StructParamIt->first == 0) { + RetTy = getDXILStructType(StructParamIt->second, OverloadTy, Ctx); + ++StructParamIt; + } + SmallVector ArgTypes; // DXIL has i32 opcode as first arg. ArgTypes.emplace_back(OpCodeTy); - FunctionType *FT = F.getFunctionType(); - ArgTypes.append(FT->param_begin(), FT->param_end()); - FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false); + + for (unsigned I = 0; I < FT->getNumParams(); ++I) { + auto *ParamTy = FT->getParamType(I); + // i+2 to skip RetType and DXIL opcode. + if (StructParamIt != StructParamEnd && StructParamIt->first == (I + 2)) { + // Change struct type to DXIL type. + ArgTypes.emplace_back( + getDXILStructType(StructParamIt->second, OverloadTy, Ctx)); + ++StructParamIt; + } else { + ArgTypes.emplace_back(ParamTy); + } + } + + FunctionType *DXILOpFT = FunctionType::get(RetTy, ArgTypes, false); + return M.getOrInsertFunction(FnName, DXILOpFT); } -static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) { +static FunctionCallee getOrCreateCastFunction(Type *FromTy, Type *ToTy, + Module &M) { + std::string CastFnName = "Tmp.Cast."; + llvm::raw_string_ostream OS(CastFnName); + FromTy->print(OS); + OS << "."; + ToTy->print(OS); + OS.flush(); + if (auto Fn = M.getFunction(CastFnName)) + return Fn; + FunctionType *FT = FunctionType::get(ToTy, FromTy, false); + + return M.getOrInsertFunction(CastFnName, FT); +} + +static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M, + SmallDenseSet &TmpCastFnSet) { auto DXILOpFn = createDXILOpFunction(DXILOp, F, M); IRBuilder<> B(M.getContext()); Value *DXILOpArg = B.getInt32(static_cast(DXILOp)); + FunctionType *FT = DXILOpFn.getFunctionType(); + Type *RetTy = FT->getReturnType(); + + // When RetTy or ParamTy for DXILOpFn is DXIL struct type, it will be mismatch + // from the type in intrinsic. The size will be the same. Create bitcast to + // make it match. These bitcast will be removed later. for (User *U : make_early_inc_range(F.users())) { CallInst *CI = dyn_cast(U); if (!CI) continue; - + if (TmpCastFnSet.contains(CI->getCalledFunction())) + continue; SmallVector Args; Args.emplace_back(DXILOpArg); - Args.append(CI->arg_begin(), CI->arg_end()); + auto ParamTyIt = FT->param_begin(); + // Skip param for DXIL opcode. + ++ParamTyIt; B.SetInsertPoint(CI); - CallInst *DXILCI = B.CreateCall(DXILOpFn, Args); + for (auto &Arg : CI->args()) { + auto *ArgTy = Arg->getType(); + auto *ParamTy = *(ParamTyIt++); + if (ArgTy == ParamTy) { + Args.emplace_back(Arg); + continue; + } + auto CastFn = getOrCreateCastFunction(ArgTy, ParamTy, M); + TmpCastFnSet.insert(cast(CastFn.getCallee())); + auto *Cast = B.CreateCall(CastFn, {Arg}); + Args.emplace_back(Cast); + } + + Value *DXILCI = B.CreateCall(DXILOpFn, Args); LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp))); + if (CI->getType() != RetTy) { + auto CastFn = getOrCreateCastFunction(RetTy, CI->getType(), M); + TmpCastFnSet.insert(cast(CastFn.getCallee())); + DXILCI = B.CreateCall(CastFn, {DXILCI}); + } CI->replaceAllUsesWith(DXILCI); CI->eraseFromParent(); } @@ -215,6 +377,8 @@ #include "DXILOperation.inc" #undef DXIL_OP_INTRINSIC_MAP + SmallDenseSet TmpCastFnSet; + for (Function &F : make_early_inc_range(M.functions())) { if (!F.isDeclaration()) continue; @@ -224,9 +388,63 @@ auto LowerIt = LowerMap.find(ID); if (LowerIt == LowerMap.end()) continue; - lowerIntrinsic(LowerIt->second, F, M); + lowerIntrinsic(LowerIt->second, F, M, TmpCastFnSet); Updated = true; } + // Remove cast functions. + for (auto *CastFn : TmpCastFnSet) { + Type *RetTy = CastFn->getReturnType(); + // Skip cast to non dxil struct type. + StructType *ST = dyn_cast(RetTy); + if (!ST) + continue; + if (!ST->hasName()) { + // Replace extractvalue use. + for (User *U : make_early_inc_range(CastFn->users())) { + CallInst *CI = cast(U); + Value *Arg = CI->getArgOperand(0); + + // FIXME: support cases where the Arg is not CallInst. + CallInst *CIArg = dyn_cast(Arg); + if (!CIArg) { + llvm_unreachable("unsupported DXIL struct type cast"); + break; + } + for (User *CastU : make_early_inc_range(CI->users())) { + ExtractValueInst *EVI = dyn_cast(CastU); + if (EVI) + EVI->setOperand(0, CIArg); + } + if (CI->user_empty()) + CI->eraseFromParent(); + } + + continue; + } + + for (User *U : make_early_inc_range(CastFn->users())) { + CallInst *CI = cast(U); + Value *Arg = CI->getArgOperand(0); + // FIXME: support cases where the Arg is not CallInst. + CallInst *CIArg = dyn_cast(Arg); + if (!CIArg) { + llvm_unreachable("unsupported DXIL struct type cast"); + break; + } + assert(TmpCastFnSet.contains(CIArg->getCalledFunction())); + Value *InputArg = CIArg->getArgOperand(0); + assert(InputArg->getType() == RetTy); + CI->replaceAllUsesWith(InputArg); + CI->eraseFromParent(); + if (CIArg->user_empty()) + CIArg->eraseFromParent(); + } + } + // All CastFn should be user empty now. + for (auto *CastFn : TmpCastFnSet) { + CastFn->eraseFromParent(); + } + return Updated; } diff --git a/llvm/test/CodeGen/DirectX/resources.ll b/llvm/test/CodeGen/DirectX/resources.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/resources.ll @@ -0,0 +1,44 @@ +; RUN: opt -S -dxil-op-lower < %s | FileCheck %s + +; Make sure dxil operation function calls for createHandle bufferLoad/Store operations are generated. + +target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" +target triple = "dxil-pc-shadermodel6.3-library" + +%struct.anon = type { float, float, float, float, i32 } + + +; CHECK-LABEL:test_buffer_load_f32 +; CHECK: %[[HDL:.+]] = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 %res_class, i32 %range_id, i32 %index, i1 %non_uniform_index) +; CHECK: call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle %[[HDL]], i32 %idx, i32 undef) +define float @test_buffer_load_f32(i32 %idx, i8 %res_class, i32 %range_id, i32 %index, i1 %non_uniform_index) #0 { + %hdl = call i64 @llvm.dxil.create.handle(i8 %res_class, i32 %range_id, i32 %index, i1 %non_uniform_index) + %1 = call %struct.anon @llvm.dxil.buffer.load.f32(i64 %hdl, i32 %idx, i32 undef) + %2 = extractvalue %struct.anon %1, 0 + ret float %2 +} +; CHECK-LABEL:test_buffer_store_f32 +; CHECK: %[[HDL2:.+]] = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 %res_class, i32 %range_id, i32 %index, i1 %non_uniform_index) +; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %[[HDL2]], i32 %idx, i32 undef, float %v, float undef, float undef, float undef, i8 1) +define void @test_buffer_store_f32(i32 %idx, float %v, i8 %res_class, i32 %range_id, i32 %index, i1 %non_uniform_index) #0 { + %hdl = call i64 @llvm.dxil.create.handle(i8 %res_class, i32 %range_id, i32 %index, i1 %non_uniform_index) + call void @llvm.dxil.buffer.store.f32(i64 %hdl, i32 %idx, i32 undef, float %v, float undef, float undef, float undef, i8 1) + ret void +} + +; CHECK-DAG:declare %dx.types.Handle @dx.op.createHandle(i32, i8, i32, i32, i1) +declare i64 @llvm.dxil.create.handle(i8, i32, i32, i1) #1 + +; CHECK-DAG:declare %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32, %dx.types.Handle, i32, i32) +declare %struct.anon @llvm.dxil.buffer.load.f32(i64, i32, i32) #2 + +; CHECK-DAG:declare void @dx.op.bufferStore.f32(i32, %dx.types.Handle, i32, i32, float, float, float, float, i8) +declare void @llvm.dxil.buffer.store.f32(i64, i32, i32, float, float, float, float, i8) #3 + +; Make sure no other function declaration. +; CHECK-NOT:declare + +attributes #0 = { noinline nounwind } +attributes #1 = { nounwind readnone willreturn } +attributes #2 = { nounwind readonly willreturn } +attributes #3 = { nounwind willreturn } diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp --- a/llvm/utils/TableGen/DXILEmitter.cpp +++ b/llvm/utils/TableGen/DXILEmitter.cpp @@ -75,6 +75,15 @@ DXILShaderModel ShaderModelTranslated; // minimum shader model required with // translation by linker SmallVector counters; // counters for this inst. + + bool HasOverload; // If only has one type in OverloadTypes, HasOverload will + // be false. + int OverloadParamIndex; // parameter index which control the overload + bool OverloadTypeInStruct; // The overload parameter type is struct, the + // overload type is first field type of the struct + // type. This happens for things like buffer load. + SmallVector> StructParams; + DXILOperationData(const Record *R) { Name = R->getValueAsString("name"); DXILOp = R->getValueAsString("dxil_op"); @@ -92,15 +101,34 @@ Doc = R->getValueAsString("doc"); + OverloadParamIndex = -1; + OverloadTypeInStruct = false; ListInit *ParamList = R->getValueAsListInit("ops"); for (unsigned i = 0; i < ParamList->size(); ++i) { Record *Param = ParamList->getElementAsRecord(i); Params.emplace_back(DXILParam(Param)); + auto &CurParam = Params.back(); + if (CurParam.Type == "$o" || CurParam.Type == "udt" || + CurParam.Type == "obj") { + OverloadParamIndex = i; + } else if (CurParam.Type == "dx.types.CBufRet" || + CurParam.Type == "dx.types.ResRet") { + OverloadParamIndex = i; + OverloadTypeInStruct = true; + } + if (CurParam.Type.startswith("dx.types.")) + StructParams.emplace_back(std::make_pair(i, CurParam.Type)); } OverloadTypes = R->getValueAsString("oload_types"); FnAttr = R->getValueAsString("fn_attr"); + + SmallVector OverloadStrs; + OverloadTypes.split(OverloadStrs, ';', /*MaxSplit*/ -1, + /*KeepEmpty*/ false); + HasOverload = OverloadStrs.size() > 1; } }; + } // end anonymous namespace static void emitDXILOpEnum(DXILOperationData &DXILOp, raw_ostream &OS) { @@ -303,7 +331,16 @@ << ", OpCodeClass::" << DXILOp.DXILClass << ", " << OpClassStrings.get(getDXILOpClassName(DXILOp.DXILClass)) << ", " << getDXILOperationOverload(DXILOp.OverloadTypes) << ", " - << emitDXILOperationFnAttr(DXILOp.FnAttr) << " },\n"; + << emitDXILOperationFnAttr(DXILOp.FnAttr) << ", " << DXILOp.HasOverload + << ", " << DXILOp.OverloadParamIndex << ", " + << DXILOp.OverloadTypeInStruct << ", "; + OS << "{ "; + for (auto &StructParam : DXILOp.StructParams) { + OS << " { " << StructParam.first << ", \"" << StructParam.second + << "\" } ,"; + } + OS << "} "; + OS << " },\n"; } OS << " };\n";