Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -134,10 +134,11 @@
   }
 
   case CK_NullToPointer:
-  case CK_IntegralToPointer: {
-    if (isa<IntegerLiteral>(SubExpr))
-      return this->visit(SubExpr);
+    if (DiscardResult)
+      return true;
+    return this->emitNull(classifyPrim(CE->getType()), CE);
 
+  case CK_IntegralToPointer: {
     if (!this->visit(SubExpr))
       return false;
 
@@ -968,6 +969,7 @@
     return this->emitZeroUint64(E);
   case PT_Ptr:
     return this->emitNullPtr(E);
+  case PT_FnPtr:
   case PT_Float:
     assert(false);
   }
@@ -1134,6 +1136,7 @@
   case PT_Bool:
     return this->emitConstBool(Value, E);
   case PT_Ptr:
+  case PT_FnPtr:
   case PT_Float:
     llvm_unreachable("Invalid integral type");
     break;
@@ -1667,8 +1670,27 @@
   if (E->getBuiltinCallee())
     return VisitBuiltinCallExpr(E);
 
-  const Decl *Callee = E->getCalleeDecl();
-  if (const auto *FuncDecl = dyn_cast_if_present<FunctionDecl>(Callee)) {
+  QualType ReturnType = E->getCallReturnType(Ctx.getASTContext());
+  std::optional<PrimType> T = classify(ReturnType);
+  bool HasRVO = !ReturnType->isVoidType() && !T;
+
+  if (HasRVO && DiscardResult) {
+    // If we need to discard the return value but the function returns its
+    // value via an RVO pointer, we need to create one such pointer just
+    // for this call.
+    if (std::optional<unsigned> LocalIndex = allocateLocal(E)) {
+      if (!this->emitGetPtrLocal(*LocalIndex, E))
+        return false;
+    }
+  }
+
+  // Put arguments on the stack.
+  for (const auto *Arg : E->arguments()) {
+    if (!this->visit(Arg))
+      return false;
+  }
+
+  if (const FunctionDecl *FuncDecl = E->getDirectCallee()) {
     const Function *Func = getFunction(FuncDecl);
     if (!Func)
       return false;
@@ -1680,24 +1702,7 @@
     if (Func->isFullyCompiled() && !Func->isConstexpr())
       return false;
 
-    QualType ReturnType = E->getCallReturnType(Ctx.getASTContext());
-    std::optional<PrimType> T = classify(ReturnType);
-
-    if (Func->hasRVO() && DiscardResult) {
-      // If we need to discard the return value but the function returns its
-      // value via an RVO pointer, we need to create one such pointer just
-      // for this call.
-      if (std::optional<unsigned> LocalIndex = allocateLocal(E)) {
-        if (!this->emitGetPtrLocal(*LocalIndex, E))
-          return false;
-      }
-    }
-
-    // Put arguments on the stack.
-    for (const auto *Arg : E->arguments()) {
-      if (!this->visit(Arg))
-        return false;
-    }
+    assert(HasRVO == Func->hasRVO());
 
     // In any case call the function. The return value will end up on the stack
     // and if the function has RVO, we already have the pointer on the stack to
@@ -1705,15 +1710,22 @@
     if (!this->emitCall(Func, E))
       return false;
 
-    if (DiscardResult && !ReturnType->isVoidType() && T)
-      return this->emitPop(*T, E);
-
-    return true;
   } else {
-    assert(false && "We don't support non-FunctionDecl callees right now.");
+    // Indirect call. Visit the callee, which will leave a FunctionPointer on
+    // the stack. Cleanup of the returned value if necessary will be done after
+    // the function call completed.
+    if (!this->visit(E->getCallee()))
+      return false;
+
+    if (!this->emitCallPtr(E))
+      return false;
   }
 
-  return false;
+  // Cleanup for discarded return values.
+  if (DiscardResult && !ReturnType->isVoidType() && T)
+    return this->emitPop(*T, E);
+
+  return true;
 }
 
 template <class Emitter>
@@ -1912,6 +1924,9 @@
     }
   } else if (const auto *ECD = dyn_cast<EnumConstantDecl>(Decl)) {
     return this->emitConst(ECD->getInitVal(), E);
+  } else if (const auto *FuncDecl = dyn_cast<FunctionDecl>(Decl)) {
+    const Function *F = getFunction(FuncDecl);
+    return F && this->emitGetFnPtr(F, E);
   }
 
   return false;
