diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -432,9 +432,9 @@ ----------------------------- - Improved ``-O0`` code generation for calls to ``std::move``, ``std::forward``, - ``std::move_if_noexcept``, ``std::addressof``, and ``std::as_const``. These - are now treated as compiler builtins and implemented directly, rather than - instantiating the definition from the standard library. + ``std::move_if_noexcept``, ``std::addressof``, ``std::as_const``, ``std::invoke``, and + ``std::invoke_r``. These are now treated as compiler builtins and implemented directly, rather + than instantiating the definition from the standard library. - Fixed mangling of nested dependent names such as ``T::a::b``, where ``T`` is a template parameter, to conform to the Itanium C++ ABI and be compatible with GCC. This breaks binary compatibility with code compiled with earlier versions @@ -565,7 +565,7 @@ - Added ``forEachTemplateArgument`` matcher which creates a match every time a ``templateArgument`` matches the matcher supplied to it. - + - Added ``objcStringLiteral`` matcher which matches ObjectiveC String literal expressions. 
diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def --- a/clang/include/clang/Basic/Builtins.def +++ b/clang/include/clang/Basic/Builtins.def @@ -1559,6 +1559,8 @@ LANGBUILTIN(__addressof, "v*v&", "zfncT", CXX_LANG) LIBBUILTIN(as_const, "v&v&", "zfncTh", "utility", CXX_LANG) LIBBUILTIN(forward, "v&v&", "zfncTh", "utility", CXX_LANG) +LIBBUILTIN(invoke, "v.", "zfTh", "functional", CXX_LANG) +LIBBUILTIN(invoke_r, "v.", "zfTh", "functional", CXX_LANG) LIBBUILTIN(move, "v&v&", "zfncTh", "utility", CXX_LANG) LIBBUILTIN(move_if_noexcept, "v&v&", "zfncTh", "utility", CXX_LANG) diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -8431,6 +8431,19 @@ "%select{function|block|method|kernel function}0 call, " "expected at most %1, have %2; did you mean %3?">; +def err_invoke_pointer_to_member_too_few_args : Error< + "can't invoke '%0': 'std::invoke%select{|_r}1' must have " + "%select{exactly|at least}2 2 arguments for a pointer-to-" + "%select{data member|member function}2, got %3">; +def err_invoke_pointer_to_member_incompatible_second_arg : Error< + "can't invoke '%0': expected second argument to be a " + "%select{reference|wrapee|pointer}1 to a class compatible with %2, got " + "%select{nothing|%4}3">; +def err_invoke_wrong_number_of_args : Error< + "can't invoke '%0': expected %1 %select{argument|arguments}2, got %3">; +def err_invoke_bad_conversion : Error< + "can't invoke '%0': return type %1 isn't convertible to %2">; + def err_arc_typecheck_convert_incompatible_pointer : Error< "incompatible pointer types passing retainable parameter of type %0" "to a CF function expecting %1 type">; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -5686,16 +5686,18 
@@ const TemplateArgumentListInfo *TemplateArgs = nullptr); void ActOnDefaultCtorInitializers(Decl *CDtorDecl); - bool ConvertArgumentsForCall(CallExpr *Call, Expr *Fn, - FunctionDecl *FDecl, + bool ConvertArgumentsForCall(CallExpr *Call, Expr *Fn, FunctionDecl *FDecl, const FunctionProtoType *Proto, - ArrayRef Args, - SourceLocation RParenLoc, - bool ExecConfig = false); + ArrayRef Args, SourceLocation RParenLoc, + bool ExecConfig = false, + bool IsStdInvoke = false); void CheckStaticArrayArgument(SourceLocation CallLoc, ParmVarDecl *Param, const Expr *ArgExpr); + ExprResult BuildStdInvokeCall(CallExpr *TheCall, FunctionDecl *FDecl, + unsigned int BuiltinID); + /// ActOnCallExpr - Handle a call to Fn with the specified array of arguments. /// This provides the location of the left/right parens and a list of comma /// locations. @@ -5706,7 +5708,8 @@ MultiExprArg ArgExprs, SourceLocation RParenLoc, Expr *ExecConfig = nullptr, bool IsExecConfig = false, - bool AllowRecovery = false); + bool AllowRecovery = false, + bool IsStdInvoke = false); Expr *BuildBuiltinCallExpr(SourceLocation Loc, Builtin::ID Id, MultiExprArg CallArgs); enum class AtomicArgumentOrder { API, AST }; @@ -5719,7 +5722,8 @@ BuildResolvedCallExpr(Expr *Fn, NamedDecl *NDecl, SourceLocation LParenLoc, ArrayRef Arg, SourceLocation RParenLoc, Expr *Config = nullptr, bool IsExecConfig = false, - ADLCallKind UsesADL = ADLCallKind::NotADL); + ADLCallKind UsesADL = ADLCallKind::NotADL, + bool IsStdInvoke = false); ExprResult ActOnCUDAExecConfigExpr(Scope *S, SourceLocation LLLLoc, MultiExprArg ExecConfig, diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -8316,6 +8316,8 @@ switch (E->getBuiltinCallee()) { case Builtin::BIas_const: case Builtin::BIforward: + case Builtin::BIinvoke: + case Builtin::BIinvoke_r: case Builtin::BImove: case Builtin::BImove_if_noexcept: if 
(cast<FunctionDecl>(E->getCalleeDecl())->isConstexpr())
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -4632,6 +4632,8 @@
   case Builtin::BImove_if_noexcept:
   case Builtin::BIforward:
   case Builtin::BIas_const:
