diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -398,51 +398,55 @@ diag::warn_coroutine_handle_address_invalid_return_type) << JustAddress->getType(); - // After the await_suspend call on the awaiter, the coroutine may have - // been destroyed. In that case, we can not store anything to the frame - // from this point on. Hence here we wrap it immediately with a cleanup. This - // could have applied to all await_suspend calls. However doing so causes - // alive objects being destructed for reasons that need further - // investigations. Here we walk-around it temporarily by only doing it after - // the suspend call on the final awaiter (indicated by IsImplicit) where it's - // most common to happen. - // TODO: Properly clean up the temps generated by await_suspend calls. - if (IsImplicit) - JustAddress = S.MaybeCreateExprWithCleanups(JustAddress); + // Clean up temporary objects so that they don't live across suspension points + // unnecessarily. We choose to clean up before the call to + // __builtin_coro_resume so that the cleanup code is not inserted in-between + // the resume call and return instruction, which would interfere with the + // musttail call contract. + JustAddress = S.MaybeCreateExprWithCleanups(JustAddress); return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume, JustAddress); } /// Build calls to await_ready, await_suspend, and await_resume for a co_await /// expression. +/// The generated AST tries to clean up temporary objects as early as +/// possible so that they don't live across suspension points. +/// Having temporary objects living across suspension points unnecessarily can +/// lead to a large frame size, and also to memory corruption if the +/// coroutine frame is destroyed after coming back from suspension.
This is done +/// by wrapping both the await_ready call and the await_suspend call with +/// ExprWithCleanups. At the end of this function, we also need to explicitly +/// set cleanup state so that the CoawaitExpr is also wrapped with an +/// ExprWithCleanups to clean up the awaiter associated with the co_await +/// expression. static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise, SourceLocation Loc, Expr *E, bool IsImplicit) { OpaqueValueExpr *Operand = new (S.Context) OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E); - // Assume invalid until we see otherwise. - ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true}; - - ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc); - if (CoroHandleRes.isInvalid()) - return Calls; - Expr *CoroHandle = CoroHandleRes.get(); + // Assume valid until we see otherwise. + // Further operations are responsible for setting IsInvalid to true. + ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/false}; - const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"}; - MultiExprArg Args[] = {None, CoroHandle, None}; - for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) { - ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]); - if (Result.isInvalid()) - return Calls; - Calls.Results[I] = Result.get(); - } + using ACT = ReadySuspendResumeResult::AwaitCallType; - // Assume the calls are valid; all further checking should make them invalid.
- Calls.IsInvalid = false; + auto BuildSubExpr = [&](ACT CallType, StringRef Func, + MultiExprArg Arg) -> Expr * { + ExprResult Result = buildMemberCall(S, Operand, Loc, Func, Arg); + if (Result.isInvalid()) { + Calls.IsInvalid = true; + return nullptr; + } + Calls.Results[CallType] = Result.get(); + return Result.get(); + }; - using ACT = ReadySuspendResumeResult::AwaitCallType; - CallExpr *AwaitReady = cast(Calls.Results[ACT::ACT_Ready]); + CallExpr *AwaitReady = + cast_or_null(BuildSubExpr(ACT::ACT_Ready, "await_ready", None)); + if (!AwaitReady) + return Calls; if (!AwaitReady->getType()->isDependentType()) { // [expr.await]p3 [...] // — await-ready is the expression e.await_ready(), contextually converted @@ -454,10 +458,21 @@ S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) << AwaitReady->getDirectCallee() << E->getSourceRange(); Calls.IsInvalid = true; - } - Calls.Results[ACT::ACT_Ready] = Conv.get(); + } else + Calls.Results[ACT::ACT_Ready] = S.MaybeCreateExprWithCleanups(Conv.get()); + } + + ExprResult CoroHandleRes = + buildCoroutineHandle(S, CoroPromise->getType(), Loc); + if (CoroHandleRes.isInvalid()) { + Calls.IsInvalid = true; + return Calls; } - CallExpr *AwaitSuspend = cast(Calls.Results[ACT::ACT_Suspend]); + Expr *CoroHandle = CoroHandleRes.get(); + CallExpr *AwaitSuspend = cast_or_null( + BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle)); + if (!AwaitSuspend) + return Calls; if (!AwaitSuspend->getType()->isDependentType()) { // [expr.await]p3 [...] // - await-suspend is the expression e.await_suspend(h), which shall be @@ -468,6 +483,11 @@ // Experimental support for coroutine_handle returning await_suspend. if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc, IsImplicit)) + // Note that we don't wrap the expression with ExprWithCleanups here + // because that might interfere with tailcall contract (e.g. inserting + // clean up instructions in-between tailcall and return). 
Instead + // ExprWithCleanups is wrapped within maybeTailCall() prior to the resume + // call. Calls.Results[ACT::ACT_Suspend] = TailCallSuspend; else { // non-class prvalues always have cv-unqualified types @@ -479,10 +499,17 @@ S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) << AwaitSuspend->getDirectCallee(); Calls.IsInvalid = true; - } + } else + Calls.Results[ACT::ACT_Suspend] = + S.MaybeCreateExprWithCleanups(AwaitSuspend); } } + BuildSubExpr(ACT::ACT_Resume, "await_resume", None); + + // Make sure the awaiter object gets a chance to be cleaned up. + S.Cleanup.setExprNeedsCleanups(true); + return Calls; } diff --git a/clang/test/AST/Inputs/std-coroutine.h b/clang/test/AST/Inputs/std-coroutine.h --- a/clang/test/AST/Inputs/std-coroutine.h +++ b/clang/test/AST/Inputs/std-coroutine.h @@ -5,18 +5,54 @@ namespace std { namespace experimental { -template -struct coroutine_traits { using promise_type = typename Ret::promise_type; }; +template struct coroutine_traits { + using promise_type = typename R::promise_type; +}; + +template struct coroutine_handle; + +template <> struct coroutine_handle { + static coroutine_handle from_address(void *addr) noexcept { + coroutine_handle me; + me.ptr = addr; + return me; + } + void operator()() { resume(); } + void *address() const noexcept { return ptr; } + void resume() const { __builtin_coro_resume(ptr); } + void destroy() const { __builtin_coro_destroy(ptr); } + bool done() const { return __builtin_coro_done(ptr); } + coroutine_handle &operator=(decltype(nullptr)) { + ptr = nullptr; + return *this; + } + coroutine_handle(decltype(nullptr)) : ptr(nullptr) {} + coroutine_handle() : ptr(nullptr) {} + // void reset() { ptr = nullptr; } // add to P0057? 
+ explicit operator bool() const { return ptr; } -template -struct coroutine_handle { - static coroutine_handle from_address(void *) noexcept; +protected: + void *ptr; }; -template <> -struct coroutine_handle { - template - coroutine_handle(coroutine_handle) noexcept; - static coroutine_handle from_address(void *); + +template struct coroutine_handle : coroutine_handle<> { + using coroutine_handle<>::operator=; + + static coroutine_handle from_address(void *addr) noexcept { + coroutine_handle me; + me.ptr = addr; + return me; + } + + Promise &promise() const { + return *reinterpret_cast( + __builtin_coro_promise(ptr, alignof(Promise), false)); + } + static coroutine_handle from_promise(Promise &promise) { + coroutine_handle p; + p.ptr = __builtin_coro_promise(&promise, alignof(Promise), true); + return p; + } }; struct suspend_always { diff --git a/clang/test/AST/coroutine-locals-cleanup.cpp b/clang/test/AST/coroutine-locals-cleanup.cpp new file mode 100644 --- /dev/null +++ b/clang/test/AST/coroutine-locals-cleanup.cpp @@ -0,0 +1,107 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fcoroutines-ts -std=c++14 -fsyntax-only -ast-dump %s | FileCheck %s + +#include "Inputs/std-coroutine.h" + +using namespace std::experimental; + +struct Task { + struct promise_type { + Task get_return_object() noexcept { + return Task{coroutine_handle::from_promise(*this)}; + } + + void return_void() noexcept {} + + struct final_awaiter { + bool await_ready() noexcept { return false; } + coroutine_handle<> await_suspend(coroutine_handle h) noexcept { + h.destroy(); + return {}; + } + void await_resume() noexcept {} + }; + + void unhandled_exception() noexcept {} + + final_awaiter final_suspend() noexcept { return {}; } + + suspend_always initial_suspend() noexcept { return {}; } + + template + auto await_transform(Awaitable &&awaitable) { + return awaitable.co_viaIfAsync(); + } + }; + + using handle_t = coroutine_handle; + + class Awaiter { + public: + explicit Awaiter(handle_t 
coro) noexcept; + Awaiter(Awaiter &&other) noexcept; + Awaiter(const Awaiter &) = delete; + ~Awaiter(); + + bool await_ready() noexcept { return false; } + handle_t await_suspend(coroutine_handle<> continuation) noexcept; + void await_resume(); + + private: + handle_t coro_; + }; + + Task(handle_t coro) noexcept : coro_(coro) {} + + handle_t coro_; + + Task(const Task &t) = delete; + Task(Task &&t) noexcept; + ~Task(); + Task &operator=(Task t) noexcept; + + Awaiter co_viaIfAsync(); +}; + +static Task foo() { + co_return; +} + +Task bar() { + auto mode = 2; + switch (mode) { + case 1: + co_await foo(); + break; + case 2: + co_await foo(); + break; + default: + break; + } +} + +// CHECK-LABEL: FunctionDecl {{.*}} bar 'Task ()' +// CHECK: SwitchStmt +// CHECK: CaseStmt +// CHECK: ExprWithCleanups {{.*}} 'void' +// CHECK-NEXT: CoawaitExpr +// CHECK-NEXT: MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter' +// CHECK: ExprWithCleanups {{.*}} 'bool' +// CHECK-NEXT: CXXMemberCallExpr {{.*}} 'bool' +// CHECK-NEXT: MemberExpr {{.*}} .await_ready +// CHECK: CallExpr {{.*}} 'void' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(void *)' +// CHECK-NEXT: DeclRefExpr {{.*}} '__builtin_coro_resume' 'void (void *)' +// CHECK-NEXT: ExprWithCleanups {{.*}} 'void *' + +// CHECK: CaseStmt +// CHECK: ExprWithCleanups {{.*}} 'void' +// CHECK-NEXT: CoawaitExpr +// CHECK-NEXT: MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter' +// CHECK: ExprWithCleanups {{.*}} 'bool' +// CHECK-NEXT: CXXMemberCallExpr {{.*}} 'bool' +// CHECK-NEXT: MemberExpr {{.*}} .await_ready +// CHECK: CallExpr {{.*}} 'void' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(void *)' +// CHECK-NEXT: DeclRefExpr {{.*}} '__builtin_coro_resume' 'void (void *)' +// CHECK-NEXT: ExprWithCleanups {{.*}} 'void *' diff --git a/clang/test/CodeGenCoroutines/coro-symmetric-transfer.cpp b/clang/test/CodeGenCoroutines/coro-symmetric-transfer-01.cpp rename from 
clang/test/CodeGenCoroutines/coro-symmetric-transfer.cpp rename to clang/test/CodeGenCoroutines/coro-symmetric-transfer-01.cpp --- a/clang/test/CodeGenCoroutines/coro-symmetric-transfer.cpp +++ b/clang/test/CodeGenCoroutines/coro-symmetric-transfer-01.cpp @@ -48,7 +48,7 @@ co_return; } -// check that the lifetime of the coroutine handle used to obtain the address is contained within single basic block. +// check that the lifetime of the coroutine handle used to obtain the address is contained within single basic block, and hence does not live across suspension points. // CHECK-LABEL: final.suspend: // CHECK: %[[PTR1:.+]] = bitcast %"struct.std::experimental::coroutines_v1::coroutine_handle.0"* %[[ADDR_TMP:.+]] to i8* // CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 8, i8* %[[PTR1]]) diff --git a/clang/test/CodeGenCoroutines/coro-symmetric-transfer-02.cpp b/clang/test/CodeGenCoroutines/coro-symmetric-transfer-02.cpp new file mode 100644 --- /dev/null +++ b/clang/test/CodeGenCoroutines/coro-symmetric-transfer-02.cpp @@ -0,0 +1,126 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fcoroutines-ts -std=c++14 -O1 -emit-llvm %s -o - -disable-llvm-passes | FileCheck %s + +#include "Inputs/coroutine.h" + +namespace coro = std::experimental::coroutines_v1; + +struct Task { + struct promise_type { + Task get_return_object() noexcept { + return Task{coro::coroutine_handle::from_promise(*this)}; + } + + void return_void() noexcept {} + + struct final_awaiter { + bool await_ready() noexcept { return false; } + coro::coroutine_handle<> await_suspend(coro::coroutine_handle h) noexcept { + h.destroy(); + return {}; + } + void await_resume() noexcept {} + }; + + void unhandled_exception() noexcept {} + + final_awaiter final_suspend() noexcept { return {}; } + + coro::suspend_always initial_suspend() noexcept { return {}; } + + template + auto await_transform(Awaitable &&awaitable) { + return awaitable.co_viaIfAsync(); + } + }; + + using handle_t = 
coro::coroutine_handle; + + class Awaiter { + public: + explicit Awaiter(handle_t coro) noexcept; + Awaiter(Awaiter &&other) noexcept; + Awaiter(const Awaiter &) = delete; + ~Awaiter(); + + bool await_ready() noexcept { return false; } + handle_t await_suspend(coro::coroutine_handle<> continuation) noexcept; + void await_resume(); + + private: + handle_t coro_; + }; + + Task(handle_t coro) noexcept : coro_(coro) {} + + handle_t coro_; + + Task(const Task &t) = delete; + Task(Task &&t) noexcept; + ~Task(); + Task &operator=(Task t) noexcept; + + Awaiter co_viaIfAsync(); +}; + +static Task foo() { + co_return; +} + +Task bar() { + auto mode = 2; + switch (mode) { + case 1: + co_await foo(); + break; + case 2: + co_await foo(); + break; + default: + break; + } +} + +// CHECK-LABEL: define void @_Z3barv +// CHECK: %[[MODE:.+]] = load i32, i32* %mode +// CHECK-NEXT: switch i32 %[[MODE]], label %{{.+}} [ +// CHECK-NEXT: i32 1, label %[[CASE1:.+]] +// CHECK-NEXT: i32 2, label %[[CASE2:.+]] +// CHECK-NEXT: ] + +// CHECK: [[CASE1]]: +// CHECK: br i1 %{{.+}}, label %[[CASE1_AWAIT_READY:.+]], label %[[CASE1_AWAIT_SUSPEND:.+]] +// CHECK: [[CASE1_AWAIT_SUSPEND]]: +// CHECK-NEXT: %{{.+}} = call token @llvm.coro.save(i8* null) +// CHECK-NEXT: %[[HANDLE11:.+]] = bitcast %"struct.std::experimental::coroutines_v1::coroutine_handle"* %[[TMP1:.+]] to i8* +// CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 8, i8* %[[HANDLE11]]) + +// CHECK: %[[HANDLE12:.+]] = bitcast %"struct.std::experimental::coroutines_v1::coroutine_handle"* %[[TMP1]] to i8* +// CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 8, i8* %[[HANDLE12]]) +// CHECK-NEXT: call void @llvm.coro.resume +// CHECK-NEXT: %{{.+}} = call i8 @llvm.coro.suspend +// CHECK-NEXT: switch i8 %{{.+}}, label %coro.ret [ +// CHECK-NEXT: i8 0, label %[[CASE1_AWAIT_READY]] +// CHECK-NEXT: i8 1, label %[[CASE1_AWAIT_CLEANUP:.+]] +// CHECK-NEXT: ] +// CHECK: [[CASE1_AWAIT_CLEANUP]]: +// make sure that the awaiter eventually gets cleaned up. 
+// CHECK: call void @{{.+Awaiter.+}} + +// CHECK: [[CASE2]]: +// CHECK: br i1 %{{.+}}, label %[[CASE2_AWAIT_READY:.+]], label %[[CASE2_AWAIT_SUSPEND:.+]] +// CHECK: [[CASE2_AWAIT_SUSPEND]]: +// CHECK-NEXT: %{{.+}} = call token @llvm.coro.save(i8* null) +// CHECK-NEXT: %[[HANDLE21:.+]] = bitcast %"struct.std::experimental::coroutines_v1::coroutine_handle"* %[[TMP2:.+]] to i8* +// CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 8, i8* %[[HANDLE21]]) + +// CHECK: %[[HANDLE22:.+]] = bitcast %"struct.std::experimental::coroutines_v1::coroutine_handle"* %[[TMP2]] to i8* +// CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 8, i8* %[[HANDLE22]]) +// CHECK-NEXT: call void @llvm.coro.resume +// CHECK-NEXT: %{{.+}} = call i8 @llvm.coro.suspend +// CHECK-NEXT: switch i8 %{{.+}}, label %coro.ret [ +// CHECK-NEXT: i8 0, label %[[CASE2_AWAIT_READY]] +// CHECK-NEXT: i8 1, label %[[CASE2_AWAIT_CLEANUP:.+]] +// CHECK-NEXT: ] +// CHECK: [[CASE2_AWAIT_CLEANUP]]: +// make sure that the awaiter eventually gets cleaned up. +// CHECK: call void @{{.+Awaiter.+}}