[mlir][async] Create and await the async.group inside the async dispatch branch

The async.group is only needed when the loop is split into more than one
block. Create it, and wait on it, inside the `else` branch of the
`scf.if %isSingleBlock` that chooses between the synchronous single-block call
and the async dispatch function, rather than around that scf.if.

All ops of the `else` branch are built with the branch-local builder `nb`
instead of being split between `nb` and the outer builder `b`. Their placement
therefore does not depend on how scf::IfOp::build positions `b` while it runs
the region-builder callback.

NOTE(review): template arguments and indentation in this patch were restored
by hand from a whitespace-mangled copy. If a context line does not match, try
`git apply --ignore-whitespace`.

diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -508,12 +508,6 @@
   Value c0 = b.create<ConstantIndexOp>(0);
   Value c1 = b.create<ConstantIndexOp>(1);
 
-  // Create an async.group to wait on all async tokens from the concurrent
-  // execution of multiple parallel compute function. First block will be
-  // executed synchronously in the caller thread.
-  Value groupSize = b.create<SubIOp>(blockCount, c1);
-  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
-
   // Appends operands shared by async dispatch and parallel compute functions to
   // the given operands vector.
   auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
@@ -543,6 +537,16 @@
   };
 
   auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
     ImplicitLocOpBuilder nb(loc, nestedBuilder);
 
+    // Create an async.group to wait on all async tokens from the concurrent
+    // execution of multiple parallel compute functions. First block will be
+    // executed synchronously in the caller thread.
+    //
+    // Build every op of this branch with `nb` rather than the outer `b`, so
+    // that their placement does not depend on which builder scf::IfOp::build
+    // passes in as `nestedBuilder`.
+    Value groupSize = nb.create<SubIOp>(blockCount, c1);
+    Value group = nb.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
+
     // Launch async dispatch function for [0, blockCount) range.
@@ -551,14 +555,15 @@
 
     nb.create<CallOp>(asyncDispatchFunction.sym_name(),
                       asyncDispatchFunction.getCallableResults(), operands);
+
+    // Wait for the completion of all parallel compute operations.
+    nb.create<AwaitAllOp>(group);
+
     nb.create<scf::YieldOp>();
   };
 
   // Dispatch either single block compute function, or launch async dispatch.
   b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
-
-  // Wait for the completion of all parallel compute operations.
-  b.create<AwaitAllOp>(group);
 }
 
 // Dispatch parallel compute functions by submitting all async compute tasks
diff --git a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
--- a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
@@ -12,13 +12,13 @@
 
   // CHECK: scf.if %[[IS_NOOP]] {
   // CHECK-NEXT: } else {
-  // CHECK: %[[GROUP:.*]] = async.create_group
   // CHECK: scf.if {{.*}} {
   // CHECK: call @parallel_compute_fn(%[[C0]]
   // CHECK: } else {
+  // CHECK: %[[GROUP:.*]] = async.create_group
   // CHECK: call @async_dispatch_fn
+  // CHECK: async.await_all %[[GROUP]]
   // CHECK: }
-  // CHECK: async.await_all %[[GROUP]]
   // CHECK: }
   scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
     %one = constant 1.0 : f32