diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -12,8 +12,10 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -571,10 +573,22 @@ << " functions built from async.execute operations\n"; }); + // Returns true if `op` is nested inside one of the outlined coroutines. + auto isInCoroutine = [&](Operation *op) -> bool { + auto parentFunc = op->getParentOfType(); + return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); + }; + // Lower async operations to async.runtime operations. MLIRContext *ctx = module->getContext(); RewritePatternSet asyncPatterns(ctx); + // Conversion to async runtime augments original CFG with the coroutine CFG, + // and we have to make sure that structured control flow operations with async + // operations in nested regions will be converted to branch-based control flow + // before we add the coroutine basic blocks. + populateLoopToStdConversionPatterns(asyncPatterns); + // Async lowering does not use type converter because it must preserve all // types for async.runtime operations. asyncPatterns.add(ctx); @@ -591,12 +605,22 @@ runtimeTarget.addIllegalOp(); runtimeTarget.addIllegalOp(); + // SCF ops that contain async ops inside a coroutine must be lowered to CFG. + runtimeTarget.addDynamicallyLegalDialect([&](Operation *op) { + auto walkResult = op->walk([&](Operation *nested) { + bool isAsync = isa(nested->getDialect()); + return isAsync && isInCoroutine(nested) ? 
WalkResult::interrupt() + : WalkResult::advance(); + }); + return !walkResult.wasInterrupted(); + }); + runtimeTarget.addLegalOp(); + // Assertions must be converted to runtime errors inside async functions. runtimeTarget.addDynamicallyLegalOp([&](AssertOp op) -> bool { auto func = op->getParentOfType(); return outlinedFunctions.find(func) == outlinedFunctions.end(); }); - runtimeTarget.addLegalOp(); if (failed(applyPartialConversion(module, runtimeTarget, std::move(asyncPatterns)))) { diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -15,6 +15,7 @@ MLIRAsync MLIRPass MLIRSCF + MLIRSCFToStandard MLIRStandard MLIRTransforms MLIRTransformUtils diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -374,3 +374,35 @@ // CHECK: ^[[SUSPEND]]: // CHECK: async.coro.end %[[HDL]] // CHECK: return %[[TOKEN]] + +// ----- +// Structured control flow operations with async operations in the body must be +// lowered to branch-based control flow to enable the coroutine CFG rewrite. + +// CHECK-LABEL: @lower_scf_to_cfg +func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) { + %token0 = async.execute { async.yield } + %token1 = async.execute { + scf.if %arg2 { + async.await %token0 : !async.token + } else { + async.await %token0 : !async.token + } + async.yield + } + return +} + +// Function outlined from the first async.execute operation. +// CHECK-LABEL: func private @async_execute_fn( +// CHECK-SAME: -> !async.token + +// Function outlined from the second async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn_0( +// CHECK: %[[TOKEN:.*]]: !async.token +// CHECK: %[[FLAG:.*]]: i1 +// CHECK-SAME: -> !async.token + +// Check that structured control flow was lowered to CFG. +// CHECK-NOT: scf.if +// CHECK: cond_br %[[FLAG]] diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1791,6 +1791,7 @@ ":IR", ":Pass", ":SCFDialect", + ":SCFToStandard", ":StandardOps", ":Support", ":TransformUtils",