[HLSL] Emit a void(void) wrapper for shader entry functions

Functions carrying the HLSL shader attribute are no longer exempt from C++
name mangling (Mangle.cpp). Instead, when CodeGenFunction starts emitting such
a function it asks CGHLSLRuntime::emitEntryFunction to create a second,
externally visible function that
  * is named after the unmangled declaration name and has type void(),
  * receives a copy of the function-level attributes of the real function
    plus the "dx.shader" attribute,
  * materialises each parameter through emitInputSemantic (currently only
    SV_GroupIndex, lowered to the dx_flattened_thread_id_in_group intrinsic),
  * calls the real (now internal-linkage) function and returns void.

Review changes relative to the submitted patch:
  * GroupIndex-codegen.hlsl: the RUN line now pipes into FileCheck; before,
    none of the CHECK lines were evaluated.
  * GroupIndex-codegen.hlsl: ENTRY_ATTR is captured with [0-9]+ and the
    pattern allows the space before the opening brace; the previous capture
    regex '#' could only match a literal "##{".
  * emitInputSemantic: assert(false)/return nullptr replaced by
    llvm_unreachable, so no null Value can reach the call's argument list
    (hunk header adjusted to +89,54).
  * StartFunction: the guard tests FD, the pointer actually handed to
    emitEntryFunction, instead of D.

NOTE(review): this patch text was recovered from a whitespace-collapsed copy
in which template arguments had been stripped. Indentation, blank context
lines and the template arguments on context lines (MSGuidDecl, uint32_t,
CXXMethodDecl, llvm::CallingConv::ID, FunctionDecl) are reconstructed and
should be confirmed against the tree; `git apply --ignore-whitespace` may be
needed.

diff --git a/clang/lib/AST/Mangle.cpp b/clang/lib/AST/Mangle.cpp
--- a/clang/lib/AST/Mangle.cpp
+++ b/clang/lib/AST/Mangle.cpp
@@ -133,10 +133,6 @@
   if (isa<MSGuidDecl>(D))
     return true;
 
-  // HLSL shader entry function never need to be mangled.
-  if (getASTContext().getLangOpts().HLSL && D->hasAttr<HLSLShaderAttr>())
-    return false;
-
   return shouldMangleCXXName(D);
 }
 
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -15,6 +15,8 @@
 #ifndef LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
 #define LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
 
+#include "llvm/IR/IRBuilder.h"
+
 #include "clang/Basic/HLSLRuntime.h"
 
 namespace llvm {
@@ -26,6 +28,7 @@
 class CallExpr;
 class Type;
 class VarDecl;
+class ParmVarDecl;
 
 class FunctionDecl;
 
@@ -39,6 +42,8 @@
   uint32_t ResourceCounters[static_cast<uint32_t>(
       hlsl::ResourceClass::NumClasses)] = {0};
 
+  llvm::Value *emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D);
+
 public:
   CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
   virtual ~CGHLSLRuntime() {}
@@ -47,7 +52,9 @@
 
   void finishCodeGen();
 
-  void setHLSLFunctionAttributes(llvm::Function *, const FunctionDecl *);
+  void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn);
+
+  void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn);
 };
 
 } // namespace CodeGen
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -14,7 +14,9 @@
 
 #include "CGHLSLRuntime.h"
 #include "CodeGenModule.h"
+#include "clang/AST/Decl.h"
 #include "clang/Basic/TargetOptions.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 
@@ -87,11 +89,54 @@
                        ConstantAsMetadata::get(B.getInt32(Counter))}));
 }
 
