diff --git a/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu b/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu --- a/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu +++ b/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu @@ -88,6 +88,11 @@ int threadId = GetThreadIdInBlock(); if (threadId == 0) { usedSlotIdx = __kmpc_impl_smid() % MAX_SM; + parallelLevel[0] = + 1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0); + } else if (GetLaneId() == 0) { + parallelLevel[GetWarpId()] = + 1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0); } if (!RequiresOMPRuntime) { // Runtime is not required - exit. diff --git a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu --- a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu +++ b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu @@ -289,31 +289,21 @@ int proc_bind, void *fn, void *wrapper_fn, void **args, size_t nargs) { - // Handle the serialized case first, same for SPMD/non-SPMD. - // TODO: Add UNLIKELY to optimize? - bool InParallelRegion = (__kmpc_parallel_level(ident, global_tid) > 0); + // Handle the serialized case first, same for SPMD/non-SPMD except that in + // SPMD mode the kernel init already incremented the parallel level counter, + // account for that. Always return so the region is not run a second time. + bool InParallelRegion = + (__kmpc_parallel_level(ident, global_tid) > __kmpc_is_spmd_exec_mode()); if (!if_expr || InParallelRegion) { __kmpc_serialized_parallel(ident, global_tid); __kmp_invoke_microtask(global_tid, 0, fn, args, nargs); __kmpc_end_serialized_parallel(ident, global_tid); return; } if (__kmpc_is_spmd_exec_mode()) { - // Increment parallel level for SPMD warps. - if (GetLaneId() == 0) - parallelLevel[GetWarpId()] = - 1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0); - // TODO: Is that synchronization correct/needed? Can only using a memory - // fence ensure consistency?
- __kmpc_impl_syncthreads(); - __kmp_invoke_microtask(global_tid, 0, fn, args, nargs); - - // Decrement (zero out) parallel level for SPMD warps. - if (GetLaneId() == 0) - parallelLevel[GetWarpId()] = 0; return; }