diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeTarget.h b/clang/lib/CodeGen/CGOpenMPRuntimeTarget.h
--- a/clang/lib/CodeGen/CGOpenMPRuntimeTarget.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeTarget.h
@@ -122,6 +122,35 @@
     /// Call to void __kmpc_barrier_simple_spmd(ident_t *loc, kmp_int32
     /// global_tid);
     OMPRTL__kmpc_barrier_simple_spmd,
+
+    /// Target Region (TREgion) Kernel interface
+    /// Entries for the kernel init, deinit, and parallel runtime functions.
+    ///{
+
+    /// char __kmpc_target_region_kernel_init(ident_t *Ident,
+    ///                                       bool UseSPMDMode,
+    ///                                       bool UseStateMachine,
+    ///                                       bool RequiresOMPRuntime,
+    ///                                       bool RequiresDataSharing);
+    OMPRTL__kmpc_target_region_kernel_init,
+
+    /// void __kmpc_target_region_kernel_deinit(ident_t *Ident,
+    ///                                         bool UseSPMDMode,
+    ///                                         bool RequiredOMPRuntime);
+    OMPRTL__kmpc_target_region_kernel_deinit,
+
+    /// void __kmpc_target_region_kernel_parallel(ident_t *Ident,
+    ///                                           bool UseSPMDMode,
+    ///                                           bool RequiredOMPRuntime,
+    ///                                           ParallelWorkFnTy WorkFn,
+    ///                                           void *SharedVars,
+    ///                                           uint16_t SharedVarsBytes,
+    ///                                           void *PrivateVars,
+    ///                                           uint16_t PrivateVarsBytes,
+    ///                                           bool SharedPointers);
+    OMPRTL__kmpc_target_region_kernel_parallel,
+
+    ///}
   };
 
   /// Returns the OpenMP runtime function identified by \p ID.
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeTarget.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeTarget.cpp
--- a/clang/lib/CodeGen/CGOpenMPRuntimeTarget.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeTarget.cpp
@@ -52,6 +52,7 @@
 llvm::FunctionCallee CGOpenMPRuntimeTarget::createTargetRuntimeFunction(
     OpenMPRTLTargetFunctions ID) {
   llvm::FunctionCallee RTLFn = nullptr;
+  auto *I1Ty = llvm::IntegerType::getInt1Ty(CGM.getLLVMContext());
   switch (ID) {
   case OMPRTL_NVPTX__kmpc_kernel_init: {
     // Build void __kmpc_kernel_init(kmp_int32 thread_limit, int16_t
@@ -343,7 +344,100 @@
         ->addFnAttr(llvm::Attribute::Convergent);
     break;
   }
+  case OMPRTL__kmpc_target_region_kernel_init: {
+    // char __kmpc_target_region_kernel_init(ident_t *Ident,
+    //                                       bool UseSPMDMode,
+    //                                       bool UseStateMachine,
+    //                                       bool RequiresOMPRuntime,
+    //                                       bool RequiresDataSharing);
+    llvm::Type *TypeParams[] = {getIdentTyPointerTy(), I1Ty, I1Ty, I1Ty, I1Ty};
+    auto *FnTy =
+        llvm::FunctionType::get(CGM.Int8Ty, TypeParams, /* isVarArg */ false);
+    RTLFn =
+        CGM.CreateRuntimeFunction(FnTy, "__kmpc_target_region_kernel_init");
+
+    // Mark the ident_t pointer (argument 0) as nocapture.
+    llvm::Function *RTFn = cast<llvm::Function>(RTLFn.getCallee());
+    RTFn->addParamAttr(0, llvm::Attribute::NoCapture);
+    break;
+  }
+  case OMPRTL__kmpc_target_region_kernel_deinit: {
+    // void __kmpc_target_region_kernel_deinit(ident_t *Ident,
+    //                                         bool UseSPMDMode,
+    //                                         bool RequiredOMPRuntime);
+    llvm::Type *TypeParams[] = {getIdentTyPointerTy(), I1Ty, I1Ty};
+    auto *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /* isVarArg */ false);
+    RTLFn =
+        CGM.CreateRuntimeFunction(FnTy, "__kmpc_target_region_kernel_deinit");
+
+    // Mark the ident_t pointer (argument 0) as nocapture.
+    llvm::Function *RTFn = cast<llvm::Function>(RTLFn.getCallee());
+    RTFn->addParamAttr(0, llvm::Attribute::NoCapture);
+    break;
+  }
+  case OMPRTL__kmpc_target_region_kernel_parallel: {
+    // typedef void (*ParallelWorkFnTy)(void *, void *);
+    auto *ParWorkFnTy =
+        llvm::FunctionType::get(CGM.VoidTy, {CGM.VoidPtrTy, CGM.VoidPtrTy},
+                                /* isVarArg */ false);
+
+    // void __kmpc_target_region_kernel_parallel(ident_t *Ident,
+    //                                           bool UseSPMDMode,
+    //                                           bool RequiredOMPRuntime,
+    //                                           ParallelWorkFnTy WorkFn,
+    //                                           void *SharedVars,
+    //                                           uint16_t SharedVarsBytes,
+    //                                           void *PrivateVars,
+    //                                           uint16_t PrivateVarsBytes,
+    //                                           bool SharedPointers);
+    llvm::Type *TypeParams[] = {getIdentTyPointerTy(),
+                                I1Ty,
+                                I1Ty,
+                                ParWorkFnTy->getPointerTo(),
+                                CGM.VoidPtrTy,
+                                CGM.Int16Ty,
+                                CGM.VoidPtrTy,
+                                CGM.Int16Ty,
+                                I1Ty};
+    auto *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /* isVarArg */ false);
+
+    RTLFn =
+        CGM.CreateRuntimeFunction(FnTy, "__kmpc_target_region_kernel_parallel");
+
+    // Pointer arguments Ident (0), WorkFn (3), SharedVars (4), and PrivateVars
+    // (6) are marked nocapture; PrivateVars is additionally marked readonly.
+    llvm::Function *RTFn = cast<llvm::Function>(RTLFn.getCallee());
+    RTFn->addParamAttr(0, llvm::Attribute::NoCapture);
+    RTFn->addParamAttr(3, llvm::Attribute::NoCapture);
+    RTFn->addParamAttr(4, llvm::Attribute::NoCapture);
+    RTFn->addParamAttr(6, llvm::Attribute::NoCapture);
+    RTFn->addParamAttr(6, llvm::Attribute::ReadOnly);
+
+    // Add the callback metadata if it is not present already.
+    if (!RTFn->hasMetadata(llvm::LLVMContext::MD_callback)) {
+      llvm::LLVMContext &Ctx = RTFn->getContext();
+      llvm::MDBuilder MDB(Ctx);
+      // Annotate the callback behavior of __kmpc_target_region_kernel_parallel:
+      // - The callback callee is WorkFn, argument 3 starting with 0.
+      // - The first callback payload is SharedVars.
+      // - The second callback payload is PrivateVars.
+      RTFn->addMetadata(
+          llvm::LLVMContext::MD_callback,
+          *llvm::MDNode::get(
+              Ctx, {MDB.createCallbackEncoding(3, {4, 6},
+                                               /* VarArgsArePassed */ false)}));
+    }
+    break;
   }
+  }
+
+  // TODO: Remove all globals and set this attribute.
+  //
+  // This is overwritten when the definition is linked in.
+  // RTFn->addFnAttr(llvm::Attribute::InaccessibleMemOrArgMemOnly);
+
   return RTLFn;
 }
 