diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td @@ -73,4 +73,8 @@ def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType, Async_TokenType]>; +def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType, + Async_TokenType, + Async_GroupType]>; + #endif // ASYNC_BASE_TD diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -227,4 +227,50 @@ let assemblyFormat = "$operand attr-dict"; } +//===----------------------------------------------------------------------===// +// Async Dialect Automatic Reference Counting Operations. +//===----------------------------------------------------------------------===// + +// All async values (values, tokens, groups) are reference counted at runtime +// and automatically destroyed when their reference count drops to 0. +// +// All values are semantically created with a reference count of +1, and it is +// the responsibility of the async value owner to add/drop references based on +// the number of uses. +// +// See `AsyncRefCountingPass` for the automatic reference counting +// implementation details. + +def Async_AddRefOp : Async_Op<"add_ref"> { + let summary = "adds a reference to async value"; + let description = [{ + The `async.add_ref` operation adds a reference(s) to async value (token, + value or group). + }]; + + let arguments = (ins Async_AnyAsyncType:$operand, + Confined:$count); + let results = (outs ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) + }]; +} + +def Async_DropRefOp : Async_Op<"drop_ref"> { + let summary = "drops a reference to async value"; + let description = [{ + The `async.drop_ref` operation drops a reference(s) to async value (token, + value or group).
+ }]; + + let arguments = (ins Async_AnyAsyncType:$operand, + Confined:$count); + let results = (outs ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) + }]; +} + #endif // ASYNC_OPS diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -19,6 +19,8 @@ std::unique_ptr> createAsyncParallelForPass(); +std::unique_ptr> createAsyncRefCountingPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -24,4 +24,10 @@ let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"]; } +def AsyncRefCounting : FunctionPass<"async-ref-counting"> { + let summary = "Automatic reference counting for Async dialect data types"; + let constructor = "mlir::createAsyncRefCountingPass()"; + let dependentDialects = ["async::AsyncDialect"]; +} + #endif // MLIR_DIALECT_ASYNC_PASSES diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -48,6 +48,18 @@ using CoroHandle = void *; // coroutine handle using CoroResume = void (*)(void *); // coroutine resume function +// Async runtime uses reference counting to manage the lifetime of async values +// (values of async types like tokens, values and groups). +using RefCountedObjPtr = void *; + +// Adds the given number (second argument) of references to a reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void + mlirAsyncRuntimeAddRef(RefCountedObjPtr, int32_t); + +// Drops the given number (second argument) of references from a reference counted runtime object; the object is destroyed once its reference count reaches zero. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void + mlirAsyncRuntimeDropRef(RefCountedObjPtr, int32_t); + // Create a new `async.token` in not-ready state. extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken(); diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-ref-counting \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir --- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-ref-counting \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -33,6 +33,8 @@ // Async Runtime C API declaration.
//===----------------------------------------------------------------------===// +static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; +static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; @@ -49,6 +51,12 @@ namespace { // Async Runtime API function types. struct AsyncAPI { + static FunctionType refCountingFunctionType(MLIRContext *ctx) { + auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); // Opaque pointer to the ref counted runtime object. + auto count = IntegerType::get(32, ctx); // Number of references to add or drop. + return FunctionType::get({ref, count}, {}, ctx); + } + static FunctionType createTokenFunctionType(MLIRContext *ctx) { return FunctionType::get({}, {TokenType::get(ctx)}, ctx); } @@ -109,6 +117,14 @@ MLIRContext *ctx = module.getContext(); Location loc = module.getLoc(); + if (!module.lookupSymbol(kAddRef)) + builder.create(loc, kAddRef, + AsyncAPI::refCountingFunctionType(ctx)); + + if (!module.lookupSymbol(kDropRef)) + builder.create(loc, kDropRef, + AsyncAPI::refCountingFunctionType(ctx)); + if (!module.lookupSymbol(kCreateToken)) builder.create(loc, kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); @@ -631,6 +647,56 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` +// to the corresponding Async Runtime API function calls).
+//===----------------------------------------------------------------------===// + +namespace { + +template +class RefCountingOpLowering : public ConversionPattern { +public: + explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName) + : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx), + apiFunctionName(apiFunctionName) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RefCountingOp refCountingOp = cast(op); + + auto i32 = IntegerType::get(32, op->getContext()); + auto count = IntegerAttr::get(i32, refCountingOp.count()); + auto countCst = rewriter.create(op->getLoc(), i32, count); + + rewriter.replaceOpWithNewOp( + op, Type(), apiFunctionName, + ValueRange({operands[0], countCst})); // Use the type-converted operand, not the original one. + + return success(); + } + +private: + StringRef apiFunctionName; +}; + +// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call. +class AddRefOpLowering : public RefCountingOpLowering { +public: + explicit AddRefOpLowering(MLIRContext *ctx) + : RefCountingOpLowering(ctx, kAddRef) {} +}; + +// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. +class DropRefOpLowering : public RefCountingOpLowering { +public: + explicit DropRefOpLowering(MLIRContext *ctx) + : RefCountingOpLowering(ctx, kDropRef) {} +}; + +} // namespace + //===----------------------------------------------------------------------===// // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
//===----------------------------------------------------------------------===// @@ -837,10 +903,12 @@ populateFuncOpTypeConversionPattern(patterns, ctx, converter); patterns.insert(ctx); + patterns.insert(ctx); patterns.insert(ctx); patterns.insert(ctx, outlinedFunctions); ConversionTarget target(*ctx); + target.addLegalOp(); target.addLegalDialect(); target.addIllegalDialect(); target.addDynamicallyLegalOp( diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp @@ -0,0 +1,351 @@ +//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements automatic reference counting for Async dialect data +// types. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Analysis/NumberOfExecutions.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; +using namespace mlir::async; + +#define DEBUG_TYPE "async-ref-counting" + +namespace { + +class AsyncRefCountingPass : public AsyncRefCountingBase { +public: + AsyncRefCountingPass() = default; + void runOnFunction() override; + +private: + /// Adds automatic reference counting to the `value`. + /// + /// Automatic reference counting algorithm outline: + /// + /// 1.
`ReturnLike` operations forward the reference counted values without + /// modifying the reference count. + /// + /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of + /// reference counted values ends, and insert `drop_ref` operations after + /// the last use of the value. + /// + /// 3. Insert `add_ref` before the `async.execute` operation capturing the + /// value, and a pairing `drop_ref` before the async body region terminator, + /// to release the captured reference counted value when execution + /// completes. + /// + /// 4. If the reference counted value is passed only to some of the block + /// successors, insert `drop_ref` operations at the beginning of the blocks + /// that do not use the reference counted value. + /// + LogicalResult addAutomaticRefCounting(Value value); +}; + +} // namespace + +/// Returns true if the type is reference counted. All async dialect types are +/// reference counted at runtime. +static bool isRefCounted(Type type) { + return type.isa(); +} + +LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) { + MLIRContext *ctx = value.getContext(); + OpBuilder builder(ctx); + + // Set insertion point after the operation producing a value, or at the + // beginning of the block if the value is a block argument. + if (Operation *op = value.getDefiningOp()) + builder.setInsertionPointAfter(op); + else + builder.setInsertionPointToStart(value.getParentBlock()); + + Location loc = value.getLoc(); + auto i32 = IntegerType::get(32, ctx); + + // Drop the reference immediately if the value has no uses. + if (value.getUses().empty()) { + builder.create(loc, value, IntegerAttr::get(i32, 1)); + return success(); + } + + // Use the cached liveness analysis (by reference, not a per-value copy) to find `drop_ref` placement. + auto &liveness = getAnalysis(); + + // We analyse only the blocks of the region that defines the `value`, and do + // not check nested blocks attached to operations.
+ Region *definingRegion = value.getParentRegion(); + + // ------------------------------------------------------------------------ // + // Find blocks where the `value` dies: the value is in `liveIn` set and not + // in the `liveOut` set. We place `drop_ref` immediately after the last use + // of the `value` in such blocks. + // ------------------------------------------------------------------------ // + + // Last users of the `value` inside all blocks where the value dies. + llvm::SmallSet lastUsers; + + for (Block &block : definingRegion->getBlocks()) { + const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); + + // Value is in the live-in set or was defined in the block. + bool liveIn = blockLiveness->isLiveIn(value) || + blockLiveness->getBlock() == value.getParentBlock(); + if (!liveIn) + continue; + + // Skip the block if the value is in the live-out set. + bool liveOut = blockLiveness->isLiveOut(value); + if (liveOut) + continue; + + // We proved that `value` dies in the `block`. Now find the last use of the + // `value` inside the `block`. + + // Find any user of the `value` inside the block. + Operation *userInTheBlock = nullptr; + for (Operation *user : value.getUsers()) + if (Operation *ancestor = block.findAncestorOpInBlock(*user)) { + userInTheBlock = ancestor; + break; + } + assert(userInTheBlock && "value must have a user in the block"); + + // Find the last user of the `value` in the block. + Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); + assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); + lastUsers.insert(lastUser); + } + + // Process all the last users of the `value` inside each block where the value + // dies. + for (Operation *lastUser : lastUsers) { + // ReturnLike operations forward the reference count. + if (lastUser->hasTrait()) + continue; + + // We can't currently handle other types of terminators.
+ if (lastUser->hasTrait()) + return lastUser->emitError() << "async reference counting can't handle " + "terminators that are not ReturnLike"; + + // Add a drop_ref immediately after the last user. + builder.setInsertionPointAfter(lastUser); + builder.create(loc, value, IntegerAttr::get(i32, 1)); + } + + // ------------------------------------------------------------------------ // + // Find blocks where the `value` is in the `liveOut` set, but not in the + // `liveIn` set of every successor. If the `value` is not in a successor's + // `liveIn` set, we add a `drop_ref` at its beginning. NOTE(review): this assumes `value` dominates that successor; a successor also reachable from a block where `value` is undefined would need its incoming edge split instead -- confirm. + // ------------------------------------------------------------------------ // + + // Successors that need a `drop_ref` for the `value`. + llvm::SmallSet dropRefSuccessors; + + for (Block &block : definingRegion->getBlocks()) { + const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); + + // Skip the block if value is not in the `liveOut` set. + if (!blockLiveness->isLiveOut(value)) + continue; + + // Find successors that do not have `value` in the `liveIn` set. + for (Block *successor : block.getSuccessors()) { + const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); + + if (!succLiveness->isLiveIn(value)) + dropRefSuccessors.insert(successor); + } + } + + // Drop reference in all successor blocks that do not have the `value` in + // their `liveIn` set. + for (Block *dropRefSuccessor : dropRefSuccessors) { + builder.setInsertionPointToStart(dropRefSuccessor); + builder.create(loc, value, IntegerAttr::get(i32, 1)); + } + + // ------------------------------------------------------------------------ // + // Find all `async.execute` operations that take `value` as an operand + // (dependency token or async value), or capture it implicitly in their nested + // region. Each `async.execute` operation requires an `add_ref` operation + // to keep all captured values alive until it finishes its execution.
+ // ------------------------------------------------------------------------ // + + llvm::SmallSet executeOperations; + + for (Operation *user : value.getUsers()) { + Block *block = user->getBlock(); + + Block *ancestorBlock = definingRegion->findAncestorBlockInRegion(*block); + assert(ancestorBlock && "user lies outside of the defining region"); + + Operation *ancestorOp = ancestorBlock->findAncestorOpInBlock(*user); + + auto trackAsyncExecute = [&](Operation *op) { + if (auto execute = dyn_cast(op)) + executeOperations.insert(execute); + }; + + // Check the ancestor operation that lies in the region defining `value`. + trackAsyncExecute(ancestorOp); + + // Follow all parent operations starting from the user. + while (user != ancestorOp) { + trackAsyncExecute(user); + user = user->getParentOp(); + } + } + + // Process all `async.execute` operations capturing `value`. + for (ExecuteOp execute : executeOperations) { + // Add a reference before the execute operation to keep the reference + // counted value alive until the async region completes execution. + builder.setInsertionPoint(execute.getOperation()); + builder.create(loc, value, IntegerAttr::get(i32, 1)); + + // Drop the reference inside the async region before completion. + OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody()); + executeBuilder.create(loc, value, IntegerAttr::get(i32, 1)); + } + + // ------------------------------------------------------------------------ // + // Remove all cancelling pairs of `add_ref` and `drop_ref` operations. + // ------------------------------------------------------------------------ // + + // Returns true if the operation itself is the user of the `value`, or it has + // nested operations that are users of the `value`.
+ auto isValueUser = [&](Operation *op) -> bool { + Block *block = op->getBlock(); + for (Operation *user : value.getUsers()) + if (block->findAncestorOpInBlock(*user) == op) + return true; + return false; + }; + + // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the + // `block` that modify the reference count of the `value`. + auto eraseCancellableRefCounting = [&](Block *block) -> void { + for (auto addRef : llvm::make_early_inc_range(block->getOps())) { + // `add_ref` operation adding a reference to the `value`. + if (addRef.operand() != value) + continue; + + for (auto dropRef : block->getOps()) { + // `drop_ref` operation after the `add_ref` with matching count. + if (dropRef.operand() != addRef.operand() || + dropRef.count() != addRef.count() || + dropRef.getOperation()->isBeforeInBlock(addRef.getOperation())) + continue; + + // Check `value` users between `addRef` and `dropRef` in the `block`. + Operation *addRefOp = addRef.getOperation(); + Operation *dropRefOp = dropRef.getOperation(); + + // If there is a "regular" user after the `async.execute` user it is + // unsafe to erase the cancellable pair of reference counting operations, + // because the async region can complete before the "regular" user and + // destroy the reference counted value. + bool hasExecuteUser = false; + bool unsafeToCancel = false; + + for (auto *op = addRefOp; op != dropRefOp; op = op->getNextNode()) { + // Check if the operation is a user of the `value`. + bool user = isValueUser(op); + bool isRegularUser = user && !isa(op); + bool isExecuteUser = user && !isRegularUser; + + // A regular user after an execute user makes the pair unsafe to cancel. + if (isRegularUser && hasExecuteUser) { + unsafeToCancel = true; + break; + } + + hasExecuteUser |= isExecuteUser; + } + + // Erase the pair of reference counting operations if it is safe.
+ if (!unsafeToCancel) { + addRef.erase(); + dropRef.erase(); + } + + // If it is unsafe to cancel the `addRef <-> dropRef` pair at this point, + // all the following pairs will also be unsafe. + break; + } + } + }; + + // Walk blocks of the value defining region. + for (Block &block : definingRegion->getBlocks()) + eraseCancellableRefCounting(&block); + + // Walk all the nested operations in the same region. + definingRegion->walk(eraseCancellableRefCounting); + + return success(); +} + +void AsyncRefCountingPass::runOnFunction() { + FuncOp func = getFunction(); + + // Check that we do not have explicit `add_ref` or `drop_ref` in the IR, + // because automatic reference counting would produce incorrect results; + // bail out before modifying the IR if any are found. + WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult { + if (isa(op)) + return op->emitError() << "explicit reference counting is not supported"; + return WalkResult::advance(); + }); + + if (refCountingWalk.wasInterrupted()) + return signalPassFailure(); + + // Add reference counting to block arguments; stop at the first failure. + WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { + for (BlockArgument arg : block->getArguments()) + if (isRefCounted(arg.getType())) + if (failed(addAutomaticRefCounting(arg))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (blockWalk.wasInterrupted()) + return signalPassFailure(); + + // Add reference counting to operation results.
+ WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { + for (unsigned i = 0; i < op->getNumResults(); ++i) + if (isRefCounted(op->getResultTypes()[i])) + if (failed(addAutomaticRefCounting(op->getResult(i)))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (opWalk.wasInterrupted()) + signalPassFailure(); +} + +std::unique_ptr> mlir::createAsyncRefCountingPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRAsyncTransforms AsyncParallelFor.cpp + AsyncRefCounting.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -16,6 +16,7 @@ #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS #include +#include #include #include #include @@ -27,30 +28,141 @@ // Async runtime API. //===----------------------------------------------------------------------===// -struct AsyncToken { - bool ready = false; +namespace { + +// Forward declaration: RefCounted (defined below) is a friend of AsyncRuntime. +class RefCounted; + +// -------------------------------------------------------------------------- // +// AsyncRuntime holds state shared by all runtime objects (for now a debug count +// of live reference counted objects); the C API uses the default instance.
+// -------------------------------------------------------------------------- // + +class AsyncRuntime { +public: + AsyncRuntime() : numRefCountedObjects(0) {} + + ~AsyncRuntime() { + assert(getNumRefCountedObjects() == 0 && + "all ref counted objects must be destroyed"); + } + + int32_t getNumRefCountedObjects() { + return numRefCountedObjects.load(std::memory_order_relaxed); + } + +private: + friend class RefCounted; + + // Track the number of live reference counted objects in this instance of + // an AsyncRuntime. For debugging purposes only. + void addNumRefCountedObjects() { + numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); + } + void dropNumRefCountedObjects() { + numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); + } + + std::atomic numRefCountedObjects; +}; + +// Returns the default per-process instance of an async runtime. It is a function-local static, so it is destroyed at normal program exit, where its destructor asserts (in builds with assertions) that no reference counted objects are still alive. +AsyncRuntime *getDefaultAsyncRuntimeInstance() { + static auto runtime = std::make_unique(); + return runtime.get(); +} + +// -------------------------------------------------------------------------- // +// A base class for all reference counted objects created by the async runtime.
+// -------------------------------------------------------------------------- // + +class RefCounted { +public: + explicit RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) + : runtime(runtime), refCount(refCount) { + runtime->addNumRefCountedObjects(); + } + + virtual ~RefCounted() { + assert(refCount.load() == 0 && "reference count must be zero"); + runtime->dropNumRefCountedObjects(); + } + + RefCounted(const RefCounted &) = delete; + RefCounted &operator=(const RefCounted &) = delete; + + void addRef(int32_t count = 1) { refCount.fetch_add(count); } + + void dropRef(int32_t count = 1) { + int32_t previous = refCount.fetch_sub(count); + assert(previous >= count && "reference count should not go below zero"); + if (previous == count) + destroy(); + } + +protected: + virtual void destroy() { delete this; } + +private: + AsyncRuntime *runtime; + std::atomic refCount; +}; + +} // namespace + +struct AsyncToken : public RefCounted { + // AsyncToken is created with a reference count of 2 because it is returned + // to the `async.execute` caller and is also later emplaced by the + // asynchronously executed task. If the caller drops its reference + // immediately, the second reference keeps the token alive until the + // asynchronous operation is completed. + explicit AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*refCount=*/2) {} + + // Internal state below guarded by a mutex. + std::mutex mu; + std::condition_variable cv; + + bool ready = false; std::vector> awaiters; }; -struct AsyncGroup { - std::atomic pendingTokens{0}; - std::atomic rank{0}; +struct AsyncGroup : public RefCounted { + explicit AsyncGroup(AsyncRuntime *runtime) + : RefCounted(runtime), pendingTokens(0), rank(0) {} + + std::atomic pendingTokens; + std::atomic rank; + + // Internal state below guarded by a mutex. + std::mutex mu; + std::condition_variable cv; + + std::vector> awaiters; +}; + +// Adds the given number of references to a reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { + RefCounted *refCounted = static_cast(ptr); // NOTE(review): assumes `ptr` is the address of the RefCounted base subobject (true for AsyncToken/AsyncGroup pointers only if that base sits at offset 0) -- confirm. + refCounted->addRef(count); +} + +// Drops the given number of references; the object is destroyed once its reference count reaches zero. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { + RefCounted *refCounted = static_cast(ptr); + refCounted->dropRef(count); +} + // Create a new `async.token` in not-ready state. extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { - AsyncToken *token = new AsyncToken; + AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance()); return token; } // Create a new `async.group` in empty state. extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() { - AsyncGroup *group = new AsyncGroup; + AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); return group; } @@ -59,23 +171,34 @@ std::unique_lock lockToken(token->mu); std::unique_lock lockGroup(group->mu); + // Get the rank of the token inside the group before we drop the reference. + int rank = group->rank.fetch_add(1); group->pendingTokens.fetch_add(1); - auto onTokenReady = [group]() { + auto onTokenReady = [group, token](bool dropRef) { // Run all group awaiters if it was the last token in the group. if (group->pendingTokens.fetch_sub(1) == 1) { group->cv.notify_all(); for (auto &awaiter : group->awaiters) awaiter(); } + + // We no longer need the token or the group; drop the references taken when the awaiter was registered. + if (dropRef) { + group->dropRef(); + token->dropRef(); + } }; - if (token->ready) - onTokenReady(); - else - token->awaiters.push_back([onTokenReady]() { onTokenReady(); }); + if (token->ready) { + onTokenReady(false); + } else { + group->addRef(); + token->addRef(); + token->awaiters.push_back([onTokenReady]() { onTokenReady(true); }); + } - return group->rank.fetch_add(1); + return rank; } // Switches `async.token` to ready state and runs all awaiters.
@@ -85,6 +208,10 @@ token->cv.notify_all(); for (auto &awaiter : token->awaiters) awaiter(); + + // Async tokens are created with a ref count of `2` to keep the token alive + // until the async task completes; drop that reference now that the token is emplaced. NOTE(review): if `token->mu` is still locked at this point (the lock is outside this hunk), this may destroy the token and its locked mutex -- confirm the lock is released before this call. + token->dropRef(); } extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { @@ -114,14 +241,18 @@ CoroResume resume) { std::unique_lock lock(token->mu); - auto execute = [handle, resume]() { + auto execute = [handle, resume, token](bool dropRef) { + if (dropRef) + token->dropRef(); // Release the reference taken when the awaiter was registered. mlirAsyncRuntimeExecute(handle, resume); }; - if (token->ready) - execute(); - else - token->awaiters.push_back([execute]() { execute(); }); + if (token->ready) { + execute(false); + } else { + token->addRef(); + token->awaiters.push_back([execute]() { execute(true); }); + } } extern "C" MLIR_ASYNCRUNTIME_EXPORT void @@ -129,14 +260,18 @@ CoroResume resume) { std::unique_lock lock(group->mu); - auto execute = [handle, resume]() { + auto execute = [handle, resume, group](bool dropRef) { + if (dropRef) + group->dropRef(); // Release the reference taken when the awaiter was registered. mlirAsyncRuntimeExecute(handle, resume); }; - if (group->pendingTokens == 0) - execute(); - else - group->awaiters.push_back([execute]() { execute(); }); + if (group->pendingTokens == 0) { + execute(false); + } else { + group->addRef(); + group->awaiters.push_back([execute]() { execute(true); }); + } } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -1,5 +1,20 @@ // RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s +// CHECK-LABEL: reference_counting +func @reference_counting(%arg0: !async.token) { + // CHECK: %[[C2:.*]] = constant 2 : i32 + // CHECK: call @mlirAsyncRuntimeAddRef(%arg0, %[[C2]]) + async.add_ref %arg0
{count = 2 : i32} : !async.token + + // CHECK: %[[C1:.*]] = constant 1 : i32 + // CHECK: call @mlirAsyncRuntimeDropRef(%arg0, %[[C1]]) + async.drop_ref %arg0 {count = 1 : i32} : !async.token + + return +} + +// ----- + // CHECK-LABEL: execute_no_async_args func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) { // CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1) diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir b/mlir/test/Dialect/Async/async-ref-counting.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-ref-counting.mlir @@ -0,0 +1,242 @@ +// RUN: mlir-opt %s -async-ref-counting | FileCheck %s + +// CHECK-LABEL: @cond +func @cond() -> i1 + +// CHECK-LABEL: @token_arg_no_uses +func @token_arg_no_uses(%arg0: !async.token) { + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + return +} + +// CHECK-LABEL: @token_arg_conditional_await +func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) { + cond_br %arg1, ^bb1, ^bb2 +^bb1: + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + return +^bb2: + // CHECK: async.await %arg0 + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + async.await %arg0 : !async.token + return +} + +// CHECK-LABEL: @token_no_uses +func @token_no_uses() { + // CHECK: %[[TOKEN:.*]] = async.execute + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + %token = async.execute { + async.yield + } + return +} + +// CHECK-LABEL: @token_return +func @token_return() -> !async.token { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: return %[[TOKEN]] + return %token : !async.token +} + +// CHECK-LABEL: @token_await +func @token_await() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: async.await %[[TOKEN]] + async.await %token : !async.token + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_await_and_return +func 
@token_await_and_return() -> !async.token { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: async.await %[[TOKEN]] + // CHECK-NOT: async.drop_ref + async.await %token : !async.token + // CHECK: return %[[TOKEN]] + return %token : !async.token +} + +// CHECK-LABEL: @token_await_inside_scf_if +func @token_await_inside_scf_if(%arg0: i1) { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: scf.if %arg0 { + scf.if %arg0 { + // CHECK: async.await %[[TOKEN]] + async.await %token : !async.token + } + // CHECK: } + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_conditional_await +func @token_conditional_await(%arg0: i1) { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + cond_br %arg0, ^bb1, ^bb2 +^bb1: + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + return +^bb2: + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + return +} + +// CHECK-LABEL: @token_await_in_the_loop +func @token_await_in_the_loop() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + br ^bb1 +^bb1: + // CHECK: async.await %[[TOKEN]] + async.await %token : !async.token + %0 = call @cond(): () -> (i1) + cond_br %0, ^bb1, ^bb2 +^bb2: + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + return +} + +// CHECK-LABEL: @token_defined_in_the_loop +func @token_defined_in_the_loop() { + br ^bb1 +^bb1: + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + %0 = call @cond(): () -> (i1) + cond_br %0, ^bb1, ^bb2 +^bb2: + return +} + +// CHECK-LABEL: @token_capture +func @token_capture() { + // CHECK: %[[TOKEN:.*]] = async.execute + 
%token = async.execute { + async.yield + } + + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK-NEXT: async.yield + async.await %token : !async.token + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_nested_capture +func @token_nested_capture() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute { + // CHECK-NEXT: %[[TOKEN_1:.*]] = async.execute + %token_1 = async.execute { + // CHECK-NEXT: %[[TOKEN_2:.*]] = async.execute + %token_2 = async.execute { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_2]] {count = 1 : i32} + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_1]] {count = 1 : i32} + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_dependency +func @token_dependency() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute[%token] { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK-NEXT: async.yield + async.yield + } + + // CHECK: async.await %[[TOKEN]] + // CHECK: async.await %[[TOKEN_0]] + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + async.await %token : !async.token + async.await %token_0 : !async.token + + // CHECK: return + return +} + +// CHECK-LABEL: @value_operand +func @value_operand() -> f32 { + // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute + %token, %results = async.execute -> !async.value { + %0 = constant 0.0 : f32 + async.yield %0 : f32 + } + + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.add_ref %[[RESULTS]] {count 
= 1 : i32} + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute[%token](%results as %arg0 : !async.value) { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32} + // CHECK: async.yield + async.yield + } + + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + + // CHECK: async.await %[[TOKEN_0]] + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + async.await %token_0 : !async.token + + // CHECK: async.await %[[RESULTS]] + // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32} + %0 = async.await %results : !async.value + + // CHECK: return + return %0 : f32 +} diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir --- a/mlir/test/Dialect/Async/ops.mlir +++ b/mlir/test/Dialect/Async/ops.mlir @@ -134,3 +134,17 @@ %3 = addi %1, %2 : index return %3 : index } + +// CHECK-LABEL: @add_ref +func @add_ref(%arg0: !async.token) { + // CHECK: async.add_ref %arg0 {count = 1 : i32} + async.add_ref %arg0 {count = 1 : i32} : !async.token + return +} + +// CHECK-LABEL: @drop_ref +func @drop_ref(%arg0: !async.token) { + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + async.drop_ref %arg0 {count = 1 : i32} : !async.token + return +} diff --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir --- a/mlir/test/Dialect/Async/verify.mlir +++ b/mlir/test/Dialect/Async/verify.mlir @@ -19,3 +19,17 @@ // expected-error @+1 {{'async.await' op result type 'f64' does not match async value type 'f32'}} %0 = "async.await"(%arg0): (!async.value) -> f64 } + +// ----- + +func @wrong_add_ref_count(%arg0: !async.token) { + // expected-error @+1 {{'async.add_ref' op attribute 'count' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} + async.add_ref %arg0 {count = 0 : i32} : !async.token +} + +// ----- + +func @wrong_drop_ref_count(%arg0: 
!async.token) { + // expected-error @+1 {{'async.drop_ref' op attribute 'count' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} + async.drop_ref %arg0 {count = 0 : i32} : !async.token +} diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir --- a/mlir/test/mlir-cpu-runner/async-group.mlir +++ b/mlir/test/mlir-cpu-runner/async-group.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -convert-async-to-llvm \ +// RUN: mlir-opt %s -async-ref-counting \ +// RUN: -convert-async-to-llvm \ // RUN: -convert-std-to-llvm \ // RUN: | mlir-cpu-runner \ // RUN: -e main -entry-point-result=void -O0 \ diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir --- a/mlir/test/mlir-cpu-runner/async.mlir +++ b/mlir/test/mlir-cpu-runner/async.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -convert-async-to-llvm \ +// RUN: mlir-opt %s -async-ref-counting \ +// RUN: -convert-async-to-llvm \ // RUN: -convert-linalg-to-loops \ // RUN: -convert-linalg-to-llvm \ // RUN: -convert-std-to-llvm \