diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td @@ -73,4 +73,8 @@ def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType, Async_TokenType]>; +def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType, + Async_TokenType, + Async_GroupType]>; + #endif // ASYNC_BASE_TD diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -224,4 +224,50 @@ let assemblyFormat = "$operand attr-dict"; } +//===----------------------------------------------------------------------===// +// Async Dialect Automatic Reference Counting Operations. +//===----------------------------------------------------------------------===// + +// All async values (values, tokens, groups) are reference counted at runtime +// and automatically destructed when reference count drops to 0. +// +// All values semantically created with a reference count of +1 and it is +// the responsibility of the async value owner to add/drop reference count +// based on the number of uses. +// +// See `AsyncRefCountingPass` for the automatic reference counting +// implementation details. + +def Async_AddRefOp : Async_Op<"add_ref"> { + let summary = "adds a reference to async value"; + let description = [{ + The `async.add_ref` operation adds a reference(s) to async value (token, + value or group). + }]; + + let arguments = (ins Async_AnyAsyncType:$operand, + Confined:$count); + let results = (outs ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) + }]; +} + +def Async_DropRefOp : Async_Op<"drop_ref"> { + let summary = "drops a reference to async value"; + let description = [{ + The `async.drop_ref` operation drops a reference(s) to async value (token, + value or group). 
+ }]; + + let arguments = (ins Async_AnyAsyncType:$operand, + Confined:$count); + let results = (outs ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) + }]; +} + #endif // ASYNC_OPS diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -19,6 +19,8 @@ std::unique_ptr> createAsyncParallelForPass(); +std::unique_ptr> createAsyncRefCountingPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -24,4 +24,10 @@ let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"]; } +def AsyncRefCounting : FunctionPass<"async-ref-counting"> { + let summary = "Automatic reference counting for Async dialect data types"; + let constructor = "mlir::createAsyncRefCountingPass()"; + let dependentDialects = ["async::AsyncDialect"]; +} + #endif // MLIR_DIALECT_ASYNC_PASSES diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -48,6 +48,18 @@ using CoroHandle = void *; // coroutine handle using CoroResume = void (*)(void *); // coroutine resume function +// Async runtime uses reference counting to manage the lifetime of async values +// (values of async types like tokens, values and groups). +using RefCountedObjPtr = void *; + +// Adds references to reference counted runtime object. 
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void + mlirAsyncRuntimeAddRef(RefCountedObjPtr, int32_t); + +// Drops references from reference counted runtime object. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void + mlirAsyncRuntimeDropRef(RefCountedObjPtr, int32_t); + // Create a new `async.token` in not-ready state. extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken(); diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-ref-counting \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir --- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-ref-counting \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -33,6 +33,8 @@ // Async Runtime C API declaration. 
//===----------------------------------------------------------------------===// +static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; +static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; @@ -49,6 +51,12 @@ namespace { // Async Runtime API function types. struct AsyncAPI { + static FunctionType refCountingFunctionType(MLIRContext *ctx) { + auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); + auto count = IntegerType::get(32, ctx); + return FunctionType::get({ref, count}, {}, ctx); + } + static FunctionType createTokenFunctionType(MLIRContext *ctx) { return FunctionType::get({}, {TokenType::get(ctx)}, ctx); } @@ -109,6 +117,14 @@ MLIRContext *ctx = module.getContext(); Location loc = module.getLoc(); + if (!module.lookupSymbol(kAddRef)) + builder.create(loc, kAddRef, + AsyncAPI::refCountingFunctionType(ctx)); + + if (!module.lookupSymbol(kDropRef)) + builder.create(loc, kDropRef, + AsyncAPI::refCountingFunctionType(ctx)); + if (!module.lookupSymbol(kCreateToken)) builder.create(loc, kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); @@ -631,6 +647,55 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` +// to the corresponding API calls). 
+//===----------------------------------------------------------------------===// + +namespace { + +template +class RefCountingOpLowering : public ConversionPattern { +public: + explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName) + : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx), + apiFunctionName(apiFunctionName) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RefCountingOp refCountingOp = cast(op); + + auto i32 = IntegerType::get(32, op->getContext()); + auto count = IntegerAttr::get(i32, refCountingOp.count()); + auto countCst = rewriter.create(op->getLoc(), i32, count); + + SmallVector args = {refCountingOp.operand(), countCst}; + rewriter.replaceOpWithNewOp(op, Type(), apiFunctionName, + ValueRange(args)); + return success(); + } + +private: + StringRef apiFunctionName; +}; + +// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call. +class AddRefOpLowering : public RefCountingOpLowering { +public: + explicit AddRefOpLowering(MLIRContext *ctx) + : RefCountingOpLowering(ctx, kAddRef) {} +}; + +// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. +class DropRefOpLowering : public RefCountingOpLowering { +public: + explicit DropRefOpLowering(MLIRContext *ctx) + : RefCountingOpLowering(ctx, kDropRef) {} +}; + +} // namespace + //===----------------------------------------------------------------------===// // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 
//===----------------------------------------------------------------------===// @@ -837,10 +902,12 @@ populateFuncOpTypeConversionPattern(patterns, ctx, converter); patterns.insert(ctx); + patterns.insert(ctx); patterns.insert(ctx); patterns.insert(ctx, outlinedFunctions); ConversionTarget target(*ctx); + target.addLegalOp(); target.addLegalDialect(); target.addIllegalDialect(); target.addDynamicallyLegalOp( diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp @@ -0,0 +1,360 @@ +//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements automatic reference counting for Async dialect data +// types. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; +using namespace mlir::async; + +#define DEBUG_TYPE "async-ref-counting" + +namespace { + +struct AsyncRefCountingPass + : public AsyncRefCountingBase { + AsyncRefCountingPass() = default; + void runOnFunction() override; +}; + +} // namespace + +/// Returns true if the type is reference counted. All async dialect types a +/// reference counted at runtime. 
+bool isRefCounted(Type type) { + return type.isa(); +} + +/// Returns true if the operation `op` supports async reference counting. +/// +/// It is the async value consumer's responsibility to drop the reference count +/// when the value is no longer needed. If the async value is passed to a +/// consumer that is not aware of reference counting, this async value will leak +/// at runtime. +bool isSupportedConsumer(Operation *op) { + // Return operation transfers ownership to the caller. + if (isa(op)) + return true; + + // Async dialect operations correctly handle reference counted values. + if (isa(op)) + return true; + + return false; +} + +/// Returns the statically known number of instances for all operations in the +/// attached region (the number of times each operation inside the attached +/// region will be executed). +Optional getStaticNumberOfInstances(Operation *op) { + assert(!op->getRegions().empty() && "operation must have attached regions"); + + // `async.execute` will execute all operations exactly once. + if (isa(op)) + return 1; + + // TODO: Loops with statically known bounds have statically known number of + // operation instances in the loop body. + return None; +} + +/// Returns the statically known number of instances of the `user` operation +/// that consumes async values produced by the `owner` operation. Returns empty +/// optional if the number of instances is dynamic. +/// +/// Examples: +/// +/// 1. `owner` and `user` are in the same region. +/// +/// %token = ... +/// "use"(%token): (!async.token) -> () +/// +/// Number of instances: 1 +/// +/// 2. `user` is inside the region with statically known execution. +/// +/// %token = ... +/// async.execute { +/// "use"(%token): (!async.token) -> () +/// } +/// +/// Number of instances: 1 (async.execute will execute all operations in the +/// attached body region) +/// +/// 3. `user` is inside the dynamic control flow operation (e.g. `scf.if`, +/// `scf.for` or `scf.parallel`). 
+/// +/// %token = ... +/// scf.if %condition { +/// "use"(%token): (!async.token) -> () +/// } else { +/// "some_other_operation"(): () -> () +/// } +/// +/// Number of instances: (it is not statically known if the +/// execution will go into the first region). +/// +/// If we know the number of `user` instances statically, we can increment the +/// reference count for the async value produced by the `owner`: +/// +/// %token = ... +/// async.add_ref %token {count = } +/// +/// For dynamic instances we can safely add a reference only in the same region +/// as the `user` parent region. See details below. +Optional getStaticNumberOfInstances(Operation *owner, + Operation *user) { + int32_t result = 1; + + Operation *ownerParent = owner->getParentOp(); + Operation *userParent = user->getParentOp(); + + while (ownerParent != userParent) { + if (auto num = getStaticNumberOfInstances(userParent)) + result *= *num; + else + return None; + + userParent = userParent->getParentOp(); + } + + return result; +} + +namespace { +struct DynamicInstanceProperties { + /// The static number instances of the `user` operation inside the dynamic + /// operation. + int32_t staticNumberOfInstances; + + /// The block owned by the region attached to the dynamic operation. + Block *dynamicBlock; + + /// Operation in the same region as async value `owner` that contains the + /// dynamic operation (can be the dynamic operation itself). We'll use this as + /// an anchor to add explicit `async.drop_ref` operation after it. + Operation *refCountAnchor; +}; +} // namespace + +/// Returns the dynamic instance properties of the `user` operation that +/// consumes async value produced by the `owner` operation. +/// +/// Example: +/// +/// %token = ... +/// scf.for %i = %c0 to %c2 step %c1 { +/// scf.if %cond { +/// async.execute { +/// async.await %token : !async.token +/// } +/// } +/// } +/// +/// Innermost dynamic operation that contains the async value user `async.await` +/// is `scf.for`. 
Inside this operation statically known number of uses is 1. +/// +/// Dynamic reference counting must be added to the block that owns the +/// operation that has statically known number of instances of the async uses. +/// +/// With automatic reference counting this should become: +/// +/// %token = ... +/// +/// // Add a reference count statically because we know that we have one +/// // dynamic use inside the `scf.for` operation. +/// async.add_ref %token {count = 1 : i32} : !async.token +/// +/// scf.for %i = %c0 to %c2 step %c1 { +/// scf.if %cond { +/// // Add a reference count dynamically. +/// async.add_ref %token {count = 1 : i32} : !async.token +/// async.execute { +/// async.await %token : !async.token +/// } +/// } +/// } +/// +/// // Explicitly drop the static reference that we added on behalf of +/// // `scf.for` operation. +/// async.drop_ref %token {count = 1 : i32} : !async.token +DynamicInstanceProperties getDynamicInstanceProperties(Operation *owner, + Operation *user) { + assert(!getStaticNumberOfInstances(owner, user).hasValue() && + "user must have dynamic number of instance"); + + // Compute the number of static instances before we reach the first dynamic + // parent. + int32_t numberOfStaticInstances = 1; + + // Operation with statically known number of instances of the `user` + // operation (can be `user` operation itself). + Operation *staticUser = user; + + Operation *ownerParent = owner->getParentOp(); + Operation *userParent = user->getParentOp(); + + // Find the parent with statically known number of instances. + while (ownerParent != userParent) { + if (auto n = getStaticNumberOfInstances(userParent)) + numberOfStaticInstances *= *n; + else + break; + + staticUser = userParent; + userParent = userParent->getParentOp(); + } + + assert(ownerParent != userParent && "did not find dynamic operation"); + + // Block that owns operation with statically known number of user instances, + // but the parent has dynamic number of instances. 
+ Block *dynamicBlock = staticUser->getBlock(); + Operation *dynamicOperation = userParent; + + assert(dynamicBlock->getParentOp() == dynamicOperation); + + // Find the operation that owns the operation with dynamic number of + // instances and has the same parent as the `owner`. + Operation *dynamicParent = dynamicOperation->getParentOp(); + while (ownerParent != dynamicParent) { + dynamicOperation = dynamicParent; + dynamicParent = dynamicOperation->getParentOp(); + } + + return {numberOfStaticInstances, dynamicBlock, dynamicOperation}; +} + +LogicalResult addAutomaticRefCounting(OpResult result) { + Operation *op = result.getOwner(); + MLIRContext *ctx = op->getContext(); + + Location loc = result.getLoc(); + + OpBuilder builder(op); + builder.setInsertionPointAfter(op); + + auto i32 = IntegerType::get(32, ctx); + + // Drop ref count -1 if the result has no users. + if (result.getUsers().empty()) { + builder.create(loc, result, IntegerAttr::get(i32, 1)); + return success(); + } + + // Verify that all users support automatic reference counting. + for (Operation *user : result.getUsers()) { + if (!isSupportedConsumer(user)) { + op->emitError() << "result #" << result.getResultNumber() + << " passed to the operation that does not support " + "automatic async reference counting: " + << user->getName(); + return failure(); + } + } + + // The number of statically known uses of the `result`. + int32_t staticInstances = 0; + + // Collect properties of the dynamic uses of the `result`. + SmallVector dynamicInstances; + + for (const OpOperand &use : result.getUses()) { + Operation *user = use.getOwner(); + + // Check if we know the number of user instances statically. + if (auto knownStatically = getStaticNumberOfInstances(op, user)) { + staticInstances += *knownStatically; + continue; + } + + // Collect dynamic instance properties otherwise. 
+ DynamicInstanceProperties props = getDynamicInstanceProperties(op, user); + dynamicInstances.push_back(props); + } + + // Remove redundant reference counting from the same anchors ... + llvm::SmallSet anchors; + // ... and aggregate the number of static instances per dynamic block. + llvm::DenseMap blockStaticCounts; + + for (DynamicInstanceProperties &props : dynamicInstances) { + blockStaticCounts[props.dynamicBlock] += props.staticNumberOfInstances; + anchors.insert(props.refCountAnchor); + } + + // We'll add +1 reference for each static instance of the user operation, and + // also +1 for every dynamic instance anchor operation. Adding references for + // dynamic instance anchors is required to keep reference counted objects + // alive until the control flow reaches `async.add_ref` operation inside the + // dynamic region. + int32_t useCount = staticInstances + anchors.size(); + + // Add +1 reference for each result use to eventually drop the reference count + // to zero. + if (useCount > 1) + builder.create(loc, result, IntegerAttr::get(i32, useCount - 1)); + + // Drop reference count immediately after the anchor operation. + for (Operation *anchor : anchors) { + builder.setInsertionPointAfter(anchor); + builder.create(loc, result, IntegerAttr::get(i32, 1)); + } + + // Add statically known references at the beginning of the dynamic block. 
+ for (auto &kv : blockStaticCounts) { + Block *block = kv.first; + int32_t count = kv.second; + builder.setInsertionPointToStart(block); + builder.create(loc, result, IntegerAttr::get(i32, count)); + } + + return success(); +} + +LogicalResult addAutomaticRefCounting(Operation *op) { + for (unsigned i = 0; i < op->getNumResults(); ++i) { + Type resultType = op->getResultTypes()[i]; + if (!isRefCounted(resultType)) + continue; + + if (failed(addAutomaticRefCounting(op->getResult(i)))) + return failure(); + } + return success(); +} + +void AsyncRefCountingPass::runOnFunction() { + FuncOp func = getFunction(); + + // Add `async.add_ref` operations to match the number of uses for each async + // value. + WalkResult walkResult = func.walk([](Operation *op) -> WalkResult { + if (failed(addAutomaticRefCounting(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + + if (walkResult.wasInterrupted()) + signalPassFailure(); +} + +std::unique_ptr> mlir::createAsyncRefCountingPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRAsyncTransforms AsyncParallelFor.cpp + AsyncRefCounting.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -16,6 +16,7 @@ #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS #include +#include #include #include #include @@ -27,30 +28,141 @@ // Async runtime API. //===----------------------------------------------------------------------===// -struct AsyncToken { - bool ready = false; +namespace { + +// Forward declare class defined below. 
+class RefCounted; + +// -------------------------------------------------------------------------- // +// AsyncRuntime orchestrates all async operations and Async runtime API is built +// on top of the default runtime instance. +// -------------------------------------------------------------------------- // + +class AsyncRuntime { +public: + AsyncRuntime() : numRefCountedObjects(0) {} + + ~AsyncRuntime() { + assert(getNumRefCountedObjects() == 0 && + "all ref counted objects must be destroyed"); + } + + int32_t getNumRefCountedObjects() { + return numRefCountedObjects.load(std::memory_order_relaxed); + } + +private: + friend class RefCounted; + + // Count the total number of reference counted objects in this instance + // of an AsyncRuntime. For debugging purposes only. + void addNumRefCountedObjects() { + numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); + } + void dropNumRefCountedObjects() { + numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); + } + + std::atomic numRefCountedObjects; +}; + +// Returns the default per-process instance of an async runtime. +AsyncRuntime *getDefaultAsyncRuntimeInstance() { + static auto runtime = std::make_unique(); + return runtime.get(); +} + +// -------------------------------------------------------------------------- // +// A base class for all reference counted objects created by the async runtime. 
+// -------------------------------------------------------------------------- // + +class RefCounted { +public: + RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) + : runtime(runtime), refCount(refCount) { + runtime->addNumRefCountedObjects(); + } + + virtual ~RefCounted() { + assert(refCount.load() == 0 && "reference count must be zero"); + runtime->dropNumRefCountedObjects(); + } + + RefCounted(const RefCounted &) = delete; + RefCounted &operator=(const RefCounted &) = delete; + + void addRef(int32_t count = 1) { refCount.fetch_add(count); } + + void dropRef(int32_t count = 1) { + int32_t previous = refCount.fetch_sub(count); + assert(previous >= count && "reference count should not go below zero"); + if (previous == count) + destroy(); + } + +protected: + virtual void destroy() { delete this; } + +private: + AsyncRuntime *runtime; + std::atomic refCount; +}; + +} // namespace + +struct AsyncToken : public RefCounted { + // AsyncToken created with a reference count of 2 because it will be returned + // to the `async.execute` caller and also will be later on emplaced by the + // asynchronously executed task. If the caller immediately will drop its + // reference we must ensure that the token will be alive until the + // asynchronous operation is completed. + AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {} + + // Internal state below guarded by a mutex. std::mutex mu; std::condition_variable cv; + + bool ready = false; std::vector> awaiters; }; -struct AsyncGroup { - std::atomic pendingTokens{0}; - std::atomic rank{0}; +struct AsyncGroup : public RefCounted { + AsyncGroup(AsyncRuntime *runtime) + : RefCounted(runtime), pendingTokens(0), rank(0) {} + + std::atomic pendingTokens; + std::atomic rank; + + // Internal state below guarded by a mutex. std::mutex mu; std::condition_variable cv; + std::vector> awaiters; }; +// Adds references to reference counted runtime object. 
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { + RefCounted *refCounted = static_cast(ptr); + refCounted->addRef(count); +} + +// Drops references from reference counted runtime object. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { + RefCounted *refCounted = static_cast(ptr); + refCounted->dropRef(count); +} + // Create a new `async.token` in not-ready state. extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken() { - AsyncToken *token = new AsyncToken; + AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance()); return token; } // Create a new `async.group` in empty state. extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() { - AsyncGroup *group = new AsyncGroup; + AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); return group; } @@ -59,15 +171,21 @@ std::unique_lock lockToken(token->mu); std::unique_lock lockGroup(group->mu); + // Get the rank of the token inside the group before we drop the reference. + int rank = group->rank.fetch_add(1); group->pendingTokens.fetch_add(1); - auto onTokenReady = [group]() { + auto onTokenReady = [group, token]() { // Run all group awaiters if it was the last token in the group. if (group->pendingTokens.fetch_sub(1) == 1) { group->cv.notify_all(); for (auto &awaiter : group->awaiters) awaiter(); } + + // We no longer need the token or the group, drop references on them. + group->dropRef(); + token->dropRef(); }; if (token->ready) @@ -75,7 +193,7 @@ else token->awaiters.push_back([onTokenReady]() { onTokenReady(); }); - return group->rank.fetch_add(1); + return rank; } // Switches `async.token` to ready state and runs all awaiters. 
@@ -86,6 +204,8 @@ token->cv.notify_all(); for (auto &awaiter : token->awaiters) awaiter(); + + token->dropRef(); } extern "C" MLIR_ASYNCRUNTIME_EXPORT void @@ -93,6 +213,8 @@ std::unique_lock lock(token->mu); if (!token->ready) token->cv.wait(lock, [token] { return token->ready; }); + + token->dropRef(); } extern "C" MLIR_ASYNCRUNTIME_EXPORT void @@ -100,6 +222,8 @@ std::unique_lock lock(group->mu); if (group->pendingTokens != 0) group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); + + group->dropRef(); } extern "C" MLIR_ASYNCRUNTIME_EXPORT void @@ -117,7 +241,8 @@ CoroResume resume) { std::unique_lock lock(token->mu); - auto execute = [handle, resume]() { + auto execute = [handle, resume, token]() { + token->dropRef(); mlirAsyncRuntimeExecute(handle, resume); }; @@ -132,7 +257,8 @@ CoroResume resume) { std::unique_lock lock(group->mu); - auto execute = [handle, resume]() { + auto execute = [handle, resume, group]() { + group->dropRef(); mlirAsyncRuntimeExecute(handle, resume); }; diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -1,5 +1,20 @@ // RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s +// CHECK-LABEL: reference_counting +func @reference_counting(%arg0: !async.token) { + // CHECK: %[[C2:.*]] = constant 2 : i32 + // CHECK: call @mlirAsyncRuntimeAddRef(%arg0, %[[C2]]) + async.add_ref %arg0 {count = 2 : i32} : !async.token + + // CHECK: %[[C1:.*]] = constant 1 : i32 + // CHECK: call @mlirAsyncRuntimeDropRef(%arg0, %[[C1]]) + async.drop_ref %arg0 {count = 1 : i32} : !async.token + + return +} + +// ----- + // CHECK-LABEL: execute_no_async_args func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) { // CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1) diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir 
b/mlir/test/Dialect/Async/async-ref-counting.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-ref-counting.mlir @@ -0,0 +1,181 @@ +// RUN: mlir-opt %s -async-ref-counting | FileCheck %s + +// CHECK-LABEL: @token_no_uses +func @token_no_uses() { + // CHECK: %[[TOKEN:.*]] = async.execute + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + %token = async.execute { + async.yield + } + return +} + +// CHECK-LABEL: @token_return +func @token_return() -> !async.token { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: return %token + return %token : !async.token +} + +// CHECK-LABEL: @token_with_await +func @token_with_await() -> !async.token { + // CHECK: %[[TOKEN:.*]] = async.execute + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + %token = async.execute { + async.yield + } + async.await %token : !async.token + // CHECK: return %token + return %token : !async.token +} + +// CHECK-LABEL: @token_capture +func @token_capture() { + // CHECK: %[[TOKEN:.*]] = async.execute + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + %token = async.execute { + async.yield + } + + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute { + async.await %token : !async.token + async.yield + } + + // CHECK: async.await %[[TOKEN]] + // CHECK: async.await %[[TOKEN_0]] + async.await %token : !async.token + async.await %token_0 : !async.token + + // CHECK: return + return +} + +// CHECK-LABEL: @token_dependency +func @token_dependency() { + // CHECK: %[[TOKEN:.*]] = async.execute + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + %token = async.execute { + async.yield + } + + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute[%token] { + async.yield + } + + // CHECK: async.await %[[TOKEN]] + // CHECK: async.await %[[TOKEN_0]] + async.await %token : !async.token + async.await %token_0 : !async.token + + // CHECK: return + return +} + +// CHECK-LABEL: @value_operand 
+func @value_operand() -> f32 { + // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute + // CHECK: async.add_ref %[[RESULTS]] {count = 1 : i32} + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + %token, %results = async.execute -> !async.value { + %0 = constant 0.0 : f32 + async.yield %0 : f32 + } + + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute[%token](%results as %arg0 : !async.value) { + async.yield + } + + // CHECK: async.await %[[TOKEN]] + // CHECK: async.await %[[TOKEN_0]] + async.await %token : !async.token + async.await %token_0 : !async.token + + // CHECK: async.await %[[RESULTS]] + %0 = async.await %results : !async.value + + // CHECK: return + return %0 : f32 +} + +// CHECK-LABEL: @async_group +func @async_group() { + // CHECK: %[[GROUP:.*]] = async.create_group + // CHECK: async.add_ref %[[GROUP]] {count = 1 : i32} : !async.group + %0 = async.create_group + + // CHECK: %[[TOKEN:.*]] = async.execute + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token = async.execute { async.yield } + %token_0 = async.execute { async.yield } + + // CHECK: async.add_to_group %[[TOKEN]], %[[GROUP]] + // CHECK: async.add_to_group %[[TOKEN_0]], %[[GROUP]] + %1 = async.add_to_group %token, %0 : !async.token + %2 = async.add_to_group %token_0, %0 : !async.token + + // CHECK: return + return +} + +// CHECK-LABEL: @capture_by_scf_if +func @capture_by_scf_if(%arg0 : i1) { + %token = async.execute { async.yield } + + scf.if %arg0 { + // CHECK: async.add_ref %token {count = 2 : i32} + async.await %token : !async.token + async.await %token : !async.token + } else { + // CHECK: async.add_ref %token {count = 1 : i32} + async.await %token : !async.token + } + // CHECK: async.drop_ref %token {count = 1 : i32} + + return +} + +// CHECK-LABEL: @capture_by_scf_if_with_async_execute +func @capture_by_scf_if_with_async_execute(%arg0 : i1) { + %token = async.execute { async.yield } + + // `async.await` from the `async.execute` rolled up to the first + // 
operation with dynamic number of instances. + scf.if %arg0 { + // CHECK: async.add_ref %token {count = 2 : i32} + async.execute { + async.await %token : !async.token + async.await %token : !async.token + async.yield + } + } else { + // CHECK: async.add_ref %token {count = 1 : i32} + async.await %token : !async.token + } + // CHECK: async.drop_ref %token {count = 1 : i32} + + return +} + +// CHECK-LABEL: @capture_by_scf_for +func @capture_by_scf_for() { + %token = async.execute { async.yield } + + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 0 : index + + scf.for %i = %c0 to %c2 step %c1 { + // CHECK: async.add_ref %token {count = 1 : i32} + async.await %token : !async.token + } + // CHECK: async.drop_ref %token {count = 1 : i32} + + return +} diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir --- a/mlir/test/Dialect/Async/ops.mlir +++ b/mlir/test/Dialect/Async/ops.mlir @@ -134,3 +134,17 @@ %3 = addi %1, %2 : index return %3 : index } + +// CHECK-LABEL: @add_ref +func @add_ref(%arg0: !async.token) { + // CHECK: async.add_ref %arg0 {count = 1 : i32} + async.add_ref %arg0 {count = 1 : i32} : !async.token + return +} + +// CHECK-LABEL: @drop_ref +func @drop_ref(%arg0: !async.token) { + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + async.drop_ref %arg0 {count = 1 : i32} : !async.token + return +} diff --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir --- a/mlir/test/Dialect/Async/verify.mlir +++ b/mlir/test/Dialect/Async/verify.mlir @@ -19,3 +19,17 @@ // expected-error @+1 {{'async.await' op result type 'f64' does not match async value type 'f32'}} %0 = "async.await"(%arg0): (!async.value) -> f64 } + +// ----- + +func @wrong_add_ref_count(%arg0: !async.token) { + // expected-error @+1 {{'async.add_ref' op reference count must be greater than 0}} + async.add_ref %arg0 {count = 0 : i32} : !async.token +} + +// ----- + +func @wrong_drop_ref_count(%arg0: !async.token) { + // 
expected-error @+1 {{'async.drop_ref' op reference count must be greater than 0}} + async.drop_ref %arg0 {count = 0 : i32} : !async.token +} diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir --- a/mlir/test/mlir-cpu-runner/async-group.mlir +++ b/mlir/test/mlir-cpu-runner/async-group.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -convert-async-to-llvm \ +// RUN: mlir-opt %s -async-ref-counting \ +// RUN: -convert-async-to-llvm \ // RUN: -convert-std-to-llvm \ // RUN: | mlir-cpu-runner \ // RUN: -e main -entry-point-result=void -O0 \