diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -19,6 +19,12 @@
 
 std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
 
+/// Variant of `createAsyncParallelForPass()` that overrides the pass's
+/// `numConcurrentAsyncExecute` option with `numWorkerThreads`.
+/// `numWorkerThreads` must be at least 1 (checked with an assertion).
+std::unique_ptr<OperationPass<FuncOp>>
+createAsyncParallelForPass(int numWorkerThreads);
+
 std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
 
 std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -96,6 +96,12 @@
 struct AsyncParallelForPass
     : public AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
+  /// Overrides the `numConcurrentAsyncExecute` pass option with
+  /// `numWorkerThreads`, which must be at least 1.
+  explicit AsyncParallelForPass(int numWorkerThreads) {
+    assert(numWorkerThreads >= 1 && "expected at least one worker thread");
+    numConcurrentAsyncExecute = numWorkerThreads;
+  }
   void runOnFunction() override;
 };
 
@@ -276,3 +282,8 @@
 std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncParallelForPass(int numWorkerThreads) {
+  return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
+}