diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp @@ -13,6 +13,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/Debug.h" @@ -109,6 +110,58 @@ dropRef->isBeforeInBlock(addRef.getOperation())) continue; + // When a reference counted value is passed to a function as an + // argument, the function takes ownership of the +1 reference and will + // drop it before returning. + // + // Example: + // + // %token = ... : !async.token + // + // async.runtime.add_ref %token {count = 1 : i32} : !async.token + // call @pass_token(%token: !async.token, ...) + // + // async.await %token : !async.token + // async.runtime.drop_ref %token {count = 1 : i32} : !async.token + // + // In this example, if we cancel the pair of reference counting + // operations we might end up with a deallocated token by the time we + // reach the `async.await` operation. + Operation *firstFunctionCallUser = nullptr; + Operation *lastNonFunctionCallUser = nullptr; + + for (Operation *user : info.users) { + // `user` operation lies after `addRef` ... + if (user == addRef || user->isBeforeInBlock(addRef)) + continue; + // ... and before `dropRef`. + if (user == dropRef || dropRef->isBeforeInBlock(user)) + break; + + // Find the first function call user of the reference counted value. + Operation *functionCall = dyn_cast<CallOp>(user); + if (functionCall && + (!firstFunctionCallUser || + functionCall->isBeforeInBlock(firstFunctionCallUser))) { + firstFunctionCallUser = functionCall; + continue; + } + + // Find the last regular user of the reference counted value.
+ if (!functionCall && + (!lastNonFunctionCallUser || + lastNonFunctionCallUser->isBeforeInBlock(user))) { + lastNonFunctionCallUser = user; + continue; + } + } + + // A non function call user comes after a function call user of the + // reference counted value: the pair can't be cancelled safely. + if (firstFunctionCallUser && lastNonFunctionCallUser && + firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser)) + continue; + // Try to cancel the pair of `add_ref` and `drop_ref` operations. auto emplaced = cancellable.try_emplace(dropRef.getOperation(), addRef.getOperation()); diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -13,8 +13,9 @@ LINK_LIBS PUBLIC MLIRIR MLIRAsync - MLIRSCF MLIRPass + MLIRSCF + MLIRStandard MLIRTransforms MLIRTransformUtils ) diff --git a/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir --- a/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir +++ b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir @@ -53,3 +53,17 @@ // CHECK: return return } + +// CHECK-LABEL: @not_cancellable_operations_0 +func @not_cancellable_operations_0(%arg0: !async.token) { + // CHECK: add_ref + async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: call @consume_token + call @consume_token(%arg0): (!async.token) -> () + // CHECK: async.runtime.await + async.runtime.await %arg0 : !async.token + // CHECK: async.runtime.drop_ref + async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir +++
b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir @@ -3,7 +3,7 @@ // RUN: -async-parallel-for \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -std-expand \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir @@ -2,13 +2,13 @@ // RUN: -async-parallel-for \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-linalg-to-loops \ // RUN: -convert-scf-to-std \ // RUN: -std-expand \ // RUN: -convert-vector-to-llvm \ // RUN: -convert-std-to-llvm \ // RUN: | mlir-cpu-runner \ // RUN: -e entry -entry-point-result=void -O3 \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ @@ -20,7 +20,7 @@ // RUN: -async-parallel-for=async-dispatch=false \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-linalg-to-loops \ // RUN: -convert-scf-to-std \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s
-async-parallel-for \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ @@ -16,7 +16,7 @@ // RUN: target-block-size=1" \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -async-parallel-for \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ @@ -16,7 +16,7 @@ // RUN: target-block-size=1" \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ -// FIXME: -async-runtime-ref-counting-opt \ +// RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \