Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===// | //===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===// | ||||
// | // | ||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
// See https://llvm.org/LICENSE.txt for license information. | // See https://llvm.org/LICENSE.txt for license information. | ||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// | // | ||||
// This file implements scf.parallel to scf.for + async.execute conversion pass. | // This file implements scf.parallel to scf.for + async.execute conversion pass. | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Async/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

#include <utility>
▲ Show 20 Lines • Show All 75 Lines • ▼ Show 20 Lines | AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads, | ||||
this->minTaskSize = minTaskSize; | this->minTaskSize = minTaskSize; | ||||
} | } | ||||
void runOnOperation() override; | void runOnOperation() override; | ||||
}; | }; | ||||
// Rewrites `scf.parallel` operations into blocks of work dispatched either
// asynchronously (recursive async.execute work splitting) or sequentially
// from the caller thread, depending on `asyncDispatch`.
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(
      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads),
        // Move the callback into the member: `computeMinTaskSize` is a
        // by-value sink parameter and copying a std::function may allocate.
        computeMinTaskSize(std::move(computeMinTaskSize)) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  // If true, dispatch parallel compute blocks with async recursive work
  // splitting; otherwise submit compute tasks from the caller thread.
  bool asyncDispatch;
  // Number of worker threads; used to pick the over-sharding factor.
  int32_t numWorkerThreads;
  // Builds an IR Value holding the minimum task (block) size for the given
  // parallel operation.
  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};
