diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -653,9 +653,19 @@ for (size_t i = 1; i < tripCounts.size(); ++i) tripCount = b.create(tripCount, tripCounts[i]); + // With large number of threads the value of creating many compute blocks + // is reduced because the problem typically becomes memory bound. For small + // number of threads it helps with stragglers. + float overshardingFactor = numWorkerThreads <= 4 ? 8.0 + : numWorkerThreads <= 8 ? 4.0 + : numWorkerThreads <= 16 ? 2.0 + : numWorkerThreads <= 32 ? 1.0 + : numWorkerThreads <= 64 ? 0.8 + : 0.6; + // Do not overload worker threads with too many compute blocks. - Value maxComputeBlocks = - b.create(numWorkerThreads * kMaxOversharding); + Value maxComputeBlocks = b.create( + std::max(1, static_cast(numWorkerThreads * overshardingFactor))); // Target block size from the pass parameters. Value targetComputeBlockSize = b.create(targetBlockSize); @@ -668,7 +678,11 @@ Value bs1 = b.create(CmpIPredicate::sge, bs0, targetComputeBlockSize); Value bs2 = b.create(bs1, bs0, targetComputeBlockSize); Value bs3 = b.create(CmpIPredicate::sle, tripCount, bs2); - Value blockSize = b.create(bs3, tripCount, bs2); + Value blockSize0 = b.create(bs3, tripCount, bs2); + Value blockCount0 = b.create(tripCount, blockSize0); + + // Compute balanced block size for the estimated block count. + Value blockSize = b.create(tripCount, blockCount0); Value blockCount = b.create(tripCount, blockSize); // Create a parallel compute function that takes a block id and computes the