diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -160,20 +160,24 @@ let summary = "creates an empty async group"; let description = [{ The `async.create_group` allocates an empty async group. Async tokens or - values can be added to this group later. + values can be added to this group later. The size of the group must be + specified at construction time, and the `await_all` operation will first + wait until the number of added tokens or values reaches the group size. Example: ```mlir - %0 = async.create_group + %size = ... : index + %group = async.create_group %size : !async.group ... - async.await_all %0 + async.await_all %group ``` }]; + let arguments = (ins Index:$size); let results = (outs Async_GroupType:$result); - let assemblyFormat = "attr-dict"; + let assemblyFormat = "$size `:` type($result) attr-dict"; } def Async_AddToGroupOp : Async_Op<"add_to_group", []> { @@ -186,7 +190,7 @@ Example: ```mlir - %0 = async.create_group + %0 = async.create_group %size : !async.group %1 = ... : !async.token %2 = async.add_to_group %1, %0 : !async.token ``` @@ -209,7 +213,7 @@ Example: ```mlir - %0 = async.create_group + %0 = async.create_group %size : !async.group %1 = ... : !async.token %2 = async.add_to_group %1, %0 : !async.token @@ -331,17 +335,28 @@ // Runtime API defined in the `ExecutionEngine/AsyncRuntime.h`. def Async_RuntimeCreateOp : Async_Op<"runtime.create"> { - let summary = "creates an async runtime value (token, value or group)"; + let summary = "creates an async runtime token or value"; let description = [{ - The `async.runtime.create` operation creates an async dialect value - (token, value or group). Tokens and values are created in non-ready state. - Groups are created in empty state. + The `async.runtime.create` operation creates an async dialect token or + value. 
Tokens and values are created in the non-ready state. }]; - let results = (outs Async_AnyAsyncType:$result); + let results = (outs Async_AnyValueOrTokenType:$result); let assemblyFormat = "attr-dict `:` type($result)"; } +def Async_RuntimeCreateGroupOp : Async_Op<"runtime.create_group"> { + let summary = "creates an async runtime group"; + let description = [{ + The `async.runtime.create_group` operation creates an async dialect group + of the given size. The group is created in the empty state. + }]; + + let arguments = (ins Index:$size); + let results = (outs Async_GroupType:$result); + let assemblyFormat = "$size `:` type($result) attr-dict "; +} + def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> { let summary = "switches token or value to available state"; let description = [{ diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -66,7 +66,7 @@ extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t); // Create a new `async.group` in empty state. 
-extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(); +extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size); extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -89,7 +89,8 @@ } static FunctionType createGroupFunctionType(MLIRContext *ctx) { - return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); + auto i64 = IntegerType::get(ctx, 64); + return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); } static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { @@ -543,11 +544,10 @@ TypeConverter *converter = getTypeConverter(); Type resultType = op->getResultTypes()[0]; - // Tokens and Groups lowered to function calls without arguments. - if (resultType.isa() || resultType.isa()) { - rewriter.replaceOpWithNewOp( - op, resultType.isa() ? kCreateToken : kCreateGroup, - converter->convertType(resultType)); + // Token creation maps to a simple function call. + if (resultType.isa()) { + rewriter.replaceOpWithNewOp(op, kCreateToken, + converter->convertType(resultType)); return success(); } @@ -582,6 +582,29 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Convert async.runtime.create_group to the corresponding runtime API call. 
+//===----------------------------------------------------------------------===// + +namespace { +class RuntimeCreateGroupOpLowering + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TypeConverter *converter = getTypeConverter(); + Type resultType = op.getResult().getType(); + + rewriter.replaceOpWithNewOp( + op, kCreateGroup, converter->convertType(resultType), operands); + return success(); + } +}; +} // namespace + //===----------------------------------------------------------------------===// // Convert async.runtime.set_available to the corresponding runtime API call. //===----------------------------------------------------------------------===// @@ -967,8 +990,9 @@ // Lower async.runtime operations that rely on LLVM type converter to convert // from async value payload type to the LLVM type. - patterns.add(llvmConverter, ctx); + patterns.add(llvmConverter, + ctx); // Lower async coroutine operations to LLVM coroutine intrinsics. patterns diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -165,8 +165,14 @@ numBlocks[i] = divup(tripCounts[i], blockSize[i]); } + // Total number of async compute blocks. + Value totalBlocks = numBlocks[0]; + for (size_t i = 1; i < op.getNumLoops(); ++i) + totalBlocks = rewriter.create(loc, totalBlocks, numBlocks[i]); + // Create an async.group to wait on all async tokens from async execute ops. - auto group = rewriter.create(loc, GroupType::get(ctx)); + auto group = + rewriter.create(loc, GroupType::get(ctx), totalBlocks); // Build a scf.for loop nest from the parallel operation. 
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -302,7 +302,7 @@ } //===----------------------------------------------------------------------===// -// Convert async.create_group operation to async.runtime.create +// Convert async.create_group operation to async.runtime.create_group //===----------------------------------------------------------------------===// namespace { @@ -313,8 +313,8 @@ LogicalResult matchAndRewrite(CreateGroupOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, GroupType::get(op->getContext())); + rewriter.replaceOpWithNewOp( + op, GroupType::get(op->getContext()), operands); return success(); } }; diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -211,8 +211,8 @@ // values to await on all of them together (wait for the completion of all // tokens or values added to the group). struct AsyncGroup : public RefCounted { - AsyncGroup(AsyncRuntime *runtime) - : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {} + AsyncGroup(AsyncRuntime *runtime, int64_t size) + : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {} std::atomic pendingTokens; std::atomic numErrors; @@ -249,8 +249,8 @@ } // Create a new `async.group` in empty state. -extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { - AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime()); +extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) { + AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size); return group; } @@ -261,13 +261,16 @@ // Get the rank of the token inside the group before we drop the reference. 
int rank = group->rank.fetch_add(1); - group->pendingTokens.fetch_add(1); auto onTokenReady = [group, token]() { // Increment the number of errors in the group. if (State(token->state).isError()) group->numErrors.fetch_add(1); + // If pending tokens go below zero it means that more tokens than the group + // size were added to this group. + assert(group->pendingTokens > 0 && "wrong group size"); + // Run all group awaiters if it was the last token in the group. if (group->pendingTokens.fetch_sub(1) == 1) { group->cv.notify_all(); diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s +// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s --dump-input=always // CHECK-LABEL: @create_token func @create_token() { @@ -20,8 +20,11 @@ // CHECK-LABEL: @create_group func @create_group() { - // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup - %0 = async.runtime.create : !async.group + // CHECK: %[[C:.*]] = constant 1 : index + // CHECK: %[[S:.*]] = llvm.mlir.cast %[[C]] : index to i64 + %c = constant 1 : index + // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup(%[[S]]) + %0 = async.runtime.create_group %c: !async.group return } @@ -81,8 +84,9 @@ // CHECK-LABEL: @await_group func @await_group() { + %c = constant 1 : index // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup - %0 = async.runtime.create : !async.group + %0 = async.runtime.create_group %c: !async.group // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]]) async.runtime.await %0 : !async.group return @@ -118,11 +122,12 @@ // CHECK-LABEL: @await_and_resume_group func @await_and_resume_group() { + %c = constant 1 : index %0 = async.coro.id // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin %1 = 
async.coro.begin %0 // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup - %2 = async.runtime.create : !async.group + %2 = async.runtime.create_group %c : !async.group // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume // CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute // CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]]) @@ -168,10 +173,11 @@ // CHECK-LABEL: @add_token_to_group func @add_token_to_group() { + %c = constant 1 : index // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken %0 = async.runtime.create : !async.token // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup - %1 = async.runtime.create : !async.group + %1 = async.runtime.create_group %c : !async.group // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]]) async.runtime.add_to_group %0, %1 : !async.token return diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -170,12 +170,13 @@ // CHECK-LABEL: async_group_await_all func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) { - // CHECK: %0 = call @mlirAsyncRuntimeCreateGroup() - %0 = async.create_group + %c = constant 1 : index + // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup + %0 = async.create_group %c : !async.group // CHECK: %[[TOKEN:.*]] = call @async_execute_fn %token = async.execute { async.yield } - // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %0) + // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]]) async.add_to_group %token, %0 : !async.token // CHECK: call @async_execute_fn_0 @@ -184,7 +185,7 @@ async.yield } - // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%0) + // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]]) async.await_all %0 return diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir 
b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -179,8 +179,10 @@ // CHECK-LABEL: @async_group_await_all func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) { - // CHECK: %[[GROUP:.*]] = async.runtime.create : !async.group - %0 = async.create_group + // CHECK: %[[C:.*]] = constant 1 : index + %c = constant 1 : index + // CHECK: %[[GROUP:.*]] = async.runtime.create_group %[[C]] : !async.group + %0 = async.create_group %c : !async.group // CHECK: %[[TOKEN:.*]] = call @async_execute_fn %token = async.execute { async.yield } diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir --- a/mlir/test/Dialect/Async/ops.mlir +++ b/mlir/test/Dialect/Async/ops.mlir @@ -122,8 +122,10 @@ } // CHECK-LABEL: @create_group_and_await_all -func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value) -> index { - %0 = async.create_group +func @create_group_and_await_all(%arg0: !async.token, + %arg1: !async.value) -> index { + %c = constant 2 : index + %0 = async.create_group %c : !async.group // CHECK: async.add_to_group %arg0 // CHECK: async.add_to_group %arg1 diff --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir --- a/mlir/test/Dialect/Async/runtime.mlir +++ b/mlir/test/Dialect/Async/runtime.mlir @@ -18,9 +18,11 @@ // CHECK-LABEL: @create_group func @create_group() -> !async.group { - // CHECK: %0 = async.runtime.create : !async.group - %0 = async.runtime.create : !async.group - // CHECK: return %0 : !async.group + // CHECK: %[[C:.*]] = constant 10 : index + %c = constant 10 : index + // CHECK: %[[V:.*]] = async.runtime.create_group %[[C]] : !async.group + %0 = async.runtime.create_group %c : !async.group + // CHECK: return %[[V]] : !async.group return %0 : !async.group } diff --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir --- 
a/mlir/test/mlir-cpu-runner/async-error.mlir +++ b/mlir/test/mlir-cpu-runner/async-error.mlir @@ -85,7 +85,8 @@ // Check error propagation from a token to the group. // ------------------------------------------------------------------------ // - %group0 = async.create_group + %c2 = constant 2 : index + %group0 = async.create_group %c2 : !async.group %token4 = async.execute { async.yield diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir --- a/mlir/test/mlir-cpu-runner/async-group.mlir +++ b/mlir/test/mlir-cpu-runner/async-group.mlir @@ -11,7 +11,10 @@ // RUN: | FileCheck %s func @main() { - %group = async.create_group + %c1 = constant 1 : index + %c5 = constant 5 : index + + %group = async.create_group %c5 : !async.group %token0 = async.execute { async.yield } %token1 = async.execute { async.yield } @@ -30,7 +33,7 @@ async.yield } - %group0 = async.create_group + %group0 = async.create_group %c1 : !async.group %5 = async.add_to_group %token5, %group0 : !async.token async.await_all %group0