diff --git a/clang/lib/AST/Mangle.cpp b/clang/lib/AST/Mangle.cpp
--- a/clang/lib/AST/Mangle.cpp
+++ b/clang/lib/AST/Mangle.cpp
@@ -133,10 +133,6 @@
   if (isa<MSGuidDecl>(D))
     return true;
 
-  // HLSL shader entry function never need to be mangled.
-  if (getASTContext().getLangOpts().HLSL && D->hasAttr<HLSLShaderAttr>())
-    return false;
-
   return shouldMangleCXXName(D);
 }
 
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -15,6 +15,8 @@
 #ifndef LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
 #define LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
 
+#include "llvm/IR/IRBuilder.h"
+
 #include "clang/Basic/HLSLRuntime.h"
 
 namespace llvm {
@@ -26,6 +28,7 @@
 class CallExpr;
 class Type;
 class VarDecl;
+class ParmVarDecl;
 
 class FunctionDecl;
 
@@ -39,6 +42,8 @@
   uint32_t ResourceCounters[static_cast<uint32_t>(
       hlsl::ResourceClass::NumClasses)] = {0};
 
+  llvm::Value *emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D);
+
 public:
   CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
   virtual ~CGHLSLRuntime() {}
@@ -48,6 +53,8 @@
   void finishCodeGen();
 
   void setHLSLFunctionAttributes(llvm::Function *, const FunctionDecl *);
+
+  void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn);
 };
 
 } // namespace CodeGen
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -14,7 +14,9 @@
 
 #include "CGHLSLRuntime.h"
 #include "CodeGenModule.h"
+#include "clang/AST/Decl.h"
 #include "clang/Basic/TargetOptions.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 
@@ -95,3 +97,41 @@
         ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
   }
 }
+
+/// Emit the value for the semantic-annotated entry parameter \p D using
+/// builder \p B. Only SV_GroupIndex is handled so far; it is lowered to a call
+/// to the dx_flattened_thread_id_in_group intrinsic.
+llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
+                                              const ParmVarDecl &D) {
+  assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
+  if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
+    llvm::Function *DxGroupIndex =
+        CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
+    CallInst *CI = B.CreateCall(FunctionCallee(DxGroupIndex));
+    return CI;
+  }
+  assert(false && "Unhandled parameter attribute");
+  return nullptr;
+}
+
+/// Emit an externally visible void() wrapper named after \p FD. The wrapper
+/// materializes the input semantic of every parameter of \p FD and forwards
+/// the resulting values in a call to \p Fn.
+void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
+                                      llvm::Function *Fn) {
+  llvm::Module &M = CGM.getModule();
+  llvm::LLVMContext &Ctx = M.getContext();
+  auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
+  Function *EntryFn =
+      Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
+  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
+  IRBuilder<> B(BB);
+  llvm::SmallVector<Value *> Args;
+  for (const auto *Param : FD->parameters())
+    Args.push_back(emitInputSemantic(B, *Param));
+
+  CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
+  (void)CI;
+  // FIXME: Handle codegen for return type semantics
+  B.CreateRetVoid();
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -16,6 +16,7 @@
 #include "CGCXXABI.h"
 #include "CGCleanup.h"
 #include "CGDebugInfo.h"
+#include "CGHLSLRuntime.h"
 #include "CGOpenMPRuntime.h"
 #include "CodeGenModule.h"
 #include "CodeGenPGO.h"
@@ -907,6 +908,9 @@
   if (D && D->hasAttr<NoProfileFunctionAttr>())
     Fn->addFnAttr(llvm::Attribute::NoProfile);
 
+  if (FD && FD->hasAttr<HLSLShaderAttr>())
+    CGM.getHLSLRuntime().emitEntryFunction(FD, Fn);
+
   if (D) {
     // Function attributes take precedence over command line flags.
     if (auto *A = D->getAttr<FunctionReturnThunksAttr>()) {
diff --git a/clang/test/CodeGenHLSL/semantics/GroupIndex-codegen.hlsl b/clang/test/CodeGenHLSL/semantics/GroupIndex-codegen.hlsl
new file mode 100644
--- /dev/null
+++ b/clang/test/CodeGenHLSL/semantics/GroupIndex-codegen.hlsl
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -emit-llvm -disable-llvm-passes -o - -hlsl-entry main %s | FileCheck %s
+
+[numthreads(1,1,1)]
+void main(unsigned GI : SV_GroupIndex) {
+  main(GI - 1);
+}
+
+// For HLSL entry functions, we generate a C-export function that wraps the
+// C++-mangled entry function. The wrapper can be used to populate semantic
+// parameters, and it provides the void(void) signature that drivers expect
+// for entry points.
+
+//CHECK: define void @main() {
+//CHECK-NEXT: entry:
+//CHECK-NEXT: %0 = call i32 @llvm.dx.flattened.thread.id.in.group()
+//CHECK-NEXT: call void @"?main@@YAXI@Z"(i32 %0)
+//CHECK-NEXT: ret void
+//CHECK-NEXT: }