diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -779,10 +779,10 @@
   // and we can elide dynamic loop boundaries, and give LLVM an opportunity to
   // unroll the loops. The constant `512` is arbitrary, it should depend on
   // how many iterations LLVM will typically decide to unroll.
-  static constexpr int64_t maxIterations = 512;
+  static constexpr int64_t maxUnrollableIterations = 512;
 
   // The number of inner loops with statically known number of iterations less
-  // than the `maxIterations` value.
+  // than the `maxUnrollableIterations` value.
   int numUnrollableLoops = 0;
 
   auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };
@@ -796,7 +796,7 @@
     numIterations[i] = tripCount * innerIterations;
 
     // Update the number of inner loops that we can potentially unroll.
-    if (innerIterations > 0 && innerIterations <= maxIterations)
+    if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
       numUnrollableLoops++;
   }
 
@@ -856,9 +856,6 @@
     Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
     Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
 
-    ParallelComputeFunction notUnrollableParallelComputeFunction =
-        createParallelComputeFunction(op, staticBounds, 0, rewriter);
-
     // Dispatch parallel compute function using async recursive work splitting,
     // or by submitting compute task sequentially from a caller thread.
     auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
@@ -869,42 +866,47 @@
     // Compute the number of parallel compute blocks.
     Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
 
-    // Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
-    bool staticShouldUnroll = numUnrollableLoops > 0;
-    auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
+    // Dispatch parallel compute function without hints to unroll inner loops.
+    auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
+      ParallelComputeFunction compute =
+          createParallelComputeFunction(op, staticBounds, 0, rewriter);
+
+      ImplicitLocOpBuilder b(loc, nestedBuilder);
+      doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
+      b.create<scf::YieldOp>();
+    };
+
+    // Dispatch parallel compute function with hints for unrolling inner loops.
+    auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
+      ParallelComputeFunction compute = createParallelComputeFunction(
+          op, staticBounds, numUnrollableLoops, rewriter);
+
       ImplicitLocOpBuilder b(loc, nestedBuilder);
-      doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
-                 blockSize, blockCount, tripCounts);
+      // Align the block size to be a multiple of the statically known
+      // number of iterations in the inner loops.
+      Value numIters = b.create<arith::ConstantIndexOp>(
+          numIterations[op.getNumLoops() - numUnrollableLoops]);
+      Value alignedBlockSize = b.create<arith::MulIOp>(
+          b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+      doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
+                 tripCounts);
       b.create<scf::YieldOp>();
     };
 
-    if (staticShouldUnroll) {
-      Value dynamicShouldUnroll = b.create<arith::CmpIOp>(
-          arith::CmpIPredicate::sge, blockSize,
-          b.create<arith::ConstantIndexOp>(maxIterations));
-
-      ParallelComputeFunction unrollableParallelComputeFunction =
-          createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
-                                        rewriter);
-
-      auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
-        ImplicitLocOpBuilder b(loc, nestedBuilder);
-        // Align the block size to be a multiple of the statically known
-        // number of iterations in the inner loops.
-        Value numIters = b.create<arith::ConstantIndexOp>(
-            numIterations[op.getNumLoops() - numUnrollableLoops]);
-        Value alignedBlockSize = b.create<arith::MulIOp>(
-            b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
-        doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
-                   alignedBlockSize, blockCount, tripCounts);
-        b.create<scf::YieldOp>();
-      };
-
-      b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
-                          dispatchNotUnrollable);
+    // Dispatch to block aligned compute function only if the computed block
+    // size is larger than the number of iterations in the unrollable inner
+    // loops, because otherwise it can reduce the available parallelism.
+    if (numUnrollableLoops > 0) {
+      Value numIters = b.create<arith::ConstantIndexOp>(
+          numIterations[op.getNumLoops() - numUnrollableLoops]);
+      Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
+          arith::CmpIPredicate::sge, blockSize, numIters);
+
+      b.create<scf::IfOp>(TypeRange(), useBlockAlignedComputeFn,
+                          dispatchBlockAligned, dispatchDefault);
+      b.create<scf::YieldOp>();
     } else {
-      dispatchNotUnrollable(b, loc);
+      dispatchDefault(b, loc);
     }
   };
 
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -87,7 +87,7 @@
   return
 }
 
-// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
 // CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
 // CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
 // CHECK-SAME: %[[TRIP_COUNT0:arg[0-9]+]]: index,
@@ -100,12 +100,14 @@
 // CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
 // CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref
 // CHECK-SAME: ) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK: scf.for %[[I:arg[0-9]+]]
-// CHECK: arith.select
-// CHECK: scf.for %[[J:arg[0-9]+]]
-// CHECK: memref.store
+// CHECK-NOT: arith.select
+// CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
 
-// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
+// CHECK-LABEL: func private @parallel_compute_fn(
 // CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
 // CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
 // CHECK-SAME: %[[TRIP_COUNT0:arg[0-9]+]]: index,
@@ -118,9 +120,7 @@
 // CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
 // CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref
 // CHECK-SAME: ) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK: scf.for %[[I:arg[0-9]+]]
-// CHECK-NOT: arith.select
-// CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
+// CHECK: arith.select
+// CHECK: scf.for %[[J:arg[0-9]+]]
+// CHECK: memref.store