diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -88,8 +88,9 @@
   /// has a `+1` reference count.
   LogicalResult addAddRefBeforeFunctionCall(Value value);
 
-  /// (#3) Verifies that if a block has a value in the `liveOut` set, then the
-  /// value is in `liveIn` set in all successors.
+  /// (#3) Adds the `drop_ref` operation to account for successor blocks with
+  /// divergent `liveIn` property: `value` is not in the `liveIn` set of all
+  /// successor blocks.
   ///
   /// Example:
   ///
@@ -98,12 +99,29 @@
   ///     cond_br %cond, ^bb1, ^bb2
   ///   ^bb1:
   ///     async.runtime.await %token
-  ///     return
+  ///     async.runtime.drop_ref %token
+  ///     br ^bb2
   ///   ^bb2:
   ///     return
   ///
-  /// This CFG will be rejected because ^bb2 does not have `value` in the
-  /// `liveIn` set, and it will leak a reference counted object.
+  /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have
+  /// to branch into a special "reference counting block" from the ^entry that
+  /// will have a `drop_ref` operation, and then branch into the ^bb2.
+  ///
+  /// After transformation:
+  ///
+  ///   ^entry:
+  ///     %token = async.runtime.create : !async.token
+  ///     cond_br %cond, ^bb1, ^reference_counting
+  ///   ^bb1:
+  ///     async.runtime.await %token
+  ///     async.runtime.drop_ref %token
+  ///     br ^bb2
+  ///   ^reference_counting:
+  ///     async.runtime.drop_ref %token
+  ///     br ^bb2
+  ///   ^bb2:
+  ///     return
   ///
   /// An exception to this rule are blocks with `async.coro.suspend` terminator,
   /// because in Async to LLVM lowering it is guaranteed that the control flow
@@ -126,7 +144,7 @@
   /// Although cleanup and suspend blocks do not have the `value` in the
   /// `liveIn` set, it is guaranteed that execution will eventually continue in
   /// the resume block (we never explicitly destroy coroutines).
-  LogicalResult verifySuccessors(Value value);
+  LogicalResult addDropRefInDivergentLivenessSuccessor(Value value);
 };
 
 } // namespace
@@ -237,11 +255,16 @@
   return success();
 }
 
