Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===// | //===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===// | ||||
// | // | ||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
// See https://llvm.org/LICENSE.txt for license information. | // See https://llvm.org/LICENSE.txt for license information. | ||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// | // | ||||
// This file implements scf.parallel to scf.for + async.execute conversion pass. | // This file implements scf.parallel to scf.for + async.execute conversion pass. | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Async/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

#include <utility>
▲ Show 20 Lines • Show All 74 Lines • ▼ Show 20 Lines | AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads, | ||||
this->minTaskSize = minTaskSize; | this->minTaskSize = minTaskSize; | ||||
} | } | ||||
void runOnOperation() override; | void runOnOperation() override; | ||||
}; | }; | ||||
// Rewrites `scf.parallel` into a block-dispatch loop nest, dispatching blocks
// either with recursive `async.execute` work splitting or sequentially from
// the caller thread (see `matchAndRewrite`).
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  /// \param asyncDispatch       use async recursive work splitting instead of
  ///                            sequential submission from the caller thread.
  /// \param numWorkerThreads    number of worker threads the async runtime is
  ///                            expected to use (drives over-sharding).
  /// \param computeMinTaskSize  callback that materializes the minimum task
  ///                            size (iterations per block) for a parallel op.
  ///                            Taken by value and moved into the member to
  ///                            avoid copying the `std::function` state.
  AsyncParallelForRewrite(
      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads),
        computeMinTaskSize(std::move(computeMinTaskSize)) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  bool asyncDispatch;
  int32_t numWorkerThreads;
  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};
