diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -21,7 +21,7 @@
 std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
                                                  int32_t numWorkerThreads,
-                                                 int32_t targetBlockSize);
+                                                 int32_t minTaskSize);
 
 std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
diff --git a/mlir/include/mlir/Dialect/Async/Transforms.h b/mlir/include/mlir/Dialect/Async/Transforms.h
new file
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Async/Transforms.h
@@ -0,0 +1,40 @@
+//===- Transforms.h - Async dialect transformation utilities ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines transformations on Async operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ASYNC_TRANSFORMS_H_
+#define MLIR_DIALECT_ASYNC_TRANSFORMS_H_
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+
+namespace mlir {
+namespace async {
+
+/// Emit the IR to compute the minimum number of iterations of scf.parallel body
+/// that would be viable for a single parallel task. Allows the user to avoid
+/// incurring the overheads of spawning costly parallel tasks in absence of
+/// sufficient amount of parallelizable work.
+///
+/// Must return an index type.
+using AsyncMinTaskSizeComputationFunction =
+    std::function<Value(ImplicitLocOpBuilder, scf::ParallelOp)>;
+
+/// Add a pattern to the given pattern list to lower scf.parallel to async
+/// operations.
+void populateAsyncParallelForPatterns(
+    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
+    AsyncMinTaskSizeComputationFunction computeMinTaskSize);
+
+} // namespace async
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ASYNC_TRANSFORMS_H_
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/Async/Transforms.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -105,10 +106,12 @@
 struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
 public:
-  AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
-                          int32_t numWorkerThreads, int32_t minTaskSize)
+  AsyncParallelForRewrite(
+      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
+      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
       : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
-        numWorkerThreads(numWorkerThreads), minTaskSize(minTaskSize) {}
+        numWorkerThreads(numWorkerThreads),
+        computeMinTaskSize(computeMinTaskSize) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp op,
                                 PatternRewriter &rewriter) const override;
@@ -116,7 +119,7 @@
 private:
   bool asyncDispatch;
   int32_t numWorkerThreads;
-  int32_t minTaskSize;
+  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
 };
 
 struct ParallelComputeFunctionType {
@@ -252,7 +255,11 @@
       getParallelComputeFunctionType(op, rewriter);
 
   FunctionType type = computeFuncType.type;
-  FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
+  FuncOp func = FuncOp::create(op.getLoc(),
+                               numBlockAlignedInnerLoops > 0
+                                   ? "parallel_compute_fn_with_aligned_loops"
+                                   : "parallel_compute_fn",
+                               type);
   func.setPrivate();
 
   // Insert function into the module symbol table and assign it unique name.
@@ -752,7 +759,7 @@
   };
 
   // Find how many inner iteration dimensions are statically known, and their
-  // product is smaller than the `512`. We aling the parallel compute block
+  // product is smaller than the `512`. We align the parallel compute block
   // size by the product of statically known dimensions, so that we can
   // guarantee that the inner loops executes from 0 to the loop trip counts
   // and we can elide dynamic loop boundaries, and give LLVM an opportunity to
@@ -793,50 +800,65 @@
     Value maxComputeBlocks = b.create<arith::ConstantIndexOp>(
        std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
 
-    // Target block size from the pass parameters.
-    Value minTaskSizeCst = b.create<arith::ConstantIndexOp>(minTaskSize);
-
     // Compute parallel block size from the parallel problem size:
     //   blockSize = min(tripCount,
     //                   max(ceil_div(tripCount, maxComputeBlocks),
-    //                       ceil_div(minTaskSize, bodySize)))
+    //                       minTaskSize))
     Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
-    Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst);
+    Value minTaskSize = computeMinTaskSize(b, op);
+    Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
     Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
 
-    // Align the block size to be a multiple of the statically known number
-    // of iterations in the inner loops.
-    if (numUnrollableLoops > 0 && minTaskSize >= maxIterations) {
-      Value numIters = b.create<arith::ConstantIndexOp>(
-          numIterations[op.getNumLoops() - numUnrollableLoops]);
-      Value bs2 = b.create<arith::MulIOp>(
-          b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
-      blockSize = b.create<arith::MinSIOp>(tripCount, bs2);
-    } else {
-      // Reset the number of unrollable loops if we didn't align the block size.
-      numUnrollableLoops = 0;
-    }
+    ParallelComputeFunction notUnrollableParallelComputeFunction =
+        createParallelComputeFunction(op, staticBounds, 0, rewriter);
 
-    // Compute the number of parallel compute blocks.
-    Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
+    // Dispatch parallel compute function using async recursive work splitting,
+    // or by submitting compute task sequentially from a caller thread.
+    auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
 
     // Create a parallel compute function that takes a block id and computes
     // the parallel operation body for a subset of iteration space.
-    ParallelComputeFunction parallelComputeFunction =
-        createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
-                                      rewriter);
 
