diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -149,6 +149,9 @@ because there is no way to fully qualify the enumerator name, so this "extension" was unintentional and useless. This fixes `Issue 42372 <https://github.com/llvm/llvm-project/issues/42372>`_. +- Clang no longer looks up the allocation function in the global scope for + coroutines when the allocation function name is found in the scope of the + promise type. This fixes `Issue 54881 <https://github.com/llvm/llvm-project/issues/54881>`_. Improvements to Clang's diagnostics ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11245,6 +11245,9 @@ "this coroutine may be split into pieces; not every piece is guaranteed to be inlined" >, InGroup; +def err_coroutine_unusable_new : Error< + "'operator new' provided by %0 is not usable with the function signature of %1" +>; } // end of coroutines issue category let CategoryName = "Documentation Issue" in { diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -1308,10 +1308,33 @@ PlacementArgs.push_back(PDRefExpr.get()); } - S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class, - /*DeleteScope*/ Sema::AFS_Both, PromiseType, - /*isArray*/ false, PassAlignment, PlacementArgs, - OperatorNew, UnusedResult, /*Diagnose*/ false); + + bool PromiseContainNew = [this, &PromiseType]() -> bool { + DeclarationName NewName = + S.getASTContext().DeclarationNames.getCXXOperatorName(OO_New); + LookupResult R(S, NewName, Loc, Sema::LookupOrdinaryName); + + if (PromiseType->isRecordType()) + S.LookupQualifiedName(R, PromiseType->getAsCXXRecordDecl()); + + return !R.empty() && !R.isAmbiguous(); + }(); + + auto LookupAllocationFunction = [&]() { + // 
[dcl.fct.def.coroutine]p9 + // The allocation function's name is looked up by searching for it in the + // scope of the promise type. + // - If any declarations are found, ... + // - Otherwise, a search is performed in the global scope. + Sema::AllocationFunctionScope NewScope = PromiseContainNew ? Sema::AFS_Class : Sema::AFS_Global; + S.FindAllocationFunctions(Loc, SourceRange(), + NewScope, + /*DeleteScope*/ Sema::AFS_Both, PromiseType, + /*isArray*/ false, PassAlignment, PlacementArgs, + OperatorNew, UnusedResult, /*Diagnose*/ false); + }; + + LookupAllocationFunction(); // [dcl.fct.def.coroutine]p9 // If no viable function is found ([over.match.viable]), overload resolution @@ -1319,22 +1342,7 @@ // space required as an argument of type std::size_t. if (!OperatorNew && !PlacementArgs.empty()) { PlacementArgs.clear(); - S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class, - /*DeleteScope*/ Sema::AFS_Both, PromiseType, - /*isArray*/ false, PassAlignment, PlacementArgs, - OperatorNew, UnusedResult, /*Diagnose*/ false); - } - - // [dcl.fct.def.coroutine]p9 - // The allocation function's name is looked up by searching for it in the - // scope of the promise type. - // - If any declarations are found, ... - // - Otherwise, a search is performed in the global scope. 
- if (!OperatorNew) { - S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Global, - /*DeleteScope*/ Sema::AFS_Both, PromiseType, - /*isArray*/ false, PassAlignment, PlacementArgs, - OperatorNew, UnusedResult); + LookupAllocationFunction(); } bool IsGlobalOverload = @@ -1354,8 +1362,12 @@ OperatorNew, UnusedResult); } - if (!OperatorNew) + if (!OperatorNew) { + if (PromiseContainNew) + S.Diag(Loc, diag::err_coroutine_unusable_new) << PromiseType << &FD; + return false; + } if (RequiresNoThrowAlloc) { const auto *FT = OperatorNew->getType()->castAs(); diff --git a/clang/test/SemaCXX/coroutine-allocs.cpp b/clang/test/SemaCXX/coroutine-allocs.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/coroutine-allocs.cpp @@ -0,0 +1,61 @@ +// RUN: %clang_cc1 %s -std=c++20 -fsyntax-only -verify +#include "Inputs/std-coroutine.h" + +namespace std { +typedef decltype(sizeof(int)) size_t; +} + +struct Allocator {}; + +struct resumable { + struct promise_type { + void *operator new(std::size_t sz, Allocator &); + + resumable get_return_object() { return {}; } + auto initial_suspend() { return std::suspend_always(); } + auto final_suspend() noexcept { return std::suspend_always(); } + void unhandled_exception() {} + void return_void(){}; + }; +}; + +resumable f1() { // expected-error {{'operator new' provided by 'std::coroutine_traits::promise_type' (aka 'resumable::promise_type') is not usable}} + co_return; +} + +// NOTE: The parameter here is an rvalue reference while the corresponding +// allocation function in resumable::promise_type takes an lvalue reference, so it +// may look as if the signature of f2 is invalid. But according to [dcl.fct.def.coroutine]p4: +// +// In the following, pi is an lvalue of type Pi, where p1 denotes the object +// parameter and pi+1 denotes the ith non-object function parameter for a +// non-static member function. 
+// +// And [dcl.fct.def.coroutine]p9.1 +// +// overload resolution is performed on a function call created by assembling an argument list. +// The first argument is the amount of space requested, and has type std::size_t. +// The lvalues p1…pn are the succeeding arguments. +// +// So the actual argument passed to resumable::promise_type::operator new is an lvalue +// of type Allocator. An lvalue is allowed to bind to an lvalue reference. So the +// following one is valid. +resumable f2(Allocator &&) { + co_return; +} + +resumable f3(Allocator &) { + co_return; +} + +resumable f4(Allocator) { + co_return; +} + +resumable f5(const Allocator) { // expected-error {{operator new' provided by 'std::coroutine_traits::promise_type' (aka 'resumable::promise_type') is not usable}} + co_return; +} + +resumable f6(const Allocator &) { // expected-error {{operator new' provided by 'std::coroutine_traits::promise_type' (aka 'resumable::promise_type') is not usable}} + co_return; +}