// Type of the outlined parallel compute function together with the values it
// captures from the surrounding scf.parallel operation.
struct ParallelComputeFunctionType {
  // Signature of the outlined compute function.
  FunctionType type;
  // Values captured from the parallel operation; presumably forwarded to the
  // compute function as extra arguments (see getParallelComputeFunctionType).
  SmallVector<Value> captures;
};
// Helper struct to parse parallel compute function argument list. | // Helper struct to parse parallel compute function argument list. | ||||
▲ Show 20 Lines • Show All 119 Lines • ▼ Show 20 Lines | static ParallelComputeFunction createParallelComputeFunction( | ||||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||||
ModuleOp module = op->getParentOfType<ModuleOp>(); | ModuleOp module = op->getParentOfType<ModuleOp>(); | ||||
ParallelComputeFunctionType computeFuncType = | ParallelComputeFunctionType computeFuncType = | ||||
getParallelComputeFunctionType(op, rewriter); | getParallelComputeFunctionType(op, rewriter); | ||||
FunctionType type = computeFuncType.type; | FunctionType type = computeFuncType.type; | ||||
FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type); | FuncOp func = FuncOp::create(op.getLoc(), | ||||
numBlockAlignedInnerLoops > 0 | |||||
? "parallel_compute_fn_with_aligned_loops" | |||||
: "parallel_compute_fn", | |||||
type); | |||||
func.setPrivate(); | func.setPrivate(); | ||||
// Insert function into the module symbol table and assign it unique name. | // Insert function into the module symbol table and assign it unique name. | ||||
SymbolTable symbolTable(module); | SymbolTable symbolTable(module); | ||||
symbolTable.insert(func); | symbolTable.insert(func); | ||||
rewriter.getListener()->notifyOperationInserted(func); | rewriter.getListener()->notifyOperationInserted(func); | ||||
// Create function entry block. | // Create function entry block. | ||||
▲ Show 20 Lines • Show All 483 Lines • ▼ Show 20 Lines | auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) { | ||||
ParallelComputeFunctionBounds staticBounds = { | ParallelComputeFunctionBounds staticBounds = { | ||||
integerConstants(tripCounts), | integerConstants(tripCounts), | ||||
integerConstants(op.lowerBound()), | integerConstants(op.lowerBound()), | ||||
integerConstants(op.upperBound()), | integerConstants(op.upperBound()), | ||||
integerConstants(op.step()), | integerConstants(op.step()), | ||||
}; | }; | ||||
// Find how many inner iteration dimensions are statically known, and their | // Find how many inner iteration dimensions are statically known, and their | ||||
// product is smaller than the `512`. We aling the parallel compute block | // product is smaller than the `512`. We align the parallel compute block | ||||
// size by the product of statically known dimensions, so that we can | // size by the product of statically known dimensions, so that we can | ||||
// guarantee that the inner loops executes from 0 to the loop trip counts | // guarantee that the inner loops executes from 0 to the loop trip counts | ||||
// and we can elide dynamic loop boundaries, and give LLVM an opportunity to | // and we can elide dynamic loop boundaries, and give LLVM an opportunity to | ||||
// unroll the loops. The constant `512` is arbitrary, it should depend on | // unroll the loops. The constant `512` is arbitrary, it should depend on | ||||
// how many iterations LLVM will typically decide to unroll. | // how many iterations LLVM will typically decide to unroll. | ||||
static constexpr int64_t maxIterations = 512; | static constexpr int64_t maxIterations = 512; | ||||
// The number of inner loops with statically known number of iterations less | // The number of inner loops with statically known number of iterations less | ||||
Show All 24 Lines | float overshardingFactor = numWorkerThreads <= 4 ? 8.0 | ||||
: numWorkerThreads <= 32 ? 1.0 | : numWorkerThreads <= 32 ? 1.0 | ||||
: numWorkerThreads <= 64 ? 0.8 | : numWorkerThreads <= 64 ? 0.8 | ||||
: 0.6; | : 0.6; | ||||
// Do not overload worker threads with too many compute blocks. | // Do not overload worker threads with too many compute blocks. | ||||
Value maxComputeBlocks = b.create<arith::ConstantIndexOp>( | Value maxComputeBlocks = b.create<arith::ConstantIndexOp>( | ||||
std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor))); | std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor))); | ||||
// Target block size from the pass parameters. | |||||
Value minTaskSizeCst = b.create<arith::ConstantIndexOp>(minTaskSize); | |||||
// Compute parallel block size from the parallel problem size: | // Compute parallel block size from the parallel problem size: | ||||
// blockSize = min(tripCount, | // blockSize = min(tripCount, | ||||
// max(ceil_div(tripCount, maxComputeBlocks), | // max(ceil_div(tripCount, maxComputeBlocks), | ||||
// ceil_div(minTaskSize, bodySize))) | // minTaskSize)) | ||||
Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks); | Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks); | ||||
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst); | Value minTaskSize = computeMinTaskSize(b, op); | ||||
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize); | |||||
Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1); | Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1); | ||||
// Align the block size to be a multiple of the statically known number | ParallelComputeFunction notUnrollableParallelComputeFunction = | ||||
// of iterations in the inner loops. | createParallelComputeFunction(op, staticBounds, 0, rewriter); | ||||
if (numUnrollableLoops > 0 && minTaskSize >= maxIterations) { | |||||
Value numIters = b.create<arith::ConstantIndexOp>( | |||||
numIterations[op.getNumLoops() - numUnrollableLoops]); | |||||
Value bs2 = b.create<arith::MulIOp>( | |||||
b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters); | |||||
blockSize = b.create<arith::MinSIOp>(tripCount, bs2); | |||||
} else { | |||||
// Reset the number of unrollable loops if we didn't align the block size. | |||||
numUnrollableLoops = 0; | |||||
} | |||||
// Compute the number of parallel compute blocks. | // Dispatch parallel compute function using async recursive work splitting, | ||||
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize); | // or by submitting compute task sequentially from a caller thread. | ||||
auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch; | |||||
// Create a parallel compute function that takes a block id and computes | // Create a parallel compute function that takes a block id and computes | ||||
// the parallel operation body for a subset of iteration space. | // the parallel operation body for a subset of iteration space. | ||||
ParallelComputeFunction parallelComputeFunction = | |||||
ezhulenev: Computing `numUnrollableLoops` as `Value` (compute function argument) + runtime `scf.if` in the… | |||||
Done, PTAL. bakhtiyarneyman: Done, PTAL. | |||||
// Compute the number of parallel compute blocks. | |||||
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize); | |||||
// Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations. | |||||
bool staticShouldUnroll = numUnrollableLoops > 0; | |||||
auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) { | |||||
ImplicitLocOpBuilder nb(loc, nestedBuilder); | |||||
doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op, | |||||
You're creating a value dynamicShouldUnroll here which isn't used in the else branch below, can you sink this in the then branch? mehdi_amini: You're creating a value `dynamicShouldUnroll` here which isn't used in the else branch below… | |||||
blockSize, blockCount, tripCounts); | |||||
return nb.create<scf::YieldOp>(); | |||||
}; | |||||
if (staticShouldUnroll) { | |||||
Value dynamicShouldUnroll = b.create<arith::CmpIOp>( | |||||
arith::CmpIPredicate::sge, blockSize, | |||||
b.create<arith::ConstantIndexOp>(maxIterations)); | |||||
ParallelComputeFunction unrollableParallelComputeFunction = | |||||
createParallelComputeFunction(op, staticBounds, numUnrollableLoops, | createParallelComputeFunction(op, staticBounds, numUnrollableLoops, | ||||
rewriter); | rewriter); | ||||
// Dispatch parallel compute function using async recursive work splitting, | auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) { | ||||
Not Done ReplyInline Actionsnit: I'd move this lambda close to the dispatchNotUnrollable to reduce the nesting, and put similar things together, although it's used only inside one branch of the if. I think it's ok to put createParallelComputeFunction inside lamdba, so you don't create aligned compute function if you'll not need it. ezhulenev: nit: I'd move this lambda close to the `dispatchNotUnrollable` to reduce the nesting, and put… | |||||
// or by submitting compute task sequentially from a caller thread. | ImplicitLocOpBuilder nb(loc, nestedBuilder); | ||||
if (asyncDispatch) { | // Align the block size to be a multiple of the statically known | ||||
doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize, | // number of iterations in the inner loops. | ||||
blockCount, tripCounts); | Value numIters = nb.create<arith::ConstantIndexOp>( | ||||
} else { | numIterations[op.getNumLoops() - numUnrollableLoops]); | ||||
doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize, | Value alignedBlockSize = nb.create<arith::MulIOp>( | ||||
blockCount, tripCounts); | nb.create<arith::CeilDivSIOp>(blockSize, numIters), numIters); | ||||
} | doDispatch(b, rewriter, unrollableParallelComputeFunction, op, | ||||
alignedBlockSize, blockCount, tripCounts); | |||||
return nb.create<scf::YieldOp>(); | |||||
}; | |||||
b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable, | |||||
dispatchNotUnrollable); | |||||
nb.create<scf::YieldOp>(); | nb.create<scf::YieldOp>(); | ||||
} else { | |||||
dispatchNotUnrollable(nb, loc); | |||||
} | |||||
}; | }; | ||||
// Replace the `scf.parallel` operation with the parallel compute function. | // Replace the `scf.parallel` operation with the parallel compute function. | ||||
b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch); | b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch); | ||||
// Parallel operation was replaced with a block iteration loop. | // Parallel operation was replaced with a block iteration loop. | ||||
rewriter.eraseOp(op); | rewriter.eraseOp(op); | ||||
return success(); | return success(); | ||||
} | } | ||||
void AsyncParallelForPass::runOnOperation() { | void AsyncParallelForPass::runOnOperation() { | ||||
MLIRContext *ctx = &getContext(); | MLIRContext *ctx = &getContext(); | ||||
RewritePatternSet patterns(ctx); | RewritePatternSet patterns(ctx); | ||||
patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads, | populateAsyncParallelForPatterns( | ||||
minTaskSize); | patterns, asyncDispatch, numWorkerThreads, | ||||
[&](ImplicitLocOpBuilder builder, scf::ParallelOp op) { | |||||
return builder.create<arith::ConstantIndexOp>(minTaskSize); | |||||
}); | |||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) | if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) | ||||
signalPassFailure(); | signalPassFailure(); | ||||
} | } | ||||
// Creates the async-parallel-for pass with its default options.
std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
  return std::make_unique<AsyncParallelForPass>();
}
// Creates the async-parallel-for pass with explicit options:
//   asyncDispatch    - dispatch compute tasks with async recursive work
//                      splitting instead of sequentially from the caller.
//   numWorkerThreads - number of worker threads to shard the work for.
//   minTaskSize      - minimum parallel block (task) size.
std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
                                                       int32_t numWorkerThreads,
                                                       int32_t minTaskSize) {
  return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
                                                minTaskSize);
}
// Populates `patterns` with the scf.parallel to async dispatch rewrite.
//
// `computeMinTaskSize` is invoked during the rewrite to build an IR Value
// holding the minimum task (parallel block) size, letting clients compute it
// dynamically per operation instead of using a fixed pass option.
void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    AsyncMinTaskSizeComputationFunction computeMinTaskSize) {
  MLIRContext *ctx = patterns.getContext();
  // Move the callback: it is a by-value sink parameter and copying a
  // std::function may allocate.
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        std::move(computeMinTaskSize));
}
Computing numUnrollableLoops as a Value (compute function argument) plus a runtime scf.if in the loop nest prevents LLVM from performing loop unrolling and vectorization, and leads to large regressions in: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfrt/benchmarks/compute_function_benchmark.cc
numUnrollableLoops should be known at compile time; with a dynamic minTaskSize the structure should look like this:
There is no real need or benefit in computing numUnrollableLoops as a Value, because unrolling/vectorization can only happen if trip counts (loop bounds) are known at compile time.