diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -22,11 +22,11 @@ std::unique_ptr> createAsyncParallelForPass(int numWorkerThreads); -std::unique_ptr> createAsyncRefCountingPass(); +std::unique_ptr> createAsyncToAsyncRuntimePass(); -std::unique_ptr> createAsyncRefCountingOptimizationPass(); +std::unique_ptr> createAsyncRuntimeRefCountingPass(); -std::unique_ptr> createAsyncToAsyncRuntimePass(); +std::unique_ptr> createAsyncRuntimeRefCountingOptPass(); //===----------------------------------------------------------------------===// // Registration diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -24,24 +24,35 @@ let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"]; } -def AsyncRefCounting : FunctionPass<"async-ref-counting"> { - let summary = "Automatic reference counting for Async dialect data types"; - let constructor = "mlir::createAsyncRefCountingPass()"; +def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> { + let summary = "Lower high level async operations (e.g. 
async.execute) to the" + "explicit async.runtime and async.coro operations"; + let constructor = "mlir::createAsyncToAsyncRuntimePass()"; let dependentDialects = ["async::AsyncDialect"]; } -def AsyncRefCountingOptimization : - FunctionPass<"async-ref-counting-optimization"> { - let summary = "Optimize automatic reference counting operations for the" - "Async dialect by removing redundant operations"; - let constructor = "mlir::createAsyncRefCountingOptimizationPass()"; +def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> { + let summary = "Automatic reference counting for Async runtime operations"; + let description = [{ + This pass works at the async runtime abstraction level, after all + `async.execute` and `async.await` operations are lowered to the async + runtime API calls, and async coroutine operations. + + It relies on the LLVM coroutines switched-resume lowering semantics for + the correct placing of the reference counting operations. + + See: https://llvm.org/docs/Coroutines.html#switched-resume-lowering + }]; + + let constructor = "mlir::createAsyncRuntimeRefCountingPass()"; let dependentDialects = ["async::AsyncDialect"]; } -def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> { - let summary = "Lower high level async operations (e.g. 
async.execute) to the" - "explicit async.rutime and async.coro operations"; - let constructor = "mlir::createAsyncToAsyncRuntimePass()"; +def AsyncRuntimeRefCountingOpt : + FunctionPass<"async-runtime-ref-counting-opt"> { + let summary = "Optimize automatic reference counting operations for the" + "Async runtime by removing redundant operations"; + let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()"; let dependentDialects = ["async::AsyncDialect"]; } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp +++ /dev/null @@ -1,325 +0,0 @@ -//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements automatic reference counting for Async dialect data -// types. -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" -#include "mlir/Analysis/Liveness.h" -#include "mlir/Dialect/Async/IR/Async.h" -#include "mlir/Dialect/Async/Passes.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/SmallSet.h" - -using namespace mlir; -using namespace mlir::async; - -#define DEBUG_TYPE "async-ref-counting" - -namespace { - -class AsyncRefCountingPass : public AsyncRefCountingBase { -public: - AsyncRefCountingPass() = default; - void runOnFunction() override; - -private: - /// Adds an automatic reference counting to the `value`. 
- /// - /// All values are semantically created with a reference count of +1 and it is - /// the responsibility of the last async value user to drop reference count. - /// - /// Async values created when: - /// 1. Operation returns async result (e.g. the result of an - /// `async.execute`). - /// 2. Async value passed in as a block argument. - /// - /// To implement automatic reference counting, we must insert a +1 reference - /// before each `async.execute` operation using the value, and drop it after - /// the last use inside the async body region (we currently drop the reference - /// before the `async.yield` terminator). - /// - /// Automatic reference counting algorithm outline: - /// - /// 1. `ReturnLike` operations forward the reference counted values without - /// modifying the reference count. - /// - /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of - /// reference counted values ends, and insert `drop_ref` operations after - /// the last use of the value. - /// - /// 3. Insert `add_ref` before the `async.execute` operation capturing the - /// value, and pairing `drop_ref` before the async body region terminator, - /// to release the captured reference counted value when execution - /// completes. - /// - /// 4. If the reference counted value is passed only to some of the block - /// successors, insert `drop_ref` operations in the beginning of the blocks - /// that do not have reference counted value uses. - /// - /// - /// Example: - /// - /// %token = ... - /// async.execute { - /// async.await %token : !async.token // await #1 - /// async.yield - /// } - /// async.await %token : !async.token // await #2 - /// - /// Based on the liveness analysis await #2 is the last use of the %token, - /// however the execution of the async region can be delayed, and to guarantee - /// that the %token is still alive when await #1 executes we need to - /// explicitly extend its lifetime using `add_ref` operation. 
- /// - /// After automatic reference counting: - /// - /// %token = ... - /// - /// // Make sure that %token is alive inside async.execute. - /// async.add_ref %token {count = 1 : i32} : !async.token - /// - /// async.execute { - /// async.await %token : !async.token // await #1 - /// - /// // Drop the extra reference added to keep %token alive. - /// async.drop_ref %token {count = 1 : i32} : !async.token - /// - /// async.yied - /// } - /// async.await %token : !async.token // await #2 - /// - /// // Drop the reference after the last use of %token. - /// async.drop_ref %token {count = 1 : i32} : !async.token - /// - LogicalResult addAutomaticRefCounting(Value value); -}; - -} // namespace - -LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) { - MLIRContext *ctx = value.getContext(); - OpBuilder builder(ctx); - - // Set inserton point after the operation producing a value, or at the - // beginning of the block if the value defined by the block argument. - if (Operation *op = value.getDefiningOp()) - builder.setInsertionPointAfter(op); - else - builder.setInsertionPointToStart(value.getParentBlock()); - - Location loc = value.getLoc(); - auto i32 = IntegerType::get(ctx, 32); - - // Drop the reference count immediately if the value has no uses. - if (value.getUses().empty()) { - builder.create(loc, value, IntegerAttr::get(i32, 1)); - return success(); - } - - // Use liveness analysis to find the placement of `drop_ref`operation. - auto liveness = getAnalysis(); - - // We analyse only the blocks of the region that defines the `value`, and do - // not check nested blocks attached to operations. - // - // By analyzing only the `definingRegion` CFG we potentially loose an - // opportunity to drop the reference count earlier and can extend the lifetime - // of reference counted value longer then it is really required. - // - // We also assume that all nested regions finish their execution before the - // completion of the owner operation. 
The only exception to this rule is - // `async.execute` operation, which is handled explicitly below. - Region *definingRegion = value.getParentRegion(); - - // ------------------------------------------------------------------------ // - // Find blocks where the `value` dies: the value is in `liveIn` set and not - // in the `liveOut` set. We place `drop_ref` immediately after the last use - // of the `value` in such regions. - // ------------------------------------------------------------------------ // - - // Last users of the `value` inside all blocks where the value dies. - llvm::SmallSet lastUsers; - - for (Block &block : definingRegion->getBlocks()) { - const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); - - // Value in live input set or was defined in the block. - bool liveIn = blockLiveness->isLiveIn(value) || - blockLiveness->getBlock() == value.getParentBlock(); - if (!liveIn) - continue; - - // Value is in the live out set. - bool liveOut = blockLiveness->isLiveOut(value); - if (liveOut) - continue; - - // We proved that `value` dies in the `block`. Now find the last use of the - // `value` inside the `block`. - - // Find any user of the `value` inside the block (including uses in nested - // regions attached to the operations in the block). - Operation *userInTheBlock = nullptr; - for (Operation *user : value.getUsers()) { - userInTheBlock = block.findAncestorOpInBlock(*user); - if (userInTheBlock) - break; - } - - // Values with zero users handled explicitly in the beginning, if the value - // is in live out set it must have at least one use in the block. 
- assert(userInTheBlock && "value must have a user in the block"); - - // Find the last user of the `value` in the block; - Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); - assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); - lastUsers.insert(lastUser); - } - - // Process all the last users of the `value` inside each block where the value - // dies. - for (Operation *lastUser : lastUsers) { - // Return like operations forward reference count. - if (lastUser->hasTrait()) - continue; - - // We can't currently handle other types of terminators. - if (lastUser->hasTrait()) - return lastUser->emitError() << "async reference counting can't handle " - "terminators that are not ReturnLike"; - - // Add a drop_ref immediately after the last user. - builder.setInsertionPointAfter(lastUser); - builder.create(loc, value, IntegerAttr::get(i32, 1)); - } - - // ------------------------------------------------------------------------ // - // Find blocks where the `value` is in `liveOut` set, however it is not in - // the `liveIn` set of all successors. If the `value` is not in the successor - // `liveIn` set, we add a `drop_ref` to the beginning of it. - // ------------------------------------------------------------------------ // - - // Successors that we'll need a `drop_ref` for the `value`. - llvm::SmallSet dropRefSuccessors; - - for (Block &block : definingRegion->getBlocks()) { - const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); - - // Skip the block if value is not in the `liveOut` set. - if (!blockLiveness->isLiveOut(value)) - continue; - - // Find successors that do not have `value` in the `liveIn` set. 
- for (Block *successor : block.getSuccessors()) { - const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); - - if (!succLiveness->isLiveIn(value)) - dropRefSuccessors.insert(successor); - } - } - - // Drop reference in all successor blocks that do not have the `value` in - // their `liveIn` set. - for (Block *dropRefSuccessor : dropRefSuccessors) { - builder.setInsertionPointToStart(dropRefSuccessor); - builder.create(loc, value, IntegerAttr::get(i32, 1)); - } - - // ------------------------------------------------------------------------ // - // Find all `async.execute` operation that take `value` as an operand - // (dependency token or async value), or capture implicitly by the nested - // region. Each `async.execute` operation will require `add_ref` operation - // to keep all captured values alive until it will finish its execution. - // ------------------------------------------------------------------------ // - - llvm::SmallSet executeOperations; - - auto trackAsyncExecute = [&](Operation *op) { - if (auto execute = dyn_cast(op)) - executeOperations.insert(execute); - }; - - for (Operation *user : value.getUsers()) { - // Follow parent operations up until the operation in the `definingRegion`. - while (user->getParentRegion() != definingRegion) { - trackAsyncExecute(user); - user = user->getParentOp(); - assert(user != nullptr && "value user lies outside of the value region"); - } - - // Don't forget to process the parent in the `definingRegion` (can be the - // original user operation itself). - trackAsyncExecute(user); - } - - // Process all `async.execute` operations capturing `value`. - for (ExecuteOp execute : executeOperations) { - // Add a reference before the execute operation to keep the reference - // counted alive before the async region completes execution. 
- builder.setInsertionPoint(execute.getOperation()); - builder.create(loc, value, IntegerAttr::get(i32, 1)); - - // Drop the reference inside the async region before completion. - OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody()); - executeBuilder.create(loc, value, - IntegerAttr::get(i32, 1)); - } - - return success(); -} - -void AsyncRefCountingPass::runOnFunction() { - FuncOp func = getFunction(); - - // Check that we do not have explicit `add_ref` or `drop_ref` in the IR - // because otherwise automatic reference counting will produce incorrect - // results. - WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult { - if (isa(op)) - return op->emitError() << "explicit reference counting is not supported"; - return WalkResult::advance(); - }); - - if (refCountingWalk.wasInterrupted()) - signalPassFailure(); - - // Add reference counting to block arguments. - WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { - for (BlockArgument arg : block->getArguments()) - if (isRefCounted(arg.getType())) - if (failed(addAutomaticRefCounting(arg))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }); - - if (blockWalk.wasInterrupted()) - signalPassFailure(); - - // Add reference counting to operation results. 
- WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { - for (unsigned i = 0; i < op->getNumResults(); ++i) - if (isRefCounted(op->getResultTypes()[i])) - if (failed(addAutomaticRefCounting(op->getResult(i)))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }); - - if (opWalk.wasInterrupted()) - signalPassFailure(); -} - -std::unique_ptr> mlir::createAsyncRefCountingPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp +++ /dev/null @@ -1,218 +0,0 @@ -//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Optimize Async dialect reference counting operations. 
-// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" -#include "mlir/Dialect/Async/IR/Async.h" -#include "mlir/Dialect/Async/Passes.h" -#include "llvm/ADT/SmallSet.h" - -using namespace mlir; -using namespace mlir::async; - -#define DEBUG_TYPE "async-ref-counting" - -namespace { - -class AsyncRefCountingOptimizationPass - : public AsyncRefCountingOptimizationBase< - AsyncRefCountingOptimizationPass> { -public: - AsyncRefCountingOptimizationPass() = default; - void runOnFunction() override; - -private: - LogicalResult optimizeReferenceCounting(Value value); -}; - -} // namespace - -LogicalResult -AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) { - Region *definingRegion = value.getParentRegion(); - - // Find all users of the `value` inside each block, including operations that - // do not use `value` directly, but have a direct use inside nested region(s). - // - // Example: - // - // ^bb1: - // %token = ... - // scf.if %cond { - // ^bb2: - // async.await %token : !async.token - // } - // - // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`). - // - // In addition to the operation that uses the `value` we also keep track if - // this user is an `async.execute` operation itself, or has `async.execute` - // operations in the nested regions that do use the `value`. 
- - struct UserInfo { - Operation *operation; - bool hasExecuteUser; - }; - - struct BlockUsersInfo { - llvm::SmallVector addRefs; - llvm::SmallVector dropRefs; - llvm::SmallVector users; - }; - - llvm::DenseMap blockUsers; - - auto updateBlockUsersInfo = [&](UserInfo user) { - BlockUsersInfo &info = blockUsers[user.operation->getBlock()]; - info.users.push_back(user); - - if (auto addRef = dyn_cast(user.operation)) - info.addRefs.push_back(addRef); - if (auto dropRef = dyn_cast(user.operation)) - info.dropRefs.push_back(dropRef); - }; - - for (Operation *user : value.getUsers()) { - bool isAsyncUser = isa(user); - - while (user->getParentRegion() != definingRegion) { - updateBlockUsersInfo({user, isAsyncUser}); - user = user->getParentOp(); - isAsyncUser |= isa(user); - assert(user != nullptr && "value user lies outside of the value region"); - } - - updateBlockUsersInfo({user, isAsyncUser}); - } - - // Sort all operations found in the block. - auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & { - auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool { - return a->isBeforeInBlock(b); - }; - llvm::sort(info.addRefs, isBeforeInBlock); - llvm::sort(info.dropRefs, isBeforeInBlock); - llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool { - return isBeforeInBlock(a.operation, b.operation); - }); - - return info; - }; - - // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the - // blocks that modify the reference count of the `value`. - for (auto &kv : blockUsers) { - BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second); - - // Find all cancellable pairs first and erase them later to keep all - // pointers in the `info` valid until the end. - // - // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`. 
- llvm::SmallDenseMap cancellable; - - for (RuntimeAddRefOp addRef : info.addRefs) { - for (RuntimeDropRefOp dropRef : info.dropRefs) { - // `drop_ref` operation after the `add_ref` with matching count. - if (dropRef.count() != addRef.count() || - dropRef->isBeforeInBlock(addRef.getOperation())) - continue; - - // `drop_ref` was already marked for removal. - if (cancellable.find(dropRef.getOperation()) != cancellable.end()) - continue; - - // Check `value` users between `addRef` and `dropRef` in the `block`. - Operation *addRefOp = addRef.getOperation(); - Operation *dropRefOp = dropRef.getOperation(); - - // If there is a "regular" user after the `async.execute` user it is - // unsafe to erase cancellable reference counting operations pair, - // because async region can complete before the "regular" user and - // destroy the reference counted value. - bool hasExecuteUser = false; - bool unsafeToCancel = false; - - for (UserInfo &user : info.users) { - Operation *op = user.operation; - - // `user` operation lies after `addRef` ... - if (op == addRefOp || op->isBeforeInBlock(addRefOp)) - continue; - // ... and before `dropRef`. - if (op == dropRefOp || dropRefOp->isBeforeInBlock(op)) - break; - - bool isRegularUser = !user.hasExecuteUser; - bool isExecuteUser = user.hasExecuteUser; - - // It is unsafe to cancel `addRef` / `dropRef` pair. - if (isRegularUser && hasExecuteUser) { - unsafeToCancel = true; - break; - } - - hasExecuteUser |= isExecuteUser; - } - - // Mark the pair of reference counting operations for removal. - if (!unsafeToCancel) - cancellable[dropRef.getOperation()] = addRef.getOperation(); - - // If it us unsafe to cancel `addRef <-> dropRef` pair at this point, - // all the following pairs will be also unsafe. - break; - } - } - - // Erase all cancellable `addRef <-> dropRef` operation pairs. 
- for (auto &kv : cancellable) { - kv.first->erase(); - kv.second->erase(); - } - } - - return success(); -} - -void AsyncRefCountingOptimizationPass::runOnFunction() { - FuncOp func = getFunction(); - - // Optimize reference counting for values defined by block arguments. - WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { - for (BlockArgument arg : block->getArguments()) - if (isRefCounted(arg.getType())) - if (failed(optimizeReferenceCounting(arg))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }); - - if (blockWalk.wasInterrupted()) - signalPassFailure(); - - // Optimize reference counting for values defined by operation results. - WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { - for (unsigned i = 0; i < op->getNumResults(); ++i) - if (isRefCounted(op->getResultTypes()[i])) - if (failed(optimizeReferenceCounting(op->getResult(i)))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }); - - if (opWalk.wasInterrupted()) - signalPassFailure(); -} - -std::unique_ptr> -mlir::createAsyncRefCountingOptimizationPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -0,0 +1,377 @@ +//===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements automatic reference counting for Async runtime +// operations and types. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; +using namespace mlir::async; + +#define DEBUG_TYPE "async-runtime-ref-counting" + +namespace { + +class AsyncRuntimeRefCountingPass + : public AsyncRuntimeRefCountingBase { +public: + AsyncRuntimeRefCountingPass() = default; + void runOnFunction() override; + +private: + /// Adds automatic reference counting to the `value`. + /// + /// All values (token, group or value) are semantically created with a + /// reference count of +1 and it is the responsibility of the async value user + /// to place the `add_ref` and `drop_ref` operations to ensure that the value + /// is destroyed after the last use. + /// + /// The function returns failure if it can't deduce the locations where + /// to place the reference counting operations. + /// + /// Async values are "semantically created" when: + /// 1. Operation returns async result (e.g. `async.runtime.create`) + /// 2. Async value passed in as a block argument (or function argument, + /// because function arguments are just entry block arguments) + /// + /// Passing async value as a function argument (or block argument) does not + /// really mean that a new async value is created, it only means that the + /// caller of a function transferred ownership of `+1` reference to the callee. + /// It is convenient to think that from the callee perspective async value was + /// "created" with `+1` reference by the block argument. + /// + /// Automatic reference counting algorithm outline: + /// + /// #1 Insert `drop_ref` operations after last use of the `value`. 
+ /// #2 Insert `add_ref` operations before function calls with reference + /// counted `value` operand (newly created `+1` reference will be + /// transferred to the callee). + /// #3 Verify that divergent control flow does not lead to leaked reference + /// counted objects. + /// + /// Async runtime reference counting optimization pass will optimize away + /// some of the redundant `add_ref` and `drop_ref` operations inserted by this + /// strategy (see `async-runtime-ref-counting-opt`). + LogicalResult addAutomaticRefCounting(Value value); + + /// (#1) Adds the `drop_ref` operation after the last use of the `value` + /// relying on the liveness analysis. + /// + /// If the `value` is in the block `liveIn` set and it is not in the block + /// `liveOut` set, it means that it "dies" in the block. We find the last + /// use of the value in such block and: + /// + /// 1. If the last user is a `ReturnLike` operation we do nothing, because + /// it forwards the ownership to the caller. + /// 2. Otherwise we add a `drop_ref` operation immediately after the last + /// use. + LogicalResult addDropRefAfterLastUse(Value value); + + /// (#2) Adds the `add_ref` operation before the function call taking `value` + /// operand to ensure that the value passed to the function entry block + /// has a `+1` reference count. + LogicalResult addAddRefBeforeFunctionCall(Value value); + + /// (#3) Verifies that if a block has a value in the `liveOut` set, then the + /// value is in `liveIn` set in all successors. + /// + /// Example: + /// + /// ^entry: + /// %token = async.runtime.create : !async.token + /// cond_br %cond, ^bb1, ^bb2 + /// ^bb1: + /// async.runtime.await %token + /// return + /// ^bb2: + /// return + /// + /// This CFG will be rejected because ^bb2 does not have `value` in the + /// `liveIn` set, and it will leak a reference counted object. 
+ /// + /// An exception to this rule are blocks with `async.coro.suspend` terminator, + /// because in Async to LLVM lowering it is guaranteed that the control flow + /// will jump into the resume block, and then follow into the cleanup and + /// suspend blocks. + /// + /// Example: + /// + /// ^entry(%value: !async.value): + /// async.runtime.await_and_resume %value, %hdl : !async.value + /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup + /// ^resume: + /// %0 = async.runtime.load %value + /// br ^cleanup + /// ^cleanup: + /// ... + /// ^suspend: + /// ... + /// + /// Although cleanup and suspend blocks do not have the `value` in the + /// `liveIn` set, it is guaranteed that execution will eventually continue in + /// the resume block (we never explicitly destroy coroutines). + LogicalResult verifySuccessors(Value value); +}; + +} // namespace + +LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { + OpBuilder builder(value.getContext()); + Location loc = value.getLoc(); + + // Use liveness analysis to find the placement of `drop_ref` operation. + auto &liveness = getAnalysis(); + + // We analyse only the blocks of the region that defines the `value`, and do + // not check nested blocks attached to operations. + // + // By analyzing only the `definingRegion` CFG we potentially lose an + // opportunity to drop the reference count earlier and can extend the lifetime + // of reference counted value longer than it is really required. + // + // We also assume that all nested regions finish their execution before the + // completion of the owner operation. The only exception to this rule is + // `async.execute` operation, and we verify that they are lowered to the + // `async.runtime` operations before adding automatic reference counting. + Region *definingRegion = value.getParentRegion(); + + // Last users of the `value` inside all blocks where the value dies. 
+ llvm::SmallSet lastUsers; + + // Find blocks in the `definingRegion` that have users of the `value` (if + // there are multiple users in the block, which one will be selected is + // undefined). User operation might be not the actual user of the value, but + // the operation in the block that has a "real user" in one of the attached + // regions. + llvm::DenseMap usersInTheBlocks; + + for (Operation *user : value.getUsers()) { + Block *userBlock = user->getBlock(); + Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock); + assert(ancestor && "ancestor block must be not null"); + usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user); + assert(usersInTheBlocks[ancestor] && "ancestor op must be not null"); + } + + // Find blocks where the `value` dies: the value is in `liveIn` set and not + // in the `liveOut` set. We place `drop_ref` immediately after the last use + // of the `value` in such regions (after handling a few special cases). + // + // We do not traverse all the blocks in the `definingRegion`, because the + // `value` can be in the live in set only if it has users in the block, or it + // is defined in the block. + // + // Values with zero users (only definition) handled explicitly above. + for (auto &blockAndUser : usersInTheBlocks) { + Block *block = blockAndUser.getFirst(); + Operation *userInTheBlock = blockAndUser.getSecond(); + + const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); + + // Value must be in the live input set or defined in the block. + assert(blockLiveness->isLiveIn(value) || + blockLiveness->getBlock() == value.getParentBlock()); + + // If value is in the live out set, it means it doesn't "die" in the block. + if (blockLiveness->isLiveOut(value)) + continue; + + // At this point we proved that `value` dies in the `block`. Find the last + // use of the `value` inside the `block`, this is where it "dies". 
+ Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); + assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); + lastUsers.insert(lastUser); + } + + // Process all the last users of the `value` inside each block where the value + // dies. + for (Operation *lastUser : lastUsers) { + // Return like operations forward reference count. + if (lastUser->hasTrait()) + continue; + + // We can't currently handle other types of terminators. + if (lastUser->hasTrait()) + return lastUser->emitError() << "async reference counting can't handle " + "terminators that are not ReturnLike"; + + // Add a drop_ref immediately after the last user. + builder.setInsertionPointAfter(lastUser); + builder.create(loc, value, builder.getI32IntegerAttr(1)); + } + + return success(); +} + +LogicalResult +AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { + OpBuilder builder(value.getContext()); + Location loc = value.getLoc(); + + for (Operation *user : value.getUsers()) { + if (!isa(user)) + continue; + + // Add a reference before the function call to pass the value at `+1` + // reference to the function entry block. + builder.setInsertionPoint(user); + builder.create(loc, value, builder.getI32IntegerAttr(1)); + } + + return success(); +} + +LogicalResult AsyncRuntimeRefCountingPass::verifySuccessors(Value value) { + OpBuilder builder(value.getContext()); + + // Blocks with successors with different `liveIn` properties of the `value`. + llvm::SmallSet divergentLivenessBlocks; + + // Use liveness analysis to compare the `liveIn` sets of block successors. + auto &liveness = getAnalysis(); + + // Because we only add `drop_ref` operations to the region that defines the + // `value` we can only process CFG for the same region. + Region *definingRegion = value.getParentRegion(); + + // Collect blocks with successors with mismatching `liveIn` sets. 
+ for (Block &block : definingRegion->getBlocks()) { + const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); + + // Skip the block if value is not in the `liveOut` set. + if (!blockLiveness->isLiveOut(value)) + continue; + + // Successors with and without the `value` in their `liveIn` set. + llvm::SmallSet liveInSuccessors; + llvm::SmallSet noLiveInSuccessors; + + // Partition successors by whether `value` is in their `liveIn` set. + for (Block *successor : block.getSuccessors()) { + const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); + if (succLiveness->isLiveIn(value)) + liveInSuccessors.insert(successor); + else + noLiveInSuccessors.insert(successor); + } + + // Block has successors with different `liveIn` property of the `value`. + if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) + divergentLivenessBlocks.insert(&block); + } + + // Verify that divergent `liveIn` property is only present in blocks with + // async.coro.suspend terminator. + for (Block *block : divergentLivenessBlocks) { + Operation *terminator = block->getTerminator(); + if (isa(terminator)) + continue; + + return terminator->emitOpError("successor have different `liveIn` property " + "of the reference counted value: "); + } + + return success(); +} + +LogicalResult +AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { + OpBuilder builder(value.getContext()); + Location loc = value.getLoc(); + + // Set insertion point after the operation producing a value, or at the + // beginning of the block if the value is defined by the block argument. + if (Operation *op = value.getDefiningOp()) + builder.setInsertionPointAfter(op); + else + builder.setInsertionPointToStart(value.getParentBlock()); + + // Drop the reference count immediately if the value has no uses. + if (value.getUses().empty()) { + builder.create(loc, value, builder.getI32IntegerAttr(1)); + return success(); + } + + // Add `drop_ref` operations based on the liveness analysis. 
+ if (failed(addDropRefAfterLastUse(value))) + return failure(); + + // Add `add_ref` operations before function calls. + if (failed(addAddRefBeforeFunctionCall(value))) + return failure(); + + // Verify that the `value` is in `liveIn` set of all successors. + if (failed(verifySuccessors(value))) + return failure(); + + return success(); +} + +void AsyncRuntimeRefCountingPass::runOnFunction() { + FuncOp func = getFunction(); + + // Check that we do not have high level async operations in the IR because + // otherwise automatic reference counting will produce incorrect results after + // execute operations will be lowered to `async.runtime` + WalkResult executeOpWalk = func.walk([&](Operation *op) -> WalkResult { + if (!isa(op)) + return WalkResult::advance(); + + return op->emitError() + << "async operations must be lowered to async runtime operations"; + }); + + if (executeOpWalk.wasInterrupted()) { + signalPassFailure(); + return; + } + + // Add reference counting to block arguments. + WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { + for (BlockArgument arg : block->getArguments()) + if (isRefCounted(arg.getType())) + if (failed(addAutomaticRefCounting(arg))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (blockWalk.wasInterrupted()) { + signalPassFailure(); + return; + } + + // Add reference counting to operation results. 
+ WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { + for (unsigned i = 0; i < op->getNumResults(); ++i) + if (isRefCounted(op->getResultTypes()[i])) + if (failed(addAutomaticRefCounting(op->getResult(i)))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (opWalk.wasInterrupted()) + signalPassFailure(); +} + +std::unique_ptr> +mlir::createAsyncRuntimeRefCountingPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp @@ -0,0 +1,177 @@ +//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Optimize Async dialect reference counting operations. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Async/Passes.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; +using namespace mlir::async; + +#define DEBUG_TYPE "async-ref-counting" + +namespace { + +class AsyncRuntimeRefCountingOptPass + : public AsyncRuntimeRefCountingOptBase { +public: + AsyncRuntimeRefCountingOptPass() = default; + void runOnFunction() override; + +private: + LogicalResult optimizeReferenceCounting( + Value value, llvm::SmallDenseMap &cancellable); +}; + +} // namespace + +LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting( + Value value, llvm::SmallDenseMap &cancellable) { + Region *definingRegion = value.getParentRegion(); + + // Find all users of the `value` inside each block, including operations that + // do not use `value` directly, but have a direct use inside nested region(s). + // + // Example: + // + // ^bb1: + // %token = ... + // scf.if %cond { + // ^bb2: + // async.runtime.await %token : !async.token + // } + // + // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1 + // (`scf.if`). + + struct BlockUsersInfo { + llvm::SmallVector addRefs; + llvm::SmallVector dropRefs; + llvm::SmallVector users; + }; + + llvm::DenseMap blockUsers; + + auto updateBlockUsersInfo = [&](Operation *user) { + BlockUsersInfo &info = blockUsers[user->getBlock()]; + info.users.push_back(user); + + if (auto addRef = dyn_cast(user)) + info.addRefs.push_back(addRef); + if (auto dropRef = dyn_cast(user)) + info.dropRefs.push_back(dropRef); + }; + + for (Operation *user : value.getUsers()) { + while (user->getParentRegion() != definingRegion) { + updateBlockUsersInfo(user); + user = user->getParentOp(); + assert(user != nullptr && "value user lies outside of the value region"); + } + + updateBlockUsersInfo(user); + } + + // Sort all operations found in the block. 
+ auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & { + auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool { + return a->isBeforeInBlock(b); + }; + llvm::sort(info.addRefs, isBeforeInBlock); + llvm::sort(info.dropRefs, isBeforeInBlock); + llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool { + return isBeforeInBlock(a, b); + }); + + return info; + }; + + // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the + // blocks that modify the reference count of the `value`. + for (auto &kv : blockUsers) { + BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second); + + for (RuntimeAddRefOp addRef : info.addRefs) { + for (RuntimeDropRefOp dropRef : info.dropRefs) { + // `drop_ref` operation after the `add_ref` with matching count. + if (dropRef.count() != addRef.count() || + dropRef->isBeforeInBlock(addRef.getOperation())) + continue; + + // Try to cancel the pair of `add_ref` and `drop_ref` operations. + auto emplaced = cancellable.try_emplace(dropRef.getOperation(), + addRef.getOperation()); + + if (!emplaced.second) // `drop_ref` was already marked for removal + continue; // go to the next `drop_ref` + + if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref` + break; // go to the next `add_ref` + } + } + } + + return success(); +} + +void AsyncRuntimeRefCountingOptPass::runOnFunction() { + FuncOp func = getFunction(); + + // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`. + // + // Find all cancellable pairs of operation and erase them in the end to keep + // all iterators valid while we are walking the function operations. + llvm::SmallDenseMap cancellable; + + // Optimize reference counting for values defined by block arguments. 
+ WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { + for (BlockArgument arg : block->getArguments()) + if (isRefCounted(arg.getType())) + if (failed(optimizeReferenceCounting(arg, cancellable))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (blockWalk.wasInterrupted()) + signalPassFailure(); + + // Optimize reference counting for values defined by operation results. + WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { + for (unsigned i = 0; i < op->getNumResults(); ++i) + if (isRefCounted(op->getResultTypes()[i])) + if (failed(optimizeReferenceCounting(op->getResult(i), cancellable))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (opWalk.wasInterrupted()) + signalPassFailure(); + + LLVM_DEBUG({ + llvm::dbgs() << "Found " << cancellable.size() + << " cancellable reference counting operations\n"; + }); + + // Erase all cancellable `add_ref <-> drop_ref` operation pairs. + for (auto &kv : cancellable) { + kv.first->erase(); + kv.second->erase(); + } +} + +std::unique_ptr> +mlir::createAsyncRuntimeRefCountingOptPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -1,7 +1,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms AsyncParallelFor.cpp - AsyncRefCounting.cpp - AsyncRefCountingOptimization.cpp + AsyncRuntimeRefCounting.cpp + AsyncRuntimeRefCountingOpt.cpp AsyncToAsyncRuntime.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir +++ /dev/null @@ -1,114 +0,0 @@ -// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s - -// CHECK-LABEL: 
@cancellable_operations_0 -func @cancellable_operations_0(%arg0: !async.token) { - // CHECK-NOT: async.runtime.add_ref - // CHECK-NOT: async.runtime.drop_ref - async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - // CHECK: return - return -} - -// CHECK-LABEL: @cancellable_operations_1 -func @cancellable_operations_1(%arg0: !async.token) { - // CHECK-NOT: async.runtime.add_ref - // CHECK: async.execute - async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token - async.execute [%arg0] { - // CHECK: async.runtime.drop_ref - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - // CHECK-NEXT: async.yield - async.yield - } - // CHECK-NOT: async.runtime.drop_ref - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - // CHECK: return - return -} - -// CHECK-LABEL: @cancellable_operations_2 -func @cancellable_operations_2(%arg0: !async.token) { - // CHECK: async.await - // CHECK-NEXT: async.await - // CHECK-NEXT: async.await - // CHECK-NEXT: return - async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token - async.await %arg0 : !async.token - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - async.await %arg0 : !async.token - async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token - async.await %arg0 : !async.token - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - return -} - -// CHECK-LABEL: @cancellable_operations_3 -func @cancellable_operations_3(%arg0: !async.token) { - // CHECK-NOT: add_ref - async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token - %token = async.execute { - async.await %arg0 : !async.token - // CHECK: async.runtime.drop_ref - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - async.yield - } - // CHECK-NOT: async.runtime.drop_ref - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - // CHECK: async.await - async.await %arg0 : !async.token - // CHECK: return - return -} - -// 
CHECK-LABEL: @not_cancellable_operations_0 -func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) { - // It is unsafe to cancel `add_ref` / `drop_ref` pair because it is possible - // that the body of the `async.execute` operation will run before the await - // operation in the function body, and will destroy the `%arg0` token. - // CHECK: add_ref - async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token - %token = async.execute { - // CHECK: async.await - async.await %arg0 : !async.token - // CHECK: async.runtime.drop_ref - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - // CHECK: async.yield - async.yield - } - // CHECK: async.await - async.await %arg0 : !async.token - // CHECK: drop_ref - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - // CHECK: return - return -} - -// CHECK-LABEL: @not_cancellable_operations_1 -func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) { - // Same reason as above, although `async.execute` is inside the nested - // region or "regular" operation. - // - // NOTE: This test is not correct w.r.t. reference counting, and at runtime - // would leak %arg0 value if %arg1 is false. IR like this will not be - // constructed by automatic reference counting pass, because it would - // place `async.runtime.add_ref` right before the `async.execute` - // inside `scf.if`. 
- - // CHECK: async.runtime.add_ref - async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token - scf.if %arg1 { - %token = async.execute { - async.await %arg0 : !async.token - // CHECK: async.runtime.drop_ref - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - async.yield - } - } - // CHECK: async.await - async.await %arg0 : !async.token - // CHECK: async.runtime.drop_ref - async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token - // CHECK: return - return -} diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir b/mlir/test/Dialect/Async/async-ref-counting.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Async/async-ref-counting.mlir +++ /dev/null @@ -1,253 +0,0 @@ -// RUN: mlir-opt %s -async-ref-counting | FileCheck %s - -// CHECK-LABEL: @cond -func private @cond() -> i1 - -// CHECK-LABEL: @token_arg_no_uses -func @token_arg_no_uses(%arg0: !async.token) { - // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32} - return -} - -// CHECK-LABEL: @token_arg_conditional_await -func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) { - cond_br %arg1, ^bb1, ^bb2 -^bb1: - // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32} - return -^bb2: - // CHECK: async.await %arg0 - // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32} - async.await %arg0 : !async.token - return -} - -// CHECK-LABEL: @token_no_uses -func @token_no_uses() { - // CHECK: %[[TOKEN:.*]] = async.execute - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - %token = async.execute { - async.yield - } - return -} - -// CHECK-LABEL: @token_return -func @token_return() -> !async.token { - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - async.yield - } - // CHECK: return %[[TOKEN]] - return %token : !async.token -} - -// CHECK-LABEL: @token_await -func @token_await() { - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - async.yield - } - // CHECK: async.await %[[TOKEN]] - async.await %token : !async.token - 
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: return - return -} - -// CHECK-LABEL: @token_await_and_return -func @token_await_and_return() -> !async.token { - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - async.yield - } - // CHECK: async.await %[[TOKEN]] - // CHECK-NOT: async.runtime.drop_ref - async.await %token : !async.token - // CHECK: return %[[TOKEN]] - return %token : !async.token -} - -// CHECK-LABEL: @token_await_inside_scf_if -func @token_await_inside_scf_if(%arg0: i1) { - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - async.yield - } - // CHECK: scf.if %arg0 { - scf.if %arg0 { - // CHECK: async.await %[[TOKEN]] - async.await %token : !async.token - } - // CHECK: } - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: return - return -} - -// CHECK-LABEL: @token_conditional_await -func @token_conditional_await(%arg0: i1) { - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - async.yield - } - cond_br %arg0, ^bb1, ^bb2 -^bb1: - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - return -^bb2: - // CHECK: async.await %[[TOKEN]] - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - async.await %token : !async.token - return -} - -// CHECK-LABEL: @token_await_in_the_loop -func @token_await_in_the_loop() { - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - async.yield - } - br ^bb1 -^bb1: - // CHECK: async.await %[[TOKEN]] - async.await %token : !async.token - %0 = call @cond(): () -> (i1) - cond_br %0, ^bb1, ^bb2 -^bb2: - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - return -} - -// CHECK-LABEL: @token_defined_in_the_loop -func @token_defined_in_the_loop() { - br ^bb1 -^bb1: - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - async.yield - } - // CHECK: async.await %[[TOKEN]] - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - async.await %token : 
!async.token - %0 = call @cond(): () -> (i1) - cond_br %0, ^bb1, ^bb2 -^bb2: - return -} - -// CHECK-LABEL: @token_capture -func @token_capture() { - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - async.yield - } - - // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: %[[TOKEN_0:.*]] = async.execute - %token_0 = async.execute { - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - // CHECK-NEXT: async.yield - async.await %token : !async.token - async.yield - } - // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32} - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: return - return -} - -// CHECK-LABEL: @token_nested_capture -func @token_nested_capture() { - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - async.yield - } - - // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: %[[TOKEN_0:.*]] = async.execute - %token_0 = async.execute { - // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: %[[TOKEN_1:.*]] = async.execute - %token_1 = async.execute { - // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: %[[TOKEN_2:.*]] = async.execute - %token_2 = async.execute { - // CHECK: async.await %[[TOKEN]] - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - async.await %token : !async.token - async.yield - } - // CHECK: async.runtime.drop_ref %[[TOKEN_2]] {count = 1 : i32} - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - async.yield - } - // CHECK: async.runtime.drop_ref %[[TOKEN_1]] {count = 1 : i32} - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - async.yield - } - // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32} - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: return - return -} - -// CHECK-LABEL: @token_dependency -func @token_dependency() { - // CHECK: %[[TOKEN:.*]] = async.execute - %token = async.execute { - 
async.yield - } - - // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: %[[TOKEN_0:.*]] = async.execute - %token_0 = async.execute[%token] { - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - // CHECK-NEXT: async.yield - async.yield - } - - // CHECK: async.await %[[TOKEN]] - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - async.await %token : !async.token - // CHECK: async.await %[[TOKEN_0]] - // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32} - async.await %token_0 : !async.token - - // CHECK: return - return -} - -// CHECK-LABEL: @value_operand -func @value_operand() -> f32 { - // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute - %token, %results = async.execute -> !async.value { - %0 = constant 0.0 : f32 - async.yield %0 : f32 - } - - // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: async.runtime.add_ref %[[RESULTS]] {count = 1 : i32} - // CHECK: %[[TOKEN_0:.*]] = async.execute - %token_0 = async.execute[%token](%results as %arg0 : !async.value) { - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - // CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32} - // CHECK: async.yield - async.yield - } - - // CHECK: async.await %[[TOKEN]] - // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} - async.await %token : !async.token - - // CHECK: async.await %[[TOKEN_0]] - // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32} - async.await %token_0 : !async.token - - // CHECK: async.await %[[RESULTS]] - // CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32} - %0 = async.await %results : !async.value - - // CHECK: return - return %0 : f32 -} diff --git a/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir @@ -0,0 +1,55 @@ +// RUN: mlir-opt %s 
-async-runtime-ref-counting-opt | FileCheck %s + +func private @consume_token(%arg0: !async.token) + +// CHECK-LABEL: @cancellable_operations_0 +func @cancellable_operations_0(%arg0: !async.token) { + // CHECK-NOT: async.runtime.add_ref + // CHECK-NOT: async.runtime.drop_ref + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} + +// CHECK-LABEL: @cancellable_operations_1 +func @cancellable_operations_1(%arg0: !async.token) { + // CHECK-NOT: async.runtime.add_ref + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: call @consume_toke + call @consume_token(%arg0): (!async.token) -> () + // CHECK-NOT: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} + +// CHECK-LABEL: @cancellable_operations_2 +func @cancellable_operations_2(%arg0: !async.token) { + // CHECK: async.runtime.await + // CHECK-NEXT: async.runtime.await + // CHECK-NEXT: async.runtime.await + // CHECK-NEXT: return + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.await %arg0 : !async.token + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.await %arg0 : !async.token + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token + async.runtime.await %arg0 : !async.token + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token + return +} + +// CHECK-LABEL: @cancellable_operations_3 +func @cancellable_operations_3(%arg0: !async.token) { + // CHECK-NOT: add_ref + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: call @consume_toke + call @consume_token(%arg0): (!async.token) -> () + // CHECK-NOT: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: async.runtime.await + async.runtime.await %arg0 : !async.token + // CHECK: return + return +} diff --git 
a/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir @@ -0,0 +1,215 @@ +// RUN: mlir-opt %s -async-runtime-ref-counting | FileCheck %s + +// CHECK-LABEL: @token +func private @token() -> !async.token + +// CHECK-LABEL: @cond +func private @cond() -> i1 + +// CHECK-LABEL: @take_token +func private @take_token(%arg0: !async.token) + +// CHECK-LABEL: @token_arg_no_uses +// CHECK: %[[TOKEN:.*]]: !async.token +func @token_arg_no_uses(%arg0: !async.token) { + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + return +} + +// CHECK-LABEL: @token_value_no_uses +func @token_value_no_uses() { + // CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + %0 = async.runtime.create : !async.token + return +} + +// CHECK-LABEL: @token_returned_no_uses +func @token_returned_no_uses() { + // CHECK: %[[TOKEN:.*]] = call @token + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + %0 = call @token() : () -> !async.token + return +} + +// CHECK-LABEL: @token_arg_to_func +// CHECK: %[[TOKEN:.*]]: !async.token +func @token_arg_to_func(%arg0: !async.token) { + // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token + call @take_token(%arg0): (!async.token) -> () + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} : !async.token + return +} + +// CHECK-LABEL: @token_value_to_func +func @token_value_to_func() { + // CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token + %0 = async.runtime.create : !async.token + // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token + call @take_token(%0): (!async.token) -> () + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + return +} + +// CHECK-LABEL: @token_arg_cond_br_await_with_fallthough +// CHECK: %[[TOKEN:.*]]: !async.token 
+func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) { + // CHECK: cond_br + // CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]] + cond_br %arg1, ^bb1, ^bb2 +^bb1: + // CHECK: ^[[BB1]]: + // CHECK: br ^[[BB2]] + br ^bb2 +^bb2: + // CHECK: ^[[BB2]]: + // CHECK: async.runtime.await %[[TOKEN]] + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + async.runtime.await %arg0 : !async.token + return +} + +// CHECK-LABEL: @token_simple_return +func @token_simple_return() -> !async.token { + // CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token + %token = async.runtime.create : !async.token + // CHECK: return %[[TOKEN]] + return %token : !async.token +} + +// CHECK-LABEL: @token_coro_return +// CHECK-NOT: async.runtime.drop_ref +// CHECK-NOT: async.runtime.add_ref +func @token_coro_return() -> !async.token { + %token = async.runtime.create : !async.token + %id = async.coro.id + %hdl = async.coro.begin %id + %saved = async.coro.save %hdl + async.runtime.resume %hdl + async.coro.suspend %saved, ^suspend, ^resume, ^cleanup +^resume: + br ^cleanup +^cleanup: + async.coro.free %id, %hdl + br ^suspend +^suspend: + async.coro.end %hdl + return %token : !async.token +} + +// CHECK-LABEL: @token_coro_await_and_resume +// CHECK: %[[TOKEN:.*]]: !async.token +func @token_coro_await_and_resume(%arg0: !async.token) -> !async.token { + %token = async.runtime.create : !async.token + %id = async.coro.id + %hdl = async.coro.begin %id + %saved = async.coro.save %hdl + // CHECK: async.runtime.await_and_resume %[[TOKEN]] + async.runtime.await_and_resume %arg0, %hdl : !async.token + // CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + async.coro.suspend %saved, ^suspend, ^resume, ^cleanup +^resume: + br ^cleanup +^cleanup: + async.coro.free %id, %hdl + br ^suspend +^suspend: + async.coro.end %hdl + return %token : !async.token +} + +// CHECK-LABEL: @value_coro_await_and_resume +// CHECK: %[[VALUE:.*]]: !async.value +func 
@value_coro_await_and_resume(%arg0: !async.value) -> !async.token { + %token = async.runtime.create : !async.token + %id = async.coro.id + %hdl = async.coro.begin %id + %saved = async.coro.save %hdl + // CHECK: async.runtime.await_and_resume %[[VALUE]] + async.runtime.await_and_resume %arg0, %hdl : !async.value + // CHECK: async.coro.suspend + // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] + async.coro.suspend %saved, ^suspend, ^resume, ^cleanup +^resume: + // CHECK: ^[[RESUME]]: + // CHECK: %[[LOADED:.*]] = async.runtime.load %[[VALUE]] + // CHECK: async.runtime.drop_ref %[[VALUE]] {count = 1 : i32} + %0 = async.runtime.load %arg0 : !async.value + // CHECK: addf %[[LOADED]], %[[LOADED]] + %1 = addf %0, %0 : f32 + br ^cleanup +^cleanup: + async.coro.free %id, %hdl + br ^suspend +^suspend: + async.coro.end %hdl + return %token : !async.token +} + +// CHECK-LABEL: @outlined_async_execute +// CHECK: %[[TOKEN:.*]]: !async.token +func private @outlined_async_execute(%arg0: !async.token) -> !async.token { + %0 = async.runtime.create : !async.token + %1 = async.coro.id + %2 = async.coro.begin %1 + %3 = async.coro.save %2 + async.runtime.resume %2 + // CHECK: async.coro.suspend + async.coro.suspend %3, ^suspend, ^resume, ^cleanup +^resume: + // CHECK: ^[[RESUME:.*]]: + %4 = async.coro.save %2 + async.runtime.await_and_resume %arg0, %2 : !async.token + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.coro.suspend + async.coro.suspend %4, ^suspend, ^resume_1, ^cleanup +^resume_1: + // CHECK: ^[[RESUME_1:.*]]: + // CHECK: async.runtime.set_available + async.runtime.set_available %0 : !async.token + br ^cleanup +^cleanup: + // CHECK: ^[[CLEANUP:.*]]: + // CHECK: async.coro.free + async.coro.free %1, %2 + br ^suspend +^suspend: + // CHECK: ^[[SUSPEND:.*]]: + // CHECK: async.coro.end + async.coro.end %2 + return %0 : !async.token +} + +// CHECK-LABEL: @token_await_inside_nested_region +// CHECK: %[[ARG:.*]]: i1 +func 
@token_await_inside_nested_region(%arg0: i1) { + // CHECK: %[[TOKEN:.*]] = call @token() + %token = call @token() : () -> !async.token + // CHECK: scf.if %[[ARG]] { + scf.if %arg0 { + // CHECK: async.runtime.await %[[TOKEN]] + async.runtime.await %token : !async.token + } + // CHECK: } + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_defined_in_the_loop +func @token_defined_in_the_loop() { + br ^bb1 +^bb1: + // CHECK: ^[[BB1:.*]]: + // CHECK: %[[TOKEN:.*]] = call @token() + %token = call @token() : () -> !async.token + // CHECK: async.runtime.await %[[TOKEN]] + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + async.runtime.await %token : !async.token + %0 = call @cond(): () -> (i1) + cond_br %0, ^bb1, ^bb2 +^bb2: + // CHECK: ^[[BB2:.*]]: + // CHECK: return + return +} diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir @@ -1,8 +1,9 @@ // RUN: mlir-opt %s \ // RUN: -linalg-tile-to-parallel-loops="linalg-tile-sizes=256" \ // RUN: -async-parallel-for="num-concurrent-async-execute=4" \ -// RUN: -async-ref-counting \ // RUN: -async-to-async-runtime \ +// RUN: -async-runtime-ref-counting \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -lower-affine \ // RUN: -convert-linalg-to-loops \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -async-parallel-for \ 
-// RUN: -async-ref-counting \ // RUN: -async-to-async-runtime \ +// RUN: -async-runtime-ref-counting \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -async-parallel-for \ // RUN: -async-to-async-runtime \ -// RUN: -async-ref-counting \ +// RUN: -async-runtime-ref-counting \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir --- a/mlir/test/mlir-cpu-runner/async-group.mlir +++ b/mlir/test/mlir-cpu-runner/async-group.mlir @@ -1,5 +1,6 @@ -// RUN: mlir-opt %s -async-ref-counting \ -// RUN: -async-to-async-runtime \ +// RUN: mlir-opt %s -async-to-async-runtime \ +// RUN: -async-runtime-ref-counting \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-std-to-llvm \ // RUN: | mlir-cpu-runner \ diff --git a/mlir/test/mlir-cpu-runner/async-value.mlir b/mlir/test/mlir-cpu-runner/async-value.mlir --- a/mlir/test/mlir-cpu-runner/async-value.mlir +++ b/mlir/test/mlir-cpu-runner/async-value.mlir @@ -1,5 +1,6 @@ -// RUN: mlir-opt %s -async-ref-counting \ -// RUN: -async-to-async-runtime \ +// RUN: mlir-opt %s -async-to-async-runtime \ +// RUN: -async-runtime-ref-counting \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-vector-to-llvm \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir --- 
a/mlir/test/mlir-cpu-runner/async.mlir +++ b/mlir/test/mlir-cpu-runner/async.mlir @@ -1,5 +1,6 @@ -// RUN: mlir-opt %s -async-ref-counting \ -// RUN: -async-to-async-runtime \ +// RUN: mlir-opt %s -async-to-async-runtime \ +// RUN: -async-runtime-ref-counting \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-linalg-to-loops \ // RUN: -convert-scf-to-std \