+  case Builtin::BIinvoke:
+  case Builtin::BIinvoke_r:
     return RValue::get(EmitLValue(E->getArg(0)).getPointer(*this));
   case Builtin::BI__GetExceptionInfo: {
     if (llvm::GlobalVariable *GV =
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2418,6 +2418,11 @@
     }
     break;
   }
+  case Builtin::BIinvoke:
+  case Builtin::BIinvoke_r: {
+    TheCallResult = BuildStdInvokeCall(TheCall, FDecl, BuiltinID);
+    break;
+  }
   // OpenCL v2.0, s6.13.16 - Pipe functions
   case Builtin::BIread_pipe:
   case Builtin::BIwrite_pipe:
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -9292,6 +9292,9 @@
     const auto *FPT = FD->getType()->castAs<FunctionProtoType>();
     return FPT->getNumParams() == 1 && !FPT->isVariadic();
   }
+  case Builtin::BIinvoke:
+  case Builtin::BIinvoke_r:
+    return true;
 
   default:
     return false;
diff --git
a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -6006,13 +6006,12 @@ /// Fn is the function expression. For a C++ member function, this /// routine does not attempt to convert the object argument. Returns /// true if the call is ill-formed. -bool -Sema::ConvertArgumentsForCall(CallExpr *Call, Expr *Fn, - FunctionDecl *FDecl, - const FunctionProtoType *Proto, - ArrayRef Args, - SourceLocation RParenLoc, - bool IsExecConfig) { +bool Sema::ConvertArgumentsForCall(CallExpr *Call, Expr *Fn, + FunctionDecl *FDecl, + const FunctionProtoType *Proto, + ArrayRef Args, + SourceLocation RParenLoc, bool IsExecConfig, + bool IsStdInvoke) { // Bail out early if calling a builtin with custom typechecking. if (FDecl) if (unsigned ID = FDecl->getBuiltinID()) @@ -6034,7 +6033,11 @@ if (Args.size() < NumParams) { if (Args.size() < MinArgs) { TypoCorrection TC; - if (FDecl && (TC = TryTypoCorrectionForCall(*this, Fn, FDecl, Args))) { + if (IsStdInvoke) + Diag(Call->getBeginLoc(), diag::err_invoke_wrong_number_of_args) + << Fn << NumParams << (NumParams != 1) << Args.size(); + else if (FDecl && + (TC = TryTypoCorrectionForCall(*this, Fn, FDecl, Args))) { unsigned diag_id = MinArgs == NumParams && !Proto->isVariadic() ? diag::err_typecheck_call_too_few_args_suggest @@ -6072,7 +6075,11 @@ if (Args.size() > NumParams) { if (!Proto->isVariadic()) { TypoCorrection TC; - if (FDecl && (TC = TryTypoCorrectionForCall(*this, Fn, FDecl, Args))) { + if (IsStdInvoke) + Diag(Call->getBeginLoc(), diag::err_invoke_wrong_number_of_args) + << Fn << NumParams << (NumParams != 1) << Args.size(); + else if (FDecl && + (TC = TryTypoCorrectionForCall(*this, Fn, FDecl, Args))) { unsigned diag_id = MinArgs == NumParams && !Proto->isVariadic() ? 
diag::err_typecheck_call_too_many_args_suggest @@ -6641,7 +6648,7 @@ ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc, MultiExprArg ArgExprs, SourceLocation RParenLoc, Expr *ExecConfig, bool IsExecConfig, - bool AllowRecovery) { + bool AllowRecovery, bool IsStdInvoke) { // Since this might be a postfix expression, get rid of ParenListExprs. ExprResult Result = MaybeConvertParenListExprToParenExpr(Scope, Fn); if (Result.isInvalid()) return ExprError(); @@ -6840,7 +6847,8 @@ CurFPFeatureOverrides()); } return BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc, - ExecConfig, IsExecConfig); + ExecConfig, IsExecConfig, ADLCallKind::NotADL, + IsStdInvoke); } /// BuildBuiltinCallExpr - Create a call to a builtin function specified by Id @@ -6915,7 +6923,8 @@ SourceLocation LParenLoc, ArrayRef Args, SourceLocation RParenLoc, Expr *Config, - bool IsExecConfig, ADLCallKind UsesADL) { + bool IsExecConfig, ADLCallKind UsesADL, + bool IsStdInvoke) { FunctionDecl *FDecl = dyn_cast_or_null(NDecl); unsigned BuiltinID = (FDecl ? 
FDecl->getBuiltinID() : 0);
 
