diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -146,6 +146,9 @@ - Implemented DR692, DR1395 and DR1432. Use the ``-fclang-abi-compat=14`` option to get the old partial ordering behavior regarding packs. +- Improved ``-O0`` code generation for calls to ``std::invoke`` and + ``std::invoke_r``. These are now treated as compiler builtins and implemented + directly, rather than instantiating a definition from the standard library. C++20 Feature Support ^^^^^^^^^^^^^^^^^^^^^ diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def --- a/clang/include/clang/Basic/Builtins.def +++ b/clang/include/clang/Basic/Builtins.def @@ -1560,6 +1560,8 @@ LANGBUILTIN(__addressof, "v*v&", "zfncT", CXX_LANG) LIBBUILTIN(as_const, "v&v&", "zfncTh", "utility", CXX_LANG) LIBBUILTIN(forward, "v&v&", "zfncTh", "utility", CXX_LANG) +LIBBUILTIN(invoke, "v.", "zfTh", "functional", CXX_LANG) +LIBBUILTIN(invoke_r, "v.", "zfTh", "functional", CXX_LANG) LIBBUILTIN(move, "v&v&", "zfncTh", "utility", CXX_LANG) LIBBUILTIN(move_if_noexcept, "v&v&", "zfncTh", "utility", CXX_LANG) diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -8402,6 +8402,33 @@ "%select{function|block|method|kernel function}0 call, " "expected at most %1, have %2; did you mean %3?">; +def err_invoke_pointer_to_member_too_few_args : Error< + "can't invoke pointer-to-%select{data member|member function}0: " + "'std::invoke%select{|_r}1' must have %select{exactly|at least}0 2 arguments " + "for a pointer-to-%select{data member|member function}0, got %2">; +def err_invoke_pointer_to_member_incompatible_second_arg : Error< + "can't invoke pointer-to-%select{data member|member function}0: expected " + "second argument to be a 
%select{reference|wrapee|pointer}1 to a class " + "compatible with %2, got %3">; +def err_invoke_pointer_to_member_drops_qualifiers : Error< + "can't invoke pointer-to-member function: '%0' drops '%1' qualifier%s2">; +def err_invoke_pointer_to_member_ref_qualifiers : Error< + "can't invoke pointer-to-member function: '%0' can only be called on an " + "%select{lvalue|rvalue}1">; +def err_invoke_wrong_number_of_args : Error< + "can't invoke %select{function|block|pointer-to-member function|" + "function object}0: expected %1 %select{argument|arguments}2, got %3">; +def err_invoke_function_object : Error< + "can't invoke %0 function object: %select{no|%2}1 suitable " + "overload%s2 found%select{|, which makes choosing ambiguous}1">; +def err_invoke_function_object_deleted : Error< + "can't invoke %select{function|pointer-to-member function|" + "%1 function object}0: chosen overload candidate is deleted">; +def err_invoke_bad_conversion : Error< + "can't invoke %select{function|block|pointer-to-data member|" + "pointer-to-member function|%3 function object}0: return type " + "%1 isn't convertible to %2">; + def err_arc_typecheck_convert_incompatible_pointer : Error< "incompatible pointer types passing retainable parameter of type %0" "to a CF function expecting %1 type">; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -4096,20 +4096,20 @@ ExprResult CreateOverloadedUnaryOp(SourceLocation OpLoc, UnaryOperatorKind Opc, - const UnresolvedSetImpl &Fns, - Expr *input, bool RequiresADL = true); + const UnresolvedSetImpl &Fns, Expr *input, + bool RequiresADL = true, + bool IsStdInvoke = false); void LookupOverloadedBinOp(OverloadCandidateSet &CandidateSet, OverloadedOperatorKind Op, const UnresolvedSetImpl &Fns, ArrayRef Args, bool RequiresADL = true); - ExprResult CreateOverloadedBinOp(SourceLocation OpLoc, - BinaryOperatorKind Opc, - const UnresolvedSetImpl &Fns, - Expr *LHS, Expr 
*RHS, - bool RequiresADL = true, + ExprResult CreateOverloadedBinOp(SourceLocation OpLoc, BinaryOperatorKind Opc, + const UnresolvedSetImpl &Fns, Expr *LHS, + Expr *RHS, bool RequiresADL = true, bool AllowRewrittenCandidates = true, - FunctionDecl *DefaultedFn = nullptr); + FunctionDecl *DefaultedFn = nullptr, + bool IsStdInvoke = false); ExprResult BuildSynthesizedThreeWayComparison(SourceLocation OpLoc, const UnresolvedSetImpl &Fns, Expr *LHS, Expr *RHS, @@ -4119,17 +4119,16 @@ SourceLocation RLoc, Expr *Base, MultiExprArg Args); - ExprResult BuildCallToMemberFunction(Scope *S, Expr *MemExpr, - SourceLocation LParenLoc, - MultiExprArg Args, - SourceLocation RParenLoc, - Expr *ExecConfig = nullptr, - bool IsExecConfig = false, - bool AllowRecovery = false); - ExprResult - BuildCallToObjectOfClassType(Scope *S, Expr *Object, SourceLocation LParenLoc, - MultiExprArg Args, - SourceLocation RParenLoc); + ExprResult BuildCallToMemberFunction( + Scope *S, Expr *MemExpr, SourceLocation LParenLoc, MultiExprArg Args, + SourceLocation RParenLoc, Expr *ExecConfig = nullptr, + bool IsExecConfig = false, bool AllowRecovery = false, + bool IsStdInvoke = false); + ExprResult BuildCallToObjectOfClassType(Scope *S, Expr *Object, + SourceLocation LParenLoc, + MultiExprArg Args, + SourceLocation RParenLoc, + bool IsStdInvoke = false); ExprResult BuildOverloadedArrowExpr(Scope *S, Expr *Base, SourceLocation OpLoc, @@ -5253,7 +5252,8 @@ const ObjCInterfaceDecl *UnknownObjCClass = nullptr, bool ObjCPropertyAccess = false, bool AvoidPartialAvailabilityChecks = false, - ObjCInterfaceDecl *ClassReciever = nullptr); + ObjCInterfaceDecl *ClassReciever = nullptr, + bool IsStdInvoke = false); void NoteDeletedFunction(FunctionDecl *FD); void NoteDeletedInheritingConstructor(CXXConstructorDecl *CD); bool DiagnosePropertyAccessorMismatch(ObjCPropertyDecl *PD, @@ -5563,9 +5563,9 @@ // Binary/Unary Operators. 'Tok' is the token for the operator. 
ExprResult CreateBuiltinUnaryOp(SourceLocation OpLoc, UnaryOperatorKind Opc, - Expr *InputExpr); - ExprResult BuildUnaryOp(Scope *S, SourceLocation OpLoc, - UnaryOperatorKind Opc, Expr *Input); + Expr *InputExpr, bool IsStdInvoke = false); + ExprResult BuildUnaryOp(Scope *S, SourceLocation OpLoc, UnaryOperatorKind Opc, + Expr *Input, bool IsStdInvoke = false); ExprResult ActOnUnaryOp(Scope *S, SourceLocation OpLoc, tok::TokenKind Op, Expr *Input); @@ -5712,16 +5712,18 @@ const TemplateArgumentListInfo *TemplateArgs = nullptr); void ActOnDefaultCtorInitializers(Decl *CDtorDecl); - bool ConvertArgumentsForCall(CallExpr *Call, Expr *Fn, - FunctionDecl *FDecl, + bool ConvertArgumentsForCall(CallExpr *Call, Expr *Fn, FunctionDecl *FDecl, const FunctionProtoType *Proto, - ArrayRef Args, - SourceLocation RParenLoc, - bool ExecConfig = false); + ArrayRef Args, SourceLocation RParenLoc, + bool ExecConfig = false, + bool IsStdInvoke = false); void CheckStaticArrayArgument(SourceLocation CallLoc, ParmVarDecl *Param, const Expr *ArgExpr); + ExprResult BuildStdInvokeCall(CallExpr *TheCall, FunctionDecl *FDecl, + unsigned int BuiltinID); + /// ActOnCallExpr - Handle a call to Fn with the specified array of arguments. /// This provides the location of the left/right parens and a list of comma /// locations. 
@@ -5732,7 +5734,8 @@ MultiExprArg ArgExprs, SourceLocation RParenLoc, Expr *ExecConfig = nullptr, bool IsExecConfig = false, - bool AllowRecovery = false); + bool AllowRecovery = false, + bool IsStdInvoke = false); Expr *BuildBuiltinCallExpr(SourceLocation Loc, Builtin::ID Id, MultiExprArg CallArgs); enum class AtomicArgumentOrder { API, AST }; @@ -5745,7 +5748,8 @@ BuildResolvedCallExpr(Expr *Fn, NamedDecl *NDecl, SourceLocation LParenLoc, ArrayRef Arg, SourceLocation RParenLoc, Expr *Config = nullptr, bool IsExecConfig = false, - ADLCallKind UsesADL = ADLCallKind::NotADL); + ADLCallKind UsesADL = ADLCallKind::NotADL, + bool IsStdInvoke = false); ExprResult ActOnCUDAExecConfigExpr(Scope *S, SourceLocation LLLLoc, MultiExprArg ExecConfig, @@ -5796,10 +5800,11 @@ public: ExprResult ActOnBinOp(Scope *S, SourceLocation TokLoc, tok::TokenKind Kind, Expr *LHSExpr, Expr *RHSExpr); - ExprResult BuildBinOp(Scope *S, SourceLocation OpLoc, - BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr); + ExprResult BuildBinOp(Scope *S, SourceLocation OpLoc, BinaryOperatorKind Opc, + Expr *LHSExpr, Expr *RHSExpr, bool IsStdInvoke = false); ExprResult CreateBuiltinBinOp(SourceLocation OpLoc, BinaryOperatorKind Opc, - Expr *LHSExpr, Expr *RHSExpr); + Expr *LHSExpr, Expr *RHSExpr, + bool IsStdInvoke = false); void LookupBinOp(Scope *S, SourceLocation OpLoc, BinaryOperatorKind Opc, UnresolvedSetImpl &Functions); @@ -12166,8 +12171,8 @@ QualType InvalidLogicalVectorOperands(SourceLocation Loc, ExprResult &LHS, ExprResult &RHS); QualType CheckPointerToMemberOperands( // C++ 5.5 - ExprResult &LHS, ExprResult &RHS, ExprValueKind &VK, - SourceLocation OpLoc, bool isIndirect); + ExprResult &LHS, ExprResult &RHS, ExprValueKind &VK, SourceLocation OpLoc, + bool isIndirect, bool IsStdInvoke = false); QualType CheckMultiplyDivideOperands( // C99 6.5.5 ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, bool IsCompAssign, bool IsDivide); diff --git a/clang/lib/AST/ExprConstant.cpp 
b/clang/lib/AST/ExprConstant.cpp --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -8319,6 +8319,8 @@ switch (E->getBuiltinCallee()) { case Builtin::BIas_const: case Builtin::BIforward: + case Builtin::BIinvoke: + case Builtin::BIinvoke_r: case Builtin::BImove: case Builtin::BImove_if_noexcept: if (cast(E->getCalleeDecl())->isConstexpr()) diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -4641,6 +4641,8 @@ case Builtin::BImove_if_noexcept: case Builtin::BIforward: case Builtin::BIas_const: + case Builtin::BIinvoke: + case Builtin::BIinvoke_r: return RValue::get(EmitLValue(E->getArg(0)).getPointer(*this)); case Builtin::BI__GetExceptionInfo: { if (llvm::GlobalVariable *GV = diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -28,6 +28,7 @@ #include "clang/AST/ExprOpenMP.h" #include "clang/AST/FormatString.h" #include "clang/AST/NSAPI.h" +#include "clang/AST/NestedNameSpecifier.h" #include "clang/AST/NonTrivialTypeVisitor.h" #include "clang/AST/OperationKinds.h" #include "clang/AST/RecordLayout.h" @@ -2423,6 +2424,11 @@ } break; } + case Builtin::BIinvoke: + case Builtin::BIinvoke_r: { + TheCallResult = BuildStdInvokeCall(TheCall, FDecl, BuiltinID); + break; + } // OpenCL v2.0, s6.13.16 - Pipe functions case Builtin::BIread_pipe: case Builtin::BIwrite_pipe: diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -9425,6 +9425,9 @@ const auto *FPT = FD->getType()->castAs(); return FPT->getNumParams() == 1 && !FPT->isVariadic(); } + case Builtin::BIinvoke: + case Builtin::BIinvoke_r: + return true; default: return false; diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ 
b/clang/lib/Sema/SemaExpr.cpp @@ -222,7 +222,8 @@ const ObjCInterfaceDecl *UnknownObjCClass, bool ObjCPropertyAccess, bool AvoidPartialAvailabilityChecks, - ObjCInterfaceDecl *ClassReceiver) { + ObjCInterfaceDecl *ClassReceiver, + bool IsStdInvoke) { SourceLocation Loc = Locs.front(); if (getLangOpts().CPlusPlus && isa(D)) { // If there were any diagnostics suppressed by template argument deduction, @@ -263,13 +264,15 @@ // See if this is a deleted function. if (FD->isDeleted()) { auto *Ctor = dyn_cast(FD); - if (Ctor && Ctor->isInheritingConstructor()) - Diag(Loc, diag::err_deleted_inherited_ctor_use) - << Ctor->getParent() - << Ctor->getInheritedConstructor().getConstructor()->getParent(); - else - Diag(Loc, diag::err_deleted_function_use); - NoteDeletedFunction(FD); + if (!IsStdInvoke) { + if (Ctor && Ctor->isInheritingConstructor()) + Diag(Loc, diag::err_deleted_inherited_ctor_use) + << Ctor->getParent() + << Ctor->getInheritedConstructor().getConstructor()->getParent(); + else + Diag(Loc, diag::err_deleted_function_use); + NoteDeletedFunction(FD); + } return true; } @@ -5999,13 +6002,12 @@ /// Fn is the function expression. For a C++ member function, this /// routine does not attempt to convert the object argument. Returns /// true if the call is ill-formed. -bool -Sema::ConvertArgumentsForCall(CallExpr *Call, Expr *Fn, - FunctionDecl *FDecl, - const FunctionProtoType *Proto, - ArrayRef Args, - SourceLocation RParenLoc, - bool IsExecConfig) { +bool Sema::ConvertArgumentsForCall(CallExpr *Call, Expr *Fn, + FunctionDecl *FDecl, + const FunctionProtoType *Proto, + ArrayRef Args, + SourceLocation RParenLoc, bool IsExecConfig, + bool IsStdInvoke) { // Bail out early if calling a builtin with custom typechecking. 
if (FDecl) if (unsigned ID = FDecl->getBuiltinID()) @@ -6027,7 +6029,17 @@ if (Args.size() < NumParams) { if (Args.size() < MinArgs) { TypoCorrection TC; - if (FDecl && (TC = TryTypoCorrectionForCall(*this, Fn, FDecl, Args))) { + if (IsStdInvoke) { + QualType FnType = Fn->getType(); + unsigned Kind = FnType->isRecordType() ? 3 // function object + : isa(Fn) + ? 2 // pointer-to-member-function + : FnType->isBlockPointerType() ? 1 // block + : 0; // function + Diag(Call->getBeginLoc(), diag::err_invoke_wrong_number_of_args) + << Kind << NumParams << (NumParams != 1) << Args.size(); + } else if (FDecl && + (TC = TryTypoCorrectionForCall(*this, Fn, FDecl, Args))) { unsigned diag_id = MinArgs == NumParams && !Proto->isVariadic() ? diag::err_typecheck_call_too_few_args_suggest @@ -6065,7 +6077,17 @@ if (Args.size() > NumParams) { if (!Proto->isVariadic()) { TypoCorrection TC; - if (FDecl && (TC = TryTypoCorrectionForCall(*this, Fn, FDecl, Args))) { + if (IsStdInvoke) { + QualType FnType = Fn->getType(); + unsigned Kind = FnType->isRecordType() ? 3 // function object + : isa(Fn) + ? 2 // pointer-to-member-function + : FnType->isBlockPointerType() ? 1 // block + : 0; // function + Diag(Call->getBeginLoc(), diag::err_invoke_wrong_number_of_args) + << Kind << NumParams << (NumParams != 1) << Args.size(); + } else if (FDecl && + (TC = TryTypoCorrectionForCall(*this, Fn, FDecl, Args))) { unsigned diag_id = MinArgs == NumParams && !Proto->isVariadic() ? diag::err_typecheck_call_too_many_args_suggest @@ -6634,7 +6656,7 @@ ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc, MultiExprArg ArgExprs, SourceLocation RParenLoc, Expr *ExecConfig, bool IsExecConfig, - bool AllowRecovery) { + bool AllowRecovery, bool IsStdInvoke) { // Since this might be a postfix expression, get rid of ParenListExprs. 
ExprResult Result = MaybeConvertParenListExprToParenExpr(Scope, Fn); if (Result.isInvalid()) return ExprError(); @@ -6685,7 +6707,7 @@ // Determine whether this is a call to an object (C++ [over.call.object]). if (Fn->getType()->isRecordType()) return BuildCallToObjectOfClassType(Scope, Fn, LParenLoc, ArgExprs, - RParenLoc); + RParenLoc, IsStdInvoke); if (Fn->getType() == Context.UnknownAnyTy) { ExprResult result = rebuildUnknownAnyFunction(*this, Fn); @@ -6833,7 +6855,8 @@ CurFPFeatureOverrides()); } return BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc, - ExecConfig, IsExecConfig); + ExecConfig, IsExecConfig, ADLCallKind::NotADL, + IsStdInvoke); } /// BuildBuiltinCallExpr - Create a call to a builtin function specified by Id @@ -6908,7 +6931,8 @@ SourceLocation LParenLoc, ArrayRef Args, SourceLocation RParenLoc, Expr *Config, - bool IsExecConfig, ADLCallKind UsesADL) { + bool IsExecConfig, ADLCallKind UsesADL, + bool IsStdInvoke) { FunctionDecl *FDecl = dyn_cast_or_null(NDecl); unsigned BuiltinID = (FDecl ? FDecl->getBuiltinID() : 0); @@ -7084,7 +7108,7 @@ if (Proto) { if (ConvertArgumentsForCall(TheCall, Fn, FDecl, Proto, Args, RParenLoc, - IsExecConfig)) + IsExecConfig, IsStdInvoke)) return ExprError(); } else { assert(isa(FuncT) && "Unknown FunctionType!"); @@ -14483,7 +14507,8 @@ /// CheckIndirectionOperand - Type check unary indirection (prefix '*'). static QualType CheckIndirectionOperand(Sema &S, Expr *Op, ExprValueKind &VK, - SourceLocation OpLoc) { + SourceLocation OpLoc, + bool IsStdInvoke = false) { if (Op->isTypeDependent()) return S.Context.DependentTy; @@ -14515,8 +14540,9 @@ } if (Result.isNull()) { - S.Diag(OpLoc, diag::err_typecheck_indirection_requires_pointer) - << OpTy << Op->getSourceRange(); + if (!IsStdInvoke) + S.Diag(OpLoc, diag::err_typecheck_indirection_requires_pointer) + << OpTy << Op->getSourceRange(); return QualType(); } @@ -14822,8 +14848,8 @@ /// operator @p Opc at location @c TokLoc. 
This routine only supports /// built-in operations; ActOnBinOp handles overloaded operators. ExprResult Sema::CreateBuiltinBinOp(SourceLocation OpLoc, - BinaryOperatorKind Opc, - Expr *LHSExpr, Expr *RHSExpr) { + BinaryOperatorKind Opc, Expr *LHSExpr, + Expr *RHSExpr, bool IsStdInvoke) { if (getLangOpts().CPlusPlus11 && isa(RHSExpr)) { // The syntax only allows initializer lists on the RHS of assignment, // so we don't need to worry about accepting invalid code for @@ -14923,7 +14949,7 @@ case BO_PtrMemD: case BO_PtrMemI: ResultTy = CheckPointerToMemberOperands(LHS, RHS, VK, OpLoc, - Opc == BO_PtrMemI); + Opc == BO_PtrMemI, IsStdInvoke); break; case BO_Mul: case BO_Div: @@ -15336,8 +15362,8 @@ /// Build an overloaded binary operator expression in the given scope. static ExprResult BuildOverloadedBinOp(Sema &S, Scope *Sc, SourceLocation OpLoc, - BinaryOperatorKind Opc, - Expr *LHS, Expr *RHS) { + BinaryOperatorKind Opc, Expr *LHS, + Expr *RHS, bool IsStdInvoke = false) { switch (Opc) { case BO_Assign: case BO_DivAssign: @@ -15359,12 +15385,13 @@ // Build the (potentially-overloaded, potentially-dependent) // binary operation. - return S.CreateOverloadedBinOp(OpLoc, Opc, Functions, LHS, RHS); + return S.CreateOverloadedBinOp(OpLoc, Opc, Functions, LHS, RHS, true, true, + nullptr, IsStdInvoke); } ExprResult Sema::BuildBinOp(Scope *S, SourceLocation OpLoc, - BinaryOperatorKind Opc, - Expr *LHSExpr, Expr *RHSExpr) { + BinaryOperatorKind Opc, Expr *LHSExpr, + Expr *RHSExpr, bool IsStdInvoke) { ExprResult LHS, RHS; std::tie(LHS, RHS) = CorrectDelayedTyposInBinOp(*this, Opc, LHSExpr, RHSExpr); if (!LHS.isUsable() || !RHS.isUsable()) @@ -15461,7 +15488,8 @@ // overloadable type. 
if (LHSExpr->getType()->isOverloadableType() || RHSExpr->getType()->isOverloadableType()) - return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr); + return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr, + IsStdInvoke); } if (getLangOpts().RecoveryAST && @@ -15520,8 +15548,8 @@ } ExprResult Sema::CreateBuiltinUnaryOp(SourceLocation OpLoc, - UnaryOperatorKind Opc, - Expr *InputExpr) { + UnaryOperatorKind Opc, Expr *InputExpr, + bool IsStdInvoke) { ExprResult Input = InputExpr; ExprValueKind VK = VK_PRValue; ExprObjectKind OK = OK_Ordinary; @@ -15571,7 +15599,8 @@ case UO_Deref: { Input = DefaultFunctionArrayLvalueConversion(Input.get()); if (Input.isInvalid()) return ExprError(); - resultType = CheckIndirectionOperand(*this, Input.get(), VK, OpLoc); + resultType = + CheckIndirectionOperand(*this, Input.get(), VK, OpLoc, IsStdInvoke); break; } case UO_Plus: @@ -15791,7 +15820,8 @@ } ExprResult Sema::BuildUnaryOp(Scope *S, SourceLocation OpLoc, - UnaryOperatorKind Opc, Expr *Input) { + UnaryOperatorKind Opc, Expr *Input, + bool IsStdInvoke) { // First things first: handle placeholders so that the // overloaded-operator check considers the right type. 
if (const BuiltinType *pty = Input->getType()->getAsPlaceholderType()) { @@ -15827,7 +15857,8 @@ if (S && OverOp != OO_None) LookupOverloadedOperatorName(OverOp, S, Functions); - return CreateOverloadedUnaryOp(OpLoc, Opc, Functions, Input); + return CreateOverloadedUnaryOp(OpLoc, Opc, Functions, Input, + /*RequiresADL*/ true, IsStdInvoke); } return CreateBuiltinUnaryOp(OpLoc, Opc, Input); diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp --- a/clang/lib/Sema/SemaExprCXX.cpp +++ b/clang/lib/Sema/SemaExprCXX.cpp @@ -17,7 +17,9 @@ #include "clang/AST/ASTLambda.h" #include "clang/AST/CXXInheritance.h" #include "clang/AST/CharUnits.h" +#include "clang/AST/DeclBase.h" #include "clang/AST/DeclObjC.h" +#include "clang/AST/Expr.h" #include "clang/AST/ExprCXX.h" #include "clang/AST/ExprObjC.h" #include "clang/AST/RecursiveASTVisitor.h" @@ -25,7 +27,9 @@ #include "clang/Basic/AlignedAllocation.h" #include "clang/Basic/DiagnosticSema.h" #include "clang/Basic/PartialDiagnostic.h" +#include "clang/Basic/SourceLocation.h" #include "clang/Basic/TargetInfo.h" +#include "clang/Basic/TokenKinds.h" #include "clang/Basic/TypeTraits.h" #include "clang/Lex/Preprocessor.h" #include "clang/Sema/DeclSpec.h" @@ -5699,6 +5703,182 @@ llvm_unreachable("Unknown type trait or not implemented"); } +static ExprResult HandleInvokePointerToMemberFunction( + Sema &S, const MemberPointerType *CalleeType, bool, + SourceLocation LParenLoc, Expr *F, MultiExprArg Args, + SourceLocation RParenLoc) { + assert(CalleeType->isMemberFunctionPointer()); + ExprResult B = S.BuildBinOp(S.getCurScope(), LParenLoc, + BinaryOperatorKind::BO_PtrMemD, Args[0], F, true); + if (B.isInvalid()) { + return ExprError(); + } + + return S.BuildCallToMemberFunction( + S.getCurScope(), B.get(), LParenLoc, Args.drop_front(), RParenLoc, + /*ExecConfig*/ nullptr, /*IsExecConfig*/ false, /*AllowRecovery*/ false, + /*IsStdInvoke*/ true); +} + +static ExprResult +HandleInvokePointerToDataMember(Sema &S, const 
MemberPointerType *CalleeType, + bool IsInvokeR, SourceLocation LParenLoc, + Expr *F, MultiExprArg Args, + SourceLocation RParenLoc) { + return S.BuildBinOp(S.getCurScope(), LParenLoc, + BinaryOperatorKind::BO_PtrMemD, Args[0], F); +} + +static ExprResult +HandleInvokePointerToMember(Sema &S, const MemberPointerType *CalleeType, + bool IsInvokeR, SourceLocation LParenLoc, Expr *F, + MultiExprArg Args, SourceLocation RParenLoc) { + auto *Fn = CalleeType->isMemberFunctionPointer() + ? HandleInvokePointerToMemberFunction + : HandleInvokePointerToDataMember; + return Fn(S, CalleeType, IsInvokeR, LParenLoc, F, Args, RParenLoc); +} + +static ExprResult UnwrapReferenceWrapper(Sema &S, QualType &FirstArgType, + Expr *Arg) { + auto *D = dyn_cast( + FirstArgType.getDesugaredType(S.Context)->getAsRecordDecl()); + FirstArgType = S.BuiltinAddReference(D->getTemplateArgs().get(0).getAsType(), + Sema::UTTKind::AddLvalueReference, + Arg->getExprLoc()); + return S.BuildCXXNamedCast(Arg->getBeginLoc(), tok::kw_static_cast, + S.Context.getTrivialTypeSourceInfo(FirstArgType), + Arg, {}, Arg->getEndLoc()); +} + +static ExprResult HandleInvokePointerToMember(Sema &S, bool IsInvokeR, + QualType CalleeType, + SourceLocation LParenLoc, Expr *F, + MultiExprArg Args, + SourceLocation RParenLoc) { + auto *PtrToMember = CalleeType->getAs(); + if (Args.size() == 0 || + (Args.size() > 1 && CalleeType->isMemberDataPointerType())) { + S.Diag(LParenLoc, diag::err_invoke_pointer_to_member_too_few_args) + << PtrToMember->isMemberFunctionPointer() << IsInvokeR << /*NumArgs, counting the callable itself*/ static_cast<unsigned>(Args.size() + 1); + return ExprError(); + } + + QualType FirstArgType = Args[0]->getType(); + QualType ClassType = QualType(PtrToMember->getClass(), 0); + bool IsBase = EvaluateBinaryTypeTrait(S, TypeTrait::BTT_IsBaseOf, ClassType, + FirstArgType.getNonReferenceType(), {}); + if (IsBase) + return HandleInvokePointerToMember(S, PtrToMember, IsInvokeR, LParenLoc, F, + Args, RParenLoc); + + if (RecordDecl *D = FirstArgType->getAsCXXRecordDecl()) { + bool 
IsReferenceWrapper = + D->isInStdNamespace() && D->getName() == "reference_wrapper"; + if (IsReferenceWrapper) { + Args[0] = UnwrapReferenceWrapper(S, FirstArgType, Args[0]).get(); + FirstArgType = Args[0]->getType(); + } + + IsBase = EvaluateBinaryTypeTrait(S, TypeTrait::BTT_IsBaseOf, ClassType, + FirstArgType, {}); + if (IsBase) + return HandleInvokePointerToMember(S, PtrToMember, IsInvokeR, LParenLoc, + F, Args, RParenLoc); + + if (IsReferenceWrapper) { + S.Diag(F->getBeginLoc(), + diag::err_invoke_pointer_to_member_incompatible_second_arg) + << CalleeType->isMemberFunctionPointerType() << 1 + << PtrToMember->getClass()->getAsRecordDecl() << FirstArgType; + return ExprError(); + } + } + + ExprResult Deref = S.BuildUnaryOp(S.getCurScope(), LParenLoc, + UnaryOperatorKind::UO_Deref, Args[0], true); + if (Deref.isInvalid()) { + S.Diag(F->getBeginLoc(), + diag::err_invoke_pointer_to_member_incompatible_second_arg) + << CalleeType->isMemberFunctionPointerType() << 0 + << PtrToMember->getClass()->getAsCXXRecordDecl() << FirstArgType; + return ExprError(); + } + + Args[0] = Deref.get(); + IsBase = EvaluateBinaryTypeTrait(S, TypeTrait::BTT_IsBaseOf, ClassType, + Args[0]->getType(), {}); + if (!IsBase) { + S.Diag(LParenLoc, + diag::err_invoke_pointer_to_member_incompatible_second_arg) + << PtrToMember->isMemberFunctionPointer() << 2 + << PtrToMember->getClass()->getAsCXXRecordDecl() << FirstArgType; + return ExprError(); + } + + return HandleInvokePointerToMember(S, PtrToMember, IsInvokeR, LParenLoc, F, + Args, RParenLoc); +} + +static ExprResult HandleInvoke(Sema &S, CallExpr *TheCall, bool IsInvokeR) { + Expr *F = TheCall->getArgs()[0]; + QualType CalleeType = F->getType(); + MultiExprArg Args(TheCall->getArgs() + 1, TheCall->getNumArgs() - 1); + + // FIXME: remove this comment block once notes are addressed. 
+ // Reviewer note 1: It really feels like there should be some way to use + // Context.getSubstTemplateTypeParmType, but it's not clear what the second + // argument should be. + // Reviewer note 2: Do we need to consider SubstTemplateTypeParmPackType too? + if (auto *T = dyn_cast(CalleeType)) + CalleeType = T->getReplacementType(); + + if (!CalleeType->isMemberPointerType()) { + return S.BuildCallExpr(S.getCurScope(), F, TheCall->getBeginLoc(), Args, + TheCall->getEndLoc(), /*ExecConfig*/ nullptr, + /*IsExecConfig*/ false, /*AllowRecovery*/ false, + /*IsStdInvoke*/ true); + } + + return HandleInvokePointerToMember(S, IsInvokeR, CalleeType, + TheCall->getBeginLoc(), F, Args, + TheCall->getEndLoc()); +} + +ExprResult Sema::BuildStdInvokeCall(CallExpr *TheCall, FunctionDecl *FDecl, + unsigned int BuiltinID) { + assert(TheCall->getNumArgs() > 0); + + ExprResult Result = + HandleInvoke(*this, TheCall, BuiltinID == Builtin::BIinvoke_r); + if (BuiltinID == Builtin::BIinvoke || Result.isInvalid()) + return Result; + + QualType ResultType = Result.get()->getType(); + QualType InvokeRType = FDecl->getReturnType(); + if (!EvaluateBinaryTypeTrait(*this, TypeTrait::BTT_IsConvertibleTo, + ResultType, InvokeRType, {})) { + QualType T = TheCall->getArgs()[0]->getType(); + unsigned Kind = T->isRecordType() ? 4 // function object + : T->isMemberFunctionPointerType() + ? 3 // pointer-to-member function + : T->isMemberDataPointerType() ? 2 // pointer-to-data member + : T->isBlockPointerType() ? 
1 // block + : 0; // function + SemaDiagnosticBuilder B = + Diag(TheCall->getBeginLoc(), diag::err_invoke_bad_conversion) + << Kind << ResultType << InvokeRType; + if (T->isRecordType()) + B << T; + return ExprError(); + } + + return BuildCXXNamedCast(TheCall->getBeginLoc(), tok::kw_static_cast, + Context.getTrivialTypeSourceInfo(InvokeRType), + Result.get(), TheCall->getBeginLoc(), + TheCall->getBeginLoc()); +} + ExprResult Sema::ActOnArrayTypeTrait(ArrayTypeTrait ATT, SourceLocation KWLoc, ParsedType Ty, @@ -5830,8 +6010,8 @@ QualType Sema::CheckPointerToMemberOperands(ExprResult &LHS, ExprResult &RHS, ExprValueKind &VK, - SourceLocation Loc, - bool isIndirect) { + SourceLocation Loc, bool isIndirect, + bool IsStdInvoke) { assert(!LHS.get()->hasPlaceholderType() && !RHS.get()->hasPlaceholderType() && "placeholders should have been weeded out by now"); @@ -5949,16 +6129,28 @@ Diag(Loc, getLangOpts().CPlusPlus20 ? diag::warn_cxx17_compat_pointer_to_const_ref_member_on_rvalue : diag::ext_pointer_to_const_ref_member_on_rvalue); - else + else if (!IsStdInvoke) Diag(Loc, diag::err_pointer_to_member_oper_value_classify) << RHSType << 1 << LHS.get()->getSourceRange(); + else { + Diag(Loc, diag::err_invoke_pointer_to_member_ref_qualifiers) + << RHS.get() << 0; + return QualType(); + } } break; case RQ_RValue: - if (isIndirect || !LHS.get()->Classify(Context).isRValue()) - Diag(Loc, diag::err_pointer_to_member_oper_value_classify) - << RHSType << 0 << LHS.get()->getSourceRange(); + if (isIndirect || !LHS.get()->Classify(Context).isRValue()) { + if (!IsStdInvoke) + Diag(Loc, diag::err_pointer_to_member_oper_value_classify) + << RHSType << 0 << LHS.get()->getSourceRange(); + else { + Diag(Loc, diag::err_invoke_pointer_to_member_ref_qualifiers) + << RHS.get() << 1; + return QualType(); + } + } break; } } diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp --- a/clang/lib/Sema/SemaOverload.cpp +++ b/clang/lib/Sema/SemaOverload.cpp @@ -49,11 +49,13 @@ 
} /// A convenience routine for creating a decayed reference to a function. -static ExprResult CreateFunctionRefExpr( - Sema &S, FunctionDecl *Fn, NamedDecl *FoundDecl, const Expr *Base, - bool HadMultipleCandidates, SourceLocation Loc = SourceLocation(), - const DeclarationNameLoc &LocInfo = DeclarationNameLoc()) { - if (S.DiagnoseUseOfDecl(FoundDecl, Loc)) +static ExprResult +CreateFunctionRefExpr(Sema &S, FunctionDecl *Fn, NamedDecl *FoundDecl, + const Expr *Base, bool HadMultipleCandidates, + SourceLocation Loc = SourceLocation(), + const DeclarationNameLoc &LocInfo = DeclarationNameLoc(), + bool IsStdInvoke = false) { + if (S.DiagnoseUseOfDecl(FoundDecl, Loc, {}, {}, {}, {}, IsStdInvoke)) return ExprError(); // If FoundDecl is different from Fn (such as if one is a template // and the other a specialization), make sure DiagnoseUseOfDecl is @@ -6412,7 +6414,7 @@ NamedDecl *ND = Function; if (auto *SpecInfo = Function->getTemplateSpecializationInfo()) ND = SpecInfo->getTemplate(); - + if (ND->getFormalLinkage() == Linkage::InternalLinkage) { Candidate.Viable = false; Candidate.FailureKind = ovl_fail_module_mismatched; @@ -9671,7 +9673,7 @@ const OverloadCandidate &Cand1, const OverloadCandidate &Cand2) { // FIXME: Per P2113R0 we also need to compare the template parameter lists - // when comparing template functions. + // when comparing template functions. if (Cand1.Function && Cand2.Function && Cand1.Function->hasPrototype() && Cand2.Function->hasPrototype()) { auto *PT1 = cast(Cand1.Function->getFunctionType()); @@ -13398,10 +13400,11 @@ /// by CreateOverloadedUnaryOp(). /// /// \param Input The input argument. 
-ExprResult -Sema::CreateOverloadedUnaryOp(SourceLocation OpLoc, UnaryOperatorKind Opc, - const UnresolvedSetImpl &Fns, - Expr *Input, bool PerformADL) { +ExprResult Sema::CreateOverloadedUnaryOp(SourceLocation OpLoc, + UnaryOperatorKind Opc, + const UnresolvedSetImpl &Fns, + Expr *Input, bool PerformADL, + bool IsStdInvoke) { OverloadedOperatorKind Op = UnaryOperator::getOverloadedOperator(Opc); assert(Op != OO_None && "Invalid opcode for overloaded unary operator"); DeclarationName OpName = Context.DeclarationNames.getCXXOperatorName(Op); @@ -13571,7 +13574,7 @@ // Either we found no viable overloaded operator or we matched a // built-in operator. In either case, fall through to trying to // build a built-in operation. - return CreateBuiltinUnaryOp(OpLoc, Opc, Input); + return CreateBuiltinUnaryOp(OpLoc, Opc, Input, IsStdInvoke); } /// Perform lookup for an overloaded binary operator. @@ -13661,12 +13664,11 @@ /// the function in question. Such a function is never a candidate in /// our overload resolution. This also enables synthesizing a three-way /// comparison from < and == as described in C++20 [class.spaceship]p1. -ExprResult Sema::CreateOverloadedBinOp(SourceLocation OpLoc, - BinaryOperatorKind Opc, - const UnresolvedSetImpl &Fns, Expr *LHS, - Expr *RHS, bool PerformADL, - bool AllowRewrittenCandidates, - FunctionDecl *DefaultedFn) { +ExprResult +Sema::CreateOverloadedBinOp(SourceLocation OpLoc, BinaryOperatorKind Opc, + const UnresolvedSetImpl &Fns, Expr *LHS, Expr *RHS, + bool PerformADL, bool AllowRewrittenCandidates, + FunctionDecl *DefaultedFn, bool IsStdInvoke) { Expr *Args[2] = { LHS, RHS }; LHS=RHS=nullptr; // Please use only Args instead of LHS/RHS couple @@ -13727,7 +13729,7 @@ // If this is the .* operator, which is not overloadable, just // create a built-in binary operator. 
if (Opc == BO_PtrMemD) - return CreateBuiltinBinOp(OpLoc, Opc, Args[0], Args[1]); + return CreateBuiltinBinOp(OpLoc, Opc, Args[0], Args[1], IsStdInvoke); // Build the overload set. OverloadCandidateSet CandidateSet( @@ -14379,12 +14381,10 @@ /// parameter). The caller needs to validate that the member /// expression refers to a non-static member function or an overloaded /// member function. -ExprResult Sema::BuildCallToMemberFunction(Scope *S, Expr *MemExprE, - SourceLocation LParenLoc, - MultiExprArg Args, - SourceLocation RParenLoc, - Expr *ExecConfig, bool IsExecConfig, - bool AllowRecovery) { +ExprResult Sema::BuildCallToMemberFunction( + Scope *S, Expr *MemExprE, SourceLocation LParenLoc, MultiExprArg Args, + SourceLocation RParenLoc, Expr *ExecConfig, bool IsExecConfig, + bool AllowRecovery, bool IsStdInvoke) { assert(MemExprE->getType() == Context.BoundMemberTy || MemExprE->getType() == Context.OverloadTy); @@ -14418,10 +14418,16 @@ difference.removeAddressSpace(); if (difference) { std::string qualsString = difference.getAsString(); - Diag(LParenLoc, diag::err_pointer_to_member_call_drops_quals) - << fnType.getUnqualifiedType() - << qualsString - << (qualsString.find(' ') == std::string::npos ? 1 : 2); + if (!IsStdInvoke) + Diag(LParenLoc, diag::err_pointer_to_member_call_drops_quals) + << fnType.getUnqualifiedType() << qualsString + << (qualsString.find(' ') == std::string::npos ? 1 : 2); + else { + Diag(LParenLoc, diag::err_invoke_pointer_to_member_drops_qualifiers) + << op->getRHS() << qualsString + << (qualsString.find(' ') == std::string::npos ? 
1 : 2); + return ExprError(); + } } CXXMemberCallExpr *call = CXXMemberCallExpr::Create( @@ -14432,7 +14438,8 @@ call, nullptr)) return ExprError(); - if (ConvertArgumentsForCall(call, op, nullptr, proto, Args, RParenLoc)) + if (ConvertArgumentsForCall(call, op, nullptr, proto, Args, RParenLoc, + false, IsStdInvoke)) return ExprError(); if (CheckOtherCall(call, proto)) @@ -14675,11 +14682,11 @@ /// type (C++ [over.call.object]), which can end up invoking an /// overloaded function call operator (@c operator()) or performing a /// user-defined conversion on the object argument. -ExprResult -Sema::BuildCallToObjectOfClassType(Scope *S, Expr *Obj, - SourceLocation LParenLoc, - MultiExprArg Args, - SourceLocation RParenLoc) { +ExprResult Sema::BuildCallToObjectOfClassType(Scope *S, Expr *Obj, + SourceLocation LParenLoc, + MultiExprArg Args, + SourceLocation RParenLoc, + bool IsStdInvoke) { if (checkPlaceholderForOverload(*this, Obj)) return ExprError(); ExprResult Object = Obj; @@ -14768,6 +14775,8 @@ // Perform overload resolution. OverloadCandidateSet::iterator Best; + QualType T = Object.get()->getType(); + SourceRange SR = Object.get()->getSourceRange(); switch (CandidateSet.BestViableFunction(*this, Object.get()->getBeginLoc(), Best)) { case OR_Success: @@ -14778,32 +14787,33 @@ case OR_No_Viable_Function: { PartialDiagnostic PD = CandidateSet.empty() - ? (PDiag(diag::err_ovl_no_oper) - << Object.get()->getType() << /*call*/ 1 - << Object.get()->getSourceRange()) - : (PDiag(diag::err_ovl_no_viable_object_call) - << Object.get()->getType() << Object.get()->getSourceRange()); + ? (PDiag(diag::err_ovl_no_oper) << T << /*call*/ 1 << SR) + : !IsStdInvoke ? 
(PDiag(diag::err_ovl_no_viable_object_call) << T << SR) + : (PDiag(diag::err_invoke_function_object) + << T << /*no*/ 0 << 1 << SR); CandidateSet.NoteCandidates( PartialDiagnosticAt(Object.get()->getBeginLoc(), PD), *this, OCD_AllCandidates, Args); break; } - case OR_Ambiguous: + case OR_Ambiguous: { + PartialDiagnostic PD = + !IsStdInvoke ? (PDiag(diag::err_ovl_ambiguous_object_call) << T << SR) + : (PDiag(diag::err_invoke_function_object) + << T << /*plural*/ 1 << CandidateSet.size() << 2 << SR); CandidateSet.NoteCandidates( - PartialDiagnosticAt(Object.get()->getBeginLoc(), - PDiag(diag::err_ovl_ambiguous_object_call) - << Object.get()->getType() - << Object.get()->getSourceRange()), - *this, OCD_AmbiguousCandidates, Args); + PartialDiagnosticAt(Object.get()->getBeginLoc(), PD), *this, + OCD_AmbiguousCandidates, Args); break; - + } case OR_Deleted: + PartialDiagnostic PD = + !IsStdInvoke + ? PDiag(diag::err_ovl_deleted_object_call) << T << SR + : PDiag(diag::err_invoke_function_object_deleted) << 2 << T << SR; CandidateSet.NoteCandidates( - PartialDiagnosticAt(Object.get()->getBeginLoc(), - PDiag(diag::err_ovl_deleted_object_call) - << Object.get()->getType() - << Object.get()->getSourceRange()), - *this, OCD_AllCandidates, Args); + PartialDiagnosticAt(Object.get()->getBeginLoc(), PD), *this, + OCD_AllCandidates, Args); break; } @@ -14860,10 +14870,9 @@ DeclarationNameInfo OpLocInfo( Context.DeclarationNames.getCXXOperatorName(OO_Call), LParenLoc); OpLocInfo.setCXXOperatorNameRange(SourceRange(LParenLoc, RParenLoc)); - ExprResult NewFn = CreateFunctionRefExpr(*this, Method, Best->FoundDecl, - Obj, HadMultipleCandidates, - OpLocInfo.getLoc(), - OpLocInfo.getInfo()); + ExprResult NewFn = CreateFunctionRefExpr( + *this, Method, Best->FoundDecl, Obj, HadMultipleCandidates, + OpLocInfo.getLoc(), OpLocInfo.getInfo(), IsStdInvoke); if (NewFn.isInvalid()) return true; diff --git a/clang/test/SemaCXX/builtin-std-invoke.cpp b/clang/test/SemaCXX/builtin-std-invoke.cpp new 
file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/builtin-std-invoke.cpp @@ -0,0 +1,496 @@ +// RUN: %clang_cc1 -fsyntax-only -std=c++17 -verify %s + +namespace std { +template +void invoke(F &&, Args &&...); // expected-note{{requires at least 1 argument, but 0 were provided}} + +template +R invoke_r(F &&, Args &&...); // expected-note{{requires at least 1 argument, but 0 were provided}} + +// Slightly different to the real deal to simplify test. +template +class reference_wrapper { +public: + constexpr reference_wrapper(T &t) : data(&t) {} + + constexpr operator T &() const noexcept { return *data; } + +private: + T *data; +}; +} // namespace std + +#define assert(...) \ + if (!(__VA_ARGS__)) \ + __builtin_unreachable(); + +struct ThrowingInt { + constexpr ThrowingInt(int x) : value(x) {} + + int value; +}; + +template +constexpr void bullet_1(F f, T &&t, Args... args) { + assert(std::invoke(f, static_cast(t), args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, static_cast(t), args...)), Returns)); + static_assert(noexcept(std::invoke(f, static_cast(t), args...)) == IsNoexcept); + + assert(std::invoke_r(f, static_cast(t), args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, static_cast(t), args...)), double)); + static_assert(noexcept(std::invoke_r(f, static_cast(t), args...)) == IsNoexcept); + + assert(std::invoke_r(f, static_cast(t), args...).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, static_cast(t), args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, static_cast(t), args...))); +} + +template +constexpr void bullet_2(F f, T &t, Args... args) { + std::reference_wrapper rw(t); + assert(std::invoke(f, rw, args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, rw, args...)), Returns)); + static_assert(noexcept(std::invoke(f, rw, args...)) == IsNoexcept); + + assert(std::invoke_r(f, rw, args...) 
== ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, rw, args...)), double)); + static_assert(noexcept(std::invoke_r(f, rw, args...)) == IsNoexcept); + + assert(std::invoke_r(f, rw, args...).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, rw, args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, rw, args...))); +} + +template +class PointerWrapper { +public: + constexpr explicit PointerWrapper(T &t) noexcept : p(&t) {} + + constexpr T &operator*() const noexcept { return *p; } + +private: + T *p; +}; + +template +constexpr void bullet_3(F f, T &t, Args... args) { + assert(std::invoke(f, &t, args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, &t, args...)), Returns)); + static_assert(noexcept(std::invoke(f, &t, args...)) == IsNoexcept); + + assert(std::invoke(f, PointerWrapper(t), args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, PointerWrapper(t), args...)), Returns)); + static_assert(noexcept(std::invoke(f, PointerWrapper(t), args...)) == IsNoexcept); + + assert(std::invoke_r(f, &t, args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, &t, args...)), double)); + static_assert(noexcept(std::invoke_r(f, &t, args...)) == IsNoexcept); + + assert(std::invoke_r(f, PointerWrapper(t), args...) 
== ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, PointerWrapper(t), args...)), double)); + static_assert(noexcept(std::invoke_r(f, PointerWrapper(t), args...)) == IsNoexcept); + + assert(std::invoke_r(f, &t, args...).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, &t, args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, &t, args...))); + + assert(std::invoke_r(f, PointerWrapper(t), args...).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, PointerWrapper(t), args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, PointerWrapper(t), args...))); +} + +template +constexpr bool bullets_1_through_3(T t) { + bullet_1(&T::plus, t, 3, -3); + bullet_1(&T::minus, static_cast(t), 1, 2, 3); + bullet_1(&T::square, static_cast<__remove_reference_t(T) &&>(t), 7); + bullet_1(&T::sum, static_cast(t), -1, -2, -4, -8, -16, -32); + + bullet_2(&T::plus, t, 3, -3); + bullet_2(&T::minus, static_cast(t), 1, 2, 3); + + bullet_3(&T::plus, t, 3, -3); + bullet_3(&T::minus, static_cast(t), 1, 2, 3); + + return true; +} + +template +constexpr void bullet_4(F f, T &&t, U ExpectedResult) { + assert(std::invoke(f, static_cast(t)) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, static_cast(t))), Returns)); + static_assert(noexcept(std::invoke(f, static_cast(t)))); + + assert(std::invoke_r(f, static_cast(t)) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, static_cast(t))), double)); + static_assert(noexcept(std::invoke_r(f, static_cast(t)))); + + assert(std::invoke_r(f, static_cast(t)).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, static_cast(t))), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, static_cast(t)))); +} + +template +constexpr void bullet_5(F f, T &t, U ExpectedResult) { + std::reference_wrapper rw(t); + assert(std::invoke(f, rw) == ExpectedResult); + 
static_assert(__is_same(decltype(std::invoke(f, rw)), Returns)); + static_assert(noexcept(std::invoke(f, rw))); + + assert(std::invoke_r(f, rw) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, rw)), double)); + static_assert(noexcept(std::invoke_r(f, rw))); + + assert(std::invoke_r(f, rw).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, rw)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, rw))); +} + +template +constexpr void bullet_6(F f, T &t, U ExpectedResult) { + assert(std::invoke(f, &t) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, &t)), Returns)); + static_assert(noexcept(std::invoke(f, &t))); + + assert(std::invoke(f, PointerWrapper(t)) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, PointerWrapper(t))), Returns)); + static_assert(noexcept(std::invoke(f, PointerWrapper(t)))); + + assert(std::invoke_r(f, &t) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, &t)), double)); + static_assert(noexcept(std::invoke_r(f, &t))); + + assert(std::invoke_r(f, PointerWrapper(t)) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, PointerWrapper(t))), double)); + static_assert(noexcept(std::invoke_r(f, PointerWrapper(t)))); + + assert(std::invoke_r(f, &t).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, &t)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, &t))); + + assert(std::invoke_r(f, PointerWrapper(t)).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, PointerWrapper(t))), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, PointerWrapper(t)))); +} + +template +constexpr bool bullets_4_through_6(F f, T t, U ExpectedResult) { + bullet_4(f, t, ExpectedResult); + bullet_4(f, static_cast(t), ExpectedResult); + bullet_4(f, static_cast(t), ExpectedResult); + bullet_4(f, static_cast(t), ExpectedResult); + + bullet_5(f, t, ExpectedResult); + bullet_5(f, 
static_cast(t), ExpectedResult); + + bullet_6(f, t, ExpectedResult); + bullet_6(f, static_cast(t), ExpectedResult); + + return true; +} + +constexpr int zero() { return 0; } +constexpr int square(int x) { return x * x; } // expected-note 4 {{'square' declared here}} +constexpr double product(double x, double y) { return x * y; } + +struct summation { + template + constexpr auto operator()(Ts... ts) { + return (ts + ...); + } +}; + +struct callable { + int x; + + constexpr int &operator()() &noexcept { return x; } // expected-note 2 {{candidate function not viable: requires 0 arguments, but 2 were provided}} + constexpr const int &operator()() const &noexcept { return x; } // expected-note 2 {{candidate function not viable: requires 0 arguments, but 2 were provided}} + constexpr int &&operator()() &&noexcept { return static_cast(x); } // expected-note 2 {{candidate function not viable: requires 0 arguments, but 2 were provided}} + constexpr const int &&operator()() const &&noexcept { return static_cast(x); } // expected-note 2 {{candidate function not viable: requires 0 arguments, but 2 were provided}} +}; + +template +constexpr bool bullet_7(F &&f, Args &&...args) { + assert(std::invoke(static_cast(f), static_cast(args)...) == expected); + static_assert(__is_same(decltype(std::invoke(static_cast(f), static_cast(args)...)), T)); + static_assert(noexcept(std::invoke(static_cast(f), static_cast(args)...)) == IsNoexcept); + + assert(std::invoke_r(f, args...) 
== expected); + static_assert(__is_same(decltype(std::invoke_r(f, args...)), double)); + static_assert(noexcept(std::invoke_r(f, args...)) == IsNoexcept); + + assert(std::invoke_r(f, args...).value == expected); + static_assert(__is_same(decltype(std::invoke_r(f, args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, args...))); + + return true; +} + +template +constexpr bool bullet_7_1() { + callable c{21}; + return std::invoke(static_cast(c)) == 21 && + __is_same(decltype(std::invoke(static_cast(c))), Expected) &&noexcept(std::invoke(static_cast(c))); +} + +struct Base { + mutable int data; + + constexpr int &plus(int x, int y) &noexcept { + data = x + y; + return data; + } + + constexpr const int &minus(int x, int y, int z) const &noexcept { + data = x - y - z; + return data; + } + + constexpr int &&square(int x) && { + data = x * x; + return static_cast(data); + } + + constexpr const int &&sum(int a, int b, int c, int d, int e, int f) const && { + data = a + b + c + d + e + f; + return static_cast(data); + } +}; + +struct Derived : Base { + double data2; +}; + +void test_invoke() { + static_assert(bullets_1_through_3(Base{})); + static_assert(bullets_1_through_3(Derived{})); + + static_assert(bullets_4_through_6(&Base::data, Base{21}, 21)); + static_assert(bullets_4_through_6(&Base::data, Derived{-96, 18}, -96)); + static_assert(bullets_4_through_6(&Derived::data2, Derived{21, 34}, 34.0)); + + static_assert(bullet_7(zero)); + static_assert(bullet_7(square, 5.0)); + static_assert(bullet_7(product, 9, 1)); + static_assert(bullet_7(&zero)); + static_assert(bullet_7(summation{}, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10L)); + static_assert(bullet_7([] { return 18; })); + + static_assert(bullet_7_1()); + static_assert(bullet_7_1()); + static_assert(bullet_7_1()); + static_assert(bullet_7_1()); +} + +struct Ambiguous { + int operator()(int) const; // expected-note 2 {{candidate function}} + long operator()(unsigned int) const; // expected-note 2 {{candidate 
function}} +}; + +struct Deleted { + int operator()() const = delete; // expected-note 2 {{candidate function has been explicitly deleted}} +}; + +int deleted_function() = delete; // expected-note 2 {{'deleted_function' has been explicitly marked deleted here}} + +struct Incompatible {}; + +void test_errors() { + // TODO: add cases for bullets 1--6 where 2nd param is an int + std::invoke(); // expected-error{{no matching function for call to 'invoke'}} + std::invoke_r(); // expected-error{{no matching function for call to 'invoke_r'}} + + { // Concerning bullet 1 + std::invoke(&Base::plus); + // expected-error@-1{{can't invoke pointer-to-member function: 'std::invoke' must have at least 2 arguments for a pointer-to-member function, got 1}} + std::invoke_r(&Base::plus); + // expected-error@-1{{can't invoke pointer-to-member function: 'std::invoke_r' must have at least 2 arguments for a pointer-to-member function, got 1}} + std::invoke(&Base::plus, Incompatible{}, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: expected second argument to be a reference to a class compatible with 'Base', got 'Incompatible'}} + std::invoke_r(&Base::plus, Incompatible{}, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: expected second argument to be a reference to a class compatible with 'Base', got 'Incompatible'}} + + std::invoke(&Base::sum, Base{}, 0); + // expected-error@-1{{can't invoke pointer-to-member function: expected 6 arguments, got 1}} + std::invoke_r(&Base::sum, Base{}, 0); + // expected-error@-1{{can't invoke pointer-to-member function: expected 6 arguments, got 1}} + + std::invoke(&Base::sum, Base{}, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke pointer-to-member function: expected 6 arguments, got 10}} + std::invoke_r(&Base::sum, Base{}, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke pointer-to-member function: expected 6 arguments, got 10}} + + std::invoke_r(&Base::sum, Base{}, 
0, 1, 2, 3, 4, 5); + // expected-error@-1{{can't invoke pointer-to-member function: return type 'const int' isn't convertible to 'void *'}} + + const Base cb; + std::invoke(&Base::plus, cb, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::plus' drops 'const' qualifier}} + std::invoke_r(&Base::plus, cb, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::plus' drops 'const' qualifier}} + + volatile Base vb; + std::invoke(&Base::plus, vb, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::plus' drops 'volatile' qualifier}} + std::invoke_r(&Base::plus, vb, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::plus' drops 'volatile' qualifier}} + + const volatile Base cvb; + std::invoke(&Base::plus, cvb, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::plus' drops 'const volatile' qualifiers}} + std::invoke_r(&Base::plus, cvb, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::plus' drops 'const volatile' qualifiers}} + + std::invoke(&Base::plus, Base{}, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::plus' can only be called on an lvalue}} + std::invoke_r(&Base::plus, Base{}, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::plus' can only be called on an lvalue}} + + std::invoke(&Base::sum, cb, 1, 2, 3, 4, 5, 6); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::sum' can only be called on an rvalue}} + std::invoke_r(&Base::sum, cb, 1, 2, 3, 4, 5, 6); + // expected-error@-1{{can't invoke pointer-to-member function: '&Base::sum' can only be called on an rvalue}} + } + { // Concerning bullet 2 + Base b; + std::reference_wrapper rw(b); + + Incompatible p; + std::reference_wrapper pw(p); + + std::invoke(&Base::plus, pw, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: expected second argument to 
be a wrapee to a class compatible with 'Base', got 'Incompatible'}} + std::invoke_r(&Base::plus, pw, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: expected second argument to be a wrapee to a class compatible with 'Base', got 'Incompatible'}} + + std::invoke(&Base::plus, rw, 0); + // expected-error@-1{{can't invoke pointer-to-member function: expected 2 arguments, got 1}} + std::invoke_r(&Base::plus, rw, 0); + // expected-error@-1{{can't invoke pointer-to-member function: expected 2 arguments, got 1}} + + std::invoke(&Base::plus, rw, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke pointer-to-member function: expected 2 arguments, got 10}} + std::invoke_r(&Base::plus, rw, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke pointer-to-member function: expected 2 arguments, got 10}} + + std::invoke_r(&Base::plus, rw, 4, 1); + // expected-error@-1{{can't invoke pointer-to-member function: return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 3 + Base b; + Incompatible p; + + std::invoke(&Base::plus, &p, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: expected second argument to be a pointer to a class compatible with 'Base', got 'Incompatible *'}} + std::invoke_r(&Base::plus, &p, 1, 2); + // expected-error@-1{{can't invoke pointer-to-member function: expected second argument to be a pointer to a class compatible with 'Base', got 'Incompatible *'}} + + std::invoke(&Base::plus, &b, 0); + // expected-error@-1{{can't invoke pointer-to-member function: expected 2 arguments, got 1}} + std::invoke_r(&Base::plus, &b, 0); + // expected-error@-1{{can't invoke pointer-to-member function: expected 2 arguments, got 1}} + + std::invoke(&Base::plus, &b, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke pointer-to-member function: expected 2 arguments, got 10}} + std::invoke_r(&Base::plus, &b, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't 
invoke pointer-to-member function: expected 2 arguments, got 10}} + + std::invoke_r(&Base::plus, &b, 4, 1); + // expected-error@-1{{can't invoke pointer-to-member function: return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 4 + std::invoke(&Base::data); + // expected-error@-1{{can't invoke pointer-to-data member: 'std::invoke' must have exactly 2 arguments for a pointer-to-data member, got 1}} + std::invoke_r(&Base::data); + // expected-error@-1{{can't invoke pointer-to-data member: 'std::invoke_r' must have exactly 2 arguments for a pointer-to-data member, got 1}} + + std::invoke(&Base::data, Incompatible{}); + // expected-error@-1{{can't invoke pointer-to-data member: expected second argument to be a reference to a class compatible with 'Base', got 'Incompatible'}} + std::invoke_r(&Base::data, Incompatible{}); + // expected-error@-1{{can't invoke pointer-to-data member: expected second argument to be a reference to a class compatible with 'Base', got 'Incompatible'}} + + std::invoke(&Base::data, Base{}, Base{}); + // expected-error@-1{{can't invoke pointer-to-data member: 'std::invoke' must have exactly 2 arguments for a pointer-to-data member, got 1}} + std::invoke_r(&Base::data, Base{}, Base{}); + // expected-error@-1{{can't invoke pointer-to-data member: 'std::invoke_r' must have exactly 2 arguments for a pointer-to-data member, got 1}} + + std::invoke_r(&Base::data, Base{}); + // expected-error@-1{{can't invoke pointer-to-data member: return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 5 + Base b; + std::reference_wrapper rw(b); + + Incompatible p; + std::reference_wrapper pw(p); + + std::invoke(&Base::data, pw); + // expected-error@-1{{can't invoke pointer-to-data member: expected second argument to be a wrapee to a class compatible with 'Base', got 'Incompatible'}} + std::invoke_r(&Base::data, pw); + // expected-error@-1{{can't invoke pointer-to-data member: expected second argument to be a wrapee to a 
class compatible with 'Base', got 'Incompatible'}} + + std::invoke(&Base::data, rw, Base{}); + // expected-error@-1{{can't invoke pointer-to-data member: 'std::invoke' must have exactly 2 arguments for a pointer-to-data member, got 1}} + std::invoke_r(&Base::data, rw, Base{}); + // expected-error@-1{{can't invoke pointer-to-data member: 'std::invoke_r' must have exactly 2 arguments for a pointer-to-data member, got 1}} + + std::invoke_r(&Base::data, rw); + // expected-error@-1{{can't invoke pointer-to-data member: return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 6 + Base b; + Incompatible p; + + std::invoke(&Base::data, &p); + // expected-error@-1{{can't invoke pointer-to-data member: expected second argument to be a pointer to a class compatible with 'Base', got 'Incompatible *'}} + std::invoke_r(&Base::data, &p); + // expected-error@-1{{can't invoke pointer-to-data member: expected second argument to be a pointer to a class compatible with 'Base', got 'Incompatible *'}} + + std::invoke(&Base::data, &b, Base{}); + // expected-error@-1{{can't invoke pointer-to-data member: 'std::invoke' must have exactly 2 arguments for a pointer-to-data member, got 1}} + std::invoke_r(&Base::data, &b, Base{}); + // expected-error@-1{{can't invoke pointer-to-data member: 'std::invoke_r' must have exactly 2 arguments for a pointer-to-data member, got 1}} + + std::invoke_r(&Base::data, &b); + // expected-error@-1{{can't invoke pointer-to-data member: return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 7 + std::invoke(square); + // expected-error@-1{{can't invoke function: expected 1 argument, got 0}} + std::invoke_r(square); + // expected-error@-1{{can't invoke function: expected 1 argument, got 0}} + + std::invoke(square, 1, 2); + // expected-error@-1{{can't invoke function: expected 1 argument, got 2}} + std::invoke_r(square, 1, 2); + // expected-error@-1{{can't invoke function: expected 1 argument, got 2}} + + 
std::invoke(&product, 1); + // expected-error@-1{{can't invoke function: expected 2 arguments, got 1}} + std::invoke_r(&product, 1); + // expected-error@-1{{can't invoke function: expected 2 arguments, got 1}} + + std::invoke(deleted_function); + // expected-error@-1{{attempt to use a deleted function}} + std::invoke_r(deleted_function); + // expected-error@-1{{attempt to use a deleted function}} + + std::invoke(callable{1}, 1, 2); + // expected-error@-1{{can't invoke 'callable' function object: no suitable overload found}} + std::invoke_r(callable{1}, 1, 2); + // expected-error@-1{{can't invoke 'callable' function object: no suitable overload found}} + + std::invoke_r(callable{1}); + // expected-error@-1{{can't invoke 'callable' function object: return type 'int' isn't convertible to 'callable'}} + + std::invoke(Ambiguous{}, 0.0); + // expected-error@-1{{can't invoke 'Ambiguous' function object: 2 suitable overloads found, which makes choosing ambiguous}} + std::invoke_r(Ambiguous{}, 0.0); + // expected-error@-1{{can't invoke 'Ambiguous' function object: 2 suitable overloads found, which makes choosing ambiguous}} + + std::invoke(Deleted{}); + // expected-error@-1{{can't invoke 'Deleted' function object: chosen overload candidate is deleted}} + std::invoke_r(Deleted{}); + // expected-error@-1{{can't invoke 'Deleted' function object: chosen overload candidate is deleted}} + } +}