diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td @@ -73,4 +73,8 @@ def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType, Async_TokenType]>; +def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType, + Async_TokenType, + Async_GroupType]>; + #endif // ASYNC_BASE_TD diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -227,4 +227,62 @@ let assemblyFormat = "$operand attr-dict"; } +//===----------------------------------------------------------------------===// +// Async Dialect Automatic Reference Counting Operations. +//===----------------------------------------------------------------------===// + +// All async values (values, tokens, groups) are reference counted at runtime +// and automatically destructed when reference count drops to 0. +// +// All values are semantically created with a reference count of +1 and it is +// the responsibility of the last async value user to drop reference count. +// +// Async values are created when: +// 1. Operation returns async result (e.g. the result of an `async.execute`). +// 2. Async value passed in as a block argument. +// +// It is the responsibility of the async value user to extend the lifetime by +// adding a +1 reference, if the reference counted value is captured by the +// asynchronously executed region (`async.execute` operation), and drop it after +// the last nested use. +// +// Reference counting operations could be added to the IR using automatic +// reference count pass, that relies on liveness analysis to find the last uses +// of all reference counted values and automatically inserts +// `drop_ref` operations.
+// +// See `AsyncRefCountingPass` documentation for the implementation details. + +def Async_AddRefOp : Async_Op<"add_ref"> { + let summary = "adds a reference to async value"; + let description = [{ + The `async.add_ref` operation adds a reference(s) to async value (token, + value or group). + }]; + + let arguments = (ins Async_AnyAsyncType:$operand, + Confined:$count); + let results = (outs ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) + }]; +} + +def Async_DropRefOp : Async_Op<"drop_ref"> { + let summary = "drops a reference to async value"; + let description = [{ + The `async.drop_ref` operation drops a reference(s) to async value (token, + value or group). + }]; + + let arguments = (ins Async_AnyAsyncType:$operand, + Confined:$count); + let results = (outs ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) + }]; +} + #endif // ASYNC_OPS diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -19,6 +19,10 @@ std::unique_ptr> createAsyncParallelForPass(); +std::unique_ptr> createAsyncRefCountingPass(); + +std::unique_ptr> createAsyncRefCountingOptimizationPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -24,4 +24,18 @@ let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"]; } +def AsyncRefCounting : FunctionPass<"async-ref-counting"> { + let summary = "Automatic reference counting for Async dialect data types"; + let constructor = "mlir::createAsyncRefCountingPass()"; + let dependentDialects = ["async::AsyncDialect"]; +} + +def 
AsyncRefCountingOptimization : + FunctionPass<"async-ref-counting-optimization"> { + let summary = "Optimize automatic reference counting operations for the " + "Async dialect by removing redundant operations"; + let constructor = "mlir::createAsyncRefCountingOptimizationPass()"; + let dependentDialects = ["async::AsyncDialect"]; +} + #endif // MLIR_DIALECT_ASYNC_PASSES diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -48,6 +48,18 @@ using CoroHandle = void *; // coroutine handle using CoroResume = void (*)(void *); // coroutine resume function +// Async runtime uses reference counting to manage the lifetime of async values +// (values of async types like tokens, values and groups). +using RefCountedObjPtr = void *; + +// Adds references to reference counted runtime object. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void + mlirAsyncRuntimeAddRef(RefCountedObjPtr, int32_t); + +// Drops references from reference counted runtime object. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void + mlirAsyncRuntimeDropRef(RefCountedObjPtr, int32_t); + // Create a new `async.token` in not-ready state.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken(); diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-ref-counting \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir --- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-ref-counting \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -33,6 +33,8 @@ // Async Runtime C API declaration. //===----------------------------------------------------------------------===// +static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; +static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; @@ -49,6 +51,12 @@ namespace { // Async Runtime API function types. 
struct AsyncAPI { + static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { + auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); + auto count = IntegerType::get(32, ctx); + return FunctionType::get({ref, count}, {}, ctx); + } + static FunctionType createTokenFunctionType(MLIRContext *ctx) { return FunctionType::get({}, {TokenType::get(ctx)}, ctx); } @@ -113,6 +121,8 @@ }; MLIRContext *ctx = module.getContext(); + addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); + addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); @@ -587,6 +597,55 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` +// to the corresponding API calls). +//===----------------------------------------------------------------------===// + +namespace { + +template +class RefCountingOpLowering : public ConversionPattern { +public: + explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName) + : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx), + apiFunctionName(apiFunctionName) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RefCountingOp refCountingOp = cast(op); + + auto count = rewriter.create( + op->getLoc(), rewriter.getI32Type(), + rewriter.getI32IntegerAttr(refCountingOp.count())); + + rewriter.replaceOpWithNewOp(op, Type(), apiFunctionName, + ValueRange({operands[0], count})); + + return success(); + } + +private: + StringRef apiFunctionName; +}; + +// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call.
+class AddRefOpLowering : public RefCountingOpLowering { +public: + explicit AddRefOpLowering(MLIRContext *ctx) + : RefCountingOpLowering(ctx, kAddRef) {} +}; + +// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. +class DropRefOpLowering : public RefCountingOpLowering { +public: + explicit DropRefOpLowering(MLIRContext *ctx) + : RefCountingOpLowering(ctx, kDropRef) {} +}; + +} // namespace + //===----------------------------------------------------------------------===// // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. //===----------------------------------------------------------------------===// @@ -694,8 +753,8 @@ // Call async runtime API to resume a coroutine in the managed thread when // the async await argument becomes ready. - SmallVector awaitAndExecuteArgs = { - await.getOperand(), coro.coroHandle, resumePtr.res()}; + SmallVector awaitAndExecuteArgs = {operands[0], coro.coroHandle, + resumePtr.res()}; builder.create(loc, Type(), coroAwaitFuncName, awaitAndExecuteArgs); @@ -793,10 +852,12 @@ populateFuncOpTypeConversionPattern(patterns, ctx, converter); patterns.insert(ctx); + patterns.insert(ctx); patterns.insert(ctx); patterns.insert(ctx, outlinedFunctions); ConversionTarget target(*ctx); + target.addLegalOp(); target.addLegalDialect(); target.addIllegalDialect(); target.addDynamicallyLegalOp( diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp @@ -0,0 +1,330 @@ +//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements automatic reference counting for Async dialect data +// types. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; +using namespace mlir::async; + +#define DEBUG_TYPE "async-ref-counting" + +namespace { + +class AsyncRefCountingPass : public AsyncRefCountingBase { +public: + AsyncRefCountingPass() = default; + void runOnFunction() override; + +private: + /// Adds automatic reference counting to the `value`. + /// + /// All values are semantically created with a reference count of +1 and it is + /// the responsibility of the last async value user to drop reference count. + /// + /// Async values are created when: + /// 1. Operation returns async result (e.g. the result of an + /// `async.execute`). + /// 2. Async value passed in as a block argument. + /// + /// It is the responsibility of the async value user to extend the lifetime by + /// adding a +1 reference, if the reference counted value is captured by the + /// asynchronously executed region (`async.execute` operation), and drop it + /// after the last nested use. + /// + /// Automatic reference counting algorithm outline: + /// + /// 1. `ReturnLike` operations forward the reference counted values without + /// modifying the reference count. + /// + /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of + /// reference counted values ends, and insert `drop_ref` operations after + /// the last use of the value. + /// + /// 3.
Insert `add_ref` before the `async.execute` operation capturing the + /// value, and pairing `drop_ref` before the async body region terminator, + /// to release the captured reference counted value when execution + /// completes. + /// + /// 4. If the reference counted value is passed only to some of the block + /// successors, insert `drop_ref` operations in the beginning of the blocks + /// that do not have reference counted value uses. + /// + /// + /// Example: + /// + /// %token = ... + /// async.execute { + /// async.await %token : !async.token // await #1 + /// async.yield + /// } + /// async.await %token : !async.token // await #2 + /// + /// Based on the liveness analysis await #2 is the last use of the %token, + /// however the execution of the async region can be delayed, and to guarantee + /// that the %token is still alive when await #1 executes we need to + /// explicitly extend its lifetime using `add_ref` operation. + /// + /// After automatic reference counting: + /// + /// %token = ... + /// + /// // Make sure that %token is alive inside async.execute. + /// async.add_ref %token {count = 1 : i32} : !async.token + /// + /// async.execute { + /// async.await %token : !async.token // await #1 + /// + /// // Drop the extra reference added to keep %token alive. + /// async.drop_ref %token {count = 1 : i32} : !async.token + /// + /// async.yield + /// } + /// async.await %token : !async.token // await #2 + /// + /// // Drop the reference after the last use of %token. + /// async.drop_ref %token {count = 1 : i32} : !async.token + /// + LogicalResult addAutomaticRefCounting(Value value); +}; + +} // namespace + +/// Returns true if the type is reference counted. All async dialect types are +/// reference counted at runtime.
+static bool isRefCounted(Type type) { + return type.isa(); +} + +LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) { + MLIRContext *ctx = value.getContext(); + OpBuilder builder(ctx); + + // Set insertion point after the operation producing a value, or at the + // beginning of the block if the value is defined by the block argument. + if (Operation *op = value.getDefiningOp()) + builder.setInsertionPointAfter(op); + else + builder.setInsertionPointToStart(value.getParentBlock()); + + Location loc = value.getLoc(); + auto i32 = IntegerType::get(32, ctx); + + // Drop the reference count immediately if the value has no uses. + if (value.getUses().empty()) { + builder.create(loc, value, IntegerAttr::get(i32, 1)); + return success(); + } + + // Use liveness analysis to find the placement of `drop_ref` operation. + auto liveness = getAnalysis(); + + // We analyse only the blocks of the region that defines the `value`, and do + // not check nested blocks attached to operations. + // + // By analyzing only the `definingRegion` CFG we potentially lose an + // opportunity to drop the reference count earlier and can extend the lifetime + // of reference counted value longer than it is really required. + // + // We also assume that all nested regions finish their execution before the + // completion of the owner operation. The only exception to this rule is + // `async.execute` operation, which is handled explicitly below. + Region *definingRegion = value.getParentRegion(); + + // ------------------------------------------------------------------------ // + // Find blocks where the `value` dies: the value is in `liveIn` set and not + // in the `liveOut` set. We place `drop_ref` immediately after the last use + // of the `value` in such blocks. + // ------------------------------------------------------------------------ // + + // Last users of the `value` inside all blocks where the value dies.
+ llvm::SmallSet lastUsers; + + for (Block &block : definingRegion->getBlocks()) { + const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); + + // Value is in the live in set or was defined in the block. + bool liveIn = blockLiveness->isLiveIn(value) || + blockLiveness->getBlock() == value.getParentBlock(); + if (!liveIn) + continue; + + // Value is in the live out set. + bool liveOut = blockLiveness->isLiveOut(value); + if (liveOut) + continue; + + // We proved that `value` dies in the `block`. Now find the last use of the + // `value` inside the `block`. + + // Find any user of the `value` inside the block (including uses in nested + // regions attached to the operations in the block). + Operation *userInTheBlock = nullptr; + for (Operation *user : value.getUsers()) { + userInTheBlock = block.findAncestorOpInBlock(*user); + if (userInTheBlock) + break; + } + + // Values with zero users are handled explicitly in the beginning; a value + // that dies in the block must therefore have at least one use in it. + assert(userInTheBlock && "value must have a user in the block"); + + // Find the last user of the `value` in the block. + Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); + assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); + lastUsers.insert(lastUser); + } + + // Process all the last users of the `value` inside each block where the value + // dies. + for (Operation *lastUser : lastUsers) { + // Return like operations forward reference count. + if (lastUser->hasTrait()) + continue; + + // We can't currently handle other types of terminators. + if (lastUser->hasTrait()) + return lastUser->emitError() << "async reference counting can't handle " + "terminators that are not ReturnLike"; + + // Add a drop_ref immediately after the last user.
+ builder.setInsertionPointAfter(lastUser); + builder.create(loc, value, IntegerAttr::get(i32, 1)); + } + + // ------------------------------------------------------------------------ // + // Find blocks where the `value` is in `liveOut` set, however it is not in + // the `liveIn` set of all successors. If the `value` is not in the successor + // `liveIn` set, we add a `drop_ref` to the beginning of it. + // ------------------------------------------------------------------------ // + + // Successors that we'll need a `drop_ref` for the `value`. + llvm::SmallSet dropRefSuccessors; + + for (Block &block : definingRegion->getBlocks()) { + const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); + + // Skip the block if value is not in the `liveOut` set. + if (!blockLiveness->isLiveOut(value)) + continue; + + // Find successors that do not have `value` in the `liveIn` set. + for (Block *successor : block.getSuccessors()) { + const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); + + if (!succLiveness->isLiveIn(value)) + dropRefSuccessors.insert(successor); + } + } + + // Drop reference in all successor blocks that do not have the `value` in + // their `liveIn` set. + for (Block *dropRefSuccessor : dropRefSuccessors) { + builder.setInsertionPointToStart(dropRefSuccessor); + builder.create(loc, value, IntegerAttr::get(i32, 1)); + } + + // ------------------------------------------------------------------------ // + // Find all `async.execute` operation that take `value` as an operand + // (dependency token or async value), or capture implicitly by the nested + // region. Each `async.execute` operation will require `add_ref` operation + // to keep all captured values alive until it will finish its execution. 
+ // ------------------------------------------------------------------------ // + + llvm::SmallSet executeOperations; + + auto trackAsyncExecute = [&](Operation *op) { + if (auto execute = dyn_cast(op)) + executeOperations.insert(execute); + }; + + for (Operation *user : value.getUsers()) { + // Follow parent operations up until the operation in the `definingRegion`. + while (user->getParentRegion() != definingRegion) { + trackAsyncExecute(user); + user = user->getParentOp(); + assert(user != nullptr && "value user lies outside of the value region"); + } + + // Don't forget to process the parent in the `definingRegion` (can be the + // original user operation itself). + trackAsyncExecute(user); + } + + // Process all `async.execute` operations capturing `value`. + for (ExecuteOp execute : executeOperations) { + // Add a reference before the execute operation to keep the reference + // counted value alive until the async region completes execution. + builder.setInsertionPoint(execute.getOperation()); + builder.create(loc, value, IntegerAttr::get(i32, 1)); + + // Drop the reference inside the async region before completion. + OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody()); + executeBuilder.create(loc, value, IntegerAttr::get(i32, 1)); + } + + return success(); +} + +void AsyncRefCountingPass::runOnFunction() { + FuncOp func = getFunction(); + + // Bail out if the IR already has explicit `add_ref` or `drop_ref` ops: + // automatic reference counting would produce incorrect results, so stop + // before inserting anything. + WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult { + if (isa(op)) + return op->emitError() << "explicit reference counting is not supported"; + return WalkResult::advance(); + }); + + if (refCountingWalk.wasInterrupted()) + return signalPassFailure(); + + // Add reference counting to block arguments.
+ WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { + for (BlockArgument arg : block->getArguments()) + if (isRefCounted(arg.getType())) + if (failed(addAutomaticRefCounting(arg))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (blockWalk.wasInterrupted()) + return signalPassFailure(); + + // Add reference counting to operation results. + WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { + for (unsigned i = 0; i < op->getNumResults(); ++i) + if (isRefCounted(op->getResultTypes()[i])) + if (failed(addAutomaticRefCounting(op->getResult(i)))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (opWalk.wasInterrupted()) + signalPassFailure(); +} + +std::unique_ptr> mlir::createAsyncRefCountingPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp @@ -0,0 +1,224 @@ +//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Optimize Async dialect reference counting operations.
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Async/Passes.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; +using namespace mlir::async; + +#define DEBUG_TYPE "async-ref-counting" + +namespace { + +class AsyncRefCountingOptimizationPass + : public AsyncRefCountingOptimizationBase< + AsyncRefCountingOptimizationPass> { +public: + AsyncRefCountingOptimizationPass() = default; + void runOnFunction() override; + +private: + LogicalResult optimizeReferenceCounting(Value value); +}; + +} // namespace + +/// Returns true if the type is reference counted. All async dialect types are +/// reference counted at runtime. +static bool isRefCounted(Type type) { + return type.isa(); +} + +LogicalResult +AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) { + Region *definingRegion = value.getParentRegion(); + + // Find all users of the `value` inside each block, including operations that + // do not use `value` directly, but have a direct use inside nested region(s). + // + // Example: + // + // ^bb1: + // %token = ... + // scf.if %cond { + // ^bb2: + // async.await %token : !async.token + // } + // + // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`). + // + // In addition to the operation that uses the `value` we also keep track if + // this user is an `async.execute` operation itself, or has `async.execute` + // operations in the nested regions that do use the `value`. 
+ + struct ValueUser { + Operation *operation; + bool hasExecuteUser; + }; + + struct BlockUsersInfo { + llvm::SmallVector addRefs; + llvm::SmallVector dropRefs; + llvm::SmallVector users; + }; + + llvm::DenseMap blockUsers; + + auto updateBlockUsersInfo = [&](ValueUser user) { + BlockUsersInfo &info = blockUsers[user.operation->getBlock()]; + info.users.push_back(user); + + if (auto addRef = dyn_cast(user.operation)) + info.addRefs.push_back(addRef); + if (auto dropRef = dyn_cast(user.operation)) + info.dropRefs.push_back(dropRef); + }; + + for (Operation *user : value.getUsers()) { + bool isAsyncUser = isa(user); + + while (user->getParentRegion() != definingRegion) { + updateBlockUsersInfo({user, isAsyncUser}); + user = user->getParentOp(); + assert(user != nullptr && "value user lies outside of the value region"); + isAsyncUser |= isa(user); + } + + updateBlockUsersInfo({user, isAsyncUser}); + } + + // Sort all operations found in the block by their position in the block. + auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & { + auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool { + return a->isBeforeInBlock(b); + }; + llvm::sort(info.addRefs, isBeforeInBlock); + llvm::sort(info.dropRefs, isBeforeInBlock); + llvm::sort(info.users, [&](ValueUser a, ValueUser b) -> bool { + return isBeforeInBlock(a.operation, b.operation); + }); + + return info; + }; + + // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the + // blocks that modify the reference count of the `value`. + for (auto &kv : blockUsers) { + BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second); + + // Find all cancellable pairs first and erase them later to keep all + // pointers in the `info` valid until the end. + // + // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
+ llvm::SmallDenseMap cancellable; + + for (AddRefOp addRef : info.addRefs) { + for (DropRefOp dropRef : info.dropRefs) { + // Look for a `drop_ref` after the `add_ref` with a matching count. + if (dropRef.count() != addRef.count() || + dropRef.getOperation()->isBeforeInBlock(addRef.getOperation())) + continue; + + // `drop_ref` was already marked for removal. + if (cancellable.find(dropRef.getOperation()) != cancellable.end()) + continue; + + // Check `value` users between `addRef` and `dropRef` in the `block`. + Operation *addRefOp = addRef.getOperation(); + Operation *dropRefOp = dropRef.getOperation(); + + // If there is a "regular" user after the `async.execute` user it is + // unsafe to erase cancellable reference counting operations pair, + // because async region can complete before the "regular" user and + // destroy the reference counted value. + bool hasExecuteUser = false; + bool unsafeToCancel = false; + + for (ValueUser &user : info.users) { + Operation *op = user.operation; + + // `user` operation lies after `addRef` ... + if (op == addRefOp || op->isBeforeInBlock(addRefOp)) + continue; + // ... and before `dropRef`. + if (op == dropRefOp || dropRefOp->isBeforeInBlock(op)) + break; + + bool isRegularUser = !user.hasExecuteUser; + bool isExecuteUser = user.hasExecuteUser; + + // It is unsafe to cancel `addRef` / `dropRef` pair. + if (isRegularUser && hasExecuteUser) { + unsafeToCancel = true; + break; + } + + hasExecuteUser |= isExecuteUser; + } + + // Mark the pair of reference counting operations for removal. + if (!unsafeToCancel) + cancellable[dropRef.getOperation()] = addRef.getOperation(); + + // If it is unsafe to cancel `addRef <-> dropRef` pair at this point, + // all the following pairs will be also unsafe. + break; + } + } + + // Erase all cancellable `addRef <-> dropRef` operation pairs.
+ for (auto &kv : cancellable) { + kv.first->erase(); + kv.second->erase(); + } + } + + return success(); +} + +void AsyncRefCountingOptimizationPass::runOnFunction() { + FuncOp func = getFunction(); + + // Optimize reference counting for values defined by block arguments. + WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { + for (BlockArgument arg : block->getArguments()) + if (isRefCounted(arg.getType())) + if (failed(optimizeReferenceCounting(arg))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (blockWalk.wasInterrupted()) + return signalPassFailure(); + + // Optimize reference counting for values defined by operation results. + WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { + for (unsigned i = 0; i < op->getNumResults(); ++i) + if (isRefCounted(op->getResultTypes()[i])) + if (failed(optimizeReferenceCounting(op->getResult(i)))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (opWalk.wasInterrupted()) + signalPassFailure(); +} + +std::unique_ptr> +mlir::createAsyncRefCountingOptimizationPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -1,5 +1,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms AsyncParallelFor.cpp + AsyncRefCounting.cpp + AsyncRefCountingOptimization.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -16,6 +16,7 @@ #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS #include +#include #include #include #include @@ -27,30 +28,141 @@ // Async runtime API.
//===----------------------------------------------------------------------===// -struct AsyncToken { - bool ready = false; +namespace { + +// Forward declare class defined below. +class RefCounted; + +// -------------------------------------------------------------------------- // +// AsyncRuntime orchestrates all async operations and Async runtime API is built +// on top of the default runtime instance. +// -------------------------------------------------------------------------- // + +class AsyncRuntime { +public: + AsyncRuntime() : numRefCountedObjects(0) {} + + ~AsyncRuntime() { + assert(getNumRefCountedObjects() == 0 && + "all ref counted objects must be destroyed"); + } + + int32_t getNumRefCountedObjects() { + return numRefCountedObjects.load(std::memory_order_relaxed); + } + +private: + friend class RefCounted; + + // Count the total number of reference counted objects in this instance + // of an AsyncRuntime. For debugging purposes only. + void addNumRefCountedObjects() { + numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); + } + void dropNumRefCountedObjects() { + numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); + } + + std::atomic numRefCountedObjects; +}; + +// Returns the default per-process instance of an async runtime. +AsyncRuntime *getDefaultAsyncRuntimeInstance() { + static auto runtime = std::make_unique(); + return runtime.get(); +} + +// -------------------------------------------------------------------------- // +// A base class for all reference counted objects created by the async runtime. 
+// -------------------------------------------------------------------------- // + +class RefCounted { +public: + RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) + : runtime(runtime), refCount(refCount) { + runtime->addNumRefCountedObjects(); + } + + virtual ~RefCounted() { + assert(refCount.load() == 0 && "reference count must be zero"); + runtime->dropNumRefCountedObjects(); + } + + RefCounted(const RefCounted &) = delete; + RefCounted &operator=(const RefCounted &) = delete; + + void addRef(int32_t count = 1) { refCount.fetch_add(count); } + + void dropRef(int32_t count = 1) { + int32_t previous = refCount.fetch_sub(count); + assert(previous >= count && "reference count should not go below zero"); + if (previous == count) + destroy(); + } + +protected: + virtual void destroy() { delete this; } + +private: + AsyncRuntime *runtime; + std::atomic refCount; +}; + +} // namespace + +struct AsyncToken : public RefCounted { + // AsyncToken created with a reference count of 2 because it will be returned + // to the `async.execute` caller and also will be later on emplaced by the + // asynchronously executed task. If the caller immediately will drop its + // reference we must ensure that the token will be alive until the + // asynchronous operation is completed. + AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {} + + // Internal state below guarded by a mutex. std::mutex mu; std::condition_variable cv; + + bool ready = false; std::vector> awaiters; }; -struct AsyncGroup { - std::atomic pendingTokens{0}; - std::atomic rank{0}; +struct AsyncGroup : public RefCounted { + AsyncGroup(AsyncRuntime *runtime) + : RefCounted(runtime), pendingTokens(0), rank(0) {} + + std::atomic pendingTokens; + std::atomic rank; + + // Internal state below guarded by a mutex. std::mutex mu; std::condition_variable cv; + std::vector> awaiters; }; +// Adds references to reference counted runtime object. 
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { + RefCounted *refCounted = static_cast(ptr); + refCounted->addRef(count); +} + +// Drops references from reference counted runtime object. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void +mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { + RefCounted *refCounted = static_cast(ptr); + refCounted->dropRef(count); +} + // Create a new `async.token` in not-ready state. extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { - AsyncToken *token = new AsyncToken; + AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance()); return token; } // Create a new `async.group` in empty state. extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() { - AsyncGroup *group = new AsyncGroup; + AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); return group; } @@ -59,23 +171,34 @@ std::unique_lock lockToken(token->mu); std::unique_lock lockGroup(group->mu); + // Get the rank of the token inside the group before we drop the reference. + int rank = group->rank.fetch_add(1); group->pendingTokens.fetch_add(1); - auto onTokenReady = [group]() { + auto onTokenReady = [group, token](bool dropRef) { // Run all group awaiters if it was the last token in the group. if (group->pendingTokens.fetch_sub(1) == 1) { group->cv.notify_all(); for (auto &awaiter : group->awaiters) awaiter(); } + + // We no longer need the token or the group, drop references on them. + if (dropRef) { + group->dropRef(); + token->dropRef(); + } }; - if (token->ready) - onTokenReady(); - else - token->awaiters.push_back([onTokenReady]() { onTokenReady(); }); + if (token->ready) { + onTokenReady(false); + } else { + group->addRef(); + token->addRef(); + token->awaiters.push_back([onTokenReady]() { onTokenReady(true); }); + } - return group->rank.fetch_add(1); + return rank; } // Switches `async.token` to ready state and runs all awaiters. 
@@ -85,6 +208,10 @@ token->cv.notify_all(); for (auto &awaiter : token->awaiters) awaiter(); + + // Async tokens created with a ref count `2` to keep token alive until the + // async task completes. Drop this reference explicitly when token emplaced. + token->dropRef(); } extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { @@ -114,14 +241,18 @@ CoroResume resume) { std::unique_lock lock(token->mu); - auto execute = [handle, resume]() { + auto execute = [handle, resume, token](bool dropRef) { + if (dropRef) + token->dropRef(); mlirAsyncRuntimeExecute(handle, resume); }; - if (token->ready) - execute(); - else - token->awaiters.push_back([execute]() { execute(); }); + if (token->ready) { + execute(false); + } else { + token->addRef(); + token->awaiters.push_back([execute]() { execute(true); }); + } } extern "C" MLIR_ASYNCRUNTIME_EXPORT void @@ -129,14 +260,18 @@ CoroResume resume) { std::unique_lock lock(group->mu); - auto execute = [handle, resume]() { + auto execute = [handle, resume, group](bool dropRef) { + if (dropRef) + group->dropRef(); mlirAsyncRuntimeExecute(handle, resume); }; - if (group->pendingTokens == 0) - execute(); - else - group->awaiters.push_back([execute]() { execute(); }); + if (group->pendingTokens == 0) { + execute(false); + } else { + group->addRef(); + group->awaiters.push_back([execute]() { execute(true); }); + } } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -1,5 +1,20 @@ // RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s +// CHECK-LABEL: reference_counting +func @reference_counting(%arg0: !async.token) { + // CHECK: %[[C2:.*]] = constant 2 : i32 + // CHECK: call @mlirAsyncRuntimeAddRef(%arg0, %[[C2]]) + async.add_ref %arg0 
{count = 2 : i32} : !async.token + + // CHECK: %[[C1:.*]] = constant 1 : i32 + // CHECK: call @mlirAsyncRuntimeDropRef(%arg0, %[[C1]]) + async.drop_ref %arg0 {count = 1 : i32} : !async.token + + return +} + +// ----- + // CHECK-LABEL: execute_no_async_args func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) { // CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1) diff --git a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir @@ -0,0 +1,90 @@ +// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s + +// CHECK-LABEL: @cancellable_operations_0 +func @cancellable_operations_0(%arg0: !async.token) { + // CHECK-NOT: async.add_ref + // CHECK-NOT: async.drop_ref + async.add_ref %arg0 {count = 1 : i32} : !async.token + async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} + +// CHECK-LABEL: @cancellable_operations_1 +func @cancellable_operations_1(%arg0: !async.token) { + // CHECK: async.await + // CHECK-NEXT: async.await + // CHECK-NEXT: async.await + // CHECK-NEXT: return + async.add_ref %arg0 {count = 1 : i32} : !async.token + async.await %arg0 : !async.token + async.drop_ref %arg0 {count = 1 : i32} : !async.token + async.await %arg0 : !async.token + async.add_ref %arg0 {count = 1 : i32} : !async.token + async.await %arg0 : !async.token + async.drop_ref %arg0 {count = 1 : i32} : !async.token + return +} + +// CHECK-LABEL: @cancellable_operations_2 +func @cancellable_operations_2(%arg0: !async.token) { + // CHECK-NOT: add_ref + async.add_ref %arg0 {count = 1 : i32} : !async.token + %token = async.execute { + async.await %arg0 : !async.token + // CHECK: async.drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + async.yield + } + // CHECK-NOT: drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: 
async.await + async.await %arg0 : !async.token + // CHECK: return + return +} + +// CHECK-LABEL: @not_cancellable_operations_0 +func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) { + // It is unsafe to cancel `add_ref` / `drop_ref` pair because it is possible + // that the body of the `async.execute` operation will run before the await + // operation in the function body, and will destroy the `%arg0` token. + // CHECK: add_ref + async.add_ref %arg0 {count = 1 : i32} : !async.token + %token = async.execute { + scf.if %arg1 { + async.await %arg0 : !async.token + // CHECK: async.drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + } + async.yield + } + // CHECK: async.await + async.await %arg0 : !async.token + // CHECK: drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} + +// CHECK-LABEL: @not_cancellable_operations_1 +func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) { + // Same reason as above, although `async.execute` is inside the nested + // region of a "regular" operation. 
+ // CHECK: add_ref + async.add_ref %arg0 {count = 1 : i32} : !async.token + scf.if %arg1 { + %token = async.execute { + async.await %arg0 : !async.token + // CHECK: async.drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + async.yield + } + } + // CHECK: async.await + async.await %arg0 : !async.token + // CHECK: drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir b/mlir/test/Dialect/Async/async-ref-counting.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-ref-counting.mlir @@ -0,0 +1,249 @@ +// RUN: mlir-opt %s -async-ref-counting | FileCheck %s + +// CHECK-LABEL: @cond +func @cond() -> i1 + +// CHECK-LABEL: @token_arg_no_uses +func @token_arg_no_uses(%arg0: !async.token) { + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + return +} + +// CHECK-LABEL: @token_arg_conditional_await +func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) { + cond_br %arg1, ^bb1, ^bb2 +^bb1: + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + return +^bb2: + // CHECK: async.await %arg0 + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + async.await %arg0 : !async.token + return +} + +// CHECK-LABEL: @token_no_uses +func @token_no_uses() { + // CHECK: %[[TOKEN:.*]] = async.execute + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + %token = async.execute { + async.yield + } + return +} + +// CHECK-LABEL: @token_return +func @token_return() -> !async.token { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: return %[[TOKEN]] + return %token : !async.token +} + +// CHECK-LABEL: @token_await +func @token_await() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: async.await %[[TOKEN]] + async.await %token : !async.token + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: 
@token_await_and_return +func @token_await_and_return() -> !async.token { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: async.await %[[TOKEN]] + // CHECK-NOT: async.drop_ref + async.await %token : !async.token + // CHECK: return %[[TOKEN]] + return %token : !async.token +} + +// CHECK-LABEL: @token_await_inside_scf_if +func @token_await_inside_scf_if(%arg0: i1) { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: scf.if %arg0 { + scf.if %arg0 { + // CHECK: async.await %[[TOKEN]] + async.await %token : !async.token + } + // CHECK: } + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_conditional_await +func @token_conditional_await(%arg0: i1) { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + cond_br %arg0, ^bb1, ^bb2 +^bb1: + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + return +^bb2: + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + return +} + +// CHECK-LABEL: @token_await_in_the_loop +func @token_await_in_the_loop() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + br ^bb1 +^bb1: + // CHECK: async.await %[[TOKEN]] + async.await %token : !async.token + %0 = call @cond(): () -> (i1) + cond_br %0, ^bb1, ^bb2 +^bb2: + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + return +} + +// CHECK-LABEL: @token_defined_in_the_loop +func @token_defined_in_the_loop() { + br ^bb1 +^bb1: + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + %0 = call @cond(): () -> (i1) + cond_br %0, ^bb1, ^bb2 +^bb2: + return +} + +// CHECK-LABEL: @token_capture +func @token_capture() { + // CHECK: 
%[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK-NEXT: async.yield + async.await %token : !async.token + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_nested_capture +func @token_nested_capture() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute { + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: %[[TOKEN_1:.*]] = async.execute + %token_1 = async.execute { + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: %[[TOKEN_2:.*]] = async.execute + %token_2 = async.execute { + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_2]] {count = 1 : i32} + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_1]] {count = 1 : i32} + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_dependency +func @token_dependency() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute[%token] { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK-NEXT: async.yield + async.yield + } + + // CHECK: async.await %[[TOKEN]] + // CHECK: async.await %[[TOKEN_0]] + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + async.await %token : !async.token + async.await 
%token_0 : !async.token + + // CHECK: return + return +} + +// CHECK-LABEL: @value_operand +func @value_operand() -> f32 { + // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute + %token, %results = async.execute -> !async.value { + %0 = constant 0.0 : f32 + async.yield %0 : f32 + } + + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.add_ref %[[RESULTS]] {count = 1 : i32} + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute[%token](%results as %arg0 : !async.value) { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32} + // CHECK: async.yield + async.yield + } + + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + + // CHECK: async.await %[[TOKEN_0]] + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + async.await %token_0 : !async.token + + // CHECK: async.await %[[RESULTS]] + // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32} + %0 = async.await %results : !async.value + + // CHECK: return + return %0 : f32 +} diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir --- a/mlir/test/Dialect/Async/ops.mlir +++ b/mlir/test/Dialect/Async/ops.mlir @@ -134,3 +134,17 @@ %3 = addi %1, %2 : index return %3 : index } + +// CHECK-LABEL: @add_ref +func @add_ref(%arg0: !async.token) { + // CHECK: async.add_ref %arg0 {count = 1 : i32} + async.add_ref %arg0 {count = 1 : i32} : !async.token + return +} + +// CHECK-LABEL: @drop_ref +func @drop_ref(%arg0: !async.token) { + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + async.drop_ref %arg0 {count = 1 : i32} : !async.token + return +} diff --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir --- a/mlir/test/Dialect/Async/verify.mlir +++ b/mlir/test/Dialect/Async/verify.mlir @@ -19,3 +19,17 @@ // expected-error @+1 {{'async.await' op result type 'f64' does not match async value type 
'f32'}} %0 = "async.await"(%arg0): (!async.value) -> f64 } + +// ----- + +func @wrong_add_ref_count(%arg0: !async.token) { + // expected-error @+1 {{'async.add_ref' op attribute 'count' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} + async.add_ref %arg0 {count = 0 : i32} : !async.token +} + +// ----- + +func @wrong_drop_ref_count(%arg0: !async.token) { + // expected-error @+1 {{'async.drop_ref' op attribute 'count' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} + async.drop_ref %arg0 {count = 0 : i32} : !async.token +} diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir --- a/mlir/test/mlir-cpu-runner/async-group.mlir +++ b/mlir/test/mlir-cpu-runner/async-group.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -convert-async-to-llvm \ +// RUN: mlir-opt %s -async-ref-counting \ +// RUN: -convert-async-to-llvm \ // RUN: -convert-std-to-llvm \ // RUN: | mlir-cpu-runner \ // RUN: -e main -entry-point-result=void -O0 \ diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir --- a/mlir/test/mlir-cpu-runner/async.mlir +++ b/mlir/test/mlir-cpu-runner/async.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -convert-async-to-llvm \ +// RUN: mlir-opt %s -async-ref-counting \ +// RUN: -convert-async-to-llvm \ // RUN: -convert-linalg-to-loops \ // RUN: -convert-linalg-to-llvm \ // RUN: -convert-std-to-llvm \