-LogicalResult AsyncRuntimeRefCountingPass::verifySuccessors(Value value) {
+LogicalResult
+AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
+    Value value) {
+  using BlockSet = llvm::SmallPtrSet<Block *, 4>;
+
   OpBuilder builder(value.getContext());
 
-  // Blocks with successfors with different `liveIn` properties of the `value`.
-  llvm::SmallSet<Block *, 4> divergentLivenessBlocks;
+  // Blocks whose successors disagree on `value` liveness, paired with the
+  // successors lacking it in `liveIn`; a vector keeps output deterministic.
+  llvm::SmallVector<std::pair<Block *, BlockSet>, 4> divergentLivenessBlocks;
 
   // Use liveness analysis to find the placement of `drop_ref`operation.
   auto &liveness = getAnalysis<Liveness>();
@@ -258,9 +281,8 @@
     if (!blockLiveness->isLiveOut(value))
       continue;
 
-    // Sucessors with value in `liveIn` set and not value in `liveIn` set.
-    llvm::SmallSet<Block *, 4> liveInSuccessors;
-    llvm::SmallSet<Block *, 4> noLiveInSuccessors;
+    BlockSet liveInSuccessors;   // `value` is in `liveIn` set
+    BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set
 
     // Collect successors that do not have `value` in the `liveIn` set.
     for (Block *successor : block.getSuccessors()) {
@@ -273,18 +295,60 @@
 
     // Block has successors with different `liveIn` property of the `value`.
     if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
-      divergentLivenessBlocks.insert(&block);
+      divergentLivenessBlocks.emplace_back(&block, noLiveInSuccessors);
   }
 
-  // Verify that divergent `liveIn` property only present in blocks with
-  // async.coro.suspend terminator.
-  for (Block *block : divergentLivenessBlocks) {
+  // Insert `drop_ref` operations to handle blocks with divergent liveness in
+  // successor blocks.
+  for (auto &entry : divergentLivenessBlocks) {
+    Block *block = entry.first;
+    BlockSet &successors = entry.second;
+
+    // Coroutine suspension is a special case terminator for which we do not
+    // need to create additional reference counting (see details above).
     Operation *terminator = block->getTerminator();
     if (isa<CoroSuspendOp>(terminator))
       continue;
 
-    return terminator->emitOpError("successor have different `liveIn` property "
-                                   "of the reference counted value: ");
+    // We only support successor blocks with empty block argument list.
+    auto hasArgs = [](Block *succ) { return !succ->getArguments().empty(); };
+    if (llvm::any_of(successors, hasArgs))
+      return terminator->emitOpError()
+             << "successor have different `liveIn` property of the reference "
+                "counted value";
+
+    // Make sure that `drop_ref` operation is called when branched into the
+    // successor block without `value` in the `liveIn` set.
+    for (Block *successor : successors) {
+      // If successor has a unique predecessor, it is safe to create `drop_ref`
+      // operations directly in the successor block.
+      //
+      // Otherwise we need to create a special block for reference counting
+      // operations, and branch from it to the original successor block.
+      Block *refCountingBlock = nullptr;
+
+      if (successor->getUniquePredecessor() == block) {
+        refCountingBlock = successor;
+      } else {
+        refCountingBlock = &successor->getParent()->emplaceBlock();
+        refCountingBlock->moveBefore(successor);
+        builder.setInsertionPointToEnd(refCountingBlock);
+        builder.create<BranchOp>(value.getLoc(), successor);
+      }
+
+      builder.setInsertionPointToStart(refCountingBlock);
+      builder.create<RuntimeDropRefOp>(value.getLoc(), value,
+                                       builder.getI32IntegerAttr(1));
+
+      // No need to update the terminator operation.
+      if (successor == refCountingBlock)
+        continue;
+
+      // Update terminator `successor` block to `refCountingBlock`.
+      for (unsigned i = 0, e = terminator->getNumSuccessors(); i < e; ++i)
+        if (terminator->getSuccessor(i) == successor)
+          terminator->setSuccessor(refCountingBlock, i);
+    }
   }
 
   return success();
@@ -316,8 +380,8 @@
   if (failed(addAddRefBeforeFunctionCall(value)))
     return failure();
 
-  // Verify that the `value` is in `liveIn` set of all successors.
-  if (failed(verifySuccessors(value)))
+  // Add `drop_ref` operations to successors with divergent `value` liveness.
+  if (failed(addDropRefInDivergentLivenessSuccessor(value)))
     return failure();
 
   return success();
diff --git a/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir
--- a/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir
+++ b/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir
@@ -213,3 +213,79 @@
   // CHECK: return
   return
 }
+
+// CHECK-LABEL: @divergent_liveness_one_token
+func @divergent_liveness_one_token(%arg0 : i1) {
+  // CHECK: %[[TOKEN:.*]] = call @token()
+  %token = call @token() : () -> !async.token
+  // CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[REF_COUNTING:.*]]
+  cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  // CHECK: ^[[LIVE_IN]]:
+  // CHECK:   async.runtime.await %[[TOKEN]]
+  // CHECK:   async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK:   br ^[[RETURN:.*]]
+  async.runtime.await %token : !async.token
+  br ^bb2
+  // CHECK: ^[[REF_COUNTING]]:
+  // CHECK:   async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK:   br ^[[RETURN]]
+^bb2:
+  // CHECK: ^[[RETURN]]:
+  // CHECK:   return
+  return
+}
+
+// CHECK-LABEL: @divergent_liveness_unique_predecessor
+func @divergent_liveness_unique_predecessor(%arg0 : i1) {
+  // CHECK: %[[TOKEN:.*]] = call @token()
+  %token = call @token() : () -> !async.token
+  // CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[NO_LIVE_IN:.*]]
+  cond_br %arg0, ^bb2, ^bb1
+^bb1:
+  // CHECK: ^[[NO_LIVE_IN]]:
+  // CHECK:   async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK:   br ^[[RETURN:.*]]
+  br ^bb3
+^bb2:
+  // CHECK: ^[[LIVE_IN]]:
+  // CHECK:   async.runtime.await %[[TOKEN]]
+  // CHECK:   async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK:   br ^[[RETURN]]
+  async.runtime.await %token : !async.token
+  br ^bb3
+^bb3:
+  // CHECK: ^[[RETURN]]:
+  // CHECK:   return
+  return
+}
+
+// CHECK-LABEL: @divergent_liveness_two_tokens
+func @divergent_liveness_two_tokens(%arg0 : i1) {
+  // CHECK: %[[TOKEN0:.*]] = call @token()
+  // CHECK: %[[TOKEN1:.*]] = call @token()
+  %token0 = call @token() : () -> !async.token
+  %token1 = call @token() : () -> !async.token
+  // CHECK: cond_br %arg0, ^[[AWAIT0:.*]], ^[[AWAIT1:.*]]
+  cond_br %arg0, ^await0, ^await1
+^await0:
+  // CHECK: ^[[AWAIT0]]:
+  // CHECK:   async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i32}
+  // CHECK:   async.runtime.await %[[TOKEN0]]
+  // CHECK:   async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i32}
+  // CHECK:   br ^[[RETURN:.*]]
+  async.runtime.await %token0 : !async.token
+  br ^ret
+^await1:
+  // CHECK: ^[[AWAIT1]]:
+  // CHECK:   async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i32}
+  // CHECK:   async.runtime.await %[[TOKEN1]]
+  // CHECK:   async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i32}
+  // CHECK:   br ^[[RETURN]]
+  async.runtime.await %token1 : !async.token
+  br ^ret
+^ret:
+  // CHECK: ^[[RETURN]]:
+  // CHECK:   return
+  return
+}