-    // Dispatch parallel compute function using async recursive work splitting,
-    // or by submitting compute task sequentially from a caller thread.
-    if (asyncDispatch) {
-      doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
-                      blockCount, tripCounts);
+    // Compute the number of parallel compute blocks.
+    Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
+
+    // Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
+    bool staticShouldUnroll = numUnrollableLoops > 0;
+    auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
+      ImplicitLocOpBuilder nb(loc, nestedBuilder);
+      doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
+                 blockSize, blockCount, tripCounts);
+      return nb.create<scf::YieldOp>();
+    };
+
+    if (staticShouldUnroll) {
+      Value dynamicShouldUnroll = b.create<arith::CmpIOp>(
+          arith::CmpIPredicate::sge, blockSize,
+          b.create<arith::ConstantIndexOp>(maxIterations));
+
+      ParallelComputeFunction unrollableParallelComputeFunction =
+          createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
+                                        rewriter);
+
+      auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
+        ImplicitLocOpBuilder nb(loc, nestedBuilder);
+        // Align the block size to be a multiple of the statically known
+        // number of iterations in the inner loops.
+        Value numIters = nb.create<arith::ConstantIndexOp>(
+            numIterations[op.getNumLoops() - numUnrollableLoops]);
+        Value alignedBlockSize = nb.create<arith::MulIOp>(
+            nb.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+        doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
+                   alignedBlockSize, blockCount, tripCounts);
+        return nb.create<scf::YieldOp>();
+      };
+
+      b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
+                          dispatchNotUnrollable);
+      nb.create<scf::YieldOp>();
     } else {
-      doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
-                           blockCount, tripCounts);
+      dispatchNotUnrollable(nb, loc);
    }
-
-    nb.create<scf::YieldOp>();
  };
 
   // Replace the `scf.parallel` operation with the parallel compute function.
@@ -852,9 +874,11 @@
   MLIRContext *ctx = &getContext();
 
   RewritePatternSet patterns(ctx);
-  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
-                                        minTaskSize);
-
+  populateAsyncParallelForPatterns(
+      patterns, asyncDispatch, numWorkerThreads,
+      [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
+        return builder.create<arith::ConstantIndexOp>(minTaskSize);
+      });
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
 }
@@ -869,3 +893,11 @@
   return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
                                                 minTaskSize);
 }
+
+void mlir::async::populateAsyncParallelForPatterns(
+    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
+    AsyncMinTaskSizeComputationFunction computeMinTaskSize) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
+                                        computeMinTaskSize);
+}
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -100,9 +100,27 @@
 // CHECK-SAME:   %[[STEP1:arg[0-9]+]]: index,
 // CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?x?xf32>
 // CHECK-SAME: ) {
+// CHECK:        scf.for %[[I:arg[0-9]+]]
+// CHECK:          select
+// CHECK:          scf.for %[[J:arg[0-9]+]]
+// CHECK:            memref.store
+
+// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
+// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[TRIP_COUNT0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[TRIP_COUNT1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[LB0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[LB1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[UB0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[UB1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[STEP0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[STEP1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?x?xf32>
+// CHECK-SAME: ) {
 // CHECK:        %[[C0:.*]] = arith.constant 0 : index
 // CHECK:        %[[C1:.*]] = arith.constant 1 : index
 // CHECK:        %[[C10:.*]] = arith.constant 10 : index
 // CHECK:        scf.for %[[I:arg[0-9]+]]
 // CHECK-NOT:      select
-// CHECK:          scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
\ No newline at end of file
+// CHECK:          scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2017,7 +2017,10 @@
         "lib/Dialect/Async/Transforms/*.cpp",
         "lib/Dialect/Async/Transforms/*.h",
     ]),
-    hdrs = ["include/mlir/Dialect/Async/Passes.h"],
+    hdrs = [
+        "include/mlir/Dialect/Async/Passes.h",
+        "include/mlir/Dialect/Async/Transforms.h",
+    ],
    includes = ["include"],
    deps = [
        ":Analysis",