Index: clang/docs/ReleaseNotes.rst =================================================================== --- clang/docs/ReleaseNotes.rst +++ clang/docs/ReleaseNotes.rst @@ -138,6 +138,10 @@ class, which can result in miscompiles in some cases. - Fix crash on use of a variadic overloaded operator. (`#42535 <https://github.com/llvm/llvm-project/issues/42535>_`) +- Fixed an issue where conditional access to local variables of the awaiter + after leaking the coroutine handle in await_suspend may be incorrectly + converted to unconditional access. + (`#56301 <https://github.com/llvm/llvm-project/issues/56301>`_) Bug Fixes to Compiler Builtins ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Index: clang/lib/CodeGen/CGCall.cpp =================================================================== --- clang/lib/CodeGen/CGCall.cpp +++ clang/lib/CodeGen/CGCall.cpp @@ -5484,6 +5484,26 @@ Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::AlwaysInline); } + // When we're emitting the suspend block for C++20 coroutines, we need to be + // sure that the call to `await_suspend()` is not inlined until the + // coroutine has been split, in case `await_suspend` may leak the coroutine + // handle. + // + // This is necessary since the standard specifies that the coroutine is + // considered to be suspended after we enter the await_suspend block. So + // we need to make sure we don't update the coroutine handle during the + // execution of the await_suspend. To achieve this, we need to prevent the + // await_suspend from being inlined before the CoroSplit pass. + // + // We can omit the `NoInline` attribute in case we are sure the await_suspend + // call won't leak the coroutine handle so that the middle end can get more + // optimization opportunities. + // + // TODO: We should try to remove the `NoInline` attribute after the CoroSplit + // pass. + if (inSuspendBlock() && maySuspendLeakCoroutineHandle()) + Attrs = Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::NoInline); + // Disable inlining inside SEH __try blocks.
if (isSEHTryScope()) { Attrs = Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::NoInline); Index: clang/lib/CodeGen/CGCoroutine.cpp =================================================================== --- clang/lib/CodeGen/CGCoroutine.cpp +++ clang/lib/CodeGen/CGCoroutine.cpp @@ -12,9 +12,10 @@ #include "CGCleanup.h" #include "CodeGenFunction.h" -#include "llvm/ADT/ScopeExit.h" #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtVisitor.h" +#include "clang/AST/TypeVisitor.h" +#include "llvm/ADT/ScopeExit.h" using namespace clang; using namespace CodeGen; @@ -139,6 +140,164 @@ return true; } +namespace { +// We need a TypeVisitor to find the actual awaiter declaration. +// We can't use (CoroutineSuspendExpr).getCommonExpr()->getType() directly +// since its type may be AutoType, ElaboratedType, ... +class AwaiterTypeFinder : public TypeVisitor<AwaiterTypeFinder> { + CXXRecordDecl *Result = nullptr; + +public: + typedef TypeVisitor<AwaiterTypeFinder> Inherited; + + void Visit(const CoroutineSuspendExpr &S) { + Visit(S.getCommonExpr()->getType()); + } + + bool IsRecordEmpty() { + assert(Result && "Why can't we find the record type from the common " + "expression of a coroutine suspend expression? " + "Maybe we missed some types or the Sema get something " + "incorrect"); + + // In a release build without assertions enabled, return false directly + // to give users a better experience. This doesn't affect correctness; + // it only costs 1 byte of memory overhead. +#ifdef NDEBUG + if (!Result) + return false; +#endif + + return Result->field_empty(); + } + + // The following methods should only be called by Inherited.
+public: + void Visit(QualType Type) { Visit(Type.getTypePtr()); } + + void Visit(const Type *T) { Inherited::Visit(T); } + + void VisitDeducedType(const DeducedType *T) { Visit(T->getDeducedType()); } + + void VisitTypedefType(const TypedefType *T) { + Visit(T->getDecl()->getUnderlyingType()); + } + + void VisitElaboratedType(const ElaboratedType *T) { + Visit(T->getNamedType()); + } + + void VisitReferenceType(const ReferenceType *T) { + Visit(T->getPointeeType()); + } + + void VisitTemplateSpecializationType(const TemplateSpecializationType *T) { + // In the case the type is sugared, we can only see InjectedClassNameType, + // which doesn't contain the definition information we need. + if (T->desugar().getTypePtr() != T) { + Visit(T->desugar().getTypePtr()); + return; + } + + TemplateName Name = T->getTemplateName(); + TemplateDecl *TD = Name.getAsTemplateDecl(); + + if (!TD) + return; + + if (auto *TypedD = dyn_cast<TypeDecl>(TD->getTemplatedDecl())) + Visit(TypedD->getTypeForDecl()); + } + + void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *T) { + Visit(T->getReplacementType()); + } + + void VisitInjectedClassNameType(const InjectedClassNameType *T) { + VisitCXXRecordDecl(T->getDecl()); + } + + void VisitCXXRecordDecl(CXXRecordDecl *Candidate) { + assert(Candidate); + +#ifdef NDEBUG + Result = Candidate; +#else + // Double check that the type we found is an awaiter class type. + // We only do this in debug mode since: + // The Sema should diagnose earlier in such cases. So this may + // be a waste of time in most cases. + // We just want to make sure our assumption is correct.
+ + auto HasMember = [](CXXRecordDecl *Candidate, llvm::StringRef Name, + auto HasMember) { + Candidate = Candidate->getDefinition(); + if (!Candidate) + return false; + + ASTContext &Context = Candidate->getASTContext(); + + auto IdenIter = Context.Idents.find(Name); + if (IdenIter == Context.Idents.end()) + return false; + + if (!Candidate->lookup(DeclarationName(IdenIter->second)).empty()) + return true; + + return llvm::any_of( + Candidate->bases(), [Name, &HasMember](CXXBaseSpecifier &Specifier) { + auto *RD = cast<CXXRecordDecl>( + Specifier.getType()->getAs<RecordType>()->getDecl()); + return HasMember(RD, Name, HasMember); + }); + }; + + bool FoundAwaitReady = HasMember(Candidate, "await_ready", HasMember); + bool FoundAwaitSuspend = HasMember(Candidate, "await_suspend", HasMember); + bool FoundAwaitResume = HasMember(Candidate, "await_resume", HasMember); + + assert(FoundAwaitReady && FoundAwaitSuspend && FoundAwaitResume); + Result = Candidate; +#endif + } + + void VisitRecordType(const RecordType *RT) { + assert(isa<CXXRecordDecl>(RT->getDecl())); + VisitCXXRecordDecl(cast<CXXRecordDecl>(RT->getDecl())); + } + + void VisitType(const Type *T) {} +}; +} // namespace + +/// Return true when the await-suspend +/// (`awaiter.await_suspend(std::coroutine_handle)` expression) may leak the +/// coroutine handle. Return false only when the await-suspend won't leak the +/// coroutine handle for sure. +/// +/// While it is always safe to return true, returning false can bring better +/// performance. +/// +/// The middle end can't understand the relationship between local +/// variables and the coroutine handle until the CoroSplit +/// pass. However, there are a lot of optimizations before CoroSplit. Luckily, +/// this is not a big problem since the C++ language doesn't allow programmers +/// to access the coroutine handle except in await_suspend. +/// +/// See https://github.com/llvm/llvm-project/issues/56301 and +/// https://reviews.llvm.org/D157070 for the example and the full discussion.
+static bool MaySuspendLeak(CoroutineSuspendExpr const &S) { + AwaiterTypeFinder Finder; + Finder.Visit(S); + // In case the awaiter type is empty, the suspend wouldn't leak the coroutine + // handle. + // + // TODO: We can improve this by looking into the implementation of + // await-suspend and checking whether the coroutine handle is passed to + // foreign functions. + return !Finder.IsRecordEmpty(); +} + // Emit suspend expression which roughly looks like: // // auto && x = CommonExpr(); @@ -199,8 +358,11 @@ auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr}); CGF.CurCoro.InSuspendBlock = true; + CGF.CurCoro.MaySuspendLeak = MaySuspendLeak(S); auto *SuspendRet = CGF.EmitScalarExpr(S.getSuspendExpr()); CGF.CurCoro.InSuspendBlock = false; + CGF.CurCoro.MaySuspendLeak = false; + if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) { // Veto suspension if requested by bool returning await_suspend. BasicBlock *RealSuspendBlock = Index: clang/lib/CodeGen/CodeGenFunction.h =================================================================== --- clang/lib/CodeGen/CodeGenFunction.h +++ clang/lib/CodeGen/CodeGenFunction.h @@ -334,6 +334,7 @@ struct CGCoroInfo { std::unique_ptr<CGCoroData> Data; bool InSuspendBlock = false; + bool MaySuspendLeak = false; CGCoroInfo(); ~CGCoroInfo(); }; @@ -347,6 +348,10 @@ return isCoroutine() && CurCoro.InSuspendBlock; } + bool maySuspendLeakCoroutineHandle() const { + return isCoroutine() && CurCoro.MaySuspendLeak; + } + /// CurGD - The GlobalDecl for the current function being compiled. GlobalDecl CurGD; Index: clang/test/CodeGenCoroutines/coro-awaiter-noinline-suspend.cpp =================================================================== --- /dev/null +++ clang/test/CodeGenCoroutines/coro-awaiter-noinline-suspend.cpp @@ -0,0 +1,207 @@ +// Tests that we can mark await-suspend as noinline correctly.
+// +// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -emit-llvm -o - %s \ +// RUN: -disable-llvm-passes | FileCheck %s + +#include "Inputs/coroutine.h" + +struct Task { + struct promise_type { + struct FinalAwaiter { + bool await_ready() const noexcept { return false; } + template <typename PromiseType> + std::coroutine_handle<> await_suspend(std::coroutine_handle<PromiseType> h) noexcept { + return h.promise().continuation; + } + void await_resume() noexcept {} + }; + + Task get_return_object() noexcept { + return std::coroutine_handle<promise_type>::from_promise(*this); + } + + std::suspend_always initial_suspend() noexcept { return {}; } + FinalAwaiter final_suspend() noexcept { return {}; } + void unhandled_exception() noexcept {} + void return_void() noexcept {} + + std::coroutine_handle<> continuation; + }; + + Task(std::coroutine_handle<promise_type> handle); + ~Task(); + +private: + std::coroutine_handle<promise_type> handle; +}; + +struct StatefulAwaiter { + int value; + bool await_ready() const noexcept { return false; } + template <typename PromiseType> + void await_suspend(std::coroutine_handle<PromiseType> h) noexcept {} + void await_resume() noexcept {} +}; + +typedef std::suspend_always NoStateAwaiter; +using AnotherStatefulAwaiter = StatefulAwaiter; + +template <typename T> +struct TemplatedAwaiter { + T value; + bool await_ready() const noexcept { return false; } + template <typename PromiseType> + void await_suspend(std::coroutine_handle<PromiseType> h) noexcept {} + void await_resume() noexcept {} +}; + + +class Awaitable {}; +StatefulAwaiter operator co_await(Awaitable) { + return StatefulAwaiter{}; +} + +StatefulAwaiter GlobalAwaiter; +class Awaitable2 {}; +StatefulAwaiter& operator co_await(Awaitable2) { + return GlobalAwaiter; +} + +Task testing() { + co_await std::suspend_always{}; + co_await StatefulAwaiter{}; + co_await AnotherStatefulAwaiter{}; + + // Test lvalue case. + StatefulAwaiter awaiter; + co_await awaiter; + + // The explicit call to await_suspend is not considered suspended.
+ awaiter.await_suspend(std::coroutine_handle<void>::from_address(nullptr)); + + co_await TemplatedAwaiter<int>{}; + TemplatedAwaiter<int> TemplatedAwaiterInstace; + co_await TemplatedAwaiterInstace; + + co_await Awaitable{}; + co_await Awaitable2{}; +} + +// CHECK-LABEL: @_Z7testingv + +// Check `co_await __promise__.initial_suspend();` Since it returns std::suspend_always, +// which is an empty class, we shouldn't generate optimization blocker for it. +// CHECK: call token @llvm.coro.save +// CHECK: call void @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE{{.*}}#[[NORMAL_ATTR:[0-9]+]] + +// Check the `co_await std::suspend_always{};` expression. We shouldn't emit the optimization +// blocker for it since it is an empty class. +// CHECK: call token @llvm.coro.save +// CHECK: call void @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE{{.*}}#[[NORMAL_ATTR]] + +// Check `co_await StatefulAwaiter{};`. We need to emit the optimization blocker since +// the awaiter is not empty. +// CHECK: call token @llvm.coro.save +// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR:[0-9]+]] + +// Check `co_await AnotherStatefulAwaiter{};` to make sure that we can handle TypedefTypes. +// CHECK: call token @llvm.coro.save +// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]] + +// Check `co_await awaiter;` to make sure we can handle lvalue cases.
+// CHECK: call token @llvm.coro.save +// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]] + +// Check `awaiter.await_suspend(...)` to make sure an explicit call to await_suspend won't be marked as noinline +// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIvEEvSt16coroutine_handleIT_E{{.*}}#[[NORMAL_ATTR]] + +// Check `co_await TemplatedAwaiter<int>{};` to make sure we can handle specialized template +// type. +// CHECK: call token @llvm.coro.save +// CHECK: call void @_ZN16TemplatedAwaiterIiE13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]] + +// Check `co_await TemplatedAwaiterInstace;` to make sure we can handle the lvalue from +// specialized template type. +// CHECK: call token @llvm.coro.save +// CHECK: call void @_ZN16TemplatedAwaiterIiE13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]] + +// Check `co_await Awaitable{};` to make sure we can handle awaiter returned by +// `operator co_await`; +// CHECK: call token @llvm.coro.save +// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]] + +// Check `co_await Awaitable2{};` to make sure we can handle awaiter returned by +// `operator co_await` which returns a reference; +// CHECK: call token @llvm.coro.save +// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]] + +// Check `co_await __promise__.final_suspend();`. We don't emit a blocker here since it is +// empty.
+// CHECK: call token @llvm.coro.save +// CHECK: call ptr @_ZN4Task12promise_type12FinalAwaiter13await_suspendIS0_EESt16coroutine_handleIvES3_IT_E{{.*}}#[[NORMAL_ATTR]] + +struct AwaitTransformTask { + struct promise_type { + struct FinalAwaiter { + bool await_ready() const noexcept { return false; } + template <typename PromiseType> + std::coroutine_handle<> await_suspend(std::coroutine_handle<PromiseType> h) noexcept { + return h.promise().continuation; + } + void await_resume() noexcept {} + }; + + AwaitTransformTask get_return_object() noexcept { + return std::coroutine_handle<promise_type>::from_promise(*this); + } + + std::suspend_always initial_suspend() noexcept { return {}; } + FinalAwaiter final_suspend() noexcept { return {}; } + void unhandled_exception() noexcept {} + void return_void() noexcept {} + + template <typename Awaitable> + auto await_transform(Awaitable &&awaitable) { + return awaitable; + } + + std::coroutine_handle<> continuation; + }; + + AwaitTransformTask(std::coroutine_handle<promise_type> handle); + ~AwaitTransformTask(); + +private: + std::coroutine_handle<promise_type> handle; +}; + +struct awaitableWithGetAwaiter { + bool await_ready() const noexcept { return false; } + template <typename PromiseType> + void await_suspend(std::coroutine_handle<PromiseType> h) noexcept {} + void await_resume() noexcept {} +}; + +AwaitTransformTask testingWithAwaitTransform() { + co_await awaitableWithGetAwaiter{}; +} + +// CHECK-LABEL: @_Z25testingWithAwaitTransformv + +// Init suspend +// CHECK: call token @llvm.coro.save +// CHECK-NOT: call void @llvm.coro.opt.blocker( +// CHECK: call void @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE{{.*}}#[[NORMAL_ATTR]] + +// Check `co_await awaitableWithGetAwaiter{};`.
+// CHECK: call token @llvm.coro.save +// CHECK-NOT: call void @llvm.coro.opt.blocker( +// CHECK: call void @_ZN23awaitableWithGetAwaiter13await_suspendIN18AwaitTransformTask12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NORMAL_ATTR]] + +// Final suspend +// CHECK: call token @llvm.coro.save +// CHECK-NOT: call void @llvm.coro.opt.blocker( +// CHECK: call ptr @_ZN18AwaitTransformTask12promise_type12FinalAwaiter13await_suspendIS0_EESt16coroutine_handleIvES3_IT_E{{.*}}#[[NORMAL_ATTR]] + +// CHECK-NOT: attributes #[[NORMAL_ATTR]] = {{.*}}noinline +// CHECK: attributes #[[NOINLINE_ATTR]] = {{.*}}noinline Index: clang/test/CodeGenCoroutines/pr56301.cpp =================================================================== --- /dev/null +++ clang/test/CodeGenCoroutines/pr56301.cpp @@ -0,0 +1,85 @@ +// An end-to-end test to make sure things get processed correctly. +// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -emit-llvm -o - %s -O3 | \ +// RUN: FileCheck %s + +#include "Inputs/coroutine.h" + +struct SomeAwaitable { + // Resume the supplied handle once the awaitable becomes ready, + // returning a handle that should be resumed now for the sake of symmetric transfer. + // If the awaitable is already ready, return an empty handle without doing anything. + // + // Defined in another translation unit. Note that this may contain + // code that synchronizes with another thread. + std::coroutine_handle<> Register(std::coroutine_handle<>); +}; + +// Defined in another translation unit. +void DidntSuspend(); + +struct Awaiter { + SomeAwaitable&& awaitable; + bool suspended; + + bool await_ready() { return false; } + + std::coroutine_handle<> await_suspend(const std::coroutine_handle<> h) { + // Assume we will suspend unless proven otherwise below. We must do + // this *before* calling Register, since we may be destroyed by another + // thread asynchronously as soon as we have registered.
+ suspended = true; + + // Attempt to hand off responsibility for resuming/destroying the coroutine. + const auto to_resume = awaitable.Register(h); + + if (!to_resume) { + // The awaitable is already ready. In this case we know that Register didn't + // hand off responsibility for the coroutine. So record the fact that we didn't + // actually suspend, and tell the compiler to resume us inline. + suspended = false; + return h; + } + + // Resume whatever Register wants us to resume. + return to_resume; + } + + void await_resume() { + // If we didn't suspend, make note of that fact. + if (!suspended) { + DidntSuspend(); + } + } +}; + +struct MyTask{ + struct promise_type { + MyTask get_return_object() { return {}; } + std::suspend_never initial_suspend() { return {}; } + std::suspend_always final_suspend() noexcept { return {}; } + void unhandled_exception(); + + Awaiter await_transform(SomeAwaitable&& awaitable) { + return Awaiter{static_cast<SomeAwaitable&&>(awaitable)}; + } + }; +}; + +MyTask FooBar() { + co_await SomeAwaitable(); +} + +// CHECK-LABEL: @_Z6FooBarv +// CHECK: %[[to_resume:.*]] = {{.*}}call ptr @_ZN13SomeAwaitable8RegisterESt16coroutine_handleIvE +// CHECK-NEXT: %[[to_bool:.*]] = icmp eq ptr %[[to_resume]], null +// CHECK-NEXT: br i1 %[[to_bool]], label %[[then:.*]], label %[[else:.*]] + +// CHECK: [[then]]: +// We only access the coroutine frame conditionally as the sources did. +// CHECK: store i8 0, +// CHECK-NEXT: br label %[[else]] + +// CHECK: [[else]]: +// No more access to the coroutine frame until suspended. +// CHECK-NOT: store +// CHECK: }