@@ -7091,7 +7100,7 @@
 
   if (Proto) {
     if (ConvertArgumentsForCall(TheCall, Fn, FDecl, Proto, Args, RParenLoc,
-                                IsExecConfig))
+                                IsExecConfig, IsStdInvoke))
       return ExprError();
   } else {
     assert(isa<FunctionNoProtoType>(FuncT) && "Unknown FunctionType!");
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -17,15 +17,20 @@
 #include "clang/AST/ASTLambda.h"
 #include "clang/AST/CXXInheritance.h"
 #include "clang/AST/CharUnits.h"
+#include "clang/AST/DeclBase.h"
 #include "clang/AST/DeclObjC.h"
+#include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
 #include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/Type.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Basic/AlignedAllocation.h"
 #include "clang/Basic/DiagnosticSema.h"
 #include "clang/Basic/PartialDiagnostic.h"
+#include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/TargetInfo.h"
+#include "clang/Basic/TokenKinds.h"
 #include "clang/Basic/TypeTraits.h"
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Sema/DeclSpec.h"
@@ -5645,6 +5650,219 @@
   llvm_unreachable("Unknown type trait or not implemented");
 }
 
+/// Builds the INVOKE call for a pointer to member function (bullets 1-3 of
+/// [func.require]); \p Args[0] must already be the adjusted object argument.
+static ExprResult HandleInvokePointerToMemberFunction(
+    Sema &S, const MemberPointerType *CalleeType, bool,
+    SourceLocation LParenLoc, Expr *F, MultiExprArg Args,
+    SourceLocation RParenLoc) {
+  assert(CalleeType->isMemberFunctionPointer());
+  const auto *Prototype =
+      CalleeType->getPointeeType()->getAs<FunctionProtoType>();
+  // Args[0] is the object argument; the rest are the call arguments. Default
+  // arguments are never usable through a member pointer, but a C-style
+  // variadic member function accepts extra trailing arguments.
+  unsigned NumArgs = Args.size() - 1;
+  unsigned NumParams = Prototype->getNumParams();
+  if (NumArgs < NumParams ||
+      (NumArgs > NumParams && !Prototype->isVariadic())) {
+    S.Diag(LParenLoc, diag::err_invoke_wrong_number_of_args)
+        << F << NumParams << (NumParams != 1) << NumArgs;
+    return ExprError();
+  }
+
+  ExprResult B = S.BuildBinOp(S.getCurScope(), LParenLoc,
+                              BinaryOperatorKind::BO_PtrMemD, Args[0], F);
+  if (B.isInvalid())
+    return ExprError();
+
+  return S.BuildCallToMemberFunction(S.getCurScope(), B.get(), LParenLoc,
+                                     Args.drop_front(), RParenLoc);
+}
+
+/// Builds `t1.*f` for a pointer to data member (bullets 4-6 of
+/// [func.require]); \p Args must hold only the adjusted object argument.
+static ExprResult
+HandleInvokePointerToDataMember(Sema &S, const MemberPointerType *CalleeType,
+                                bool IsInvokeR, SourceLocation LParenLoc,
+                                Expr *F, MultiExprArg Args,
+                                SourceLocation RParenLoc) {
+  assert(CalleeType->isMemberDataPointer());
+  if (Args.size() != 1) {
+    S.Diag(LParenLoc, diag::err_invoke_pointer_to_member_too_few_args)
+        << F << IsInvokeR << 0 << (Args.size() + 1);
+    return ExprError();
+  }
+
+  return S.BuildBinOp(S.getCurScope(), LParenLoc,
+                      BinaryOperatorKind::BO_PtrMemD, Args[0], F);
+}
+
+/// Dispatches to the member-function or data-member handler.
+static ExprResult
+HandleInvokePointerToMember(Sema &S, const MemberPointerType *CalleeType,
+                            bool IsInvokeR, SourceLocation LParenLoc, Expr *F,
+                            MultiExprArg Args, SourceLocation RParenLoc) {
+  auto *Fn = CalleeType->isMemberFunctionPointer()
+                 ? HandleInvokePointerToMemberFunction
+                 : HandleInvokePointerToDataMember;
+  return Fn(S, CalleeType, IsInvokeR, LParenLoc, F, Args, RParenLoc);
+}
+
+/// Rewrites \p Arg, whose class type is a specialization of
+/// std::reference_wrapper<T>, as `static_cast<T &>(Arg)` and sets
+/// \p FirstArgType to `T &`. Returns ExprError() on failure.
+static ExprResult UnwrapReferenceWrapper(Sema &S, QualType &FirstArgType,
+                                         Expr *Arg) {
+  // Read T from the specialization rather than from the type as written,
+  // which need not be an ElaboratedType (e.g. a typedef or the result type
+  // of std::ref).
+  const auto *Spec = dyn_cast_or_null<ClassTemplateSpecializationDecl>(
+      FirstArgType->getAsCXXRecordDecl());
+  if (!Spec || Spec->getTemplateArgs().size() == 0 ||
+      Spec->getTemplateArgs()[0].getKind() != TemplateArgument::Type)
+    return ExprError();
+
+  FirstArgType = S.BuiltinAddReference(Spec->getTemplateArgs()[0].getAsType(),
+                                       Sema::UTTKind::AddLvalueReference, {});
+  return S.BuildCXXNamedCast({}, tok::kw_static_cast,
+                             S.Context.getTrivialTypeSourceInfo(FirstArgType),
+                             Arg, {}, {});
+}
+
+/// Applies the object-argument rules of INVOKE for the pointer to member \p F:
+/// \p Args[0] is used as-is when its type is (derived from) the member's
+/// class, unwrapped when it is a std::reference_wrapper, and dereferenced
+/// otherwise.
+static ExprResult HandleInvokePointerToMember(Sema &S, bool IsInvokeR,
+                                              QualType CalleeType,
+                                              SourceLocation LParenLoc, Expr *F,
+                                              MultiExprArg Args,
+                                              SourceLocation RParenLoc) {
+  auto *PtrToMember = CalleeType->getAs<MemberPointerType>();
+  if (Args.size() == 0) {
+    S.Diag(LParenLoc, diag::err_invoke_pointer_to_member_too_few_args)
+        << F << IsInvokeR << PtrToMember->isMemberFunctionPointer() << 1;
+    return ExprError();
+  }
+
+  // Values of the first %select in
+  // err_invoke_pointer_to_member_incompatible_second_arg.
+  enum { ReferenceArg = 0, WrappedArg = 1, PointerArg = 2 };
+  auto DiagIncompatible = [&](unsigned Kind, QualType Got) {
+    S.Diag(LParenLoc,
+           diag::err_invoke_pointer_to_member_incompatible_second_arg)
+        << F << Kind << PtrToMember->getClass()->getAsCXXRecordDecl() << 1
+        << Got;
+    return ExprError();
+  };
+
+  QualType FirstArgType = Args[0]->getType();
+  QualType ClassType = QualType(PtrToMember->getClass(), 0);
+
+  // Bullets 1 and 4: t1 is an object of (a class derived from) the class.
+  if (EvaluateBinaryTypeTrait(S, TypeTrait::BTT_IsBaseOf, ClassType,
+                              FirstArgType.getNonReferenceType(), {}))
+    return HandleInvokePointerToMember(S, PtrToMember, IsInvokeR, LParenLoc, F,
+                                       Args, RParenLoc);
+
+  // Bullets 2 and 5: t1 is a specialization of std::reference_wrapper.
+  if (const CXXRecordDecl *D = FirstArgType->getAsCXXRecordDecl()) {
+    if (D->isInStdNamespace() && D->getName() == "reference_wrapper") {
+      QualType WrapperType = FirstArgType;
+      ExprResult Unwrapped = UnwrapReferenceWrapper(S, FirstArgType, Args[0]);
+      if (Unwrapped.isInvalid())
+        return DiagIncompatible(WrappedArg, WrapperType);
+      Args[0] = Unwrapped.get();
+      FirstArgType = Args[0]->getType();
+      if (!EvaluateBinaryTypeTrait(S, TypeTrait::BTT_IsBaseOf, ClassType,
+                                   FirstArgType, {}))
+        return DiagIncompatible(WrappedArg, FirstArgType);
+      return HandleInvokePointerToMember(S, PtrToMember, IsInvokeR, LParenLoc,
+                                         F, Args, RParenLoc);
+    }
+  }
+
+  // Bullets 3 and 6: t1 is dereferenced. Probe the dereference in a SFINAE
+  // context so that a non-dereferenceable argument gets the targeted
+  // diagnostic below instead of a generic indirection error.
+  ExprResult Deref;
+  {
+    Sema::SFINAETrap Trap(S);
+    Deref = S.BuildUnaryOp(S.getCurScope(), LParenLoc,
+                           UnaryOperatorKind::UO_Deref, Args[0]);
+    if (Trap.hasErrorOccurred())
+      Deref = ExprError();
+  }
+  if (Deref.isInvalid())
+    return DiagIncompatible(ReferenceArg, FirstArgType);
+
+  Args[0] = Deref.get();
+  if (!EvaluateBinaryTypeTrait(S, TypeTrait::BTT_IsBaseOf, ClassType,
+                               Args[0]->getType(), {}))
+    return DiagIncompatible(PointerArg, FirstArgType);
+
+  return HandleInvokePointerToMember(S, PtrToMember, IsInvokeR, LParenLoc, F,
+                                     Args, RParenLoc);
+}
+
+/// Type-checks the arguments of std::invoke / std::invoke_r as
+/// INVOKE(f, args...).
+static ExprResult HandleInvoke(Sema &S, CallExpr *TheCall, bool IsInvokeR) {
+  Expr *F = TheCall->getArgs()[0];
+  QualType CalleeType = F->getType();
+  MultiExprArg Args(TheCall->getArgs() + 1, TheCall->getNumArgs() - 1);
+
+  // isMemberPointerType() and getAs<MemberPointerType>() both look through
+  // type sugar such as SubstTemplateTypeParmType, so no unwrapping is needed.
+  if (!CalleeType->isMemberPointerType())
+    return S.BuildCallExpr(S.getCurScope(), F, TheCall->getBeginLoc(), Args,
+                           TheCall->getEndLoc(), /*ExecConfig*/ nullptr,
+                           /*IsExecConfig*/ false, /*AllowRecovery*/ false,
+                           /*IsStdInvoke*/ true);
+
+  return HandleInvokePointerToMember(S, IsInvokeR, CalleeType,
+                                     TheCall->getBeginLoc(), F, Args,
+                                     TheCall->getEndLoc());
+}
+
+ExprResult Sema::BuildStdInvokeCall(CallExpr *TheCall, FunctionDecl *FDecl,
+                                    unsigned int BuiltinID) {
+  assert(TheCall->getNumArgs() > 0);
+
+  ExprResult Result =
+      HandleInvoke(*this, TheCall, BuiltinID == Builtin::BIinvoke_r);
+  if (BuiltinID == Builtin::BIinvoke || Result.isInvalid())
+    return Result;
+
+  // std::invoke_r<R>: the INVOKE result is implicitly converted to R, or
+  // discarded when R is cv void ([func.invoke]). is_convertible<X, void> is
+  // false, so the void case must skip the convertibility check.
+  Expr *Invoked = Result.get();
+  QualType ResultType = Invoked->getType();
+  QualType InvokeRType = FDecl->getReturnType();
+  if (!InvokeRType->isVoidType()) {
+    // Test convertibility from the INVOKE expression's value category, as
+    // is_convertible<invoke_result_t<...>, R> does.
+    QualType FromType = ResultType;
+    if (Invoked->isLValue())
+      FromType = Context.getLValueReferenceType(ResultType);
+    else if (Invoked->isXValue())
+      FromType = Context.getRValueReferenceType(ResultType);
+
+    if (!EvaluateBinaryTypeTrait(*this, TypeTrait::BTT_IsConvertibleTo,
+                                 FromType, InvokeRType, {})) {
+      Diag(TheCall->getBeginLoc(), diag::err_invoke_bad_conversion)
+          << TheCall->getArgs()[0] << ResultType << InvokeRType;
+      return ExprError();
+    }
+  }
+
+  return BuildCXXNamedCast(TheCall->getBeginLoc(), tok::kw_static_cast,
+                           Context.getTrivialTypeSourceInfo(InvokeRType),
+                           Invoked, {}, TheCall->getBeginLoc());
+}
+
 ExprResult Sema::ActOnArrayTypeTrait(ArrayTypeTrait ATT,
                                      SourceLocation KWLoc,
                                      ParsedType Ty,
diff --git a/clang/test/SemaCXX/builtin-std-invoke.cpp b/clang/test/SemaCXX/builtin-std-invoke.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/builtin-std-invoke.cpp
@@ -0,0 +1,439 @@
+// RUN: %clang_cc1 -fsyntax-only -std=c++17 -verify %s
+
+namespace std {
+template <class F, class... Args>
+void invoke(F &&, Args &&...); // expected-note{{requires at least 1 argument, but 0 were provided}}
+
+template <class R, class F, class... Args>
+R invoke_r(F &&, Args &&...); // expected-note{{requires at least 1 argument, but 0 were provided}}
+
+// Slightly different to the real deal to simplify test.
+template +class reference_wrapper { +public: + constexpr reference_wrapper(T &t) : data(&t) {} + + constexpr operator T &() const noexcept { return *data; } + +private: + T *data; +}; +} // namespace std + +#define assert(...) \ + if (!(__VA_ARGS__)) \ + __builtin_unreachable(); + +struct ThrowingInt { + constexpr ThrowingInt(int x) : value(x) {} + + int value; +}; + +template +constexpr void bullet_1(F f, T &&t, Args... args) { + assert(std::invoke(f, static_cast(t), args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, static_cast(t), args...)), Returns)); + static_assert(noexcept(std::invoke(f, static_cast(t), args...)) == IsNoexcept); + + assert(std::invoke_r(f, static_cast(t), args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, static_cast(t), args...)), double)); + static_assert(noexcept(std::invoke_r(f, static_cast(t), args...)) == IsNoexcept); + + assert(std::invoke_r(f, static_cast(t), args...).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, static_cast(t), args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, static_cast(t), args...))); +} + +template +constexpr void bullet_2(F f, T &t, Args... args) { + std::reference_wrapper rw(t); + assert(std::invoke(f, rw, args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, rw, args...)), Returns)); + static_assert(noexcept(std::invoke(f, rw, args...)) == IsNoexcept); + + assert(std::invoke_r(f, rw, args...) 
== ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, rw, args...)), double)); + static_assert(noexcept(std::invoke_r(f, rw, args...)) == IsNoexcept); + + assert(std::invoke_r(f, rw, args...).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, rw, args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, rw, args...))); +} + +template +class PointerWrapper { +public: + constexpr explicit PointerWrapper(T &t) noexcept : p(&t) {} + + constexpr T &operator*() const noexcept { return *p; } + +private: + T *p; +}; + +template +constexpr void bullet_3(F f, T &t, Args... args) { + assert(std::invoke(f, &t, args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, &t, args...)), Returns)); + static_assert(noexcept(std::invoke(f, &t, args...)) == IsNoexcept); + + assert(std::invoke(f, PointerWrapper(t), args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, PointerWrapper(t), args...)), Returns)); + static_assert(noexcept(std::invoke(f, PointerWrapper(t), args...)) == IsNoexcept); + + assert(std::invoke_r(f, &t, args...) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, &t, args...)), double)); + static_assert(noexcept(std::invoke_r(f, &t, args...)) == IsNoexcept); + + assert(std::invoke_r(f, PointerWrapper(t), args...) 
== ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, PointerWrapper(t), args...)), double)); + static_assert(noexcept(std::invoke_r(f, PointerWrapper(t), args...)) == IsNoexcept); + + assert(std::invoke_r(f, &t, args...).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, &t, args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, &t, args...))); + + assert(std::invoke_r(f, PointerWrapper(t), args...).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, PointerWrapper(t), args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, PointerWrapper(t), args...))); +} + +template +constexpr bool bullets_1_through_3(T t) { + bullet_1(&T::plus, t, 3, -3); + bullet_1(&T::minus, static_cast(t), 1, 2, 3); + bullet_1(&T::square, static_cast<__remove_reference(T) &&>(t), 7); + bullet_1(&T::sum, static_cast(t), -1, -2, -4, -8, -16, -32); + + bullet_2(&T::plus, t, 3, -3); + bullet_2(&T::minus, static_cast(t), 1, 2, 3); + + bullet_3(&T::plus, t, 3, -3); + bullet_3(&T::minus, static_cast(t), 1, 2, 3); + + return true; +} + +template +constexpr void bullet_4(F f, T &&t, U ExpectedResult) { + assert(std::invoke(f, static_cast(t)) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, static_cast(t))), Returns)); + static_assert(noexcept(std::invoke(f, static_cast(t)))); + + assert(std::invoke_r(f, static_cast(t)) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, static_cast(t))), double)); + static_assert(noexcept(std::invoke_r(f, static_cast(t)))); + + assert(std::invoke_r(f, static_cast(t)).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, static_cast(t))), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, static_cast(t)))); +} + +template +constexpr void bullet_5(F f, T &t, U ExpectedResult) { + std::reference_wrapper rw(t); + assert(std::invoke(f, rw) == ExpectedResult); + 
static_assert(__is_same(decltype(std::invoke(f, rw)), Returns)); + static_assert(noexcept(std::invoke(f, rw))); + + assert(std::invoke_r(f, rw) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, rw)), double)); + static_assert(noexcept(std::invoke_r(f, rw))); + + assert(std::invoke_r(f, rw).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, rw)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, rw))); +} + +template +constexpr void bullet_6(F f, T &t, U ExpectedResult) { + assert(std::invoke(f, &t) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, &t)), Returns)); + static_assert(noexcept(std::invoke(f, &t))); + + assert(std::invoke(f, PointerWrapper(t)) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke(f, PointerWrapper(t))), Returns)); + static_assert(noexcept(std::invoke(f, PointerWrapper(t)))); + + assert(std::invoke_r(f, &t) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, &t)), double)); + static_assert(noexcept(std::invoke_r(f, &t))); + + assert(std::invoke_r(f, PointerWrapper(t)) == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, PointerWrapper(t))), double)); + static_assert(noexcept(std::invoke_r(f, PointerWrapper(t)))); + + assert(std::invoke_r(f, &t).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, &t)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, &t))); + + assert(std::invoke_r(f, PointerWrapper(t)).value == ExpectedResult); + static_assert(__is_same(decltype(std::invoke_r(f, PointerWrapper(t))), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, PointerWrapper(t)))); +} + +template +constexpr bool bullets_4_through_6(F f, T t, U ExpectedResult) { + bullet_4(f, t, ExpectedResult); + bullet_4(f, static_cast(t), ExpectedResult); + bullet_4(f, static_cast(t), ExpectedResult); + bullet_4(f, static_cast(t), ExpectedResult); + + bullet_5(f, t, ExpectedResult); + bullet_5(f, 
static_cast(t), ExpectedResult); + + bullet_6(f, t, ExpectedResult); + bullet_6(f, static_cast(t), ExpectedResult); + + return true; +} + +constexpr int zero() { return 0; } +constexpr int square(int x) { return x * x; } // expected-note 4 {{'square' declared here}} +constexpr double product(double x, double y) { return x * y; } + +struct summation { + template + constexpr auto operator()(Ts... ts) { + return (ts + ...); + } +}; + +struct callable { + int x; + + constexpr int &operator()() &noexcept { return x; } // expected-note 2 {{candidate function not viable: requires 0 arguments, but 2 were provided}} + constexpr const int &operator()() const &noexcept { return x; } // expected-note 2 {{candidate function not viable: requires 0 arguments, but 2 were provided}} + constexpr int &&operator()() &&noexcept { return static_cast(x); } // expected-note 2 {{candidate function not viable: requires 0 arguments, but 2 were provided}} + constexpr const int &&operator()() const &&noexcept { return static_cast(x); } // expected-note 2 {{candidate function not viable: requires 0 arguments, but 2 were provided}} +}; + +template +constexpr bool bullet_7(F &&f, Args &&...args) { + assert(std::invoke(static_cast(f), static_cast(args)...) == expected); + static_assert(__is_same(decltype(std::invoke(static_cast(f), static_cast(args)...)), T)); + static_assert(noexcept(std::invoke(static_cast(f), static_cast(args)...)) == IsNoexcept); + + assert(std::invoke_r(f, args...) 
== expected); + static_assert(__is_same(decltype(std::invoke_r(f, args...)), double)); + static_assert(noexcept(std::invoke_r(f, args...)) == IsNoexcept); + + assert(std::invoke_r(f, args...).value == expected); + static_assert(__is_same(decltype(std::invoke_r(f, args...)), ThrowingInt)); + static_assert(!noexcept(std::invoke_r(f, args...))); + + return true; +} + +template +constexpr bool bullet_7_1() { + callable c{21}; + return std::invoke(static_cast(c)) == 21 && + __is_same(decltype(std::invoke(static_cast(c))), Expected) &&noexcept(std::invoke(static_cast(c))); +} + +struct Base { + mutable int data; + + constexpr int &plus(int x, int y = 0) &noexcept { + data = x + y; + return data; + } + + constexpr const int &minus(int x, int y, int z) const &noexcept { + data = x - y - z; + return data; + } + + constexpr int &&square(int x) && { + data = x * x; + return static_cast(data); + } + + constexpr const int &&sum(int a, int b, int c, int d, int e, int f) const && { + data = a + b + c + d + e + f; + return static_cast(data); + } +}; + +struct Derived : Base { + double data2; +}; + +void test_invoke() { + static_assert(bullets_1_through_3(Base{})); + static_assert(bullets_1_through_3(Derived{})); + + static_assert(bullets_4_through_6(&Base::data, Base{21}, 21)); + static_assert(bullets_4_through_6(&Base::data, Derived{-96, 18}, -96)); + static_assert(bullets_4_through_6(&Derived::data2, Derived{21, 34}, 34.0)); + + static_assert(bullet_7(zero)); + static_assert(bullet_7(square, 5.0)); + static_assert(bullet_7(product, 9, 1)); + static_assert(bullet_7(&zero)); + static_assert(bullet_7(summation{}, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10L)); + static_assert(bullet_7([] { return 18; })); + + static_assert(bullet_7_1()); + static_assert(bullet_7_1()); + static_assert(bullet_7_1()); + static_assert(bullet_7_1()); +} + +struct Incompatible {}; + +void test_errors() { + // TODO: add cases for bullets 1--6 where 2nd param is an int + std::invoke(); // expected-error{{no matching 
function for call to 'invoke'}} + std::invoke_r(); // expected-error{{no matching function for call to 'invoke_r'}} + + { // Concerning bullet 1 + std::invoke(&Base::plus); + // expected-error@-1{{can't invoke '&Base::plus': 'std::invoke' must have at least 2 arguments for a pointer-to-member function, got 1}} + std::invoke_r(&Base::plus); + // expected-error@-1{{can't invoke '&Base::plus': 'std::invoke_r' must have at least 2 arguments for a pointer-to-member function, got 1}} + std::invoke(&Base::plus, Incompatible{}, 1, 2); + // expected-error@-1{{can't invoke '&Base::plus': expected second argument to be a reference to a class compatible with 'Base', got 'Incompatible'}} + std::invoke_r(&Base::plus, Incompatible{}, 1, 2); + // expected-error@-1{{can't invoke '&Base::plus': expected second argument to be a reference to a class compatible with 'Base', got 'Incompatible'}} + + std::invoke(&Base::sum, Base{}, 0); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 1}} + std::invoke_r(&Base::sum, Base{}, 0); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 1}} + + std::invoke(&Base::sum, Base{}, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 10}} + std::invoke_r(&Base::sum, Base{}, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 10}} + + std::invoke_r(&Base::sum, Base{}, 0, 1, 2, 3, 4, 5); + // expected-error@-1{{can't invoke '&Base::sum': return type 'const int' isn't convertible to 'void *'}} + } + { // Concerning bullet 2 + Base b; + std::reference_wrapper rw(b); + + Incompatible p; + std::reference_wrapper pw(p); + + std::invoke(&Base::plus, pw, 1, 2); + // expected-error@-1{{can't invoke '&Base::plus': expected second argument to be a wrapee to a class compatible with 'Base', got 'Incompatible'}} + std::invoke_r(&Base::plus, pw, 1, 2); + // expected-error@-1{{can't invoke '&Base::plus': 
expected second argument to be a wrapee to a class compatible with 'Base', got 'Incompatible'}} + + std::invoke(&Base::sum, rw, 0); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 1}} + std::invoke_r(&Base::sum, rw, 0); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 1}} + + std::invoke(&Base::sum, rw, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 10}} + std::invoke_r(&Base::sum, rw, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 10}} + + std::invoke_r(&Base::plus, rw, 4, 1); + // expected-error@-1{{can't invoke '&Base::plus': return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 3 + Base b; + Incompatible p; + + std::invoke(&Base::plus, &p, 1, 2); + // expected-error@-1{{can't invoke '&Base::plus': expected second argument to be a pointer to a class compatible with 'Base', got 'Incompatible *'}} + std::invoke_r(&Base::plus, &p, 1, 2); + // expected-error@-1{{can't invoke '&Base::plus': expected second argument to be a pointer to a class compatible with 'Base', got 'Incompatible *'}} + + std::invoke(&Base::sum, &b, 0); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 1}} + std::invoke_r(&Base::sum, &b, 0); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 1}} + + std::invoke(&Base::sum, &b, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 10}} + std::invoke_r(&Base::sum, &b, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + // expected-error@-1{{can't invoke '&Base::sum': expected 6 arguments, got 10}} + + std::invoke_r(&Base::plus, &b, 4, 1); + // expected-error@-1{{can't invoke '&Base::plus': return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 4 + std::invoke(&Base::data); + // expected-error@-1{{can't invoke '&Base::data': 'std::invoke' 
must have exactly 2 arguments for a pointer-to-data member, got 1}} + std::invoke_r(&Base::data); + // expected-error@-1{{can't invoke '&Base::data': 'std::invoke_r' must have exactly 2 arguments for a pointer-to-data member, got 1}} + + std::invoke(&Base::data, Incompatible{}); + // expected-error@-1{{can't invoke '&Base::data': expected second argument to be a reference to a class compatible with 'Base', got 'Incompatible'}} + std::invoke_r(&Base::data, Incompatible{}); + // expected-error@-1{{can't invoke '&Base::data': expected second argument to be a reference to a class compatible with 'Base', got 'Incompatible'}} + + std::invoke(&Base::data, Base{}, Base{}); + // expected-error@-1{{can't invoke '&Base::data': 'std::invoke' must have exactly 2 arguments for a pointer-to-data member, got 3}} + std::invoke_r(&Base::data, Base{}, Base{}); + // expected-error@-1{{can't invoke '&Base::data': 'std::invoke_r' must have exactly 2 arguments for a pointer-to-data member, got 3}} + + std::invoke_r(&Base::data, Base{}); + // expected-error@-1{{can't invoke '&Base::data': return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 5 + Base b; + std::reference_wrapper rw(b); + + Incompatible p; + std::reference_wrapper pw(p); + + std::invoke(&Base::data, pw); + // expected-error@-1{{can't invoke '&Base::data': expected second argument to be a wrapee to a class compatible with 'Base', got 'Incompatible'}} + std::invoke_r(&Base::data, pw); + // expected-error@-1{{can't invoke '&Base::data': expected second argument to be a wrapee to a class compatible with 'Base', got 'Incompatible'}} + + std::invoke(&Base::data, rw, Base{}); + // expected-error@-1{{can't invoke '&Base::data': 'std::invoke' must have exactly 2 arguments for a pointer-to-data member, got 3}} + std::invoke_r(&Base::data, rw, Base{}); + // expected-error@-1{{can't invoke '&Base::data': 'std::invoke_r' must have exactly 2 arguments for a pointer-to-data member, got 3}} + + 
std::invoke_r(&Base::data, rw); + // expected-error@-1{{can't invoke '&Base::data': return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 6 + Base b; + Incompatible p; + + std::invoke(&Base::data, &p, 1, 2); + // expected-error@-1{{can't invoke '&Base::data': expected second argument to be a pointer to a class compatible with 'Base', got 'Incompatible *'}} + std::invoke_r(&Base::data, &p, 1, 2); + // expected-error@-1{{can't invoke '&Base::data': expected second argument to be a pointer to a class compatible with 'Base', got 'Incompatible *'}} + + std::invoke(&Base::data, &b, Base{}); + // expected-error@-1{{can't invoke '&Base::data': 'std::invoke' must have exactly 2 arguments for a pointer-to-data member, got 3}} + std::invoke_r(&Base::data, &b, Base{}); + // expected-error@-1{{can't invoke '&Base::data': 'std::invoke_r' must have exactly 2 arguments for a pointer-to-data member, got 3}} + + std::invoke_r(&Base::data, &b); + // expected-error@-1{{can't invoke '&Base::data': return type 'int' isn't convertible to 'void *'}} + } + { // Concerning bullet 7 + std::invoke(square); + // expected-error@-1{{can't invoke 'square': expected 1 argument, got 0}} + (void)std::invoke_r(square); + // expected-error@-1{{can't invoke 'square': expected 1 argument, got 0}} + + std::invoke(square, 1, 2); + // expected-error@-1{{can't invoke 'square': expected 1 argument, got 2}} + (void)std::invoke_r(square, 1, 2); + // expected-error@-1{{can't invoke 'square': expected 1 argument, got 2}} + + std::invoke(&product, 1); + // expected-error@-1{{can't invoke '&product': expected 2 arguments, got 1}} + (void)std::invoke_r(&product, 1); + // expected-error@-1{{can't invoke '&product': expected 2 arguments, got 1}} + + std::invoke(callable{1}, 1, 2); + // expected-error@-1{{can't invoke 'callable{1}': expected 0 arguments, got 2}} + (void)std::invoke_r(callable{1}, 1, 2); + // expected-error@-1{{can't invoke 'callable{1}': expected 0 arguments, got 2}} + } +}