diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -29,6 +29,8 @@ std::unique_ptr<Pass> createAsyncRuntimeRefCountingOptPass(); +std::unique_ptr<Pass> createAsyncRuntimePolicyBasedRefCountingPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -66,4 +66,37 @@ let dependentDialects = ["async::AsyncDialect"]; } +def AsyncRuntimePolicyBasedRefCounting + : Pass<"async-runtime-policy-based-ref-counting"> { + let summary = "Policy based reference counting for Async runtime operations"; + let description = [{ + This pass works at the async runtime abstraction level, after all + `async.execute` and `async.await` operations are lowered to the async + runtime API calls, and async coroutine operations. + + This pass doesn't rely on reference counted values liveness analysis, and + instead uses a simple policy to create reference counting operations. If + the program violates any of the assumptions, then this pass might lead to + memory leaks or runtime errors. + + The default reference counting policy assumptions: + 1. Async token can be awaited or added to the group only once. + 2. Async value or group can be awaited only once. + + Under these assumptions reference counting only needs to drop reference: + 1. After `async.runtime.await` operation for async tokens and groups + (as long as error handling is not implemented for the sync await). + 2. After `async.runtime.is_error` operation for async tokens and groups + (this is the last operation in the coroutine resume function). + 3.
After `async.runtime.load` operation for async values. + 4. After `async.runtime.add_to_group` operation for async tokens. + + This pass introduces significantly less runtime overhead compared to the + automatic reference counting. + }]; + + let constructor = "mlir::createAsyncRuntimePolicyBasedRefCountingPass()"; + let dependentDialects = ["async::AsyncDialect"]; +} + #endif // MLIR_DIALECT_ASYNC_PASSES diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -26,6 +26,79 @@ #define DEBUG_TYPE "async-runtime-ref-counting" +//===----------------------------------------------------------------------===// +// Utility functions shared by reference counting passes. +//===----------------------------------------------------------------------===// + +// Drop `count` references to `value` if it has no uses; fail otherwise. +static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { + if (!value.getUses().empty()) + return failure(); + + OpBuilder b(value.getContext()); + + // Set insertion point after the operation producing a value, or at the + // beginning of the block if the value is defined by a block argument. + if (Operation *op = value.getDefiningOp()) + b.setInsertionPointAfter(op); + else + b.setInsertionPointToStart(value.getParentBlock()); + + b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI32IntegerAttr(count)); + return success(); +} + +// Calls `addRefCounting` for every reference counted value defined by the +// operation `op` (block arguments and values defined in nested regions).
+static LogicalResult walkReferenceCountedValues( + Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) { + // Check that we do not have high level async operations in the IR because + // otherwise reference counting will produce incorrect results after high + // level async operations are lowered to `async.runtime`. + WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult { + if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) + return WalkResult::advance(); + + return op->emitError() + << "async operations must be lowered to async runtime operations"; + }); + + if (checkNoAsyncWalk.wasInterrupted()) + return failure(); + + // Add reference counting to block arguments. + WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { + for (BlockArgument arg : block->getArguments()) + if (isRefCounted(arg.getType())) + if (failed(addRefCounting(arg))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (blockWalk.wasInterrupted()) + return failure(); + + // Add reference counting to operation results. + WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { + for (Value result : op->getResults()) + if (isRefCounted(result.getType())) + if (failed(addRefCounting(result))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (opWalk.wasInterrupted()) + return failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Automatic reference counting based on the liveness analysis. +//===----------------------------------------------------------------------===// + namespace { class AsyncRuntimeRefCountingPass @@ -356,21 +429,9 @@ LogicalResult AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { - OpBuilder builder(value.getContext()); - Location loc = value.getLoc(); - - // Set inserton point after the operation producing a value, or at the - // beginning of the block if the value defined by the block argument.
- if (Operation *op = value.getDefiningOp()) - builder.setInsertionPointAfter(op); - else - builder.setInsertionPointToStart(value.getParentBlock()); - - // Drop the reference count immediately if the value has no uses. - if (value.getUses().empty()) { - builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1)); + // Short-circuit reference counting for values without uses. + if (succeeded(dropRefIfNoUses(value))) return success(); - } // Add `drop_ref` operations based on the liveness analysis. if (failed(addDropRefAfterLastUse(value))) @@ -388,53 +449,118 @@ } void AsyncRuntimeRefCountingPass::runOnOperation() { - Operation *op = getOperation(); + auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; + if (failed(walkReferenceCountedValues(getOperation(), functor))) + signalPassFailure(); +} - // Check that we do not have high level async operations in the IR because - // otherwise automatic reference counting will produce incorrect results after - // execute operations will be lowered to `async.runtime` - WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult { - if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) - return WalkResult::advance(); +//===----------------------------------------------------------------------===// +// Reference counting based on the user defined policy. +//===----------------------------------------------------------------------===// - return op->emitError() - << "async operations must be lowered to async runtime operations"; - }); +namespace { - if (executeOpWalk.wasInterrupted()) { - signalPassFailure(); - return; - } +class AsyncRuntimePolicyBasedRefCountingPass + : public AsyncRuntimePolicyBasedRefCountingBase< + AsyncRuntimePolicyBasedRefCountingPass> { +public: + AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } - // Add reference counting to block arguments.
- WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { - for (BlockArgument arg : block->getArguments()) - if (isRefCounted(arg.getType())) - if (failed(addAutomaticRefCounting(arg))) - return WalkResult::interrupt(); + void runOnOperation() override; - return WalkResult::advance(); - }); +private: + // Adds reference counting operations for all uses of the `value` according + // to the reference counting policy. + LogicalResult addRefCounting(Value value); - if (blockWalk.wasInterrupted()) { - signalPassFailure(); - return; + void initializeDefaultPolicy(); + + // Policy functions consulted for every use of a reference counted value. + // Each returns a reference count delta for the use: a positive delta adds + // references before the operand owner, a negative delta drops references + // after it. A failure result makes the pass fail. + llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy; +}; + +} // namespace + +LogicalResult +AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { + // Short-circuit reference counting for values without uses. + if (succeeded(dropRefIfNoUses(value))) + return success(); + + OpBuilder b(value.getContext()); + + // Consult the user defined policy for every value use. + for (OpOperand &operand : value.getUses()) { + Location loc = operand.getOwner()->getLoc(); + + for (auto &func : policy) { + FailureOr<int> refCount = func(operand); + if (failed(refCount)) + return failure(); + + int cnt = refCount.getValue(); + + // Create `add_ref` operation before the operand owner. + if (cnt > 0) { + b.setInsertionPoint(operand.getOwner()); + b.create<RuntimeAddRefOp>(loc, value, b.getI32IntegerAttr(cnt)); + } + + // Create `drop_ref` operation after the operand owner. + if (cnt < 0) { + b.setInsertionPointAfter(operand.getOwner()); + b.create<RuntimeDropRefOp>(loc, value, b.getI32IntegerAttr(-cnt)); + } + } } - // Add reference counting to operation results.
- WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { - for (unsigned i = 0; i < op->getNumResults(); ++i) - if (isRefCounted(op->getResultTypes()[i])) - if (failed(addAutomaticRefCounting(op->getResult(i)))) - return WalkResult::interrupt(); + return success(); +} - return WalkResult::advance(); +void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { + policy.push_back([](OpOperand &operand) -> FailureOr<int> { + Operation *op = operand.getOwner(); + Type type = operand.get().getType(); + + bool isToken = type.isa<TokenType>(); + bool isGroup = type.isa<GroupType>(); + bool isValue = type.isa<ValueType>(); + + // Drop reference after async token or group await (sync await). + if (isa<RuntimeAwaitOp>(op)) + return (isToken || isGroup) ? -1 : 0; + + // Drop reference after async token or group error check (coro await). + if (isa<RuntimeIsErrorOp>(op)) + return (isToken || isGroup) ? -1 : 0; + + // Drop reference after async value load. + if (isa<RuntimeLoadOp>(op)) + return isValue ? -1 : 0; + + // Drop reference after async token added to the group. + if (isa<RuntimeAddToGroupOp>(op)) + return isToken ?
-1 : 0; + + return 0; }); +} - if (opWalk.wasInterrupted()) +void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { + auto functor = [&](Value value) { return addRefCounting(value); }; + if (failed(walkReferenceCountedValues(getOperation(), functor))) signalPassFailure(); } +//----------------------------------------------------------------------------// + std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() { return std::make_unique<AsyncRuntimeRefCountingPass>(); } + +std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() { + return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>(); +} diff --git a/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt %s -async-runtime-policy-based-ref-counting | FileCheck %s + +// CHECK-LABEL: @token_await +// CHECK: %[[TOKEN:.*]]: !async.token +func @token_await(%arg0: !async.token) { + // CHECK: async.runtime.await %[[TOKEN]] + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + async.runtime.await %arg0 : !async.token + return +} + +// CHECK-LABEL: @group_await +// CHECK: %[[GROUP:.*]]: !async.group +func @group_await(%arg0: !async.group) { + // CHECK: async.runtime.await %[[GROUP]] + // CHECK: async.runtime.drop_ref %[[GROUP]] {count = 1 : i32} + async.runtime.await %arg0 : !async.group + return +} + +// CHECK-LABEL: @add_token_to_group +// CHECK: %[[GROUP:.*]]: !async.group +// CHECK: %[[TOKEN:.*]]: !async.token +func @add_token_to_group(%arg0: !async.group, %arg1: !async.token) { + // CHECK: async.runtime.add_to_group %[[TOKEN]], %[[GROUP]] + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + async.runtime.add_to_group %arg1, %arg0 : !async.token + return +} + +// CHECK-LABEL: @value_load +// CHECK: %[[VALUE:.*]]: !async.value<f32> +func @value_load(%arg0: !async.value<f32>) { + // CHECK: async.runtime.load
%[[VALUE]] + // CHECK: async.runtime.drop_ref %[[VALUE]] {count = 1 : i32} + %0 = async.runtime.load %arg0 : !async.value<f32> + return +} + +// CHECK-LABEL: @error_check +// CHECK: %[[TOKEN:.*]]: !async.token +func @error_check(%arg0: !async.token) { + // CHECK: async.runtime.is_error %[[TOKEN]] + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + %0 = async.runtime.is_error %arg0 : !async.token + return +} diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -11,6 +11,18 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ // RUN: | FileCheck %s --dump-input=always +// RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-to-async-runtime \ +// RUN: -async-runtime-policy-based-ref-counting \ +// RUN: -convert-async-to-llvm \ +// RUN: -convert-scf-to-std \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void -O0 \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ +// RUN: | FileCheck %s --dump-input=always + // RUN: mlir-opt %s -async-parallel-for="async-dispatch=false \ // RUN: num-workers=20 \ // RUN: target-block-size=1" \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -11,6 +11,18 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ // RUN: | FileCheck %s --dump-input=always +// RUN: mlir-opt %s
-async-parallel-for \ +// RUN: -async-to-async-runtime \ +// RUN: -async-runtime-policy-based-ref-counting \ +// RUN: -convert-async-to-llvm \ +// RUN: -convert-scf-to-std \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void -O0 \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ +// RUN: | FileCheck %s --dump-input=always + // RUN: mlir-opt %s -async-parallel-for="async-dispatch=false \ // RUN: num-workers=20 \ // RUN: target-block-size=1" \