Index: clang/lib/AST/Interp/Context.cpp
===================================================================
--- clang/lib/AST/Interp/Context.cpp
+++ clang/lib/AST/Interp/Context.cpp
@@ -78,9 +78,11 @@
 const LangOptions &Context::getLangOpts() const { return Ctx.getLangOpts(); }
 
 std::optional<PrimType> Context::classify(QualType T) const {
-  if (T->isReferenceType() || T->isPointerType()) {
+  if (T->isFunctionPointerType() || T->isFunctionReferenceType())
+    return PT_FnPtr;
+
+  if (T->isReferenceType() || T->isPointerType())
     return PT_Ptr;
-  }
 
   if (T->isBooleanType())
     return PT_Bool;
Index: clang/lib/AST/Interp/Descriptor.cpp
===================================================================
--- clang/lib/AST/Interp/Descriptor.cpp
+++ clang/lib/AST/Interp/Descriptor.cpp
@@ -9,6 +9,7 @@
 #include "Descriptor.h"
 #include "Boolean.h"
 #include "Floating.h"
+#include "FunctionPointer.h"
 #include "Pointer.h"
 #include "PrimType.h"
 #include "Record.h"
Index: clang/lib/AST/Interp/FunctionPointer.h
===================================================================
--- /dev/null
+++ clang/lib/AST/Interp/FunctionPointer.h
@@ -0,0 +1,65 @@
+//===--- FunctionPointer.h - Types for function pointers --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_INTERP_FUNCTION_POINTER_H
+#define LLVM_CLANG_AST_INTERP_FUNCTION_POINTER_H
+
+#include "Function.h"
+#include "Primitives.h"
+#include "clang/AST/APValue.h"
+
+namespace clang {
+namespace interp {
+
+/// Representation of the PT_FnPtr primitive type: either null
+/// (default-constructed) or a non-null pointer to an interp::Function.
+class FunctionPointer final {
+private:
+  const Function *Func;
+
+public:
+  FunctionPointer() : Func(nullptr) {}
+  FunctionPointer(const Function *Func) : Func(Func) { assert(Func); }
+
+  const Function *getFunction() const { return Func; }
+
+  APValue toAPValue() const {
+    if (!Func)
+      return APValue(static_cast<Expr *>(nullptr), CharUnits::Zero(), {},
+                     /*OnePastTheEnd=*/false, /*IsNull=*/true);
+
+    return APValue(Func->getDecl(), CharUnits::Zero(), {},
+                   /*OnePastTheEnd=*/false, /*IsNull=*/false);
+  }
+
+  void print(llvm::raw_ostream &OS) const {
+    OS << "FnPtr(";
+    if (Func)
+      OS << Func->getName();
+    else
+      OS << "nullptr";
+    OS << ")";
+  }
+
+  ComparisonCategoryResult compare(const FunctionPointer &RHS) const {
+    if (Func == RHS.Func)
+      return ComparisonCategoryResult::Equal;
+    return ComparisonCategoryResult::Unordered;
+  }
+};
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
+                                     FunctionPointer FP) {
+  FP.print(OS);
+  return OS;
+}
+
+} // namespace interp
+} // namespace clang
+
+#endif
Index: clang/lib/AST/Interp/Interp.h
===================================================================
--- clang/lib/AST/Interp/Interp.h
+++ clang/lib/AST/Interp/Interp.h
@@ -16,6 +16,7 @@
 #include "Boolean.h"
 #include "Floating.h"
 #include "Function.h"
+#include "FunctionPointer.h"
 #include "InterpFrame.h"
 #include "InterpStack.h"
 #include "InterpState.h"
@@ -1545,6 +1546,22 @@
   return false;
 }
 
+inline bool CallPtr(InterpState &S, CodePtr &PC) {
+  const FunctionPointer &FuncPtr = S.Stk.pop<FunctionPointer>();
+
+  const Function *F = FuncPtr.getFunction();
+  if (!F || !F->isConstexpr())
+    return false;
+
+  return Call(S, PC, F);
+}
+
+inline bool GetFnPtr(InterpState &S, CodePtr &PC, const Function *Func) {
+  assert(Func);
+  S.Stk.push<FunctionPointer>(Func);
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Read opcode arguments
 //===----------------------------------------------------------------------===//
Index: clang/lib/AST/Interp/InterpStack.h
===================================================================
--- clang/lib/AST/Interp/InterpStack.h
+++ clang/lib/AST/Interp/InterpStack.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_AST_INTERP_INTERPSTACK_H
 #define LLVM_CLANG_AST_INTERP_INTERPSTACK_H
 
+#include "FunctionPointer.h"
 #include "PrimType.h"
 #include <memory>
 #include <vector>
@@ -162,6 +163,8 @@
       return PT_Uint64;
     else if constexpr (std::is_same_v<T, Floating>)
       return PT_Float;
+    else if constexpr (std::is_same_v<T, FunctionPointer>)
+      return PT_FnPtr;
 
     llvm_unreachable("unknown type push()'ed into InterpStack");
   }
Index: clang/lib/AST/Interp/Opcodes.td
===================================================================
--- clang/lib/AST/Interp/Opcodes.td
+++ clang/lib/AST/Interp/Opcodes.td
@@ -27,6 +27,7 @@
 def Uint64 : Type;
 def Float : Type;
 def Ptr : Type;
