diff --git a/llvm/include/llvm/Support/ThreadPool.h b/llvm/include/llvm/Support/ThreadPool.h
--- a/llvm/include/llvm/Support/ThreadPool.h
+++ b/llvm/include/llvm/Support/ThreadPool.h
@@ -70,6 +70,9 @@
 
   unsigned getThreadCount() const { return ThreadCount; }
 
+  /// Returns true if the current thread is a worker thread of this thread pool.
+  bool isWorkerThread() const;
+
 private:
   bool workCompletedUnlocked() { return !ActiveThreads && Tasks.empty(); }
 
diff --git a/llvm/lib/Support/ThreadPool.cpp b/llvm/lib/Support/ThreadPool.cpp
--- a/llvm/lib/Support/ThreadPool.cpp
+++ b/llvm/lib/Support/ThreadPool.cpp
@@ -72,6 +72,14 @@
   CompletionCondition.wait(LockGuard, [&] { return workCompletedUnlocked(); });
 }
 
+bool ThreadPool::isWorkerThread() const {
+  std::thread::id CurrentThreadId = std::this_thread::get_id();
+  for (const std::thread &Thread : Threads)
+    if (CurrentThreadId == Thread.get_id())
+      return true;
+  return false;
+}
+
 std::shared_future<void> ThreadPool::asyncImpl(TaskTy Task) {
   /// Wrap the Task in a packaged_task to return a future object.
   PackagedTaskTy PackagedTask(std::move(Task));
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -15,6 +15,10 @@
 #include <memory>
 #include <vector>
 
+namespace llvm {
+class ThreadPool;
+} // end namespace llvm
+
 namespace mlir {
 class AbstractOperation;
 class DebugActionManager;
@@ -114,6 +118,12 @@
     disableMultithreading(!enable);
   }
 
+  /// Return the thread pool owned by this context. This method requires that
+  /// multithreading be enabled within the context, and should generally not be
+  /// used directly. Users should instead prefer the threading utilities within
+  /// ThreadingUtilities.h.
+  llvm::ThreadPool &getThreadPool();
+
   /// Return true if we should attach the operation to diagnostics emitted via
   /// Operation::emit.
   bool shouldPrintOpOnDiagnostic();
diff --git a/mlir/include/mlir/IR/ThreadingUtilities.h b/mlir/include/mlir/IR/ThreadingUtilities.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/ThreadingUtilities.h
@@ -0,0 +1,92 @@
+//===- ThreadingUtilities.h - MLIR Threading Utilities ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines various utilities for multithreaded processing within
+// MLIR. These utilities automatically handle many of the necessary threading
+// conditions, such as properly ordering diagnostics, observing if threading is
+// disabled, etc. These utilities should be used over other threading utilities
+// whenever feasible.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_THREADING_UTILITIES_H
+#define MLIR_IR_THREADING_UTILITIES_H
+
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/Support/ThreadPool.h"
+#include <atomic>
+
+namespace mlir {
+
+/// Invoke the given function on the elements between [begin, end)
+/// asynchronously. Diagnostics emitted during processing are ordered relative
+/// to the element's position within [begin, end). If the provided context does
+/// not have multi-threading enabled, this function always processes elements
+/// synchronously.
+template <typename IteratorT, typename FuncT>
+void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end,
+                     FuncT &&func) {
+  unsigned numElements = static_cast<unsigned>(std::distance(begin, end));
+  if (numElements == 0)
+    return;
+
+  // If multithreading is disabled or there is only a small number of elements,
+  // process the elements directly on this thread.
+  if (!context->isMultithreadingEnabled() || numElements <= 1)
+    return std::for_each(begin, end, func);
+  llvm::ThreadPool &threadPool = context->getThreadPool();
+
+  // If this is a worker thread of the thread pool, don't execute in parallel
+  // to avoid potential deadlock.
+  // FIXME: ThreadPool should allow work stealing to avoid deadlocks when
+  // scheduling work within a worker thread.
+  if (threadPool.isWorkerThread())
+    return std::for_each(begin, end, func);
+
+  // Build a wrapper processing function that properly initializes a parallel
+  // diagnostic handler.
+  ParallelDiagnosticHandler handler(context);
+  std::atomic<unsigned> curIndex(0);
+  auto processFn = [&](unsigned threadIndex) {
+    handler.setOrderIDForThread(threadIndex);
+    while (true) {
+      unsigned index = curIndex++;
+      if (index >= numElements)
+        break;
+      func(*std::next(begin, index));
+    }
+    handler.eraseOrderIDForThread();
+  };
+
+  // Otherwise, process the elements in parallel.
+  size_t numActions = std::min(numElements, threadPool.getThreadCount());
+  SmallVector<std::shared_future<void>> threadFutures;
+  threadFutures.reserve(numActions - 1);
+  for (unsigned i = 1; i < numActions; ++i)
+    threadFutures.emplace_back(threadPool.async(processFn, i));
+  processFn(/*threadIndex=*/0);
+
+  // Wait for all of the threads to finish.
+  for (std::shared_future<void> &future : threadFutures)
+    future.wait();
+}
+
+/// Invoke the given function on the elements in the provided range
+/// asynchronously. Diagnostics emitted during processing are ordered relative
+/// to the element's position within the range. If the provided context does
+/// not have multi-threading enabled, this function always processes elements
+/// synchronously.
+template <typename RangeT, typename FuncT>
+void parallelForEach(MLIRContext *context, RangeT &&range, FuncT &&func) {
+  parallelForEach(context, std::begin(range), std::end(range),
+                  std::forward<FuncT>(func));
+}
+
+} // end namespace mlir
+
+#endif // MLIR_IR_THREADING_UTILITIES_H
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -35,6 +35,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/RWMutex.h"
+#include "llvm/Support/ThreadPool.h"
 #include "llvm/Support/raw_ostream.h"
 #include <memory>
 
@@ -261,6 +262,9 @@
   // Other
   //===--------------------------------------------------------------------===//
 
+  /// The thread pool to use when processing MLIR tasks in parallel.
+  llvm::ThreadPool threadPool;
+
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects.
   DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
 
@@ -576,6 +580,12 @@
   impl->typeUniquer.disableMultithreading(disable);
 }
 
+llvm::ThreadPool &MLIRContext::getThreadPool() {
+  assert(isMultithreadingEnabled() &&
+         "expected multi-threading to be enabled within the context");
+  return impl->threadPool;
+}
+
 void MLIRContext::enterMultiThreadedExecution() {
 #ifndef NDEBUG
   ++impl->multiThreadedExecutionContext;
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -30,6 +30,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/RegionKindInterface.h"
+#include "mlir/IR/ThreadingUtilities.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Parallel.h"
@@ -43,9 +44,6 @@
 /// This class encapsulates all the state used to verify an operation region.
 class OperationVerifier {
 public:
-  explicit OperationVerifier(MLIRContext *context)
-      : parallelismEnabled(context->isMultithreadingEnabled()) {}
-
   /// Verify the given operation.
   LogicalResult verifyOpAndDominance(Operation &op);
 
@@ -64,9 +62,6 @@
   /// Operation.
   LogicalResult verifyDominanceOfContainedRegions(Operation &op,
                                                   DominanceInfo &domInfo);
-
-  /// This is true if parallelism is enabled on the MLIRContext.
-  const bool parallelismEnabled;
 };
 } // end anonymous namespace
 
@@ -89,28 +84,12 @@
 
   // Check the dominance properties and invariants of any operations in the
   // regions contained by the 'opsWithIsolatedRegions' operations.
-  if (!parallelismEnabled || opsWithIsolatedRegions.size() <= 1) {
-    // If parallelism is disabled or if there is only 0/1 operation to do, use
-    // a simple non-parallel loop.
-    for (Operation *op : opsWithIsolatedRegions) {
-      if (failed(verifyOpAndDominance(*op)))
-        return failure();
-    }
-  } else {
-    // Otherwise, verify the operations and their bodies in parallel.
-    ParallelDiagnosticHandler handler(op.getContext());
-    std::atomic<bool> passFailed(false);
-    llvm::parallelForEachN(0, opsWithIsolatedRegions.size(), [&](size_t opIdx) {
-      handler.setOrderIDForThread(opIdx);
-      if (failed(verifyOpAndDominance(*opsWithIsolatedRegions[opIdx])))
-        passFailed = true;
-      handler.eraseOrderIDForThread();
-    });
-    if (passFailed)
-      return failure();
-  }
-
-  return success();
+  std::atomic<bool> passFailed(false);
+  parallelForEach(op.getContext(), opsWithIsolatedRegions, [&](Operation *op) {
+    if (!passFailed && failed(verifyOpAndDominance(*op)))
+      passFailed = true;
+  });
+  return failure(passFailed);
 }
 
 /// Returns true if this block may be valid without terminator. That is if:
@@ -376,5 +355,5 @@
 /// compiler bugs. On error, this reports the error through the MLIRContext and
 /// returns failure.
 LogicalResult mlir::verify(Operation *op) {
-  return OperationVerifier(op->getContext()).verifyOpAndDominance(*op);
+  return OperationVerifier().verifyOpAndDominance(*op);
 }
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -14,6 +14,7 @@
 #include "PassDetail.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/ThreadingUtilities.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Support/FileUtilities.h"
 #include "llvm/ADT/STLExtras.h"
@@ -580,13 +581,6 @@
     }
   }
 
-  // A parallel diagnostic handler that provides deterministic diagnostic
-  // ordering.
-  ParallelDiagnosticHandler diagHandler(&getContext());
-
-  // An index for the current operation/analysis manager pair.
-  std::atomic<unsigned> opIt(0);
-
   // Get the current thread for this adaptor.
   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
                                                         this};
@@ -594,44 +588,36 @@
 
   // An atomic failure variable for the async executors.
   std::atomic<bool> passFailed(false);
-  llvm::parallelForEach(
-      asyncExecutors.begin(),
-      std::next(asyncExecutors.begin(),
-                std::min(asyncExecutors.size(), opAMPairs.size())),
-      [&](MutableArrayRef<OpPassManager> pms) {
-        for (auto e = opAMPairs.size(); !passFailed && opIt < e;) {
-          // Get the next available operation index.
-          unsigned nextID = opIt++;
-          if (nextID >= e)
-            break;
-
-          // Set the order id for this thread in the diagnostic handler.
-          diagHandler.setOrderIDForThread(nextID);
-
-          // Get the pass manager for this operation and execute it.
-          auto &it = opAMPairs[nextID];
-          auto *pm = findPassManagerFor(
-              pms, it.first->getName().getIdentifier(), getContext());
-          assert(pm && "expected valid pass manager for operation");
-
-          unsigned initGeneration = pm->impl->initializationGeneration;
-          LogicalResult pipelineResult =
-              runPipeline(pm->getPasses(), it.first, it.second, verifyPasses,
-                          initGeneration, instrumentor, &parentInfo);
-
-          // Drop this thread from being tracked by the diagnostic handler.
-          // After this task has finished, the thread may be used outside of
-          // this pass manager context meaning that we don't want to track
-          // diagnostics from it anymore.
-          diagHandler.eraseOrderIDForThread();
-
-          // Handle a failed pipeline result.
-          if (failed(pipelineResult)) {
-            passFailed = true;
-            break;
-          }
-        }
-      });
+  std::vector<std::atomic<bool>> activePMs(asyncExecutors.size());
+  for (std::atomic<bool> &isActive : activePMs)
+    isActive = false;
+  parallelForEach(&getContext(), opAMPairs, [&](auto &opPMPair) {
+    if (passFailed)
+      return;
+
+    // Find a pass manager for this operation.
+    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
+      bool expectedInactive = false;
+      return isActive.compare_exchange_strong(expectedInactive, true);
+    });
+    unsigned pmIndex = it - activePMs.begin();
+
+    // Get the pass manager for this operation and execute it.
+    auto *pm = findPassManagerFor(asyncExecutors[pmIndex],
+                                  opPMPair.first->getName().getIdentifier(),
+                                  getContext());
+    assert(pm && "expected valid pass manager for operation");
+
+    unsigned initGeneration = pm->impl->initializationGeneration;
+    LogicalResult pipelineResult =
+        runPipeline(pm->getPasses(), opPMPair.first, opPMPair.second,
+                    verifyPasses, initGeneration, instrumentor, &parentInfo);
+    if (failed(pipelineResult))
+      passFailed = true;
+
+    // Reset the active bit for this pass manager.
+    activePMs[pmIndex].store(false);
+  });
 
   // Signal a failure if any of the executors failed.
   if (passFailed)
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -15,6 +15,7 @@
 #include "PassDetail.h"
 #include "mlir/Analysis/CallGraph.h"
+#include "mlir/IR/ThreadingUtilities.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -703,34 +704,29 @@
   for (CallGraphNode *node : nodesToVisit)
     getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
 
-  // An index for the current node to optimize.
-  std::atomic<unsigned> nodeIt(0);
-
-  // Optimize the nodes of the SCC in parallel.
-  ParallelDiagnosticHandler optimizerHandler(context);
+  // An atomic failure variable for the async executors.
   std::atomic<bool> passFailed(false);
-  llvm::parallelForEach(
-      opPipelines.begin(), std::next(opPipelines.begin(), numThreads),
-      [&](llvm::StringMap<OpPassManager> &pipelines) {
-        for (auto e = nodesToVisit.size(); !passFailed && nodeIt < e;) {
-          // Get the next available operation index.
-          unsigned nextID = nodeIt++;
-          if (nextID >= e)
-            break;
-
-          // Set the order for this thread so that diagnostics will be
-          // properly ordered, and reset after optimization has finished.
-          optimizerHandler.setOrderIDForThread(nextID);
-          LogicalResult pipelineResult =
-              optimizeCallable(nodesToVisit[nextID], pipelines);
-          optimizerHandler.eraseOrderIDForThread();
-
-          if (failed(pipelineResult)) {
-            passFailed = true;
-            break;
-          }
-        }
-      });
+  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
+  for (std::atomic<bool> &isActive : activePMs)
+    isActive = false;
+  parallelForEach(context, nodesToVisit, [&](CallGraphNode *node) {
+    if (passFailed)
+      return;
+
+    // Find a pass manager for this operation.
+    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
+      bool expectedInactive = false;
+      return isActive.compare_exchange_strong(expectedInactive, true);
+    });
+    unsigned pmIndex = it - activePMs.begin();
+
+    // Optimize this callable node.
+    if (failed(optimizeCallable(node, opPipelines[pmIndex])))
+      passFailed = true;
+
+    // Reset the active bit for this pass manager.
+    activePMs[pmIndex].store(false);
+  });
   return failure(passFailed);
 }
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/compose_maps.mlir b/mlir/test/Dialect/Affine/SuperVectorize/compose_maps.mlir
--- a/mlir/test/Dialect/Affine/SuperVectorize/compose_maps.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/compose_maps.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -compose-maps 2>&1 | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -compose-maps -mlir-disable-threading 2>&1 | FileCheck %s
 
 // For all these cases, the test traverses the `test_affine_map` ops and
 // composes them in order one-by-one.
diff --git a/mlir/test/Dialect/Affine/slicing-utils.mlir b/mlir/test/Dialect/Affine/slicing-utils.mlir
--- a/mlir/test/Dialect/Affine/slicing-utils.mlir
+++ b/mlir/test/Dialect/Affine/slicing-utils.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -forward-slicing=true 2>&1 | FileCheck %s --check-prefix=FWD
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -backward-slicing=true 2>&1 | FileCheck %s --check-prefix=BWD
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -slicing=true 2>&1 | FileCheck %s --check-prefix=FWDBWD
+// RUN: mlir-opt -allow-unregistered-dialect %s -mlir-disable-threading -affine-super-vectorizer-test -forward-slicing=true 2>&1 | FileCheck %s --check-prefix=FWD
+// RUN: mlir-opt -allow-unregistered-dialect %s -mlir-disable-threading -affine-super-vectorizer-test -backward-slicing=true 2>&1 | FileCheck %s --check-prefix=BWD
+// RUN: mlir-opt -allow-unregistered-dialect %s -mlir-disable-threading -affine-super-vectorizer-test -slicing=true 2>&1 | FileCheck %s --check-prefix=FWDBWD
 
 /// 1 2 3 4
 /// |_______| |______|
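
Reviewer note (not part of the patch): a minimal usage sketch of the new parallelForEach utility. The checkFuncDeclarations helper and its "no external functions" policy are hypothetical; only parallelForEach, ModuleOp, FuncOp, and emitError come from MLIR as of this change.

// Hypothetical caller, assuming the patch above is applied.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ThreadingUtilities.h"
#include <atomic>

using namespace mlir;

/// Emit an error for every external (declaration-only) function in `module`.
/// parallelForEach keeps diagnostics ordered by element position, and degrades
/// to a serial std::for_each when threading is disabled on the context or when
/// invoked from one of the context's thread-pool worker threads.
static LogicalResult checkFuncDeclarations(ModuleOp module) {
  std::atomic<bool> anyFailed(false);
  parallelForEach(module->getContext(), module.getBody()->getOps<FuncOp>(),
                  [&](FuncOp fn) {
                    if (fn.isExternal()) {
                      fn.emitError("external functions are not allowed here");
                      anyFailed = true;
                    }
                  });
  return failure(anyFailed);
}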