-void clang::CodeGen::CGHLSLRuntime::setHLSLFunctionAttributes(
-    llvm::Function *F, const FunctionDecl *FD) {
-  if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
-    const StringRef ShaderAttrKindStr = "dx.shader";
-    F->addFnAttr(ShaderAttrKindStr,
-                 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
+void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
+    const FunctionDecl *FD, llvm::Function *Fn) {
+  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
+  assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
+  const StringRef ShaderAttrKindStr = "dx.shader";
+  Fn->addFnAttr(ShaderAttrKindStr,
+                ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
+}
+
+llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
+                                              const ParmVarDecl &D) {
+  assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
+  if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
+    llvm::Function *DxGroupIndex =
+        CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
+    return B.CreateCall(FunctionCallee(DxGroupIndex));
+  }
+  llvm_unreachable("Unhandled parameter attribute");
+}
+
+void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
+                                      llvm::Function *Fn) {
+  llvm::Module &M = CGM.getModule();
+  llvm::LLVMContext &Ctx = M.getContext();
+  auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
+  Function *EntryFn =
+      Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
+
+  // Copy function attributes over, we have no argument or return attributes
+  // that can be valid on the real entry.
+  AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
+                                              Fn->getAttributes().getFnAttrs());
+  EntryFn->setAttributes(NewAttrs);
+  setHLSLEntryAttributes(FD, EntryFn);
+
+  // Set the called function as internal linkage.
+  Fn->setLinkage(GlobalValue::InternalLinkage);
+
+  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
+  IRBuilder<> B(BB);
+  llvm::SmallVector<Value *> Args;
+  // FIXME: support struct parameters where semantics are on members.
+  for (const auto *Param : FD->parameters()) {
+    Args.push_back(emitInputSemantic(B, *Param));
   }
+
+  CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
+  (void)CI;
+  // FIXME: Handle codegen for return type semantics.
+  B.CreateRetVoid();
 }
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -16,6 +16,7 @@
 #include "CGCXXABI.h"
 #include "CGCleanup.h"
 #include "CGDebugInfo.h"
+#include "CGHLSLRuntime.h"
 #include "CGOpenMPRuntime.h"
 #include "CodeGenModule.h"
 #include "CodeGenPGO.h"
@@ -1137,6 +1138,10 @@
   if (getLangOpts().OpenMP && CurCodeDecl)
     CGM.getOpenMPRuntime().emitFunctionProlog(*this, CurCodeDecl);
 
+  // Handle emitting HLSL entry functions.
+  if (FD && FD->hasAttr<HLSLShaderAttr>())
+    CGM.getHLSLRuntime().emitEntryFunction(FD, Fn);
+
   EmitFunctionProlog(*CurFnInfo, CurFn, Args);
 
   if (isa_and_nonnull<CXXMethodDecl>(D) &&
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -1678,10 +1678,6 @@
                          /*AttrOnCallSite=*/false, IsThunk);
   F->setAttributes(PAL);
   F->setCallingConv(static_cast<llvm::CallingConv::ID>(CallingConv));
-  if (getLangOpts().HLSL) {
-    if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(GD.getDecl()))
-      getHLSLRuntime().setHLSLFunctionAttributes(F, FD);
-  }
 }
 
 static void removeImageAccessQualifier(std::string& TyName) {
diff --git a/clang/test/CodeGenHLSL/semantics/GroupIndex-codegen.hlsl b/clang/test/CodeGenHLSL/semantics/GroupIndex-codegen.hlsl
new file mode 100644
--- /dev/null
+++ b/clang/test/CodeGenHLSL/semantics/GroupIndex-codegen.hlsl
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -emit-llvm -disable-llvm-passes -o - -hlsl-entry main %s | FileCheck %s
+
+[numthreads(1,1,1)]
+void main(unsigned GI : SV_GroupIndex) {
+  main(GI - 1);
+}
+
+// For HLSL entry functions, we are generating a C-export function that wraps
+// the C++-mangled entry function. The wrapper function can be used to populate
+// semantic parameters and provides the void(void) signature that drivers
+// expect for entry points.
+
+//CHECK: define void @main() #[[ENTRY_ATTR:[0-9]+]] {
+//CHECK-NEXT: entry:
+//CHECK-NEXT: %0 = call i32 @llvm.dx.flattened.thread.id.in.group()
+//CHECK-NEXT: call void @"?main@@YAXI@Z"(i32 %0)
+//CHECK-NEXT: ret void
+//CHECK-NEXT: }
+
+// Verify that the entry had the expected dx.shader attribute
+
+//CHECK: attributes #[[ENTRY_ATTR]] = { {{.*}}"dx.shader"="compute"{{.*}} }