diff --git a/llvm/include/llvm/Support/Parallel.h b/llvm/include/llvm/Support/Parallel.h
--- a/llvm/include/llvm/Support/Parallel.h
+++ b/llvm/include/llvm/Support/Parallel.h
@@ -99,6 +99,9 @@
   void spawn(std::function<void()> f, bool Sequential = false);
 
   void sync() const { L.sync(); }
+
+  // Returns true if tasks spawned into this TaskGroup may run in parallel.
+  bool isParallel() const { return Parallel; }
 };
 
 namespace detail {
diff --git a/llvm/lib/Support/Parallel.cpp b/llvm/lib/Support/Parallel.cpp
--- a/llvm/lib/Support/Parallel.cpp
+++ b/llvm/lib/Support/Parallel.cpp
@@ -99,11 +99,6 @@
 
   void add(std::function<void()> F, bool Sequential = false) override {
     {
-      if (parallel::strategy.ThreadsRequested == 1) {
-        F();
-        return;
-      }
-
       std::lock_guard<std::mutex> Lock(Mutex);
       if (Sequential)
         WorkQueueSequential.emplace_front(std::move(F));
@@ -185,18 +180,24 @@
 } // namespace detail
 #endif
 
-static std::atomic<int> TaskGroupInstances;
-
-// Latch::sync() called by the dtor may cause one thread to block. If is a dead
-// lock if all threads in the default executor are blocked. To prevent the dead
-// lock, only allow the first TaskGroup to run tasks parallelly. In the scenario
-// of nested parallel_for_each(), only the outermost one runs parallelly.
-TaskGroup::TaskGroup() : Parallel(TaskGroupInstances++ == 0) {}
+// Latch::sync() called by the dtor may cause one thread to block. It is a
+// deadlock if all threads in the default executor are blocked. To prevent the
+// deadlock, only allow a root TaskGroup -- one created outside the default
+// executor's worker threads, identified by threadIndex == UINT_MAX -- to run
+// tasks in parallel. In the scenario of nested parallel_for_each(), only the
+// outermost one runs in parallel.
+TaskGroup::TaskGroup()
+#if LLVM_ENABLE_THREADS
+    : Parallel((parallel::strategy.ThreadsRequested != 1) &&
+               (threadIndex == UINT_MAX)) {}
+#else
+    // Builds without thread support never run TaskGroup tasks in parallel.
+    : Parallel(false) {}
+#endif
 TaskGroup::~TaskGroup() {
-  // We must ensure that all the workloads have finished before decrementing the
-  // instances count.
+  // We must ensure that all the workloads have finished before the TaskGroup
+  // is destroyed.
   L.sync();
-  --TaskGroupInstances;
 }
 
 void TaskGroup::spawn(std::function<void()> F, bool Sequential) {
diff --git a/llvm/unittests/Support/ParallelTest.cpp b/llvm/unittests/Support/ParallelTest.cpp
--- a/llvm/unittests/Support/ParallelTest.cpp
+++ b/llvm/unittests/Support/ParallelTest.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Support/Parallel.h"
+#include "llvm/Support/ThreadPool.h"
 #include "gtest/gtest.h"
 #include <array>
 #include <random>
@@ -102,4 +103,77 @@
   EXPECT_EQ(Count, 500ul);
 }
 
+// TaskGroups only run in parallel in builds with thread support.
+#if LLVM_ENABLE_THREADS
+TEST(Parallel, NestedTaskGroup) {
+  // This test checks:
+  // 1. Root TaskGroup is in Parallel mode.
+  // 2. Nested TaskGroup is not in Parallel mode.
+  parallel::TaskGroup tg;
+
+  tg.spawn([&]() {
+    EXPECT_TRUE(tg.isParallel() || (parallel::strategy.ThreadsRequested == 1));
+  });
+
+  tg.spawn([&]() {
+    parallel::TaskGroup nestedTG;
+    EXPECT_FALSE(nestedTG.isParallel());
+
+    nestedTG.spawn([&]() {
+      // Check that root TaskGroup is in Parallel mode.
+      EXPECT_TRUE(tg.isParallel() ||
+                  (parallel::strategy.ThreadsRequested == 1));
+
+      // Check that nested TaskGroup is not in Parallel mode.
+      EXPECT_FALSE(nestedTG.isParallel());
+    });
+  });
+}
+
+TEST(Parallel, ParallelNestedTaskGroup) {
+  // This test checks that it is possible to have several TaskGroups
+  // run from different threads in Parallel mode.
+  std::atomic<size_t> Count{0};
+
+  {
+    std::function<void()> Fn = [&]() {
+      parallel::TaskGroup tg;
+
+      tg.spawn([&]() {
+        // Check that root TaskGroup is in Parallel mode.
+        EXPECT_TRUE(tg.isParallel() ||
+                    (parallel::strategy.ThreadsRequested == 1));
+
+        // Check that nested TaskGroup is not in Parallel mode.
+        parallel::TaskGroup nestedTG;
+        EXPECT_FALSE(nestedTG.isParallel());
+        Count++;
+
+        nestedTG.spawn([&]() {
+          // Check that root TaskGroup is in Parallel mode.
+          EXPECT_TRUE(tg.isParallel() ||
+                      (parallel::strategy.ThreadsRequested == 1));
+
+          // Check that nested TaskGroup is not in Parallel mode.
+          EXPECT_FALSE(nestedTG.isParallel());
+          Count++;
+        });
+      });
+    };
+
+    ThreadPool Pool;
+
+    Pool.async(Fn);
+    Pool.async(Fn);
+    Pool.async(Fn);
+    Pool.async(Fn);
+    Pool.async(Fn);
+    Pool.async(Fn);
+
+    Pool.wait();
+  }
+  EXPECT_EQ(Count, 12ul);
+}
+#endif
+
 #endif