diff --git a/clang/lib/AST/Interp/ByteCodeEmitter.cpp b/clang/lib/AST/Interp/ByteCodeEmitter.cpp --- a/clang/lib/AST/Interp/ByteCodeEmitter.cpp +++ b/clang/lib/AST/Interp/ByteCodeEmitter.cpp @@ -96,8 +96,15 @@ if (!FuncDecl->isDefined()) return Func; + // Compile lambda static invokers (custom code) even if not constexpr. + bool IsEligibleForCompilation = false; + if (const auto *MD = dyn_cast(FuncDecl)) + IsEligibleForCompilation = MD->isLambdaStaticInvoker(); + if (!IsEligibleForCompilation) + IsEligibleForCompilation = FuncDecl->isConstexpr(); + // Compile the function body. - if (!FuncDecl->isConstexpr() || !visitFunc(FuncDecl)) { + if (!IsEligibleForCompilation || !visitFunc(FuncDecl)) { // Return a dummy function if compilation failed. if (BailLocation) return llvm::make_error(*BailLocation); diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.h b/clang/lib/AST/Interp/ByteCodeStmtGen.h --- a/clang/lib/AST/Interp/ByteCodeStmtGen.h +++ b/clang/lib/AST/Interp/ByteCodeStmtGen.h @@ -68,6 +68,8 @@ bool visitCaseStmt(const CaseStmt *S); bool visitDefaultStmt(const DefaultStmt *S); + bool emitLambdaStaticInvokerBody(const CXXMethodDecl *MD); + /// Type of the expression returned by the function.
std::optional ReturnType; diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp --- a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp @@ -89,11 +89,67 @@ } // namespace interp } // namespace clang +template +bool ByteCodeStmtGen::emitLambdaStaticInvokerBody( + const CXXMethodDecl *MD) { + assert(MD->isLambdaStaticInvoker()); + assert(MD->hasBody()); + assert(cast(MD->getBody())->body_empty()); + + const CXXRecordDecl *ClosureClass = MD->getParent(); + const CXXMethodDecl *LambdaCallOp = ClosureClass->getLambdaCallOperator(); + assert(ClosureClass->captures_begin() == ClosureClass->captures_end()); + const Function *Func = this->getFunction(LambdaCallOp); + if (!Func) + return false; + assert(Func->hasThisPointer()); + assert(Func->getNumParams() == (MD->getNumParams() + 1 + Func->hasRVO())); + + if (Func->hasRVO()) { + if (!this->emitRVOPtr(MD)) + return false; + } + + // The lambda call operator needs an instance pointer, but we don't have + // one here, and we don't need one either because the lambda cannot have + // any captures, as verified above. Emit a null pointer. This is then + // special-cased when interpreting to not emit any misleading diagnostics. + if (!this->emitNullPtr(MD)) + return false; + + // Forward all arguments from the static invoker to the lambda call operator. + for (const ParmVarDecl *PVD : MD->parameters()) { + auto It = this->Params.find(PVD); + assert(It != this->Params.end()); + + // We do the lvalue-to-rvalue conversion manually here, so no need + // to care about references. + PrimType ParamType = this->classify(PVD->getType()).value_or(PT_Ptr); + if (!this->emitGetParam(ParamType, It->second, MD)) + return false; + } + + if (!this->emitCall(Func, LambdaCallOp)) + return false; + + this->emitCleanup(); + if (ReturnType) + return this->emitRet(*ReturnType, MD); + + // No primitive return value; a composite one went through the RVO ptr.
+ return this->emitRetVoid(MD); +} + template bool ByteCodeStmtGen::visitFunc(const FunctionDecl *F) { // Classify the return type. ReturnType = this->classify(F->getReturnType()); + // Emit custom code if this is a lambda static invoker. NOTE(review): generic lambdas are not special-cased; getLambdaCallOperator() presumably yields the call-operator template pattern, not this invoker's specialization -- confirm. + if (const auto *MD = dyn_cast(F); + MD && MD->isLambdaStaticInvoker()) + return this->emitLambdaStaticInvokerBody(MD); + // Constructor. Set up field initializers. if (const auto *Ctor = dyn_cast(F)) { const RecordDecl *RD = Ctor->getParent(); diff --git a/clang/lib/AST/Interp/Function.h b/clang/lib/AST/Interp/Function.h --- a/clang/lib/AST/Interp/Function.h +++ b/clang/lib/AST/Interp/Function.h @@ -17,6 +17,7 @@ #include "Pointer.h" #include "Source.h" +#include "clang/AST/ASTLambda.h" #include "clang/AST/Decl.h" #include "llvm/Support/raw_ostream.h" @@ -65,7 +66,7 @@ /// the argument values need to be preceeded by a Pointer for the This object. /// /// If the function uses Return Value Optimization, the arguments (and -/// potentially the This pointer) need to be proceeded by a Pointer pointing +/// potentially the This pointer) need to be preceded by a Pointer pointing /// to the location to construct the returned value. /// /// After the function has been called, it will remove all arguments, @@ -127,7 +128,7 @@ SourceInfo getSource(CodePtr PC) const; /// Checks if the function is valid to call in constexpr. - bool isConstexpr() const { return IsValid; } + bool isConstexpr() const { return IsValid || isLambdaStaticInvoker(); } /// Checks if the function is virtual. bool isVirtual() const; @@ -144,6 +145,22 @@ return nullptr; } + /// Returns whether this function is a lambda static invoker, + /// which we generate custom byte code for. + bool isLambdaStaticInvoker() const { + if (const auto *MD = dyn_cast(F)) + return MD->isLambdaStaticInvoker(); + return false; + } + + /// Returns whether this function is the call operator + /// of a lambda record decl.
+ bool isLambdaCallOperator() const { + if (const auto *MD = dyn_cast(F)) + return clang::isLambdaCallOperator(MD); + return false; + } + /// Checks if the function is fully done compiling. bool isFullyCompiled() const { return IsFullyCompiled; } diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h --- a/clang/lib/AST/Interp/Interp.h +++ b/clang/lib/AST/Interp/Interp.h @@ -1632,8 +1632,16 @@ const Pointer &ThisPtr = S.Stk.peek(ThisOffset); - if (!CheckInvoke(S, OpPC, ThisPtr)) - return false; + // If the current function is a lambda static invoker and + // the function we're about to call is a lambda call operator, + // skip the CheckInvoke: the invoker passes a null ThisPtr on purpose + // because a capture-less closure needs no instance. + if (!(S.Current->getFunction() && + S.Current->getFunction()->isLambdaStaticInvoker() && + Func->isLambdaCallOperator())) { + if (!CheckInvoke(S, OpPC, ThisPtr)) + return false; + } if (S.checkingPotentialConstantExpression()) return false; diff --git a/clang/test/AST/Interp/lambda.cpp b/clang/test/AST/Interp/lambda.cpp --- a/clang/test/AST/Interp/lambda.cpp +++ b/clang/test/AST/Interp/lambda.cpp @@ -107,3 +107,58 @@ static_assert(foo() == 1); // expected-error {{not an integral constant expression}} } +namespace StaticInvoker { + constexpr int sv1(int i) { + auto l = []() { return 12; }; + int (*fp)() = l; + return fp(); + } + static_assert(sv1(12) == 12); + + constexpr int sv2(int i) { + auto l = [](int m, float f, void *A) { return m; }; + int (*fp)(int, float, void*) = l; + return fp(i, 4.0f, nullptr); + } + static_assert(sv2(12) == 12); + + constexpr int sv3(int i) { + auto l = [](int m, const int &n) { return m; }; + int (*fp)(int, const int &) = l; + return fp(i, 3); + } + static_assert(sv3(12) == 12); + + constexpr int sv4(int i) { + auto l = [](int &m) { return m; }; + int (*fp)(int&) = l; + return fp(i); + } + static_assert(sv4(12) == 12); + + + + /// FIXME: Disabled; currently fails for reasons unrelated to lambdas.
+#if 0 + constexpr int sv5(int i) { + struct F { int a; float f; }; + auto l = [](int m, F f) { return m; }; + int (*fp)(int, F) = l; + return fp(i, F{12, 14.0}); + } + static_assert(sv5(12) == 12); +#endif + + constexpr int sv6(int i) { + struct F { int a; + constexpr F(int a) : a(a) {} + }; + + auto l = [](int m) { return F(12); }; + F (*fp)(int) = l; + F f = fp(i); + + return fp(i).a; + } + static_assert(sv6(12) == 12); +}