+def FnPtr : Type;
 
 //===----------------------------------------------------------------------===//
 // Types transferred to the interpreter.
@@ -77,7 +78,7 @@
 }
 
 def PtrTypeClass : TypeClass {
-  let Types = [Ptr];
+  let Types = [Ptr, FnPtr];
 }
 
 def BoolTypeClass : TypeClass {
@@ -187,6 +188,12 @@
   let ChangesPC = 1;
 }
 
+def CallPtr : Opcode {
+  let Args = [];
+  let Types = [];
+  let ChangesPC = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Frame management
 //===----------------------------------------------------------------------===//
@@ -228,6 +235,7 @@
 // [] -> [Pointer]
 def Null : Opcode {
   let Types = [PtrTypeClass];
+  let HasGroup = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -447,6 +455,14 @@
   let HasGroup = 0;
 }
 
+//===----------------------------------------------------------------------===//
+// Function pointers.
+//===----------------------------------------------------------------------===//
+// [] -> [FunctionPointer]
+def GetFnPtr : Opcode {
+  let Args = [ArgFunction];
+}
+
 //===----------------------------------------------------------------------===//
 // Binary operators.
 //===----------------------------------------------------------------------===//
Index: clang/lib/AST/Interp/PrimType.h
===================================================================
--- clang/lib/AST/Interp/PrimType.h
+++ clang/lib/AST/Interp/PrimType.h
@@ -24,6 +24,7 @@
 class Pointer;
 class Boolean;
 class Floating;
+class FunctionPointer;
 
 /// Enumeration of the primitive types of the VM.
 enum PrimType : unsigned {
@@ -38,6 +39,7 @@
   PT_Bool,
   PT_Float,
   PT_Ptr,
+  PT_FnPtr,
 };
 
 /// Mapping from primitive types to their representation.
@@ -53,6 +55,7 @@
 template <> struct PrimConv<PT_Float> { using T = Floating; };
 template <> struct PrimConv<PT_Bool> { using T = Boolean; };
 template <> struct PrimConv<PT_Ptr> { using T = Pointer; };
+template <> struct PrimConv<PT_FnPtr> { using T = FunctionPointer; };
 
 /// Returns the size of a primitive type in bytes.
 size_t primSize(PrimType Type);
@@ -90,6 +93,7 @@
       TYPE_SWITCH_CASE(PT_Float, B)                                            \
       TYPE_SWITCH_CASE(PT_Bool, B)                                             \
       TYPE_SWITCH_CASE(PT_Ptr, B)                                              \
+      TYPE_SWITCH_CASE(PT_FnPtr, B)                                            \
     }                                                                          \
   } while (0)
 #define COMPOSITE_TYPE_SWITCH(Expr, B, D)                                      \
Index: clang/lib/AST/Interp/PrimType.cpp
===================================================================
--- clang/lib/AST/Interp/PrimType.cpp
+++ clang/lib/AST/Interp/PrimType.cpp
@@ -9,6 +9,7 @@
 #include "PrimType.h"
 #include "Boolean.h"
 #include "Floating.h"
+#include "FunctionPointer.h"
 #include "Pointer.h"
 
 using namespace clang;
Index: clang/test/AST/Interp/functions.cpp
===================================================================
--- clang/test/AST/Interp/functions.cpp
+++ clang/test/AST/Interp/functions.cpp
@@ -99,3 +99,58 @@
   huh(); // expected-error {{use of undeclared identifier}} \
          // ref-error {{use of undeclared identifier}}
 }
+
+namespace FunctionPointers {
+  constexpr int add(int a, int b) {
+    return a + b;
+  }
+
+  struct S { int a; };
+  constexpr S getS() {
+    return S{12};
+  }
+
+  constexpr int applyBinOp(int a, int b, int (*op)(int, int)) {
+    return op(a, b);
+  }
+  static_assert(applyBinOp(1, 2, add) == 3, "");
+
+  constexpr int ignoreReturnValue() {
+    int (*foo)(int, int) = add;
+
+    foo(1, 2);
+    return 1;
+  }
+  static_assert(ignoreReturnValue() == 1, "");
+
+  constexpr int createS(S (*gimme)()) {
+    gimme(); // Ignored return value
+    return gimme().a;
+  }
+  static_assert(createS(getS) == 12, "");
+
+namespace FunctionReturnType {
+  typedef int (*ptr)(int*);
+  typedef ptr (*pm)();
+
+  constexpr int fun1(int* y) {
+      return *y + 10;
+  }
+  constexpr ptr fun() {
+      return &fun1;
+  }
+  static_assert(fun() == nullptr, ""); // expected-error {{static assertion failed}} \
+                                       // ref-error {{static assertion failed}}
+
+  constexpr int foo() {
+    int (*f)(int *) = fun();
+    int m = 0;
+
+    m = f(&m);
+
+    return m;
+  }
+  static_assert(foo() == 10);
+}
+
+}