// Patch summary (reviewer preamble; text before the first `diff --git` header
// is not part of any hunk).
//
// Purpose: introduce a team-wide `HasThreadState` flag and a `ForceTeamState`
// lookup parameter so ICV lookups can be directed at the team state.
//
// include/State.h
//   * Adds the field `uint32_t HasThreadState` next to `ParallelTeamSize`, and
//     the enumerator `VK_HasThreadState`.
//   * `lookupForModify32Impl`, `lookupImpl`, `lookup32` and `lookupPtr` gain a
//     `bool ForceTeamState` parameter. The team-state fast path in
//     `lookupForModify32Impl` is now taken when
//     `ForceTeamState || !config::mayUseThreadStates() ||
//     !TeamState.HasThreadState` (previously it tested
//     `TeamState.ICVState.LevelVar == 0`).
//   * `lookupImpl` only consults `ThreadStates[TId]` when `ForceTeamState` is
//     false and `TeamState.HasThreadState` is set.
//   * `lookup32` returns `TeamState.HasThreadState` for `VK_HasThreadState`.
//   * `Value` gains `assert_eq(V, Ident = nullptr, ForceTeamState = false)`;
//     its conversion operator, `inc` and `set` pass `ForceTeamState = false`.
//     `PtrValue` is updated the same way.
//   * `ValueRAII`'s constructor gains a trailing `bool ForceTeamState = false`.
//   * Declares the global `state::HasThreadState`.
//
// src/Parallelism.cpp
//   * The serialized path is now also taken when `state::HasThreadState` is
//     set, followed by `ASSERT(state::HasThreadState == false)`.
//   * The `ValueRAII` objects and the post-barrier checks (now `assert_eq`)
//     pass `ForceTeamState = true`.
//
// src/State.cpp
//   * Team-state init sets `HasThreadState = false`; `operator==` and
//     `assertEqual` now compare it.
//   * The code that installs `NewThreadState` sets
//     `TeamState.HasThreadState = true` first.
//   * `resetStateForThread` returns early when the flag is unset.
//
// NOTE(review): in `lookupForModify32Impl` the added store
//   `TeamState.HasThreadState = true` sits after an early return that fires
//   whenever the flag is false, so it looks reachable only when the flag is
//   already set -- confirm whether the store is intentional/needed.
// NOTE(review): no visible hunk clears `HasThreadState` again outside of the
//   team-state init; confirm whether it should be reset when thread states
//   are released.
// NOTE(review): `lookupPtr` accepts `ForceTeamState` but the visible case
//   (`VK_ParallelRegionFn`) does not read it -- presumably kept for signature
//   symmetry; verify.
// NOTE(review): this copy of the patch has its line breaks collapsed and
//   template argument lists appear to be missing (e.g. `reinterpret_cast(`,
//   `static_cast(`, `template struct Value`); re-fetch the raw diff before
//   applying -- TODO confirm against upstream.
diff --git a/openmp/libomptarget/DeviceRTL/include/State.h b/openmp/libomptarget/DeviceRTL/include/State.h --- a/openmp/libomptarget/DeviceRTL/include/State.h +++ b/openmp/libomptarget/DeviceRTL/include/State.h @@ -78,6 +78,7 @@ ///} uint32_t ParallelTeamSize; + uint32_t HasThreadState; ParallelRegionFnTy ParallelRegionFnVar; }; @@ -125,6 +126,7 @@ VK_RunSchedChunk, VK_ParallelRegionFn, VK_ParallelTeamSize, + VK_HasThreadState, }; /// TODO @@ -143,56 +145,66 @@ void resetStateForThread(uint32_t TId); inline uint32_t &lookupForModify32Impl(uint32_t state::ICVStateTy::*Var, - IdentTy *Ident) { - if (OMP_LIKELY(!config::mayUseThreadStates() || - TeamState.ICVState.LevelVar == 0)) + IdentTy *Ident, bool ForceTeamState) { + if (OMP_LIKELY(ForceTeamState || !config::mayUseThreadStates() || + !TeamState.HasThreadState)) return TeamState.ICVState.*Var; uint32_t TId = mapping::getThreadIdInBlock(); if (OMP_UNLIKELY(!ThreadStates[TId])) { ThreadStates[TId] = reinterpret_cast(memory::allocGlobal( sizeof(ThreadStateTy), "ICV modification outside data environment")); ASSERT(ThreadStates[TId] != nullptr && "Nullptr returned by malloc!"); + TeamState.HasThreadState = true; ThreadStates[TId]->init(); } return ThreadStates[TId]->ICVState.*Var; } -inline uint32_t &lookupImpl(uint32_t state::ICVStateTy::*Var) { +inline uint32_t &lookupImpl(uint32_t state::ICVStateTy::*Var, + bool ForceTeamState) { auto TId = mapping::getThreadIdInBlock(); - if (OMP_UNLIKELY(config::mayUseThreadStates() && ThreadStates[TId])) + if (OMP_UNLIKELY(!ForceTeamState && config::mayUseThreadStates() && + TeamState.HasThreadState && ThreadStates[TId])) return ThreadStates[TId]->ICVState.*Var; return TeamState.ICVState.*Var; } __attribute__((always_inline, flatten)) inline uint32_t & -lookup32(ValueKind Kind, bool IsReadonly, IdentTy *Ident) { +lookup32(ValueKind Kind, bool IsReadonly, IdentTy *Ident, bool ForceTeamState) { switch (Kind) { case state::VK_NThreads: if (IsReadonly) - return 
lookupImpl(&ICVStateTy::NThreadsVar); - return lookupForModify32Impl(&ICVStateTy::NThreadsVar, Ident); + return lookupImpl(&ICVStateTy::NThreadsVar, ForceTeamState); + return lookupForModify32Impl(&ICVStateTy::NThreadsVar, Ident, + ForceTeamState); case state::VK_Level: if (IsReadonly) - return lookupImpl(&ICVStateTy::LevelVar); - return lookupForModify32Impl(&ICVStateTy::LevelVar, Ident); + return lookupImpl(&ICVStateTy::LevelVar, ForceTeamState); + return lookupForModify32Impl(&ICVStateTy::LevelVar, Ident, ForceTeamState); case state::VK_ActiveLevel: if (IsReadonly) - return lookupImpl(&ICVStateTy::ActiveLevelVar); - return lookupForModify32Impl(&ICVStateTy::ActiveLevelVar, Ident); + return lookupImpl(&ICVStateTy::ActiveLevelVar, ForceTeamState); + return lookupForModify32Impl(&ICVStateTy::ActiveLevelVar, Ident, + ForceTeamState); case state::VK_MaxActiveLevels: if (IsReadonly) - return lookupImpl(&ICVStateTy::MaxActiveLevelsVar); - return lookupForModify32Impl(&ICVStateTy::MaxActiveLevelsVar, Ident); + return lookupImpl(&ICVStateTy::MaxActiveLevelsVar, ForceTeamState); + return lookupForModify32Impl(&ICVStateTy::MaxActiveLevelsVar, Ident, + ForceTeamState); case state::VK_RunSched: if (IsReadonly) - return lookupImpl(&ICVStateTy::RunSchedVar); - return lookupForModify32Impl(&ICVStateTy::RunSchedVar, Ident); + return lookupImpl(&ICVStateTy::RunSchedVar, ForceTeamState); + return lookupForModify32Impl(&ICVStateTy::RunSchedVar, Ident, + ForceTeamState); case state::VK_RunSchedChunk: if (IsReadonly) - return lookupImpl(&ICVStateTy::RunSchedChunkVar); - return lookupForModify32Impl(&ICVStateTy::RunSchedChunkVar, Ident); + return lookupImpl(&ICVStateTy::RunSchedChunkVar, ForceTeamState); + return lookupForModify32Impl(&ICVStateTy::RunSchedChunkVar, Ident, + ForceTeamState); case state::VK_ParallelTeamSize: return TeamState.ParallelTeamSize; + case state::VK_HasThreadState: + return TeamState.HasThreadState; default: break; } @@ -200,7 +212,7 @@ } 
__attribute__((always_inline, flatten)) inline void *& -lookupPtr(ValueKind Kind, bool IsReadonly) { +lookupPtr(ValueKind Kind, bool IsReadonly, bool ForceTeamState) { switch (Kind) { case state::VK_ParallelRegionFn: return TeamState.ParallelRegionFnVar; @@ -214,7 +226,8 @@ /// update ICV values we can declare in global scope. template struct Value { __attribute__((flatten, always_inline)) operator Ty() { - return lookup(/* IsReadonly */ true, /* IdentTy */ nullptr); + return lookup(/* IsReadonly */ true, /* IdentTy */ nullptr, + /* ForceTeamState */ false); } __attribute__((flatten, always_inline)) Value &operator=(const Ty &Other) { @@ -232,21 +245,29 @@ return *this; } + __attribute__((flatten, always_inline)) void + assert_eq(const Ty &V, IdentTy *Ident = nullptr, + bool ForceTeamState = false) { + ASSERT(lookup(/* IsReadonly */ true, Ident, ForceTeamState) == V); + } + private: - __attribute__((flatten, always_inline)) Ty &lookup(bool IsReadonly, - IdentTy *Ident) { - Ty &t = lookup32(Kind, IsReadonly, Ident); + __attribute__((flatten, always_inline)) Ty & + lookup(bool IsReadonly, IdentTy *Ident, bool ForceTeamState) { + Ty &t = lookup32(Kind, IsReadonly, Ident, ForceTeamState); return t; } __attribute__((flatten, always_inline)) Ty &inc(int UpdateVal, IdentTy *Ident) { - return (lookup(/* IsReadonly */ false, Ident) += UpdateVal); + return (lookup(/* IsReadonly */ false, Ident, /* ForceTeamState */ false) += + UpdateVal); } __attribute__((flatten, always_inline)) Ty &set(Ty UpdateVal, IdentTy *Ident) { - return (lookup(/* IsReadonly */ false, Ident) = UpdateVal); + return (lookup(/* IsReadonly */ false, Ident, /* ForceTeamState */ false) = + UpdateVal); } template friend struct ValueRAII; @@ -257,7 +278,8 @@ /// we can declare in global scope. 
template struct PtrValue { __attribute__((flatten, always_inline)) operator Ty() { - return lookup(/* IsReadonly */ true, /* IdentTy */ nullptr); + return lookup(/* IsReadonly */ true, /* IdentTy */ nullptr, + /* ForceTeamState */ false); } __attribute__((flatten, always_inline)) PtrValue &operator=(const Ty Other) { @@ -266,18 +288,22 @@ } private: - Ty &lookup(bool IsReadonly, IdentTy *) { return lookupPtr(Kind, IsReadonly); } + Ty &lookup(bool IsReadonly, IdentTy *, bool ForceTeamState) { + return lookupPtr(Kind, IsReadonly, ForceTeamState); + } Ty &set(Ty UpdateVal) { - return (lookup(/* IsReadonly */ false, /* IdentTy */ nullptr) = UpdateVal); + return (lookup(/* IsReadonly */ false, /* IdentTy */ nullptr, + /* ForceTeamState */ false) = UpdateVal); } template friend struct ValueRAII; }; template struct ValueRAII { - ValueRAII(VTy &V, Ty NewValue, Ty OldValue, bool Active, IdentTy *Ident) - : Ptr(Active ? &V.lookup(/* IsReadonly */ false, Ident) + ValueRAII(VTy &V, Ty NewValue, Ty OldValue, bool Active, IdentTy *Ident, + bool ForceTeamState = false) + : Ptr(Active ? &V.lookup(/* IsReadonly */ false, Ident, ForceTeamState) : (Ty *)utils::UndefPtr), Val(OldValue), Active(Active) { if (!Active) @@ -303,6 +329,9 @@ /// TODO inline state::Value ParallelTeamSize; +/// TODO +inline state::Value HasThreadState; + /// TODO inline state::PtrValue ParallelRegionFn; diff --git a/openmp/libomptarget/DeviceRTL/src/Parallelism.cpp b/openmp/libomptarget/DeviceRTL/src/Parallelism.cpp --- a/openmp/libomptarget/DeviceRTL/src/Parallelism.cpp +++ b/openmp/libomptarget/DeviceRTL/src/Parallelism.cpp @@ -85,14 +85,21 @@ FunctionTracingRAII(); uint32_t TId = mapping::getThreadIdInBlock(); - // Handle the serialized case first, same for SPMD/non-SPMD. 
- if (OMP_UNLIKELY(!if_expr || icv::Level)) { + + // Handle the serialized case first, same for SPMD/non-SPMD: + // 1) if-clause(0) + // 2) nested parallel regions + // 3) parallel in task or other thread state inducing construct + if (OMP_UNLIKELY(!if_expr || icv::Level || state::HasThreadState)) { state::DateEnvironmentRAII DERAII(ident); ++icv::Level; invokeMicrotask(TId, 0, fn, args, nargs); return; } + // From this point forward we know that there is no thread state used. + ASSERT(state::HasThreadState == false); + uint32_t NumThreads = determineNumberOfThreads(num_threads); if (mapping::isSPMDMode()) { // Avoid the race between the read of the `icv::Level` above and the write @@ -103,18 +110,21 @@ // last or the other updates will cause a thread specific state to be // created. state::ValueRAII ParallelTeamSizeRAII(state::ParallelTeamSize, NumThreads, - 1u, TId == 0, ident); + 1u, TId == 0, ident, + /* ForceTeamState */ true); state::ValueRAII ActiveLevelRAII(icv::ActiveLevel, 1u, 0u, TId == 0, - ident); - state::ValueRAII LevelRAII(icv::Level, 1u, 0u, TId == 0, ident); + ident, /* ForceTeamState */ true); + state::ValueRAII LevelRAII(icv::Level, 1u, 0u, TId == 0, ident, + /* ForceTeamState */ true); // Synchronize all threads after the main thread (TId == 0) set up the // team state properly. synchronize::threadsAligned(); - ASSERT(state::ParallelTeamSize == NumThreads); - ASSERT(icv::ActiveLevel == 1u); - ASSERT(icv::Level == 1u); + state::ParallelTeamSize.assert_eq(NumThreads, ident, + /* ForceTeamState */ true); + icv::ActiveLevel.assert_eq(1u, ident, /* ForceTeamState */ true); + icv::Level.assert_eq(1u, ident, /* ForceTeamState */ true); if (TId < NumThreads) invokeMicrotask(TId, 0, fn, args, nargs); @@ -128,9 +138,9 @@ // __kmpc_target_deinit may not hold. 
synchronize::threadsAligned(); - ASSERT(state::ParallelTeamSize == 1u); - ASSERT(icv::ActiveLevel == 0u); - ASSERT(icv::Level == 0u); + state::ParallelTeamSize.assert_eq(1u, ident, /* ForceTeamState */ true); + icv::ActiveLevel.assert_eq(0u, ident, /* ForceTeamState */ true); + icv::Level.assert_eq(0u, ident, /* ForceTeamState */ true); return; } @@ -213,11 +223,15 @@ // last or the other updates will cause a thread specific state to be // created. state::ValueRAII ParallelTeamSizeRAII(state::ParallelTeamSize, NumThreads, - 1u, true, ident); + 1u, true, ident, + /* ForceTeamState */ true); state::ValueRAII ParallelRegionFnRAII(state::ParallelRegionFn, wrapper_fn, - (void *)nullptr, true, ident); - state::ValueRAII ActiveLevelRAII(icv::ActiveLevel, 1u, 0u, true, ident); - state::ValueRAII LevelRAII(icv::Level, 1u, 0u, true, ident); + (void *)nullptr, true, ident, + /* ForceTeamState */ true); + state::ValueRAII ActiveLevelRAII(icv::ActiveLevel, 1u, 0u, true, ident, + /* ForceTeamState */ true); + state::ValueRAII LevelRAII(icv::Level, 1u, 0u, true, ident, + /* ForceTeamState */ true); // Master signals work to activate workers. 
synchronize::threads(); diff --git a/openmp/libomptarget/DeviceRTL/src/State.cpp b/openmp/libomptarget/DeviceRTL/src/State.cpp --- a/openmp/libomptarget/DeviceRTL/src/State.cpp +++ b/openmp/libomptarget/DeviceRTL/src/State.cpp @@ -203,17 +203,20 @@ ICVState.RunSchedVar = omp_sched_static; ICVState.RunSchedChunkVar = 1; ParallelTeamSize = 1; + HasThreadState = false; ParallelRegionFnVar = nullptr; } bool state::TeamStateTy::operator==(const TeamStateTy &Other) const { return (ICVState == Other.ICVState) & + (HasThreadState == Other.HasThreadState) & (ParallelTeamSize == Other.ParallelTeamSize); } void state::TeamStateTy::assertEqual(TeamStateTy &Other) const { ICVState.assertEqual(Other.ICVState); ASSERT(ParallelTeamSize == Other.ParallelTeamSize); + ASSERT(HasThreadState == Other.HasThreadState); } namespace { @@ -257,6 +260,7 @@ ThreadStateTy *NewThreadState = static_cast(__kmpc_alloc_shared(sizeof(ThreadStateTy))); NewThreadState->init(ThreadStates[TId]); + TeamState.HasThreadState = true; ThreadStates[TId] = NewThreadState; } @@ -269,7 +273,7 @@ } void state::resetStateForThread(uint32_t TId) { - if (OMP_LIKELY(!ThreadStates[TId])) + if (OMP_LIKELY(!TeamState.HasThreadState || !ThreadStates[TId])) return; ThreadStateTy *PreviousThreadState = ThreadStates[TId]->PreviousThreadState;