diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -21,7 +21,7 @@ std::unique_ptr createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads, - int32_t targetBlockSize); + int32_t minTaskSize); std::unique_ptr> createAsyncToAsyncRuntimePass(); diff --git a/mlir/include/mlir/Dialect/Async/Transforms.h b/mlir/include/mlir/Dialect/Async/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Async/Transforms.h @@ -0,0 +1,40 @@ +//===- Transforms.h - Async dialect transformation utilities ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines transformations on Async operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ASYNC_TRANSFORMS_H_ +#define MLIR_DIALECT_ASYNC_TRANSFORMS_H_ + +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" + +namespace mlir { +namespace async { + +/// Emit the IR to compute the minimum number of iterations of scf.parallel body +/// that would be viable for a single parallel task. Allows the user to avoid +/// incurring the overheads of spawning costly parallel tasks in absence of +/// sufficient amount of parallelizable work. +/// +/// Must return an index type. +using AsyncMinTaskSizeComputationFunction = + std::function; + +/// Add a pattern to the given pattern list to lower scf.parallel to async +/// operations. 
+void populateAsyncParallelForPatterns( + RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads, + AsyncMinTaskSizeComputationFunction computeMinTaskSize); + +} // namespace async +} // namespace mlir + +#endif // MLIR_DIALECT_ASYNC_TRANSFORMS_H_ diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/Async/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -105,10 +106,12 @@ struct AsyncParallelForRewrite : public OpRewritePattern { public: - AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch, - int32_t numWorkerThreads, int32_t minTaskSize) + AsyncParallelForRewrite( + MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads, + AsyncMinTaskSizeComputationFunction computeMinTaskSize) : OpRewritePattern(ctx), asyncDispatch(asyncDispatch), - numWorkerThreads(numWorkerThreads), minTaskSize(minTaskSize) {} + numWorkerThreads(numWorkerThreads), + computeMinTaskSize(computeMinTaskSize) {} LogicalResult matchAndRewrite(scf::ParallelOp op, PatternRewriter &rewriter) const override; @@ -116,7 +119,7 @@ private: bool asyncDispatch; int32_t numWorkerThreads; - int32_t minTaskSize; + AsyncMinTaskSizeComputationFunction computeMinTaskSize; }; struct ParallelComputeFunctionType { @@ -128,6 +131,7 @@ struct ParallelComputeFunctionArgs { BlockArgument blockIndex(); BlockArgument blockSize(); + BlockArgument numBlockAlignedInnerLoops(); ArrayRef tripCounts(); ArrayRef lowerBounds(); ArrayRef upperBounds(); @@ -155,25 +159,28 @@ BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; } 
BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; } +BlockArgument ParallelComputeFunctionArgs::numBlockAlignedInnerLoops() { + return args[2]; +} ArrayRef ParallelComputeFunctionArgs::tripCounts() { - return args.drop_front(2).take_front(numLoops); + return args.drop_front(3).take_front(numLoops); } ArrayRef ParallelComputeFunctionArgs::lowerBounds() { - return args.drop_front(2 + 1 * numLoops).take_front(numLoops); + return args.drop_front(3 + 1 * numLoops).take_front(numLoops); } ArrayRef ParallelComputeFunctionArgs::upperBounds() { - return args.drop_front(2 + 2 * numLoops).take_front(numLoops); + return args.drop_front(3 + 2 * numLoops).take_front(numLoops); } ArrayRef ParallelComputeFunctionArgs::steps() { - return args.drop_front(2 + 3 * numLoops).take_front(numLoops); + return args.drop_front(3 + 3 * numLoops).take_front(numLoops); } ArrayRef ParallelComputeFunctionArgs::captures() { - return args.drop_front(2 + 4 * numLoops); + return args.drop_front(3 + 4 * numLoops); } template @@ -216,6 +223,7 @@ // One-dimensional iteration space defined by the block index and size. inputs.push_back(indexTy); // blockIndex inputs.push_back(indexTy); // blockSize + inputs.push_back(indexTy); // numBlockAlignedInnerLoops // Multi-dimensional parallel iteration space defined by the loop trip counts. for (unsigned i = 0; i < op.getNumLoops(); ++i) @@ -240,9 +248,10 @@ } // Create a parallel compute fuction from the parallel operation. -static ParallelComputeFunction createParallelComputeFunction( - scf::ParallelOp op, ParallelComputeFunctionBounds bounds, - unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) { +static ParallelComputeFunction +createParallelComputeFunction(scf::ParallelOp op, + ParallelComputeFunctionBounds bounds, + PatternRewriter &rewriter) { OpBuilder::InsertionGuard guard(rewriter); ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -269,6 +278,7 @@ // Block iteration position defined by the block index and size. 
BlockArgument blockIndex = args.blockIndex(); BlockArgument blockSize = args.blockSize(); + BlockArgument numBlockAlignedInnerLoops = args.numBlockAlignedInnerLoops(); // Constants used below. Value c0 = b.create(0); @@ -359,9 +369,9 @@ // Builds inner loop nest inside async.execute operation that does all the // work concurrently. LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder { - return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv, + return [&, loopIdx](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange args) { - ImplicitLocOpBuilder nb(loc, nestedBuilder); + ImplicitLocOpBuilder nb(nestedLoc, nestedBuilder); // Compute induction variable for `loopIdx`. computeBlockInductionVars[loopIdx] = nb.create( @@ -383,27 +393,37 @@ // Keep building loop nest. if (loopIdx < op.getNumLoops() - 1) { - if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) { - // For block aligned loops we always iterate starting from 0 up to - // the loop trip counts. - nb.create(c0, tripCounts[loopIdx + 1], c1, ValueRange(), - workLoopBuilder(loopIdx + 1)); - - } else { - // Select nested loop lower/upper bounds depending on our position in - // the multi-dimensional iteration space. - auto lb = nb.create(isBlockFirstCoord[loopIdx], - blockFirstCoord[loopIdx + 1], c0); - - auto ub = nb.create(isBlockLastCoord[loopIdx], - blockEndCoord[loopIdx + 1], - tripCounts[loopIdx + 1]); - - nb.create(lb, ub, c1, ValueRange(), - workLoopBuilder(loopIdx + 1)); - } - - nb.create(loc); + auto inner = nb.create( + arith::CmpIPredicate::sge, numBlockAlignedInnerLoops, + nb.create(op.getNumLoops() - loopIdx - 1)); + nb.create( + TypeRange(), inner, + [&](OpBuilder &doublyNestedBuilder, Location doublyNestedLoc) { + ImplicitLocOpBuilder dnb(doublyNestedLoc, doublyNestedBuilder); + // For block aligned loops we always iterate starting from 0 up to + // the loop trip counts. 
+ dnb.create(c0, tripCounts[loopIdx + 1], c1, + ValueRange(), + workLoopBuilder(loopIdx + 1)); + return dnb.create(); + }, + [&](OpBuilder &doublyNestedBuilder, Location doublyNestedLoc) { + ImplicitLocOpBuilder dnb(doublyNestedLoc, doublyNestedBuilder); + // Select nested loop lower/upper bounds depending on our position + // in the multi-dimensional iteration space. + auto lb = dnb.create(isBlockFirstCoord[loopIdx], + blockFirstCoord[loopIdx + 1], c0); + + auto ub = dnb.create(isBlockLastCoord[loopIdx], + blockEndCoord[loopIdx + 1], + tripCounts[loopIdx + 1]); + + dnb.create(lb, ub, c1, ValueRange(), + workLoopBuilder(loopIdx + 1)); + return dnb.create(); + }); + + nb.create(nestedLoc); return; } @@ -559,7 +579,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction &parallelComputeFunction, scf::ParallelOp op, Value blockSize, - Value blockCount, + Value blockCount, Value numBlockAlignedInnerLoops, const SmallVector &tripCounts) { MLIRContext *ctx = op->getContext(); @@ -591,7 +611,7 @@ ImplicitLocOpBuilder nb(loc, nestedBuilder); // Call parallel compute function for the single block. - SmallVector operands = {c0, blockSize}; + SmallVector operands = {c0, blockSize, numBlockAlignedInnerLoops}; appendBlockComputeOperands(operands); nb.create(parallelComputeFunction.func.sym_name(), @@ -610,7 +630,9 @@ ImplicitLocOpBuilder nb(loc, nestedBuilder); // Launch async dispatch function for [0, blockCount) range. 
- SmallVector operands = {group, c0, blockCount, blockSize}; + SmallVector operands = { + group, c0, blockCount, blockSize, numBlockAlignedInnerLoops, + }; appendBlockComputeOperands(operands); nb.create(asyncDispatchFunction.sym_name(), @@ -632,6 +654,7 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction &parallelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, + Value numBlockAlignedInnerLoops, const SmallVector &tripCounts) { MLIRContext *ctx = op->getContext(); @@ -652,7 +675,8 @@ // Returns parallel compute function operands to process the given block. auto computeFuncOperands = [&](Value blockIndex) -> SmallVector { - SmallVector computeFuncOperands = {blockIndex, blockSize}; + SmallVector computeFuncOperands = {blockIndex, blockSize, + numBlockAlignedInnerLoops}; computeFuncOperands.append(tripCounts); computeFuncOperands.append(op.lowerBound().begin(), op.lowerBound().end()); computeFuncOperands.append(op.upperBound().begin(), op.upperBound().end()); @@ -752,7 +776,7 @@ }; // Find how many inner iteration dimensions are statically known, and their - // product is smaller than the `512`. We aling the parallel compute block + // product is smaller than the `512`. We align the parallel compute block // size by the product of statically known dimensions, so that we can // guarantee that the inner loops executes from 0 to the loop trip counts // and we can elide dynamic loop boundaries, and give LLVM an opportunity to @@ -793,47 +817,67 @@ Value maxComputeBlocks = b.create( std::max(1, static_cast(numWorkerThreads * overshardingFactor))); - // Target block size from the pass parameters. 
- Value minTaskSizeCst = b.create(minTaskSize); - // Compute parallel block size from the parallel problem size: // blockSize = min(tripCount, // max(ceil_div(tripCount, maxComputeBlocks), - // ceil_div(minTaskSize, bodySize))) + // minTaskSize)) Value bs0 = b.create(tripCount, maxComputeBlocks); - Value bs1 = b.create(bs0, minTaskSizeCst); + Value minTaskSize = computeMinTaskSize(b, op); + Value bs1 = b.create(bs0, minTaskSize); Value blockSize = b.create(tripCount, bs1); - // Align the block size to be a multiple of the statically known number - // of iterations in the inner loops. - if (numUnrollableLoops > 0 && minTaskSize >= maxIterations) { - Value numIters = b.create( - numIterations[op.getNumLoops() - numUnrollableLoops]); - Value bs2 = b.create( - b.create(blockSize, numIters), numIters); - blockSize = b.create(tripCount, bs2); - } else { - // Reset the number of unrollable loops if we didn't align the block size. - numUnrollableLoops = 0; - } - - // Compute the number of parallel compute blocks. - Value blockCount = b.create(tripCount, blockSize); + // Unroll when numUnrollableLoops > 0 && minTaskSize >= maxIterations. + // Value shouldUnroll = b.create( + // b.create( + // arith::CmpIPredicate::sgt, + // b.create(numUnrollableLoops), c0), + // b.create( + // arith::CmpIPredicate::sge, minTaskSize, + // b.create(maxIterations))); + Value shouldUnroll = + b.create(b.getI1Type(), b.getBoolAttr(false)); + MLIRContext *ctx = op->getContext(); + auto numUnrollableLoopsAndBlockSize = b.create( + TypeRange{IndexType::get(ctx), IndexType::get(ctx)}, shouldUnroll, + [&](OpBuilder &nestedBuilder, Location loc) { + ImplicitLocOpBuilder nb(loc, nestedBuilder); + // Align the block size to be a multiple of the statically known + // number of iterations in the inner loops. 
+ Value numIters = nb.create( + numIterations[op.getNumLoops() - numUnrollableLoops]); + Value bs2 = nb.create( + nb.create(blockSize, numIters), numIters); + return nb.create( + ValueRange{nb.create(numUnrollableLoops), + nb.create(tripCount, bs2)}); + }, + [&](OpBuilder &nestedBuilder, Location loc) { + ImplicitLocOpBuilder nb(loc, nestedBuilder); + // Reset the number of unrollable loops if we didn't align the block + // size. + return nb.create(ValueRange{c0, blockSize}); + }); + + blockSize = numUnrollableLoopsAndBlockSize.getResult(1); // Create a parallel compute function that takes a block id and computes // the parallel operation body for a subset of iteration space. ParallelComputeFunction parallelComputeFunction = - createParallelComputeFunction(op, staticBounds, numUnrollableLoops, - rewriter); + createParallelComputeFunction(op, staticBounds, rewriter); + + // Compute the number of parallel compute blocks. + Value blockCount = b.create(tripCount, blockSize); // Dispatch parallel compute function using async recursive work splitting, // or by submitting compute task sequentially from a caller thread. 
if (asyncDispatch) { doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize, - blockCount, tripCounts); + blockCount, numUnrollableLoopsAndBlockSize.getResult(0), + tripCounts); } else { - doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize, - blockCount, tripCounts); + doSequentialDispatch( + b, rewriter, parallelComputeFunction, op, blockSize, blockCount, + numUnrollableLoopsAndBlockSize.getResult(0), tripCounts); } nb.create(); @@ -852,9 +896,11 @@ MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx, asyncDispatch, numWorkerThreads, - minTaskSize); - + populateAsyncParallelForPatterns( + patterns, asyncDispatch, numWorkerThreads, + [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) { + return builder.create(minTaskSize); + }); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } @@ -869,3 +915,11 @@ return std::make_unique(asyncDispatch, numWorkerThreads, minTaskSize); } + +void mlir::async::populateAsyncParallelForPatterns( + RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads, + AsyncMinTaskSizeComputationFunction computeMinTaskSize) { + MLIRContext *ctx = patterns.getContext(); + patterns.add(ctx, asyncDispatch, numWorkerThreads, + computeMinTaskSize); +} diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir --- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir +++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir @@ -55,6 +55,7 @@ // CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index, // CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index, // CHECK-SAME: %[[TRIP_COUNT:arg[0-9]+]]: index, +// CHECK-SAME: %[[NUM_BLOCK_ALIGNED_INNER_LOOPS:arg[0-9]+]]: index, // CHECK-SAME: %[[LB:arg[0-9]+]]: index, // CHECK-SAME: %[[UB:arg[0-9]+]]: index, // CHECK-SAME: %[[STEP:arg[0-9]+]]: index, @@ -105,4 +106,4 @@ // CHECK: %[[C10:.*]] = 
arith.constant 10 : index // CHECK: scf.for %[[I:arg[0-9]+]] // CHECK-NOT: select -// CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1 \ No newline at end of file +// CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1 diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2018,7 +2018,10 @@ "lib/Dialect/Async/Transforms/*.cpp", "lib/Dialect/Async/Transforms/*.h", ]), - hdrs = ["include/mlir/Dialect/Async/Passes.h"], + hdrs = [ + "include/mlir/Dialect/Async/Passes.h", + "include/mlir/Dialect/Async/Transforms.h", + ], includes = ["include"], deps = [ ":Analysis",