diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -650,52 +650,73 @@ for (size_t i = 1; i < tripCounts.size(); ++i) tripCount = b.create(tripCount, tripCounts[i]); - // With large number of threads the value of creating many compute blocks - // is reduced because the problem typically becomes memory bound. For small - // number of threads it helps with stragglers. - float overshardingFactor = numWorkerThreads <= 4 ? 8.0 - : numWorkerThreads <= 8 ? 4.0 - : numWorkerThreads <= 16 ? 2.0 - : numWorkerThreads <= 32 ? 1.0 - : numWorkerThreads <= 64 ? 0.8 - : 0.6; - - // Do not overload worker threads with too many compute blocks. - Value maxComputeBlocks = b.create( - std::max(1, static_cast(numWorkerThreads * overshardingFactor))); - - // Target block size from the pass parameters. - Value targetComputeBlockSize = b.create(targetBlockSize); - - // Compute parallel block size from the parallel problem size: - // blockSize = min(tripCount, - // max(ceil_div(tripCount, maxComputeBlocks), - // targetComputeBlockSize)) - Value bs0 = b.create(tripCount, maxComputeBlocks); - Value bs1 = b.create(CmpIPredicate::sge, bs0, targetComputeBlockSize); - Value bs2 = b.create(bs1, bs0, targetComputeBlockSize); - Value bs3 = b.create(CmpIPredicate::sle, tripCount, bs2); - Value blockSize0 = b.create(bs3, tripCount, bs2); - Value blockCount0 = b.create(tripCount, blockSize0); - - // Compute balanced block size for the estimated block count. - Value blockSize = b.create(tripCount, blockCount0); - Value blockCount = b.create(tripCount, blockSize); - - // Create a parallel compute function that takes a block id and computes the - // parallel operation body for a subset of iteration space. - ParallelComputeFunction parallelComputeFunction = - createParallelComputeFunction(op, rewriter); - - // Dispatch parallel compute function using async recursive work splitting, or - // by submitting compute task sequentially from a caller thread. - if (asyncDispatch) { - doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize, - blockCount, tripCounts); - } else { - doSequantialDispatch(b, rewriter, parallelComputeFunction, op, blockSize, - blockCount, tripCounts); - } + // Short circuit no-op parallel loops (zero iterations) that can arise from + // the memrefs with dynamic dimension(s) equal to zero. + Value c0 = b.create(0); + Value isZeroIterations = b.create(CmpIPredicate::eq, tripCount, c0); + + // Do absolutely nothing if the trip count is zero. + auto noOp = [&](OpBuilder &nestedBuilder, Location loc) { + nestedBuilder.create(loc); + }; + + // Compute the parallel block size and dispatch concurrent tasks computing + // results for each block. + auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) { + ImplicitLocOpBuilder nb(loc, nestedBuilder); + + // With large number of threads the value of creating many compute blocks + // is reduced because the problem typically becomes memory bound. For small + // number of threads it helps with stragglers. + float overshardingFactor = numWorkerThreads <= 4 ? 8.0 + : numWorkerThreads <= 8 ? 4.0 + : numWorkerThreads <= 16 ? 2.0 + : numWorkerThreads <= 32 ? 1.0 + : numWorkerThreads <= 64 ? 0.8 + : 0.6; + + // Do not overload worker threads with too many compute blocks. + Value maxComputeBlocks = b.create( + std::max(1, static_cast(numWorkerThreads * overshardingFactor))); + + // Target block size from the pass parameters. + Value targetComputeBlock = b.create(targetBlockSize); + + // Compute parallel block size from the parallel problem size: + // blockSize = min(tripCount, + // max(ceil_div(tripCount, maxComputeBlocks), + // targetComputeBlock)) + Value bs0 = b.create(tripCount, maxComputeBlocks); + Value bs1 = b.create(CmpIPredicate::sge, bs0, targetComputeBlock); + Value bs2 = b.create(bs1, bs0, targetComputeBlock); + Value bs3 = b.create(CmpIPredicate::sle, tripCount, bs2); + Value blockSize0 = b.create(bs3, tripCount, bs2); + Value blockCount0 = b.create(tripCount, blockSize0); + + // Compute balanced block size for the estimated block count. + Value blockSize = b.create(tripCount, blockCount0); + Value blockCount = b.create(tripCount, blockSize); + + // Create a parallel compute function that takes a block id and computes the + // parallel operation body for a subset of iteration space. + ParallelComputeFunction parallelComputeFunction = + createParallelComputeFunction(op, rewriter); + + // Dispatch parallel compute function using async recursive work splitting, + // or by submitting compute task sequentially from a caller thread. + if (asyncDispatch) { + doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize, + blockCount, tripCounts); + } else { + doSequantialDispatch(b, rewriter, parallelComputeFunction, op, blockSize, + blockCount, tripCounts); + } + + nb.create(); + }; + + // Replace the `scf.parallel` operation with the parallel compute function. + b.create(TypeRange(), isZeroIterations, noOp, dispatch); // Parallel operation was replaced with a block iteration loop. rewriter.eraseOp(op);