diff --git a/mlir/include/mlir/Dialect/Async/CMakeLists.txt b/mlir/include/mlir/Dialect/Async/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Async/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Async/CMakeLists.txt @@ -1 +1,7 @@ add_subdirectory(IR) + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Async) +add_public_tablegen_target(MLIRAsyncPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc AsyncPasses ./) diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -47,6 +47,12 @@ Type getValueType(); }; +/// The group type to represent async tokens or values grouped together. +class GroupType : public Type::TypeBase { +public: + using Base::Base; +}; + } // namespace async } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td @@ -56,6 +56,16 @@ Type valueType = type; } +def Async_GroupType : DialectType()">, "group type">, + BuildableType<"$_builder.getType<::mlir::async::GroupType>()"> { + let typeDescription = [{ + `async.group` represents a set of async tokens or values and allows + executing async operations on all of them together (e.g. wait for the + completion of all/any of them).
+ }]; +} + def Async_AnyValueType : DialectType()">, "async value type">; diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -81,6 +81,20 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; let verifier = [{ return ::verify(*this); }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilderDAG<(ins "TypeRange":$resultTypes, "ValueRange":$dependencies, + "ValueRange":$operands, + CArg<"function_ref", + "nullptr">:$bodyBuilder)>, + ]; + + let extraClassDeclaration = [{ + using BodyBuilderFn = + function_ref; + + }]; } def Async_YieldOp : @@ -93,12 +107,12 @@ let arguments = (ins Variadic:$operands); - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let assemblyFormat = "($operands^ `:` type($operands))? attr-dict"; let verifier = [{ return ::verify(*this); }]; } -def Async_AwaitOp : Async_Op<"await", [NoSideEffect]> { +def Async_AwaitOp : Async_Op<"await"> { let summary = "waits for the argument to become ready"; let description = [{ The `async.await` operation waits until the argument becomes ready, and for @@ -133,12 +147,84 @@ }]; let assemblyFormat = [{ - attr-dict $operand `:` custom( + $operand `:` custom( type($operand), type($result) - ) + ) attr-dict }]; let verifier = [{ return ::verify(*this); }]; } +def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> { + let summary = "creates an empty async group"; + let description = [{ + The `async.create_group` allocates an empty async group. Async tokens or + values can be added to this group later. + + Example: + + ```mlir + %0 = async.create_group + ... 
+ async.await_all %0 + ``` + }]; + + let arguments = (ins ); + let results = (outs Async_GroupType:$result); + + let assemblyFormat = "attr-dict"; +} + +def Async_AddToGroupOp : Async_Op<"add_to_group", []> { + let summary = "adds an async token or value to the group"; + let description = [{ + The `async.add_to_group` adds an async token or value to the async group. + Returns the rank of the added element in the group. This rank is fixed + for the group lifetime. + + Example: + + ```mlir + %0 = async.create_group + %1 = ... : !async.token + %2 = async.add_to_group %1, %0 : !async.token + ``` + }]; + + let arguments = (ins Async_AnyValueOrTokenType:$operand, + Async_GroupType:$group); + let results = (outs Index:$rank); + + let assemblyFormat = "$operand `,` $group `:` type($operand) attr-dict"; +} + +def Async_AwaitAllOp : Async_Op<"await_all", []> { + let summary = "waits for all async tokens or values in the group to " + "become ready"; + let description = [{ + The `async.await_all` operation waits until all the tokens or values in the + group become ready. + + Example: + + ```mlir + %0 = async.create_group + + %1 = ... : !async.token + %2 = async.add_to_group %1, %0 : !async.token + + %3 = ... : !async.token + %4 = async.add_to_group %3, %0 : !async.token + + async.await_all %0 + ``` + }]; + + let arguments = (ins Async_GroupType:$operand); + let results = (outs); + + let assemblyFormat = "$operand attr-dict"; +} + #endif // ASYNC_OPS diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -0,0 +1,32 @@ +//===- Passes.h - Async pass entry points -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ASYNC_PASSES_H_ +#define MLIR_DIALECT_ASYNC_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +std::unique_ptr> createAsyncParallelForPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Async/Passes.h.inc" + +} // namespace mlir + +#endif // MLIR_DIALECT_ASYNC_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -0,0 +1,27 @@ +//===-- Passes.td - Async pass definition file -------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ASYNC_PASSES +#define MLIR_DIALECT_ASYNC_PASSES + +include "mlir/Pass/PassBase.td" + +def AsyncParallelFor : FunctionPass<"async-parallel-for"> { + let summary = "Convert scf.parallel operations to multiple async regions " + "executed concurrently for non-overlapping iteration ranges"; + let constructor = "mlir::createAsyncParallelForPass()"; + let options = [ + Option<"numConcurrentAsyncExecute", "num-concurrent-async-execute", + "int32_t", /*default=*/"4", + "The number of async.execute operations that will be used for concurrent " + "loop execution."> + ]; + let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"]; +} + +#endif // MLIR_DIALECT_ASYNC_PASSES diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -14,6 +14,8 @@ #ifndef MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_ #define MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_ +#include + #ifdef _WIN32 #ifndef MLIR_ASYNCRUNTIME_EXPORT #ifdef mlir_async_runtime_EXPORTS @@ -37,6 +39,9 @@ // Runtime implementation of `async.token` data type. typedef struct AsyncToken MLIR_AsyncToken; +// Runtime implementation of `async.group` data type. +typedef struct AsyncGroup MLIR_AsyncGroup; + // Async runtime uses LLVM coroutines to represent asynchronous tasks. Task // function is a coroutine handle and a resume function that continue coroutine // execution from a suspension point. @@ -46,6 +51,12 @@ // Create a new `async.token` in not-ready state. extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken(); +// Create a new `async.group` in empty state. 
+extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup(); + +extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t +mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *); + // Switches `async.token` to ready state and runs all awaiters. extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeEmplaceToken(AsyncToken *); @@ -54,6 +65,10 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeAwaitToken(AsyncToken *); +// Blocks the caller thread until all elements in the group become ready. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *); + // Executes the task (coro handle + resume function) in one of the threads // managed by the runtime. extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle, @@ -64,6 +79,11 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume); +// Executes the task (coro handle + resume function) in one of the threads +// managed by the runtime after all members of the group become ready. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume); + //===----------------------------------------------------------------------===// // Small async runtime support library for testing.
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -16,6 +16,7 @@ #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Async/Passes.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" @@ -47,6 +48,7 @@ // Dialect passes registerAffinePasses(); + registerAsyncPasses(); registerGPUPasses(); registerLinalgPasses(); LLVM::registerLLVMPasses(); diff --git a/mlir/integration_test/Dialect/Async/CPU/lit.local.cfg b/mlir/integration_test/Dialect/Async/CPU/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Async/CPU/lit.local.cfg @@ -0,0 +1,5 @@ +import sys + +# No JIT on win32. +if sys.platform == 'win32': + config.unsupported = True diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt %s -async-parallel-for \ +// RUN: -convert-async-to-llvm \ +// RUN: -convert-scf-to-std \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void -O0 \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ +// RUN: | FileCheck %s --dump-input=always + +func @entry() { + %c0 = constant 0.0 : f32 + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + + %lb = constant 0 : index + %ub = constant 9 : index + + %A = alloc() : memref<9xf32> + %U = memref_cast %A : memref<9xf32> to memref<*xf32> + + // 1. 
%i = (0) to (9) step (1) + scf.parallel (%i) = (%lb) to (%ub) step (%c1) { + %0 = index_cast %i : index to i32 + %1 = sitofp %0 : i32 to f32 + store %1, %A[%i] : memref<9xf32> + } + // CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8] + call @print_memref_f32(%U): (memref<*xf32>) -> () + + scf.parallel (%i) = (%lb) to (%ub) step (%c1) { + store %c0, %A[%i] : memref<9xf32> + } + + // 2. %i = (0) to (9) step (2) + scf.parallel (%i) = (%lb) to (%ub) step (%c2) { + %0 = index_cast %i : index to i32 + %1 = sitofp %0 : i32 to f32 + store %1, %A[%i] : memref<9xf32> + } + // CHECK: [0, 0, 2, 0, 4, 0, 6, 0, 8] + call @print_memref_f32(%U): (memref<*xf32>) -> () + + scf.parallel (%i) = (%lb) to (%ub) step (%c1) { + store %c0, %A[%i] : memref<9xf32> + } + + // 3. %i = (-20) to (-11) step (3) + %lb0 = constant -20 : index + %ub0 = constant -11 : index + scf.parallel (%i) = (%lb0) to (%ub0) step (%c3) { + %0 = index_cast %i : index to i32 + %1 = sitofp %0 : i32 to f32 + %2 = constant 20 : index + %3 = addi %i, %2 : index + store %1, %A[%3] : memref<9xf32> + } + // CHECK: [-20, 0, 0, -17, 0, 0, -14, 0, 0] + call @print_memref_f32(%U): (memref<*xf32>) -> () + + dealloc %A : memref<9xf32> + return +} + +func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -0,0 +1,93 @@ +// RUN: mlir-opt %s -async-parallel-for \ +// RUN: -convert-async-to-llvm \ +// RUN: -convert-scf-to-std \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void -O0 \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ +// RUN: | FileCheck %s --dump-input=always + +func @entry() { 
+ %c0 = constant 0.0 : f32 + %c1 = constant 1 : index + %c2 = constant 2 : index + %c8 = constant 8 : index + + %lb = constant 0 : index + %ub = constant 8 : index + + %A = alloc() : memref<8x8xf32> + %U = memref_cast %A : memref<8x8xf32> to memref<*xf32> + + // 1. (%i, %j) = (0, 0) to (8, 8) step (1, 1) + scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) { + %0 = muli %i, %c8 : index + %1 = addi %j, %0 : index + %2 = index_cast %1 : index to i32 + %3 = sitofp %2 : i32 to f32 + store %3, %A[%i, %j] : memref<8x8xf32> + } + + // CHECK: [0, 1, 2, 3, 4, 5, 6, 7] + // CHECK-NEXT: [8, 9, 10, 11, 12, 13, 14, 15] + // CHECK-NEXT: [16, 17, 18, 19, 20, 21, 22, 23] + // CHECK-NEXT: [24, 25, 26, 27, 28, 29, 30, 31] + // CHECK-NEXT: [32, 33, 34, 35, 36, 37, 38, 39] + // CHECK-NEXT: [40, 41, 42, 43, 44, 45, 46, 47] + // CHECK-NEXT: [48, 49, 50, 51, 52, 53, 54, 55] + // CHECK-NEXT: [56, 57, 58, 59, 60, 61, 62, 63] + call @print_memref_f32(%U): (memref<*xf32>) -> () + + scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) { + store %c0, %A[%i, %j] : memref<8x8xf32> + } + + // 2. (%i, %j) = (0, 0) to (8, 8) step (2, 1) + scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c2, %c1) { + %0 = muli %i, %c8 : index + %1 = addi %j, %0 : index + %2 = index_cast %1 : index to i32 + %3 = sitofp %2 : i32 to f32 + store %3, %A[%i, %j] : memref<8x8xf32> + } + + // CHECK: [0, 1, 2, 3, 4, 5, 6, 7] + // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0] + // CHECK-NEXT: [16, 17, 18, 19, 20, 21, 22, 23] + // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0] + // CHECK-NEXT: [32, 33, 34, 35, 36, 37, 38, 39] + // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0] + // CHECK-NEXT: [48, 49, 50, 51, 52, 53, 54, 55] + // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0] + call @print_memref_f32(%U): (memref<*xf32>) -> () + + scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) { + store %c0, %A[%i, %j] : memref<8x8xf32> + } + + // 3.
(%i, %j) = (0, 0) to (8, 8) step (1, 2) + scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c2) { + %0 = muli %i, %c8 : index + %1 = addi %j, %0 : index + %2 = index_cast %1 : index to i32 + %3 = sitofp %2 : i32 to f32 + store %3, %A[%i, %j] : memref<8x8xf32> + } + + // CHECK: [0, 0, 2, 0, 4, 0, 6, 0] + // CHECK-NEXT: [8, 0, 10, 0, 12, 0, 14, 0] + // CHECK-NEXT: [16, 0, 18, 0, 20, 0, 22, 0] + // CHECK-NEXT: [24, 0, 26, 0, 28, 0, 30, 0] + // CHECK-NEXT: [32, 0, 34, 0, 36, 0, 38, 0] + // CHECK-NEXT: [40, 0, 42, 0, 44, 0, 46, 0] + // CHECK-NEXT: [48, 0, 50, 0, 52, 0, 54, 0] + // CHECK-NEXT: [56, 0, 58, 0, 60, 0, 62, 0] + call @print_memref_f32(%U): (memref<*xf32>) -> () + + dealloc %A : memref<8x8xf32> + + return +} + +func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -34,11 +34,17 @@ //===----------------------------------------------------------------------===// static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; +static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; +static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; +static constexpr const char *kAddTokenToGroup = + "mlirAsyncRuntimeAddTokenToGroup"; static constexpr const char *kAwaitAndExecute = "mlirAsyncRuntimeAwaitTokenAndExecute"; +static constexpr const char *kAwaitAllAndExecute = + "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; namespace { // Async Runtime API function types.
@@ -47,6 +53,10 @@ return FunctionType::get({}, {TokenType::get(ctx)}, ctx); } + static FunctionType createGroupFunctionType(MLIRContext *ctx) { + return FunctionType::get({}, {GroupType::get(ctx)}, ctx); + } + static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { return FunctionType::get({TokenType::get(ctx)}, {}, ctx); } @@ -55,18 +65,34 @@ return FunctionType::get({TokenType::get(ctx)}, {}, ctx); } + static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { + return FunctionType::get({GroupType::get(ctx)}, {}, ctx); + } + static FunctionType executeFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); return FunctionType::get({hdl, resume}, {}, ctx); } + static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { + auto i64 = IntegerType::get(64, ctx); + return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64}, + ctx); + } + static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); } + static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { + auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); + auto resume = resumeFunctionType(ctx).getPointerTo(); + return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx); + } + // Auxiliary coroutine resume intrinsic wrapper. 
static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { auto voidTy = LLVM::LLVMType::getVoidTy(ctx); @@ -87,6 +113,10 @@ builder.create(loc, kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); + if (!module.lookupSymbol(kCreateGroup)) + builder.create(loc, kCreateGroup, + AsyncAPI::createGroupFunctionType(ctx)); + if (!module.lookupSymbol(kEmplaceToken)) builder.create(loc, kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); @@ -95,12 +125,24 @@ builder.create(loc, kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); + if (!module.lookupSymbol(kAwaitGroup)) + builder.create(loc, kAwaitGroup, + AsyncAPI::awaitGroupFunctionType(ctx)); + if (!module.lookupSymbol(kExecute)) builder.create(loc, kExecute, AsyncAPI::executeFunctionType(ctx)); + if (!module.lookupSymbol(kAddTokenToGroup)) + builder.create(loc, kAddTokenToGroup, + AsyncAPI::addTokenToGroupFunctionType(ctx)); + if (!module.lookupSymbol(kAwaitAndExecute)) builder.create(loc, kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx)); + + if (!module.lookupSymbol(kAwaitAllAndExecute)) + builder.create(loc, kAwaitAllAndExecute, + AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); } //===----------------------------------------------------------------------===// @@ -554,8 +596,8 @@ static Type convertType(Type type) { MLIRContext *ctx = type.getContext(); - // Convert async tokens to opaque pointers. - if (type.isa()) + // Convert async tokens and groups to opaque pointers. + if (type.isa()) return LLVM::LLVMType::getInt8PtrTy(ctx); return type; } @@ -590,28 +632,81 @@ } // namespace //===----------------------------------------------------------------------===// -// async.await op lowering to mlirAsyncRuntimeAwaitToken function call. +// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 
//===----------------------------------------------------------------------===// namespace { -class AwaitOpLowering : public ConversionPattern { +class CreateGroupOpLowering : public ConversionPattern { public: - explicit AwaitOpLowering( + explicit CreateGroupOpLowering(MLIRContext *ctx) + : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto retTy = GroupType::get(op->getContext()); + rewriter.replaceOpWithNewOp(op, kCreateGroup, retTy); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// async.add_to_group op lowering to runtime function call. +//===----------------------------------------------------------------------===// + +namespace { +class AddToGroupOpLowering : public ConversionPattern { +public: + explicit AddToGroupOpLowering(MLIRContext *ctx) + : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // Currently we can only add tokens to the group. + auto addToGroup = cast(op); + if (!addToGroup.operand().getType().isa()) + return failure(); + + auto i64 = IntegerType::get(64, op->getContext()); + rewriter.replaceOpWithNewOp(op, kAddTokenToGroup, i64, operands); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// async.await and async.await_all op lowerings to the corresponding async +// runtime function calls. 
+//===----------------------------------------------------------------------===// + +namespace { + +template +class AwaitOpLoweringBase : public ConversionPattern { +protected: + explicit AwaitOpLoweringBase( MLIRContext *ctx, - const llvm::DenseMap &outlinedFunctions) - : ConversionPattern(AwaitOp::getOperationName(), 1, ctx), - outlinedFunctions(outlinedFunctions) {} + const llvm::DenseMap &outlinedFunctions, + StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) + : ConversionPattern(AwaitType::getOperationName(), 1, ctx), + outlinedFunctions(outlinedFunctions), + blockingAwaitFuncName(blockingAwaitFuncName), + coroAwaitFuncName(coroAwaitFuncName) {} +public: LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - // We can only await on the token operand. Async valus are not supported. - auto await = cast(op); - if (!await.operand().getType().isa()) + // We can only await on the `AwaitableType` (for `await` it can be + // only a `token`, for `await_all` it is a `group`). + auto await = cast(op); + if (!await.operand().getType().template isa()) return failure(); - // Check if `async.await` is inside the outlined coroutine function. - auto func = await.getParentOfType(); + // Check if await operation is inside the outlined coroutine function. + auto func = await.template getParentOfType(); auto outlined = outlinedFunctions.find(func); const bool isInCoroutine = outlined != outlinedFunctions.end(); @@ -620,7 +715,7 @@ // Inside regular function we convert await operation to the blocking // async API await function call. if (!isInCoroutine) - rewriter.create(loc, Type(), kAwaitToken, + rewriter.create(loc, Type(), blockingAwaitFuncName, ValueRange(op->getOperand(0))); // Inside the coroutine we convert await operation into coroutine suspension @@ -645,7 +740,7 @@ // the async await argument becomes ready.
SmallVector awaitAndExecuteArgs = { await.getOperand(), coro.coroHandle, resumePtr.res()}; - builder.create(loc, Type(), kAwaitAndExecute, + builder.create(loc, Type(), coroAwaitFuncName, awaitAndExecuteArgs); // Split the entry block before the await operation. @@ -660,7 +755,32 @@ private: const llvm::DenseMap &outlinedFunctions; + StringRef blockingAwaitFuncName; + StringRef coroAwaitFuncName; +}; + +// Lowering for `async.await` operation (only token operands are supported). +class AwaitOpLowering : public AwaitOpLoweringBase { + using Base = AwaitOpLoweringBase; + +public: + explicit AwaitOpLowering( + MLIRContext *ctx, + const llvm::DenseMap &outlinedFunctions) + : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {} }; + +// Lowering for `async.await_all` operation. +class AwaitAllOpLowering : public AwaitOpLoweringBase { + using Base = AwaitOpLoweringBase; + +public: + explicit AwaitAllOpLowering( + MLIRContext *ctx, + const llvm::DenseMap &outlinedFunctions) + : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {} +}; + } // namespace //===----------------------------------------------------------------------===// @@ -717,7 +837,8 @@ populateFuncOpTypeConversionPattern(patterns, ctx, converter); patterns.insert(ctx); - patterns.insert(ctx, outlinedFunctions); + patterns.insert(ctx); + patterns.insert(ctx, outlinedFunctions); ConversionTarget target(*ctx); target.addLegalDialect(); diff --git a/mlir/lib/Dialect/Async/CMakeLists.txt b/mlir/lib/Dialect/Async/CMakeLists.txt --- a/mlir/lib/Dialect/Async/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -21,6 +21,7 @@ >(); addTypes(); addTypes(); + addTypes(); } /// Parse a type registered to this dialect. 
@@ -54,6 +55,7 @@ os.printType(valueTy.getValueType()); os << '>'; }) + .Case([&](GroupType) { os << "group"; }) .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); }); } @@ -139,6 +141,51 @@ regions.push_back(RegionSuccessor(&body())); } +void ExecuteOp::build(OpBuilder &builder, OperationState &result, + TypeRange resultTypes, ValueRange dependencies, + ValueRange operands, BodyBuilderFn bodyBuilder) { + + result.addOperands(dependencies); + result.addOperands(operands); + + // Add derived `operand_segment_sizes` attribute based on parsed operands. + int32_t numDependencies = dependencies.size(); + int32_t numOperands = operands.size(); + auto operandSegmentSizes = DenseIntElementsAttr::get( + VectorType::get({2}, IntegerType::get(32, result.getContext())), + {numDependencies, numOperands}); + result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); + + // First result is always a token, and then `resultTypes` wrapped into + // `async.value`. + result.addTypes({TokenType::get(result.getContext())}); + for (Type type : resultTypes) + result.addTypes(ValueType::get(type)); + + // Add a body region with block arguments as unwrapped async value operands. + Region *bodyRegion = result.addRegion(); + bodyRegion->push_back(new Block); + Block &bodyBlock = bodyRegion->front(); + for (Value operand : operands) { + auto valueType = operand.getType().dyn_cast(); + bodyBlock.addArgument(valueType ? valueType.getValueType() + : operand.getType()); + } + + // Create the default terminator if the builder is not provided and if the + // expected result is empty. Otherwise, leave this to the caller + // because we don't know which values to return from the execute op. 
+ if (resultTypes.empty() && !bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&bodyBlock); + builder.create(result.location, ValueRange()); + } else if (bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&bodyBlock); + bodyBuilder(builder, result.location, bodyBlock.getArguments()); + } +} + static void print(OpAsmPrinter &p, ExecuteOp op) { p << op.getOperationName(); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -0,0 +1,278 @@ +//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements scf.parallel to scf.for + async.execute conversion pass. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::async; + +#define DEBUG_TYPE "async-parallel-for" + +namespace { + +// Rewrite scf.parallel operation into multiple concurrent async.execute +// operations over non-overlapping subranges of the original loop.
+// +// Example: +// +// scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) { +// "do_some_compute"(%i, %j): () -> () +// } +// +// Converted to: +// +// %c0 = constant 0 : index +// %c1 = constant 1 : index +// +// // Compute block sizes for each induction variable. +// %num_blocks_i = ... : index +// %num_blocks_j = ... : index +// %block_size_i = ... : index +// %block_size_j = ... : index +// +// // Create an async group to track async execute ops. +// %group = async.create_group +// +// scf.for %bi = %c0 to %num_blocks_i step %c1 { +// %block_start_i = ... : index +// %block_end_i = ... : index +// +// scf.for %bj = %c0 to %num_blocks_j step %c1 { +// %block_start_j = ... : index +// %block_end_j = ... : index +// +// // Execute the body of original parallel operation for the current +// // block. +// %token = async.execute { +// scf.for %i = %block_start_i to %block_end_i step %si { +// scf.for %j = %block_start_j to %block_end_j step %sj { +// "do_some_compute"(%i, %j): () -> () +// } +// } +// } +// +// // Add produced async token to the group. +// async.add_to_group %token, %group +// } +// } +// +// // Await completion of all async.execute operations. +// async.await_all %group +// +// In this example outer loop launches inner block level loops as separate async +// execute operations which will be executed concurrently. +// +// At the end it waits for the completion of all async execute operations.
+//
+struct AsyncParallelForRewrite : public OpRewritePattern {
+public:
+  AsyncParallelForRewrite(MLIRContext *ctx, int numConcurrentAsyncExecute)
+      : OpRewritePattern(ctx),
+        numConcurrentAsyncExecute(numConcurrentAsyncExecute) {}
+
+  LogicalResult matchAndRewrite(scf::ParallelOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  int numConcurrentAsyncExecute;
+};
+
+struct AsyncParallelForPass
+    : public AsyncParallelForBase {
+  AsyncParallelForPass() = default;
+  void runOnFunction() override;
+};
+
+} // namespace
+
+LogicalResult
+AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
+                                         PatternRewriter &rewriter) const {
+  // Rewriting parallel ops with reductions is not supported yet.
+  if (op.getNumReductions() != 0)
+    return failure();
+
+  MLIRContext *ctx = op.getContext();
+  Location loc = op.getLoc();
+
+  // Index constants used below.
+  auto indexTy = IndexType::get(ctx);
+  auto zero = IntegerAttr::get(indexTy, 0);
+  auto one = IntegerAttr::get(indexTy, 1);
+  auto c0 = rewriter.create(loc, indexTy, zero);
+  auto c1 = rewriter.create(loc, indexTy, one);
+
+  // Shorthand for signed integer ceil division operation.
+  auto divup = [&](Value x, Value y) -> Value {
+    return rewriter.create(loc, x, y);
+  };
+
+  // Compute trip count for each loop induction variable:
+  //   tripCount = divUp(upperBound - lowerBound, step);
+  SmallVector tripCounts(op.getNumLoops());
+  for (size_t i = 0; i < op.getNumLoops(); ++i) {
+    auto lb = op.lowerBound()[i];
+    auto ub = op.upperBound()[i];
+    auto step = op.step()[i];
+    auto range = rewriter.create(loc, ub, lb);
+    tripCounts[i] = divup(range, step);
+  }
+
+  // The target number of concurrent async.execute ops.
+  auto numExecuteOps = rewriter.create(
+      loc, indexTy, IntegerAttr::get(indexTy, numConcurrentAsyncExecute));
+
+  // Block sizes per induction variable: concurrency goes to the outermost
+  // dimension %d0 first, and any remainder to %d1, %d2, ... in order.
+  SmallVector targetNumBlocks(op.getNumLoops());
+  SmallVector blockSize(op.getNumLoops());
+  SmallVector numBlocks(op.getNumLoops());
+  // Clamps to >= 1 so an empty dimension never causes a division by zero.
+  auto atLeastOne = [&](Value x) -> Value {
+    auto isPos = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, x, c0);
+    return rewriter.create<SelectOp>(loc, isPos, x, c1);
+  };
+  // Compute block size and number of blocks along the first induction variable.
+  targetNumBlocks[0] = atLeastOne(numExecuteOps);
+  blockSize[0] = atLeastOne(divup(tripCounts[0], targetNumBlocks[0]));
+  numBlocks[0] = divup(tripCounts[0], blockSize[0]);
+
+  // Assign remaining available concurrency to other induction variables.
+  for (size_t i = 1; i < op.getNumLoops(); ++i) {
+    Value prevNumBlocks = atLeastOne(numBlocks[i - 1]);
+    targetNumBlocks[i] = divup(targetNumBlocks[i - 1], prevNumBlocks);
+    blockSize[i] = atLeastOne(divup(tripCounts[i], targetNumBlocks[i]));
+    numBlocks[i] = divup(tripCounts[i], blockSize[i]);
+  }
+
+  // Create an async.group to wait on all async tokens from async execute ops.
+  auto group = rewriter.create(loc, GroupType::get(ctx));
+
+  // Build an scf.for loop nest from the parallel operation.
+
+  // Lower/upper bounds for the nested block-level computations.
+  SmallVector blockLowerBounds(op.getNumLoops());
+  SmallVector blockUpperBounds(op.getNumLoops());
+  SmallVector blockInductionVars(op.getNumLoops());
+
+  using LoopBodyBuilder =
+      std::function;
+  using LoopBuilder = std::function;
+
+  // Builds the inner loop nest, inside an async.execute operation, that does
+  // the actual work of one block.
+  LoopBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
+    return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+      blockInductionVars[loopIdx] = iv;
+
+      // Continue building the work loop nest.
+      if (loopIdx < op.getNumLoops() - 1) {
+        b.create(
+            loc, blockLowerBounds[loopIdx + 1], blockUpperBounds[loopIdx + 1],
+            op.step()[loopIdx + 1], ValueRange(), workLoopBuilder(loopIdx + 1));
+        b.create(loc);
+        return;
+      }
+
+      // Clone the parallel op body, remapping its induction variables.
+      BlockAndValueMapping mapping;
+      mapping.map(op.getInductionVars(), blockInductionVars);
+
+      for (auto &bodyOp : op.getLoopBody().getOps())
+        b.clone(bodyOp, mapping);
+    };
+  };
+
+  // Builds a loop nest that does async execute op dispatching.
+  LoopBuilder asyncLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
+    return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+      auto lb = op.lowerBound()[loopIdx];
+      auto ub = op.upperBound()[loopIdx];
+      auto step = op.step()[loopIdx];
+
+      // Compute lower bound for the current block:
+      //   blockLowerBound = iv * blockSize * step + lowerBound
+      auto s0 = b.create(loc, iv, blockSize[loopIdx]);
+      auto s1 = b.create(loc, s0, step);
+      auto s2 = b.create(loc, s1, lb);
+      blockLowerBounds[loopIdx] = s2;
+
+      // Compute upper bound for the current block:
+      //   blockUpperBound = min(upperBound,
+      //                         blockLowerBound + blockSize * step)
+      auto e0 = b.create(loc, blockSize[loopIdx], step);
+      auto e1 = b.create(loc, e0, s2);
+      auto e2 = b.create(loc, CmpIPredicate::slt, e1, ub);
+      auto e3 = b.create(loc, e2, e1, ub);
+      blockUpperBounds[loopIdx] = e3;
+
+      // Continue building async dispatch loop nest.
+      if (loopIdx < op.getNumLoops() - 1) {
+        b.create(loc, c0, numBlocks[loopIdx + 1], c1, ValueRange(),
+                 asyncLoopBuilder(loopIdx + 1));
+        b.create(loc);
+        return;
+      }
+
+      // Build the inner loop nest that will do the actual work inside the
+      // `async.execute` body region.
+      auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
+                                    Location executeLoc,
+                                    ValueRange executeArgs) {
+        executeBuilder.create(executeLoc, blockLowerBounds[0],
+                              blockUpperBounds[0], op.step()[0],
+                              ValueRange(), workLoopBuilder(0));
+        executeBuilder.create(executeLoc, ValueRange());
+      };
+
+      auto execute = b.create(
+          loc, /*resultTypes=*/TypeRange(), /*dependencies=*/ValueRange(),
+          /*operands=*/ValueRange(), executeBodyBuilder);
+      auto rankType = IndexType::get(ctx);
+      b.create(loc, rankType, execute.token(), group.result());
+      b.create(loc);
+    };
+  };
+
+  // Start building a loop nest from the first induction variable.
+  rewriter.create(loc, c0, numBlocks[0], c1, ValueRange(),
+                  asyncLoopBuilder(0));
+
+  // Wait for the completion of all subtasks.
+  rewriter.create(loc, group.result());
+
+  // Erase the original parallel operation.
+  rewriter.eraseOp(op);
+
+  return success();
+}
+
+void AsyncParallelForPass::runOnFunction() {
+  MLIRContext *ctx = &getContext();
+
+  OwningRewritePatternList patterns;
+  patterns.insert(ctx, numConcurrentAsyncExecute);
+
+  if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr> mlir::createAsyncParallelForPass() {
+  return std::make_unique();
+}
diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRAsyncTransforms
+  AsyncParallelFor.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
+
+  DEPENDS
+  MLIRAsyncPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRAsync
+  MLIRSCF
+  MLIRPass
+  MLIRTransforms
+  MLIRTransformUtils
+)
diff --git a/mlir/lib/Dialect/Async/Transforms/PassDetail.h b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
@@ -0,0 +1,30 @@
+//===- 
PassDetail.h - Async Pass class details ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_
+#define DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace async {
+class AsyncDialect;
+} // namespace async
+
+namespace scf {
+class SCFDialect;
+} // namespace scf
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/Async/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -15,6 +15,7 @@
 
 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
 
+#include 
 #include 
 #include 
 #include 
@@ -33,12 +34,62 @@
   std::vector> awaiters;
 };
 
+struct AsyncGroup {
+  std::atomic pendingTokens{0};
+  std::atomic rank{0};
+  std::mutex mu;
+  std::condition_variable cv;
+  std::vector> awaiters;
+};
+
 // Create a new `async.token` in not-ready state.
 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
   AsyncToken *token = new AsyncToken;
   return token;
 }
 
+// Create a new `async.group` in empty state.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
+  AsyncGroup *group = new AsyncGroup;
+  return group;
+}
+
+// Adds `token` to `group` and returns the rank of the token in the group.
+// Group awaiters are run once the group has no pending tokens left.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
+mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
+  std::unique_lock<std::mutex> lockToken(token->mu);
+  std::unique_lock<std::mutex> lockGroup(group->mu);
+
+  // A token that is already ready adds no pending work to the group.
+  if (token->ready)
+    return group->rank.fetch_add(1);
+
+  group->pendingTokens.fetch_add(1);
+
+  auto onTokenReady = [group]() {
+    // Update the counter and notify while holding the group lock: the group
+    // await functions check `pendingTokens` and register awaiters under the
+    // same lock, so an unlocked update could race with them (lost wakeup).
+    std::unique_lock<std::mutex> lock(group->mu);
+    if (group->pendingTokens.fetch_sub(1) != 1)
+      return;
+    group->cv.notify_all();
+
+    // Take ownership of the awaiters so each one runs exactly once, and run
+    // them without holding the group lock.
+    decltype(group->awaiters) awaiters;
+    awaiters.swap(group->awaiters);
+    lock.unlock();
+    for (auto &awaiter : awaiters)
+      awaiter();
+  };
+
+  token->awaiters.push_back(onTokenReady);
+
+  return group->rank.fetch_add(1);
+}
+
 // Switches `async.token` to ready state and runs all awaiters.
 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
   std::unique_lock lock(token->mu);
@@ -52,7 +103,14 @@
   std::unique_lock lock(token->mu);
   if (!token->ready)
     token->cv.wait(lock, [token] { return token->ready; });
-  delete token;
+}
+
+// Blocks the caller until `group` has no pending tokens.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
+  std::unique_lock lock(group->mu);
+  if (group->pendingTokens != 0)
+    group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
 }
 
 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
@@ -69,9 +127,8 @@
                                                      CoroResume resume) {
   std::unique_lock lock(token->mu);
 
-  auto execute = [token, handle, resume]() {
+  auto execute = [handle, resume]() {
     mlirAsyncRuntimeExecute(handle, resume);
-    delete token;
   };
 
   if (token->ready)
@@ -80,6 +137,22 @@
     token->awaiters.push_back([execute]() { execute(); });
 }
 
+// Runs `handle` via mlirAsyncRuntimeExecute once `group` has no pending tokens.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
+                                          CoroResume resume) {
+  std::unique_lock lock(group->mu);
+
+  auto execute = [handle, resume]() {
+    mlirAsyncRuntimeExecute(handle, resume);
+  };
+
+  if (group->pendingTokens == 0)
+    execute();
+  else
+    group->awaiters.push_back([execute]() { execute(); });
+}
+
 //===----------------------------------------------------------------------===//
 // Small async runtime support library for testing.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -156,4 +156,43 @@
 // CHECK: store %arg1, %arg2[%c0] : memref<1xf32>
 // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
 
+// -----
+
+// CHECK-LABEL: async_group_await_all
+func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
+  // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup()
+  %0 = async.create_group
+
+  // CHECK: %[[TOKEN:.*]] = call @async_execute_fn
+  %token = async.execute { async.yield }
+  // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
+  async.add_to_group %token, %0 : !async.token
+
+  // CHECK: call @async_execute_fn_0
+  async.execute {
+    async.await_all %0
+    async.yield
+  }
+
+  // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
+  async.await_all %0
+
+  return
+}
+
+// Function outlined from the async.execute operation.
+// CHECK: func private @async_execute_fn_0(%arg0: !llvm.ptr)
+// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin
+// Suspend the coroutine at the beginning.
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL_1]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Suspend the coroutine a second time, waiting for the group.
+// CHECK: llvm.call @llvm.coro.save
+// CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute(%arg0, %[[HDL_1]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Emplace the result token.
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
diff --git a/mlir/test/Dialect/Async/async-parallel-for.mlir b/mlir/test/Dialect/Async/async-parallel-for.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -async-parallel-for | FileCheck %s
+
+// CHECK-LABEL: @loop_1d
+func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref) {
+  // CHECK: %[[GROUP:.*]] = async.create_group
+  // CHECK: scf.for
+  // CHECK:   %[[TOKEN:.*]] = async.execute {
+  // CHECK:     scf.for
+  // CHECK:       store
+  // CHECK:     async.yield
+  // CHECK:   }
+  // CHECK:   async.add_to_group %[[TOKEN]], %[[GROUP]]
+  // CHECK: async.await_all %[[GROUP]]
+  scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
+    %one = constant 1.0 : f32
+    store %one, %arg3[%i] : memref
+  }
+
+  return
+}
+
+// CHECK-LABEL: @loop_2d
+func @loop_2d(%arg0: index, %arg1: index, %arg2: index, // lb, ub, step
+              %arg3: index, %arg4: index, %arg5: index, // lb, ub, step
+              %arg6: memref) {
+  // CHECK: %[[GROUP:.*]] = async.create_group
+  // CHECK: scf.for
+  // CHECK:   scf.for
+  // CHECK:     %[[TOKEN:.*]] = async.execute {
+  // CHECK:       scf.for
+  // CHECK:         scf.for
+  // CHECK:           store
+  // CHECK:       async.yield
+  // CHECK:     }
+  // CHECK:     async.add_to_group %[[TOKEN]], %[[GROUP]]
+  // CHECK: async.await_all %[[GROUP]]
+  scf.parallel (%i0, %i1) = (%arg0, %arg3) to (%arg1, %arg4)
+               step (%arg2, %arg5) {
+    %one = constant 1.0 : f32
+    store %one, %arg6[%i0, %i1] : memref
+  }
+
+  return
+}
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -120,3 +120,19 @@
   %0 = async.await %arg0 : !async.value
   return %0 : f32
 }
+
+// CHECK-LABEL: @create_group_and_await_all
+func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value) -> index {
+  // CHECK: %[[GROUP:.*]] = async.create_group
+  %0 = async.create_group
+
+  // CHECK: async.add_to_group %arg0, %[[GROUP]]
+  // CHECK: async.add_to_group %arg1, %[[GROUP]]
+  %1 = async.add_to_group %arg0, %0 : !async.token
+  %2 = async.add_to_group %arg1, %0 : !async.value
+  // CHECK: async.await_all %[[GROUP]]
+  async.await_all %0
+
+  %3 = addi %1, %2 : index
+  return %3 : index
+}
diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/async-group.mlir
@@ -0,0 +1,40 @@
+// RUN:   mlir-opt %s -convert-async-to-llvm                                   \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:     -e main -entry-point-result=void -O0                               \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext  \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext    \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext   \
+// RUN: | FileCheck %s
+
+func @main() {
+  %group = async.create_group
+
+  %token0 = async.execute { async.yield }
+  %token1 = async.execute { async.yield }
+  %token2 = async.execute { async.yield }
+  %token3 = async.execute { async.yield }
+  %token4 = async.execute { async.yield }
+
+  %0 = async.add_to_group %token0, %group : !async.token
+  %1 = async.add_to_group %token1, %group : !async.token
+  %2 = async.add_to_group %token2, %group : !async.token
+  %3 = async.add_to_group %token3, %group : !async.token
+  %4 = async.add_to_group %token4, %group : !async.token
+
+  %token5 = async.execute {
+    async.await_all %group
+    async.yield
+  }
+
+  %group0 = async.create_group
+  %5 = async.add_to_group %token5, %group0 : !async.token
+  async.await_all %group0
+
+  // CHECK: Current thread id: [[THREAD:.*]]
+  call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
+
+  return
+}
+
+func @mlirAsyncRuntimePrintCurrentThreadId() -> ()