diff --git a/llvm/include/llvm/Support/Parallel.h b/llvm/include/llvm/Support/Parallel.h
--- a/llvm/include/llvm/Support/Parallel.h
+++ b/llvm/include/llvm/Support/Parallel.h
@@ -11,6 +11,7 @@
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Config/llvm-config.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/Threading.h"
 
@@ -156,6 +157,59 @@
     Fn(J);
 }
 
+template <class IterTy, class ResultTy, class ReduceFuncTy,
+          class TransformFuncTy>
+ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init,
+                                   ReduceFuncTy Reduce,
+                                   TransformFuncTy Transform) {
+  // TaskGroup has a relatively high overhead, so we want to reduce
+  // the number of spawn() calls. We'll create up to 1024 tasks here.
+  // (Note that 1024 is an arbitrary number. This code probably needs
+  // improving to take the number of available cores into account.)
+  size_t NumInputs = std::distance(Begin, End);
+  // Nothing to reduce. Returning early also keeps NumTasks non-zero for the
+  // division below.
+  if (NumInputs == 0)
+    return std::move(Init);
+  size_t NumTasks = std::min(static_cast<size_t>(1024), NumInputs);
+  // Each task processes either TaskSize or TaskSize + 1 inputs: the inputs
+  // left over after an even split are handed out one apiece to the first
+  // tasks, so no trailing input is skipped.
+  size_t TaskSize = NumInputs / NumTasks;
+  size_t RemainingInputs = NumInputs % NumTasks;
+  assert(TaskSize > 0);
+
+  // Every task starts from a copy of Init, so Init should be an identity
+  // value for Reduce (e.g. 0 for addition); otherwise it is folded in once per
+  // task and the result differs from the serial fallback.
+  std::vector<ResultTy> Results(NumTasks, Init);
+  {
+    TaskGroup TG;
+    IterTy TBegin = Begin;
+    for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) {
+      IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0);
+      TG.spawn([=, &Transform, &Reduce, &Results] {
+        // Reduce the result of transformation eagerly within each task.
+        ResultTy R = Init;
+        for (IterTy It = TBegin; It != TEnd; ++It)
+          R = Reduce(R, Transform(*It));
+        Results[TaskId] = R;
+      });
+      TBegin = TEnd;
+    }
+    assert(TBegin == End && "every input must belong to exactly one task");
+  }
+
+  // Do a final reduction. There are at most 1024 tasks, so this only adds
+  // constant single-threaded overhead for large inputs. Hopefully most
+  // reductions are cheaper than the transformation. Start from the first
+  // partial result: Init has already been folded into every partial result.
+  ResultTy FinalResult = std::move(Results.front());
+  for (size_t TaskId = 1; TaskId < NumTasks; ++TaskId)
+    FinalResult = Reduce(FinalResult, std::move(Results[TaskId]));
+  return FinalResult;
+}
+
 #endif
 
 } // namespace detail
@@ -198,6 +252,22 @@
     Fn(I);
 }
 
+template <class IterTy, class ResultTy, class ReduceFuncTy,
+          class TransformFuncTy>
+ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
+                                 ReduceFuncTy Reduce,
+                                 TransformFuncTy Transform) {
+#if LLVM_ENABLE_THREADS
+  if (parallel::strategy.ThreadsRequested != 1) {
+    return parallel::detail::parallel_transform_reduce(Begin, End, Init, Reduce,
+                                                       Transform);
+  }
+#endif
+  for (IterTy I = Begin; I != End; ++I)
+    Init = Reduce(std::move(Init), Transform(*I));
+  return std::move(Init);
+}
+
 // Range wrappers.
 template <class RangeTy,
           class Comparator = std::less<decltype(*std::begin(RangeTy()))>>
@@ -210,6 +280,31 @@
   parallelForEach(std::begin(R), std::end(R), Fn);
 }
 
+template <class RangeTy, class ResultTy, class ReduceFuncTy,
+          class TransformFuncTy>
+ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
+                                 ReduceFuncTy Reduce,
+                                 TransformFuncTy Transform) {
+  return parallelTransformReduce(std::begin(R), std::end(R), Init, Reduce,
+                                 Transform);
+}
+
+// Parallel for-each, but with error handling.
+template <class RangeTy, class FuncTy>
+Error parallelForEachError(RangeTy &&R, FuncTy Fn) {
+  // The transform_reduce algorithm requires that the initial value be copyable.
+  // Error objects are uncopyable. We only need to copy initial success values,
+  // so work around this mismatch via the C API. The C API represents success
+  // values with a null pointer. joinErrors discards null values and joins
+  // multiple errors into an ErrorList.
+  return unwrap(parallelTransformReduce(
+      std::begin(R), std::end(R), wrap(Error::success()),
+      [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) {
+        return wrap(joinErrors(unwrap(Lhs), unwrap(Rhs)));
+      },
+      [&Fn](auto &&V) { return wrap(Fn(V)); }));
+}
+
 } // namespace llvm
 
 #endif // LLVM_SUPPORT_PARALLEL_H
diff --git a/llvm/unittests/Support/ParallelTest.cpp b/llvm/unittests/Support/ParallelTest.cpp
--- a/llvm/unittests/Support/ParallelTest.cpp
+++ b/llvm/unittests/Support/ParallelTest.cpp
@@ -49,4 +49,36 @@
   ASSERT_EQ(range[2049], 1u);
 }
 
+TEST(Parallel, TransformReduce) {
+  // Sum the lengths of these strings in parallel.
+  const char *strs[] = {"a", "ab", "abc", "abcd", "abcde", "abcdef"};
+  size_t lenSum =
+      parallelTransformReduce(strs, static_cast<size_t>(0), std::plus<size_t>(),
+                              [](const char *s) { return strlen(s); });
+  ASSERT_EQ(lenSum, static_cast<size_t>(21));
+
+  // Check an empty range and an input size that does not divide evenly
+  // into the tasks of the parallel implementation.
+  auto identity = [](uint32_t v) { return v; };
+  uint32_t range[2050];
+  std::fill(range, range + 2050, 1);
+  uint32_t sum = parallelTransformReduce(range, range, 0U,
+                                         std::plus<uint32_t>(), identity);
+  ASSERT_EQ(sum, 0U);
+  sum = parallelTransformReduce(range, 0U, std::plus<uint32_t>(), identity);
+  ASSERT_EQ(sum, 2050U);
+}
+
+TEST(Parallel, ForEachError) {
+  int nums[] = {1, 2, 3, 4, 5, 6};
+  Error e = parallelForEachError(nums, [](int v) -> Error {
+    if ((v & 1) == 0)
+      return createStringError(std::errc::invalid_argument, "asdf");
+    return Error::success();
+  });
+  EXPECT_TRUE(e.isA<ErrorList>());
+  std::string errText = toString(std::move(e));
+  EXPECT_EQ(errText, std::string("asdf\nasdf\nasdf"));
+}
+
 #endif