diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -29,6 +29,20 @@ std::unique_ptr createAsyncRuntimeRefCountingOptPass(); +std::unique_ptr createAsyncRuntimeBespokeRefCountingPass(); + +// Each policy function must return a number of references that should be added +// or dropped for the given operand: +// +// 0 - no new reference counting operations are required +// +N - `add_ref` operation will be created before the operand owner +// -N - `drop_ref` operation will be created after the operand owner +// +// Each policy is applied to every use of every reference counted value (block +// argument or operation result). Returning failure makes the pass fail. +std::unique_ptr createAsyncRuntimeBespokeRefCountingPass( + llvm::SmallVector(OpOperand &)>> policy); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -66,4 +66,36 @@ let dependentDialects = ["async::AsyncDialect"]; } +def AsyncRuntimeBespokeRefCounting + : Pass<"async-runtime-bespoke-ref-counting"> { + let summary = "Policy based reference counting for Async runtime operations"; + let description = [{ + This pass works at the async runtime abstraction level, after all + `async.execute` and `async.await` operations are lowered to the async + runtime API calls, and async coroutine operations. + + This pass takes a user-defined policy that specifies where to put reference + counting operations. Currently there is no way to specify the policy using + command line flags. 
+ + The default reference counting policy makes a few assumptions: + 1. Async token can be awaited or added to the group only once. + 2. Async value or group can be awaited only once. 
+ + Under these assumptions reference counting only needs to drop references: + 1. After `async.runtime.await` operation for async tokens and groups + (as long as error handling is not implemented for the sync await). + 2. After `async.runtime.is_error` operation for async tokens and groups + (this is the last operation in the coroutine resume function). + 3. After `async.runtime.load` operation for async values. + 4. After `async.runtime.add_to_group` operation for async tokens. + + This pass introduces significantly less runtime overhead compared to the + automatic reference counting. + }]; + + let constructor = "mlir::createAsyncRuntimeBespokeRefCountingPass()"; + let dependentDialects = ["async::AsyncDialect"]; +} + #endif // MLIR_DIALECT_ASYNC_PASSES diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeBespokeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeBespokeRefCounting.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeBespokeRefCounting.cpp @@ -0,0 +1,168 @@ +//===- AsyncRuntimeBespokeRefCounting.cpp - Async Runtime Ref Counting ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements policy based reference counting for Async runtime +// operations and types.
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; +using namespace mlir::async; + +#define DEBUG_TYPE "async-runtime-bespoke-ref-counting" + +namespace { + +class AsyncRuntimeBespokeRefCountingPass + : public AsyncRuntimeBespokeRefCountingBase< + AsyncRuntimeBespokeRefCountingPass> { +public: + AsyncRuntimeBespokeRefCountingPass() { initializeDefaultPolicy(); } + + explicit AsyncRuntimeBespokeRefCountingPass( + llvm::SmallVector(OpOperand &)>> policy) + : policy(std::move(policy)) {} + + void runOnOperation() override; + +private: + // Adds reference counting operations for all uses of the `value` according + // to the reference counting policy. Fails if any policy function fails. + LogicalResult addRefCounting(Value value); + + void initializeDefaultPolicy(); + + llvm::SmallVector(OpOperand &)>> policy; +}; + +} // namespace + +LogicalResult AsyncRuntimeBespokeRefCountingPass::addRefCounting(Value value) { + OpBuilder b(value.getContext()); + + for (OpOperand &operand : value.getUses()) { + Location loc = operand.getOwner()->getLoc(); + + for (auto &func : policy) { + FailureOr refCount = func(operand); + if (failed(refCount)) + return failure(); + + int cnt = refCount.getValue(); + + // Create `add_ref` operation before the operand owner. + if (cnt > 0) { + b.setInsertionPoint(operand.getOwner()); + b.create(loc, value, b.getI32IntegerAttr(cnt)); + } + + // Create `drop_ref` operation after the operand owner.
+ if (cnt < 0) { + b.setInsertionPointAfter(operand.getOwner()); + b.create(loc, value, b.getI32IntegerAttr(-cnt)); + } + } + } + + return success(); +} + +void AsyncRuntimeBespokeRefCountingPass::initializeDefaultPolicy() { + policy.push_back([](OpOperand &operand) -> FailureOr { + Operation *op = operand.getOwner(); + Type type = operand.get().getType(); + + bool isToken = type.isa(); + bool isGroup = type.isa(); + bool isValue = type.isa(); + + // Drop reference after async token or group await (sync await). + if (isa(op)) + return (isToken || isGroup) ? -1 : 0; + + // Drop reference after async token or group error check (coro await). + if (isa(op)) + return (isToken || isGroup) ? -1 : 0; + + // Drop reference after async value load. + if (isa(op)) + return isValue ? -1 : 0; + + // Drop reference after async token added to the group. + if (isa(op)) + return isToken ? -1 : 0; + + return 0; + }); +} + +void AsyncRuntimeBespokeRefCountingPass::runOnOperation() { + Operation *op = getOperation(); + + // Check that we do not have high level async operations in the IR because + // otherwise reference counting will produce incorrect results after + // execute operations are lowered to `async.runtime`. + WalkResult executeOpWalk = op->walk([&](Operation *nestedOp) -> WalkResult { + if (!isa(nestedOp)) + return WalkResult::advance(); + + return nestedOp->emitError() + << "async operations must be lowered to async runtime operations"; + }); + + if (executeOpWalk.wasInterrupted()) + return signalPassFailure(); + + // Add reference counting to block arguments.
+ WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { + for (BlockArgument arg : block->getArguments()) + if (isRefCounted(arg.getType())) + if (failed(addRefCounting(arg))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (blockWalk.wasInterrupted()) { + signalPassFailure(); + return; + } + + // Add reference counting to operation results. + WalkResult opWalk = op->walk([&](Operation *nestedOp) -> WalkResult { + for (Value result : nestedOp->getResults()) + if (isRefCounted(result.getType())) + if (failed(addRefCounting(result))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (opWalk.wasInterrupted()) + signalPassFailure(); +} + +std::unique_ptr mlir::createAsyncRuntimeBespokeRefCountingPass() { + return std::make_unique(); +} + +std::unique_ptr mlir::createAsyncRuntimeBespokeRefCountingPass( + llvm::SmallVector(OpOperand &)>> policy) { + return std::make_unique( + std::move(policy)); +} diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRAsyncTransforms AsyncParallelFor.cpp + AsyncRuntimeBespokeRefCounting.cpp AsyncRuntimeRefCounting.cpp AsyncRuntimeRefCountingOpt.cpp AsyncToAsyncRuntime.cpp diff --git a/mlir/test/Dialect/Async/async-runtime-bespoke-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-bespoke-ref-counting.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-runtime-bespoke-ref-counting.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt %s -async-runtime-bespoke-ref-counting | FileCheck %s + +// CHECK-LABEL: @token_await +// CHECK: %[[TOKEN:.*]]: !async.token +func @token_await(%arg0: !async.token) { + // CHECK: async.runtime.await %[[TOKEN]] + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + async.runtime.await
%arg0 : !async.token + return +} + +// CHECK-LABEL: @group_await +// CHECK: %[[GROUP:.*]]: !async.group +func @group_await(%arg0: !async.group) { + // CHECK: async.runtime.await %[[GROUP]] + // CHECK: async.runtime.drop_ref %[[GROUP]] {count = 1 : i32} + async.runtime.await %arg0 : !async.group + return +} + +// CHECK-LABEL: @add_token_to_group +// CHECK: %[[GROUP:.*]]: !async.group +// CHECK: %[[TOKEN:.*]]: !async.token +func @add_token_to_group(%arg0: !async.group, %arg1: !async.token) { + // CHECK: async.runtime.add_to_group %[[TOKEN]], %[[GROUP]] + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + async.runtime.add_to_group %arg1, %arg0 : !async.token + return +} + +// CHECK-LABEL: @value_load +// CHECK: %[[VALUE:.*]]: !async.value +func @value_load(%arg0: !async.value) { + // CHECK: async.runtime.load %[[VALUE]] + // CHECK: async.runtime.drop_ref %[[VALUE]] {count = 1 : i32} + %0 = async.runtime.load %arg0 : !async.value + return +} + +// CHECK-LABEL: @error_check +// CHECK: %[[TOKEN:.*]]: !async.token +func @error_check(%arg0: !async.token) { + // CHECK: async.runtime.is_error %[[TOKEN]] + // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} + %0 = async.runtime.is_error %arg0 : !async.token + return +} diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -11,6 +11,18 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ // RUN: | FileCheck %s --dump-input=always +// RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-to-async-runtime \ +// RUN: -async-runtime-bespoke-ref-counting \ +// RUN: -convert-async-to-llvm \ +// RUN: -convert-scf-to-std \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e entry
-entry-point-result=void -O0 \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ +// RUN: | FileCheck %s --dump-input=always + // RUN: mlir-opt %s -async-parallel-for="async-dispatch=false \ // RUN: num-workers=20 \ // RUN: target-block-size=1" \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -11,6 +11,18 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ // RUN: | FileCheck %s --dump-input=always +// RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-to-async-runtime \ +// RUN: -async-runtime-bespoke-ref-counting \ +// RUN: -convert-async-to-llvm \ +// RUN: -convert-scf-to-std \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void -O0 \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ +// RUN: | FileCheck %s --dump-input=always + // RUN: mlir-opt %s -async-parallel-for="async-dispatch=false \ // RUN: num-workers=20 \ // RUN: target-block-size=1" \