struct ParallelComputeFunctionType { | struct ParallelComputeFunctionType { | ||||
FunctionType type; | FunctionType type; | ||||
llvm::SmallVector<Value> captures; | llvm::SmallVector<Value> captures; | ||||
}; | }; | ||||
struct ParallelComputeFunction { | struct ParallelComputeFunction { | ||||
▲ Show 20 Lines • Show All 541 Lines • ▼ Show 20 Lines | auto noOp = [&](OpBuilder &nestedBuilder, Location loc) { | ||||
nestedBuilder.create<scf::YieldOp>(loc); | nestedBuilder.create<scf::YieldOp>(loc); | ||||
}; | }; | ||||
// Compute the parallel block size and dispatch concurrent tasks computing | // Compute the parallel block size and dispatch concurrent tasks computing | ||||
// results for each block. | // results for each block. | ||||
auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) { | auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) { | ||||
ImplicitLocOpBuilder nb(loc, nestedBuilder); | ImplicitLocOpBuilder nb(loc, nestedBuilder); | ||||
// Create a parallel compute function that takes a block id and computes the | |||||
// parallel operation body for a subset of iteration space. | |||||
ParallelComputeFunction parallelComputeFunction = | |||||
createParallelComputeFunction(op, rewriter); | |||||
// With large number of threads the value of creating many compute blocks | // With large number of threads the value of creating many compute blocks | ||||
// is reduced because the problem typically becomes memory bound. For small | // is reduced because the problem typically becomes memory bound. For small | ||||
// number of threads it helps with stragglers. | // number of threads it helps with stragglers. | ||||
float overshardingFactor = numWorkerThreads <= 4 ? 8.0 | float overshardingFactor = numWorkerThreads <= 4 ? 8.0 | ||||
: numWorkerThreads <= 8 ? 4.0 | : numWorkerThreads <= 8 ? 4.0 | ||||
: numWorkerThreads <= 16 ? 2.0 | : numWorkerThreads <= 16 ? 2.0 | ||||
: numWorkerThreads <= 32 ? 1.0 | : numWorkerThreads <= 32 ? 1.0 | ||||
: numWorkerThreads <= 64 ? 0.8 | : numWorkerThreads <= 64 ? 0.8 | ||||
: 0.6; | : 0.6; | ||||
// Do not overload worker threads with too many compute blocks. | // Do not overload worker threads with too many compute blocks. | ||||
Value maxComputeBlocks = b.create<arith::ConstantIndexOp>( | Value maxComputeBlocks = b.create<arith::ConstantIndexOp>( | ||||
std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor))); | std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor))); | ||||
// Target block size from the pass parameters. | |||||
Value minTaskSizeCst = b.create<arith::ConstantIndexOp>(minTaskSize); | |||||
// Compute parallel block size from the parallel problem size: | // Compute parallel block size from the parallel problem size: | ||||
// blockSize = min(tripCount, | // blockSize = min(tripCount, | ||||
// max(ceil_div(tripCount, maxComputeBlocks), | // max(ceil_div(tripCount, maxComputeBlocks), | ||||
// ceil_div(minTaskSize, bodySize))) | // minTaskSize)) | ||||
Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks); | Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks); | ||||
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst); | Value bs1 = b.create<arith::MaxSIOp>(bs0, computeMinTaskSize(b, op)); | ||||
Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1); | Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1); | ||||
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize); | Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize); | ||||
// Create a parallel compute function that takes a block id and computes the | |||||
// parallel operation body for a subset of iteration space. | |||||
ParallelComputeFunction parallelComputeFunction = | |||||
createParallelComputeFunction(op, rewriter); | |||||
// Dispatch parallel compute function using async recursive work splitting, | // Dispatch parallel compute function using async recursive work splitting, | ||||
ezhulenev: Computing `numUnrollableLoops` as `Value` (compute function argument) + runtime `scf.if` in the… | |||||
Done, PTAL. bakhtiyarneyman: Done, PTAL. | |||||
// or by submitting compute task sequentially from a caller thread. | // or by submitting compute task sequentially from a caller thread. | ||||
if (asyncDispatch) { | if (asyncDispatch) { | ||||
doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize, | doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize, | ||||
blockCount, tripCounts); | blockCount, tripCounts); | ||||
} else { | } else { | ||||
doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize, | doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize, | ||||
blockCount, tripCounts); | blockCount, tripCounts); | ||||
} | } | ||||
nb.create<scf::YieldOp>(); | nb.create<scf::YieldOp>(); | ||||
}; | }; | ||||
You're creating a value dynamicShouldUnroll here which isn't used in the else branch below, can you sink this in the then branch? mehdi_amini: You're creating a value `dynamicShouldUnroll` here which isn't used in the else branch below… | |||||
// Replace the `scf.parallel` operation with the parallel compute function. | // Replace the `scf.parallel` operation with the parallel compute function. | ||||
b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch); | b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch); | ||||
// Parallel operation was replaced with a block iteration loop. | // Parallel operation was replaced with a block iteration loop. | ||||
rewriter.eraseOp(op); | rewriter.eraseOp(op); | ||||
return success(); | return success(); | ||||
} | } | ||||
void AsyncParallelForPass::runOnOperation() { | void AsyncParallelForPass::runOnOperation() { | ||||
MLIRContext *ctx = &getContext(); | MLIRContext *ctx = &getContext(); | ||||
RewritePatternSet patterns(ctx); | RewritePatternSet patterns(ctx); | ||||
Not Done ReplyInline Actionsnit: I'd move this lambda close to the dispatchNotUnrollable to reduce the nesting, and put similar things together, although it's used only inside one branch of the if. I think it's ok to put createParallelComputeFunction inside lamdba, so you don't create aligned compute function if you'll not need it. ezhulenev: nit: I'd move this lambda close to the `dispatchNotUnrollable` to reduce the nesting, and put… | |||||
patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads, | populateAsyncParallelForPatterns( | ||||
minTaskSize); | patterns, asyncDispatch, numWorkerThreads, | ||||
[&](ImplicitLocOpBuilder builder, scf::ParallelOp op) { | |||||
return builder.create<arith::ConstantIndexOp>(minTaskSize); | |||||
}); | |||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) | if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) | ||||
signalPassFailure(); | signalPassFailure(); | ||||
} | } | ||||
std::unique_ptr<Pass> mlir::createAsyncParallelForPass() { | std::unique_ptr<Pass> mlir::createAsyncParallelForPass() { | ||||
return std::make_unique<AsyncParallelForPass>(); | return std::make_unique<AsyncParallelForPass>(); | ||||
} | } | ||||
std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch, | std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch, | ||||
int32_t numWorkerThreads, | int32_t numWorkerThreads, | ||||
int32_t minTaskSize) { | int32_t minTaskSize) { | ||||
return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads, | return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads, | ||||
minTaskSize); | minTaskSize); | ||||
} | } | ||||
/// Adds the `scf.parallel` -> async dispatch rewrite to `patterns`.
///
/// \param computeMinTaskSize callback producing the minimum task size for a
///        given parallel op; taken by value and moved into the pattern so the
///        `std::function` state is not copied.
void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    AsyncMinTaskSizeComputationFunction computeMinTaskSize) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        std::move(computeMinTaskSize));
}
Computing numUnrollableLoops as Value (compute function argument) + runtime scf.if in the loop nest prevents LLVM from loop unrolling and vectorization, and it leads to large regressions in: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfrt/benchmarks/compute_function_benchmark.cc
numUnrollableLoops should be known at compile time; with a dynamic minTaskSize the structure should look like this:
There is no real need or benefit in computing numUnrollableLoops as a Value, because unrolling/vectorization can only happen if trip counts (loop bounds) are known at compile time.