diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -21,7 +21,7 @@ std::unique_ptr createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads, - int32_t targetBlockSize); + int32_t minTaskSize); std::unique_ptr> createAsyncToAsyncRuntimePass(); diff --git a/mlir/include/mlir/Dialect/Async/Transforms.h b/mlir/include/mlir/Dialect/Async/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Async/Transforms.h @@ -0,0 +1,40 @@ +//===- Transforms.h - Async dialect transformation utilities ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines transformations on Async operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ASYNC_TRANSFORMS_H_ +#define MLIR_DIALECT_ASYNC_TRANSFORMS_H_ + +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" + +namespace mlir { +namespace async { + +/// Emit the IR to compute the minimum number of iterations of scf.parallel body +/// that would be viable for a single parallel task. Allows the user to avoid +/// incurring the overheads of spawning costly parallel tasks in the absence of +/// a sufficient amount of parallelizable work. +/// +/// Must return a value of index type. +using AsyncMinTaskSizeComputationFunction = + std::function; + +/// Add a pattern to the given pattern list to lower scf.parallel to async +/// operations. 
+void populateAsyncParallelForPatterns( + RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads, + AsyncMinTaskSizeComputationFunction computeMinTaskSize); + +} // namespace async +} // namespace mlir + +#endif // MLIR_DIALECT_ASYNC_TRANSFORMS_H_ diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/Async/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -104,10 +105,12 @@ struct AsyncParallelForRewrite : public OpRewritePattern { public: - AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch, - int32_t numWorkerThreads, int32_t minTaskSize) + AsyncParallelForRewrite( + MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads, + AsyncMinTaskSizeComputationFunction computeMinTaskSize) : OpRewritePattern(ctx), asyncDispatch(asyncDispatch), - numWorkerThreads(numWorkerThreads), minTaskSize(minTaskSize) {} + numWorkerThreads(numWorkerThreads), + computeMinTaskSize(computeMinTaskSize) {} LogicalResult matchAndRewrite(scf::ParallelOp op, PatternRewriter &rewriter) const override; @@ -115,7 +118,7 @@ private: bool asyncDispatch; int32_t numWorkerThreads; - int32_t minTaskSize; + AsyncMinTaskSizeComputationFunction computeMinTaskSize; }; struct ParallelComputeFunctionType { @@ -673,6 +676,11 @@ auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) { ImplicitLocOpBuilder nb(loc, nestedBuilder); + // Create a parallel compute function that takes a block id and computes the + // parallel operation body for a subset of iteration space. 
+ ParallelComputeFunction parallelComputeFunction = + createParallelComputeFunction(op, rewriter); + // With large number of threads the value of creating many compute blocks // is reduced because the problem typically becomes memory bound. For small // number of threads it helps with stragglers. @@ -687,23 +695,15 @@ Value maxComputeBlocks = b.create( std::max(1, static_cast(numWorkerThreads * overshardingFactor))); - // Target block size from the pass parameters. - Value minTaskSizeCst = b.create(minTaskSize); - // Compute parallel block size from the parallel problem size: // blockSize = min(tripCount, // max(ceil_div(tripCount, maxComputeBlocks), - // ceil_div(minTaskSize, bodySize))) + // minTaskSize)) Value bs0 = b.create(tripCount, maxComputeBlocks); - Value bs1 = b.create(bs0, minTaskSizeCst); + Value bs1 = b.create(bs0, computeMinTaskSize(b, op)); Value blockSize = b.create(tripCount, bs1); Value blockCount = b.create(tripCount, blockSize); - // Create a parallel compute function that takes a block id and computes the - // parallel operation body for a subset of iteration space. - ParallelComputeFunction parallelComputeFunction = - createParallelComputeFunction(op, rewriter); - // Dispatch parallel compute function using async recursive work splitting, // or by submitting compute task sequentially from a caller thread. 
if (asyncDispatch) { @@ -730,9 +730,11 @@ MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx, asyncDispatch, numWorkerThreads, - minTaskSize); - + populateAsyncParallelForPatterns( + patterns, asyncDispatch, numWorkerThreads, + [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) { + return builder.create(minTaskSize); + }); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } @@ -747,3 +749,11 @@ return std::make_unique(asyncDispatch, numWorkerThreads, minTaskSize); } + +void mlir::async::populateAsyncParallelForPatterns( + RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads, + AsyncMinTaskSizeComputationFunction computeMinTaskSize) { + MLIRContext *ctx = patterns.getContext(); + patterns.add(ctx, asyncDispatch, numWorkerThreads, + computeMinTaskSize); +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2010,7 +2010,10 @@ "lib/Dialect/Async/Transforms/*.cpp", "lib/Dialect/Async/Transforms/*.h", ]), - hdrs = ["include/mlir/Dialect/Async/Passes.h"], + hdrs = [ + "include/mlir/Dialect/Async/Passes.h", + "include/mlir/Dialect/Async/Transforms.h", + ], includes = ["include"], deps = [ ":Analysis",