diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp @@ -2072,56 +2072,45 @@ emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, OutlinedFnArgs); } -void CGOpenMPRuntimeGPU::emitParallelCall( - CodeGenFunction &CGF, SourceLocation Loc, llvm::Function *OutlinedFn, - ArrayRef CapturedVars, const Expr *IfCond) { +void CGOpenMPRuntimeGPU::emitParallelCall(CodeGenFunction &CGF, + SourceLocation Loc, + llvm::Function *OutlinedFn, + ArrayRef CapturedVars, + const Expr *IfCond) { if (!CGF.HaveInsertPoint()) return; - if (getExecutionMode() == CGOpenMPRuntimeGPU::EM_SPMD) - emitSPMDParallelCall(CGF, Loc, OutlinedFn, CapturedVars, IfCond); - else - emitNonSPMDParallelCall(CGF, Loc, OutlinedFn, CapturedVars, IfCond); -} - -void CGOpenMPRuntimeGPU::emitNonSPMDParallelCall( - CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn, - ArrayRef CapturedVars, const Expr *IfCond) { - llvm::Function *Fn = cast(OutlinedFn); - - // Force inline this outlined function at its call site. - Fn->setLinkage(llvm::GlobalValue::InternalLinkage); + auto &&CodeGen = [this, OutlinedFn, CapturedVars, + Loc](CodeGenFunction &CGF, PrePostActionTy &Action) { + Action.Enter(CGF); + llvm::Function *Fn = cast(OutlinedFn); - // Ensure we do not inline the function. This is trivially true for the ones - // passed to __kmpc_fork_call but the ones calles in serialized regions - // could be inlined. This is not a perfect but it is closer to the invariant - // we want, namely, every data environment starts with a new function. - // TODO: We should pass the if condition to the runtime function and do the - // handling there. Much cleaner code. - cast(OutlinedFn)->addFnAttr(llvm::Attribute::NoInline); + // Force inline this outlined function at its call site. 
+ Fn->setLinkage(llvm::GlobalValue::InternalLinkage); - Address ZeroAddr = CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, - /*Name=*/".zero.addr"); - CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0)); - // ThreadId for serialized parallels is 0. - Address ThreadIDAddr = ZeroAddr; - auto &&CodeGen = [this, Fn, CapturedVars, Loc, &ThreadIDAddr]( - CodeGenFunction &CGF, PrePostActionTy &Action) { - Action.Enter(CGF); + // Ensure we do not inline the function. This is trivially true for the ones + // passed to __kmpc_fork_call but the ones called in serialized regions + // could be inlined. This is not perfect but it is closer to the invariant + // we want, namely, every data environment starts with a new function. + // TODO: We should pass the if condition to the runtime function and do the + // handling there. Much cleaner code. + cast(OutlinedFn)->addFnAttr(llvm::Attribute::NoInline); - Address ZeroAddr = - CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, - /*Name=*/".bound.zero.addr"); + Address ZeroAddr = CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, + /*Name=*/".zero.addr"); CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0)); + // ThreadId for serialized parallels is 0. 
+ Address ThreadIDAddr = ZeroAddr; + llvm::SmallVector OutlinedFnArgs; OutlinedFnArgs.push_back(ThreadIDAddr.getPointer()); OutlinedFnArgs.push_back(ZeroAddr.getPointer()); OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end()); emitOutlinedFunctionCall(CGF, Loc, Fn, OutlinedFnArgs); }; + auto &&SeqGen = [this, &CodeGen, Loc](CodeGenFunction &CGF, PrePostActionTy &) { - RegionCodeGenTy RCG(CodeGen); llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc); llvm::Value *ThreadID = getThreadID(CGF, Loc); @@ -2138,47 +2127,33 @@ RCG(CGF); }; - auto &&L0ParallelGen = [this, CapturedVars, Fn](CodeGenFunction &CGF, - PrePostActionTy &Action) { + auto &&ParallelGen = [this, Loc, OutlinedFn, CapturedVars, + IfCond](CodeGenFunction &CGF, PrePostActionTy &Action) { CGBuilderTy &Bld = CGF.Builder; - llvm::Function *WFn = WrapperFunctionsMap[Fn]; - assert(WFn && "Wrapper function does not exist!"); - llvm::Value *ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy); - - // Prepare for parallel region. Indicate the outlined function. - llvm::Value *Args[] = {ID}; - CGF.EmitRuntimeCall( - OMPBuilder.getOrCreateRuntimeFunction( - CGM.getModule(), OMPRTL___kmpc_kernel_prepare_parallel), - Args); + llvm::Function *WFn = WrapperFunctionsMap[OutlinedFn]; + llvm::Value *ID = llvm::ConstantPointerNull::get(CGM.Int8PtrTy); + if (WFn) { + ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy); + // Remember for post-processing in worker loop. + Work.emplace_back(WFn); + } + llvm::Value *FnPtr = Bld.CreateBitOrPointerCast(OutlinedFn, CGM.Int8PtrTy); // Create a private scope that will globalize the arguments // passed from the outside of the target region. + // TODO: Is that needed? CodeGenFunction::OMPPrivateScope PrivateArgScope(CGF); + Address CapturedVarsAddrs = CGF.CreateDefaultAlignTempAlloca( + llvm::ArrayType::get(CGM.VoidPtrTy, CapturedVars.size()), + "captured_vars_addrs"); // There's something to share. if (!CapturedVars.empty()) { // Prepare for parallel region. 
Indicate the outlined function. - Address SharedArgs = - CGF.CreateDefaultAlignTempAlloca(CGF.VoidPtrPtrTy, "shared_arg_refs"); - llvm::Value *SharedArgsPtr = SharedArgs.getPointer(); - - llvm::Value *DataSharingArgs[] = { - SharedArgsPtr, - llvm::ConstantInt::get(CGM.SizeTy, CapturedVars.size())}; - CGF.EmitRuntimeCall( - OMPBuilder.getOrCreateRuntimeFunction( - CGM.getModule(), OMPRTL___kmpc_begin_sharing_variables), - DataSharingArgs); - - // Store variable address in a list of references to pass to workers. - unsigned Idx = 0; ASTContext &Ctx = CGF.getContext(); - Address SharedArgListAddress = CGF.EmitLoadOfPointer( - SharedArgs, Ctx.getPointerType(Ctx.getPointerType(Ctx.VoidPtrTy)) - .castAs()); + unsigned Idx = 0; for (llvm::Value *V : CapturedVars) { - Address Dst = Bld.CreateConstInBoundsGEP(SharedArgListAddress, Idx); + Address Dst = Bld.CreateConstArrayGEP(CapturedVarsAddrs, Idx); llvm::Value *PtrV; if (V->getType()->isIntegerTy()) PtrV = Bld.CreateIntToPtr(V, CGF.VoidPtrTy); @@ -2190,139 +2165,36 @@ } } - // Activate workers. This barrier is used by the master to signal - // work for the workers. - syncCTAThreads(CGF); - - // OpenMP [2.5, Parallel Construct, p.49] - // There is an implied barrier at the end of a parallel region. After the - // end of a parallel region, only the master thread of the team resumes - // execution of the enclosing task region. - // - // The master waits at this barrier until all workers are done. - syncCTAThreads(CGF); - - if (!CapturedVars.empty()) - CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction( - CGM.getModule(), OMPRTL___kmpc_end_sharing_variables)); - - // Remember for post-processing in worker loop. 
- Work.emplace_back(WFn); - }; - - auto &&LNParallelGen = [this, Loc, &SeqGen, &L0ParallelGen]( - CodeGenFunction &CGF, PrePostActionTy &Action) { - if (IsInParallelRegion) { - SeqGen(CGF, Action); - } else if (IsInTargetMasterThreadRegion) { - L0ParallelGen(CGF, Action); - } else { - // Check for master and then parallelism: - // if (__kmpc_is_spmd_exec_mode() || __kmpc_parallel_level(loc, gtid)) { - // Serialized execution. - // } else { - // Worker call. - // } - CGBuilderTy &Bld = CGF.Builder; - llvm::BasicBlock *ExitBB = CGF.createBasicBlock(".exit"); - llvm::BasicBlock *SeqBB = CGF.createBasicBlock(".sequential"); - llvm::BasicBlock *ParallelCheckBB = CGF.createBasicBlock(".parcheck"); - llvm::BasicBlock *MasterBB = CGF.createBasicBlock(".master"); - llvm::Value *IsSPMD = Bld.CreateIsNotNull( - CGF.EmitNounwindRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction( - CGM.getModule(), OMPRTL___kmpc_is_spmd_exec_mode))); - Bld.CreateCondBr(IsSPMD, SeqBB, ParallelCheckBB); - // There is no need to emit line number for unconditional branch. - (void)ApplyDebugLocation::CreateEmpty(CGF); - CGF.EmitBlock(ParallelCheckBB); - llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc); - llvm::Value *ThreadID = getThreadID(CGF, Loc); - llvm::Value *PL = CGF.EmitRuntimeCall( - OMPBuilder.getOrCreateRuntimeFunction(CGM.getModule(), - OMPRTL___kmpc_parallel_level), - {RTLoc, ThreadID}); - llvm::Value *Res = Bld.CreateIsNotNull(PL); - Bld.CreateCondBr(Res, SeqBB, MasterBB); - CGF.EmitBlock(SeqBB); - SeqGen(CGF, Action); - CGF.EmitBranch(ExitBB); - // There is no need to emit line number for unconditional branch. - (void)ApplyDebugLocation::CreateEmpty(CGF); - CGF.EmitBlock(MasterBB); - L0ParallelGen(CGF, Action); - CGF.EmitBranch(ExitBB); - // There is no need to emit line number for unconditional branch. - (void)ApplyDebugLocation::CreateEmpty(CGF); - // Emit the continuation block for code after the if. 
- CGF.EmitBlock(ExitBB, /*IsFinished=*/true); - } - }; - - if (IfCond) { - emitIfClause(CGF, IfCond, LNParallelGen, SeqGen); - } else { - CodeGenFunction::RunCleanupsScope Scope(CGF); - RegionCodeGenTy ThenRCG(LNParallelGen); - ThenRCG(CGF); - } -} - -void CGOpenMPRuntimeGPU::emitSPMDParallelCall( - CodeGenFunction &CGF, SourceLocation Loc, llvm::Function *OutlinedFn, - ArrayRef CapturedVars, const Expr *IfCond) { - // Just call the outlined function to execute the parallel region. - // OutlinedFn(>id, &zero, CapturedStruct); - // - llvm::SmallVector OutlinedFnArgs; - - Address ZeroAddr = CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, - /*Name=*/".zero.addr"); - CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0)); - // ThreadId for serialized parallels is 0. - Address ThreadIDAddr = ZeroAddr; - auto &&CodeGen = [this, OutlinedFn, CapturedVars, Loc, &ThreadIDAddr]( - CodeGenFunction &CGF, PrePostActionTy &Action) { - Action.Enter(CGF); - - Address ZeroAddr = - CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, - /*Name=*/".bound.zero.addr"); - CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0)); - llvm::SmallVector OutlinedFnArgs; - OutlinedFnArgs.push_back(ThreadIDAddr.getPointer()); - OutlinedFnArgs.push_back(ZeroAddr.getPointer()); - OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end()); - emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, OutlinedFnArgs); - }; - auto &&SeqGen = [this, &CodeGen, Loc](CodeGenFunction &CGF, - PrePostActionTy &) { + llvm::Value *IfCondVal = nullptr; + if (IfCond) + IfCondVal = Bld.CreateIntCast(CGF.EvaluateExprAsBool(IfCond), CGF.Int32Ty, + /* isSigned */ false); + else + IfCondVal = llvm::ConstantInt::get(CGF.Int32Ty, 1); - RegionCodeGenTy RCG(CodeGen); + assert(IfCondVal && "Expected a value"); llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc); - llvm::Value *ThreadID = getThreadID(CGF, Loc); - llvm::Value *Args[] = {RTLoc, ThreadID}; - - NVPTXActionTy Action( - OMPBuilder.getOrCreateRuntimeFunction( - 
CGM.getModule(), OMPRTL___kmpc_serialized_parallel), - Args, - OMPBuilder.getOrCreateRuntimeFunction( - CGM.getModule(), OMPRTL___kmpc_end_serialized_parallel), - Args); - RCG.setAction(Action); - RCG(CGF); + llvm::Value *Args[] = { + RTLoc, + getThreadID(CGF, Loc), + IfCondVal, + llvm::ConstantInt::get(CGF.Int32Ty, -1), + llvm::ConstantInt::get(CGF.Int32Ty, -1), + FnPtr, + ID, + Bld.CreateBitOrPointerCast(CapturedVarsAddrs.getPointer(), + CGF.VoidPtrPtrTy), + llvm::ConstantInt::get(CGM.SizeTy, CapturedVars.size())}; + CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction( + CGM.getModule(), OMPRTL___kmpc_parallel_51), + Args); }; - if (IsInTargetMasterThreadRegion) { - // In the worker need to use the real thread id. - ThreadIDAddr = emitThreadIDAddress(CGF, Loc); - RegionCodeGenTy RCG(CodeGen); + if (IsInParallelRegion) { + RegionCodeGenTy RCG(SeqGen); RCG(CGF); } else { - // If we are not in the target region, it is definitely L2 parallelism or - // more, because for SPMD mode we always has L1 parallel level, sowe don't - // need to check for orphaned directives. 
- RegionCodeGenTy RCG(SeqGen); + RegionCodeGenTy RCG(ParallelGen); RCG(CGF); } } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -412,6 +412,8 @@ __OMP_RTL(__kmpc_spmd_kernel_init, false, Void, Int32, Int16) __OMP_RTL(__kmpc_spmd_kernel_deinit_v2, false, Void, Int16) __OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr) +__OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32, + VoidPtr, VoidPtr, VoidPtrPtr, SizeTy) __OMP_RTL(__kmpc_kernel_parallel, false, Int1, VoidPtrPtr) __OMP_RTL(__kmpc_kernel_end_parallel, false, Void, ) __OMP_RTL(__kmpc_serialized_parallel, false, Void, IdentPtr, Int32) diff --git a/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu b/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu --- a/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu +++ b/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu @@ -87,11 +87,6 @@ int threadId = GetThreadIdInBlock(); if (threadId == 0) { usedSlotIdx = __kmpc_impl_smid() % MAX_SM; - parallelLevel[0] = - 1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0); - } else if (GetLaneId() == 0) { - parallelLevel[GetWarpId()] = - 1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0); } if (!RequiresOMPRuntime) { // Runtime is not required - exit. diff --git a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu --- a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu +++ b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu @@ -154,16 +154,6 @@ (int)newTaskDescr->ThreadId(), (int)nThreads); isActive = true; - // Reconverge the threads at the end of the parallel region to correctly - // handle parallel levels. - // In Cuda9+ in non-SPMD mode we have either 1 worker thread or the whole - // warp. 
If only 1 thread is active, not need to reconverge the threads. - // If we have the whole warp, reconverge all the threads in the warp before - // actually trying to change the parallel level. Otherwise, parallel level - // can be changed incorrectly because of threads divergence. - bool IsActiveParallelRegion = threadsInTeam != 1; - IncParallelLevel(IsActiveParallelRegion, - IsActiveParallelRegion ? __kmpc_impl_all_lanes : 1u); } return isActive; @@ -180,17 +170,6 @@ omptarget_nvptx_TaskDescr *currTaskDescr = getMyTopTaskDescriptor(threadId); omptarget_nvptx_threadPrivateContext->SetTopLevelTaskDescr( threadId, currTaskDescr->GetPrevTaskDescr()); - - // Reconverge the threads at the end of the parallel region to correctly - // handle parallel levels. - // In Cuda9+ in non-SPMD mode we have either 1 worker thread or the whole - // warp. If only 1 thread is active, not need to reconverge the threads. - // If we have the whole warp, reconverge all the threads in the warp before - // actually trying to change the parallel level. Otherwise, parallel level can - // be changed incorrectly because of threads divergence. - bool IsActiveParallelRegion = threadsInTeam != 1; - DecParallelLevel(IsActiveParallelRegion, - IsActiveParallelRegion ? __kmpc_impl_all_lanes : 1u); } //////////////////////////////////////////////////////////////////////////////// @@ -302,4 +281,91 @@ PRINT(LD_IO, "call kmpc_push_proc_bind %d\n", (int)proc_bind); } +//////////////////////////////////////////////////////////////////////////////// +// parallel interface +//////////////////////////////////////////////////////////////////////////////// + +EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid, + kmp_int32 if_expr, kmp_int32 num_threads, + int proc_bind, void *fn, void *wrapper_fn, + void **args, size_t nargs) { + + // Handle the serialized case first, same for SPMD/non-SPMD. + // TODO: Add UNLIKELY to optimize? 
+ if (!if_expr) { + __kmpc_serialized_parallel(ident, global_tid); + __kmp_invoke_microtask(global_tid, 0, fn, args, nargs); + __kmpc_end_serialized_parallel(ident, global_tid); + + return; + } + + if (__kmpc_is_spmd_exec_mode()) { + // Increment parallel level for SPMD warps. + if (GetThreadIdInBlock() == 0) + parallelLevel[0] = + 1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0); + else if (GetLaneId() == 0) + parallelLevel[GetWarpId()] = + 1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0); + // TODO: Is that synchronization correct/needed? Can only using a memory + // fence ensure consistency? + __kmpc_impl_syncthreads(); + + __kmp_invoke_microtask(global_tid, 0, fn, args, nargs); + + // TODO: is decrementing parallel level needed? parallelLevel will reset to + // the next SPMD/non-SPMD parallel region execution, existing implementation + // does not decrement? + // parallelLevel[GetWarpId()] = 0; + return; + } + + // Handle the num_threads clause. + if (num_threads != -1) + __kmpc_push_num_threads(ident, global_tid, num_threads); + + __kmpc_kernel_prepare_parallel((void *)wrapper_fn); + + if (nargs) { + void **GlobalArgs; + __kmpc_begin_sharing_variables(&GlobalArgs, nargs); + // TODO: faster memcpy? + for (int I = 0; I < nargs; I++) + GlobalArgs[I] = args[I]; + } + + // TODO: what if that's a parallel region with a single thread? this is considered + // not active in the existing implementation. + bool IsActiveParallelRegion = threadsInTeam != 1; + // Increment parallel level for non-SPMD warps. + for (int I = 0; I < threadsInTeam / WARPSIZE; ++I) + parallelLevel[I] += + (1 + (IsActiveParallelRegion ? OMP_ACTIVE_PARALLEL_LEVEL : 0)); + + // Master signals work to activate workers. + __kmpc_barrier_simple_spmd(nullptr, 0); + + // OpenMP [2.5, Parallel Construct, p.49] + // There is an implied barrier at the end of a parallel region. 
After the + // end of a parallel region, only the master thread of the team resumes + // execution of the enclosing task region. + // + // The master waits at this barrier until all workers are done. + __kmpc_barrier_simple_spmd(nullptr, 0); + + // Decrement parallel level for non-SPMD warps. + for (int I = 0; I < threadsInTeam / WARPSIZE; ++I) + parallelLevel[I] -= + (1 + (IsActiveParallelRegion ? OMP_ACTIVE_PARALLEL_LEVEL : 0)); + // TODO: Is synchronization needed since out of parallel execution? + + if (nargs) + __kmpc_end_sharing_variables(); + + // TODO: proc_bind is a noop? + // if (proc_bind != proc_bind_default) + // __kmpc_push_proc_bind(ident, global_tid, proc_bind); +} + #pragma omp end declare target diff --git a/openmp/libomptarget/deviceRTLs/common/src/support.cu b/openmp/libomptarget/deviceRTLs/common/src/support.cu --- a/openmp/libomptarget/deviceRTLs/common/src/support.cu +++ b/openmp/libomptarget/deviceRTLs/common/src/support.cu @@ -265,4 +265,110 @@ return static_cast(ReductionScratchpadPtr) + 256; } +// Invoke an outlined parallel function unwrapping arguments (up +// to 16). 
+DEVICE void __kmp_invoke_microtask(kmp_int32 global_tid, kmp_int32 bound_tid, + void *fn, void **args, size_t nargs) { + switch (nargs) { + case 0: + ((void (*)(kmp_int32 *, kmp_int32 *))fn)(&global_tid, &bound_tid); + break; + case 1: + ((void (*)(kmp_int32 *, kmp_int32 *, void *))fn)(&global_tid, &bound_tid, + args[0]); + break; + case 2: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *))fn)( + &global_tid, &bound_tid, args[0], args[1]); + break; + case 3: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *))fn)( + &global_tid, &bound_tid, args[0], args[1], args[2]); + break; + case 4: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *))fn)( + &global_tid, &bound_tid, args[0], args[1], args[2], args[3]); + break; + case 5: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, + void *))fn)(&global_tid, &bound_tid, args[0], args[1], args[2], + args[3], args[4]); + break; + case 6: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *))fn)(&global_tid, &bound_tid, args[0], args[1], args[2], + args[3], args[4], args[5]); + break; + case 7: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *))fn)(&global_tid, &bound_tid, args[0], args[1], + args[2], args[3], args[4], args[5], args[6]); + break; + case 8: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *, void *))fn)(&global_tid, &bound_tid, args[0], + args[1], args[2], args[3], args[4], + args[5], args[6], args[7]); + break; + case 9: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *, void *, void *))fn)( + &global_tid, &bound_tid, args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8]); + break; + case 10: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *, void *, void *, void *))fn)( + &global_tid, &bound_tid, args[0], 
args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], args[9]); + break; + case 11: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *, void *, void *, void *, void *))fn)( + &global_tid, &bound_tid, args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], args[9], args[10]); + break; + case 12: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *, void *, void *, void *, void *, void *))fn)( + &global_tid, &bound_tid, args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], args[9], args[10], args[11]); + break; + case 13: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *, void *, void *, void *, void *, void *, + void *))fn)(&global_tid, &bound_tid, args[0], args[1], args[2], + args[3], args[4], args[5], args[6], args[7], args[8], + args[9], args[10], args[11], args[12]); + break; + case 14: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *, void *, void *, void *, void *, void *, void *, + void *))fn)(&global_tid, &bound_tid, args[0], args[1], args[2], + args[3], args[4], args[5], args[6], args[7], args[8], + args[9], args[10], args[11], args[12], args[13]); + break; + case 15: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *, void *, void *, void *, void *, void *, void *, + void *, void *))fn)(&global_tid, &bound_tid, args[0], args[1], + args[2], args[3], args[4], args[5], args[6], + args[7], args[8], args[9], args[10], + args[11], args[12], args[13], args[14]); + break; + case 16: + ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *, + void *, void *, void *, void *, void *, void *, void *, void *, + void *, void *, void *))fn)( + &global_tid, &bound_tid, args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], 
args[9], args[10], args[11], + args[12], args[13], args[14], args[15]); + break; + default: + // TODO: assert + printf("Too many arguments in kmp_invoke_microtask, aborting execution.\n"); + return; + } +} + #pragma omp end declare target diff --git a/openmp/libomptarget/deviceRTLs/common/support.h b/openmp/libomptarget/deviceRTLs/common/support.h --- a/openmp/libomptarget/deviceRTLs/common/support.h +++ b/openmp/libomptarget/deviceRTLs/common/support.h @@ -95,4 +95,9 @@ DEVICE unsigned int *GetTeamsReductionTimestamp(); DEVICE char *GetTeamsReductionScratchpad(); +// Invoke an outlined parallel function unwrapping global, shared arguments (up +// to 16). +DEVICE void __kmp_invoke_microtask(kmp_int32 global_tid, kmp_int32 bound_tid, + void *fn, void **args, size_t nargs); + #endif diff --git a/openmp/libomptarget/deviceRTLs/interface.h b/openmp/libomptarget/deviceRTLs/interface.h --- a/openmp/libomptarget/deviceRTLs/interface.h +++ b/openmp/libomptarget/deviceRTLs/interface.h @@ -177,6 +177,7 @@ * The struct is identical to the one in the kmp.h file. * We maintain the same data structure for compatibility. */ +typedef short kmp_int16; typedef int kmp_int32; typedef struct ident { kmp_int32 reserved_1; /**< might be used in Fortran; see above */ @@ -437,6 +438,22 @@ EXTERN void __kmpc_end_sharing_variables(); EXTERN void __kmpc_get_shared_variables(void ***GlobalArgs); +/// Entry point to start a new parallel region. +/// +/// \param ident The source identifier. +/// \param global_tid The global thread ID. +/// \param if_expr The if(expr), or 1 if none given. +/// \param num_threads The num_threads(expr), or -1 if none given. +/// \param proc_bind The proc_bind, or `proc_bind_default` if none given. +/// \param fn The outlined parallel region function. +/// \param wrapper_fn The worker wrapper function of fn. +/// \param args The pointer array of arguments to fn. +/// \param nargs The number of arguments to fn. 
+EXTERN void __kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid, + kmp_int32 if_expr, kmp_int32 num_threads, + int proc_bind, void *fn, void *wrapper_fn, + void **args, size_t nargs); + // SPMD execution mode interrogation function. EXTERN int8_t __kmpc_is_spmd_exec_mode();