diff --git a/clang/docs/ConstantInterpreter.rst b/clang/docs/ConstantInterpreter.rst --- a/clang/docs/ConstantInterpreter.rst +++ b/clang/docs/ConstantInterpreter.rst @@ -177,7 +177,6 @@ * Definition of externs must override previous declaration * Changing the active field of unions -* Union copy constructors * ``typeid`` * ``volatile`` * ``__builtin_constant_p`` @@ -190,5 +189,4 @@ Known Bugs ---------- -* Pointer comparison for equality needs to narrow/expand pointers * If execution fails, memory storing APInts and APFloats is leaked when the stack is cleared diff --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt --- a/clang/lib/AST/CMakeLists.txt +++ b/clang/lib/AST/CMakeLists.txt @@ -52,6 +52,8 @@ FormatString.cpp InheritViz.cpp Interp/Block.cpp + Interp/Boolean.cpp + Interp/Builtin.cpp Interp/ByteCodeEmitter.cpp Interp/ByteCodeExprGen.cpp Interp/ByteCodeGenError.cpp @@ -62,8 +64,9 @@ Interp/EvalEmitter.cpp Interp/Frame.cpp Interp/Function.cpp - Interp/Interp.cpp + Interp/InterpLoop.cpp Interp/InterpFrame.cpp + Interp/InterpHelper.cpp Interp/InterpStack.cpp Interp/InterpState.cpp Interp/Pointer.cpp diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -12194,7 +12194,7 @@ /// EvaluateAsRValue - Try to evaluate this expression, performing an implicit /// lvalue-to-rvalue cast if it is an lvalue.
static bool EvaluateAsRValue(EvalInfo &Info, const Expr *E, APValue &Result) { - if (Info.EnableNewConstInterp) { + if (Info.EnableNewConstInterp) { auto &InterpCtx = Info.Ctx.getInterpContext(); switch (InterpCtx.evaluateAsRValue(Info, E, Result)) { case interp::InterpResult::Success: @@ -13146,6 +13146,18 @@ } } + // The constexpr VM attempts to compile all methods to bytecode here. + if (Info.EnableNewConstInterp) { + auto &InterpCtx = Info.Ctx.getInterpContext(); + switch (InterpCtx.isPotentialConstantExpr(Info, FD)) { + case interp::InterpResult::Success: + case interp::InterpResult::Fail: + return Diags.empty(); + case interp::InterpResult::Bail: + break; + } + } + const CXXMethodDecl *MD = dyn_cast(FD); const CXXRecordDecl *RD = MD ? MD->getParent()->getCanonicalDecl() : nullptr; diff --git a/clang/lib/AST/Interp/Block.h b/clang/lib/AST/Interp/Block.h --- a/clang/lib/AST/Interp/Block.h +++ b/clang/lib/AST/Interp/Block.h @@ -55,6 +55,8 @@ bool isStatic() const { return IsStatic; } /// Checks if the block is temporary. bool isTemporary() const { return Desc->IsTemporary; } + /// Checks if the block is constant. + bool isConst() const { return Desc->IsConst; } /// Returns the size of the block. InterpSize getSize() const { return Desc->getAllocSize(); } /// Returns the declaration ID. diff --git a/clang/lib/AST/Interp/Boolean.h b/clang/lib/AST/Interp/Boolean.h --- a/clang/lib/AST/Interp/Boolean.h +++ b/clang/lib/AST/Interp/Boolean.h @@ -11,7 +11,6 @@ #include #include -#include "Integral.h" #include "clang/AST/APValue.h" #include "clang/AST/ComparisonCategories.h" #include "llvm/ADT/APSInt.h" @@ -21,6 +20,9 @@ namespace clang { namespace interp { +template class Integral; +template class FixedIntegral; + /// Wrapper around boolean types. class Boolean { private: @@ -34,6 +36,9 @@ /// Zero-initializes a boolean. Boolean() : V(false) {} + /// Initializes a boolean from an APSInt.
+ explicit Boolean(const llvm::APSInt &I) : V(I != 0) {} + bool operator<(Boolean RHS) const { return V < RHS.V; } bool operator>(Boolean RHS) const { return V > RHS.V; } bool operator<=(Boolean RHS) const { return V <= RHS.V; } @@ -43,18 +48,31 @@ bool operator>(unsigned RHS) const { return static_cast(V) > RHS; } - Boolean operator-() const { return Boolean(V); } + Boolean operator+(Boolean RHS) const { return Boolean(V | RHS.V); } + Boolean operator-(Boolean RHS) const { return Boolean(V ^ RHS.V); } + Boolean operator*(Boolean RHS) const { return Boolean(V & RHS.V); } + Boolean operator/(Boolean RHS) const { return Boolean(V); } + // X % true == 0; avoid V % RHS.V, which is UB when RHS is false. + Boolean operator%(Boolean RHS) const { return Boolean(false); } + Boolean operator&(Boolean RHS) const { return Boolean(V & RHS.V); } + Boolean operator|(Boolean RHS) const { return Boolean(V | RHS.V); } + Boolean operator^(Boolean RHS) const { return Boolean(V ^ RHS.V); } + + Boolean operator-() const { return *this; } Boolean operator~() const { return Boolean(true); } + Boolean operator>>(unsigned RHS) const { return Boolean(V && RHS == 0); } + Boolean operator<<(unsigned RHS) const { return *this; } + explicit operator unsigned() const { return V; } explicit operator int64_t() const { return V; } explicit operator uint64_t() const { return V; } - APSInt toAPSInt() const { - return APSInt(APInt(1, static_cast(V), false), true); + llvm::APSInt toAPSInt() const { + return llvm::APSInt(llvm::APInt(1, static_cast(V), false), true); } - APSInt toAPSInt(unsigned NumBits) const { - return APSInt(toAPSInt().zextOrTrunc(NumBits), true); + llvm::APSInt toAPSInt(unsigned NumBits) const { + return llvm::APSInt(toAPSInt().zextOrTrunc(NumBits), true); } APValue toAPValue() const { return APValue(toAPSInt()); } @@ -62,6 +80,8 @@ constexpr static unsigned bitWidth() { return true; } bool isZero() const { return !V; } + bool isTrue() const { return !isZero(); } + bool isFalse() const { return isZero(); } bool isMin() const { return isZero(); }
constexpr static bool isMinusOne() { return false; } @@ -72,7 +92,11 @@ constexpr static bool isPositive() { return !isNegative(); } ComparisonCategoryResult compare(const Boolean &RHS) const { - return Compare(V, RHS.V); + if (!V && RHS.V) + return ComparisonCategoryResult::Less; + if (V && !RHS.V) + return ComparisonCategoryResult::Greater; + return ComparisonCategoryResult::Equal; } unsigned countLeadingZeros() const { return V ? 0 : 1; } @@ -91,17 +115,14 @@ } template - static typename std::enable_if::type from( - Integral Value) { - return Boolean(!Value.isZero()); - } + static Boolean from(const Integral &Value); template - static Boolean from(Integral<0, SrcSign> Value) { - return Boolean(!Value.isZero()); - } + static Boolean from(const FixedIntegral &I); + + static Boolean from(Boolean Value) { return Value; } - static Boolean zero() { return from(false); } + static Boolean zero() { return Boolean(false); } template static Boolean from(T Value, unsigned NumBits) { @@ -122,7 +143,7 @@ } static bool add(Boolean A, Boolean B, unsigned OpBits, Boolean *R) { - *R = Boolean(A.V || B.V); + *R = Boolean(A.V | B.V); return false; } @@ -132,7 +153,7 @@ } static bool mul(Boolean A, Boolean B, unsigned OpBits, Boolean *R) { - *R = Boolean(A.V && B.V); + *R = Boolean(A.V & B.V); return false; } }; diff --git a/clang/lib/AST/Interp/Boolean.cpp b/clang/lib/AST/Interp/Boolean.cpp new file mode 100644 --- /dev/null +++ b/clang/lib/AST/Interp/Boolean.cpp @@ -0,0 +1,7 @@ +//===--- Boolean.cpp - Wrapper for boolean types ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/Interp/Builtin.h b/clang/lib/AST/Interp/Builtin.h new file mode 100644 --- /dev/null +++ b/clang/lib/AST/Interp/Builtin.h @@ -0,0 +1,27 @@ +//===--- Builtin.h - Builtins for the constexpr VM --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Declares the dispatcher that evaluates builtin calls in the constexpr VM. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_AST_INTERP_BUILTIN_H +#define LLVM_CLANG_AST_INTERP_BUILTIN_H + +#include "Function.h" + +namespace clang { +namespace interp { +class InterpState; + +bool InterpBuiltin(InterpState &S, CodePtr OpPC, unsigned BuiltinOp); + +} // namespace interp +} // namespace clang + +#endif diff --git a/clang/lib/AST/Interp/Builtin.cpp b/clang/lib/AST/Interp/Builtin.cpp new file mode 100644 --- /dev/null +++ b/clang/lib/AST/Interp/Builtin.cpp @@ -0,0 +1,95 @@ +//===--- Builtin.cpp - Builtins for the constexpr VM ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Builtin.h" +#include "InterpState.h" +#include "Pointer.h" +#include "Program.h" +#include "PrimType.h" +#include "clang/AST/ASTContext.h" +#include "clang/Basic/Builtins.h" + +using namespace clang; +using namespace clang::interp; + +static bool BuiltinLength(InterpState &S, CodePtr OpPC) { + using ResultT = PrimConv::T; + // Get the pointer to the string and ensure it's an array of primitives. + const auto &Ptr = S.Stk.pop(); + if (!Ptr.inPrimitiveArray()) + return false; + if (!S.CheckLoad(OpPC, Ptr)) + return false; + + // Scan for the null terminator from the pointed element, stepping by the element width; its index is the length. + const unsigned Size = Ptr.getSize() - Ptr.getOffset(); + char *Data = &Ptr.narrow().deref(); + switch (Ptr.elemSize()) { + case 1: + for (unsigned Off = 0, I = 0; Off < Size; Off += 1, ++I) { + if (*reinterpret_cast(Data + Off) == 0) { + S.Stk.push(ResultT::from(I)); + return true; + } + } + break; + case 2: + for (unsigned Off = 0, I = 0; Off < Size; Off += 2, ++I) { + if (*reinterpret_cast(Data + Off) == 0) { + S.Stk.push(ResultT::from(I)); + return true; + } + } + break; + case 4: + for (unsigned Off = 0, I = 0; Off < Size; Off += 4, ++I) { + if (*reinterpret_cast(Data + Off) == 0) { + S.Stk.push(ResultT::from(I)); + return true; + } + } + break; + default: + llvm_unreachable("Unsupported character width!"); + } + + S.FFDiag(S.getSource(OpPC), diag::note_constexpr_access_past_end) << AK_Read; + return false; +} + +static void ReportInvalidExpr(InterpState &S, CodePtr OpPC, unsigned Op) { + const char *Name = S.getCtx().BuiltinInfo.getName(Op); + + if (S.getLangOpts().CPlusPlus11) + S.CCEDiag(S.getSource(OpPC), diag::note_constexpr_invalid_function) + << /*isConstexpr*/ 0 << /*isConstructor*/ 0 + << (std::string("'") + Name + "'"); + else + S.CCEDiag(S.getSource(OpPC), diag::note_invalid_subexpr_in_const_expr); +} +
+namespace clang { +namespace interp { + +bool InterpBuiltin(InterpState &S, CodePtr OpPC, unsigned Op) { + switch (Op) { + case Builtin::BIstrlen: + case Builtin::BIwcslen: + ReportInvalidExpr(S, OpPC, Op); + return BuiltinLength(S, OpPC); + case Builtin::BI__builtin_strlen: + case Builtin::BI__builtin_wcslen: + return BuiltinLength(S, OpPC); + default: + // Not a builtin this VM can evaluate: fail the evaluation. + return false; + } +} + +} // namespace interp +} // namespace clang diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.h b/clang/lib/AST/Interp/ByteCodeExprGen.h --- a/clang/lib/AST/Interp/ByteCodeExprGen.h +++ b/clang/lib/AST/Interp/ByteCodeExprGen.h @@ -66,10 +66,61 @@ : Emitter(Ctx, P, Args...), Ctx(Ctx), P(P) {} // Expression visitors - result returned on stack. + bool VisitDeclRefExpr(const DeclRefExpr *E); bool VisitCastExpr(const CastExpr *E); + bool VisitConstantExpr(const ConstantExpr *E); bool VisitIntegerLiteral(const IntegerLiteral *E); + bool VisitFloatingLiteral(const FloatingLiteral *E); + bool VisitStringLiteral(const StringLiteral *E); + bool VisitCharacterLiteral(const CharacterLiteral *E); + bool VisitImaginaryLiteral(const ImaginaryLiteral *E); bool VisitParenExpr(const ParenExpr *E); bool VisitBinaryOperator(const BinaryOperator *E); + bool VisitUnaryPostInc(const UnaryOperator *E); + bool VisitUnaryPostDec(const UnaryOperator *E); + bool VisitUnaryPreInc(const UnaryOperator *E); + bool VisitUnaryPreDec(const UnaryOperator *E); + bool VisitUnaryAddrOf(const UnaryOperator *E); + bool VisitUnaryDeref(const UnaryOperator *E); + bool VisitUnaryPlus(const UnaryOperator *E); + bool VisitUnaryMinus(const UnaryOperator *E); + bool VisitUnaryNot(const UnaryOperator *E); + bool VisitUnaryLNot(const UnaryOperator *E); + bool VisitUnaryReal(const UnaryOperator *E); + bool VisitUnaryImag(const UnaryOperator *E); + bool VisitConditionalOperator(const ConditionalOperator *E); + bool VisitBinaryConditionalOperator(const BinaryConditionalOperator *E); + bool VisitMemberExpr(const
MemberExpr *E); + bool VisitCallExpr(const CallExpr *E); + bool VisitArraySubscriptExpr(const ArraySubscriptExpr *E); + bool VisitExprWithCleanups(const ExprWithCleanups *E); + bool VisitCompoundLiteralExpr(const CompoundLiteralExpr *E); + bool VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E); + bool VisitImplicitValueInitExpr(const ImplicitValueInitExpr *E); + bool VisitSubstNonTypeTemplateParmExpr(const SubstNonTypeTemplateParmExpr *E); + bool VisitUnaryExprOrTypeTraitExpr(const UnaryExprOrTypeTraitExpr *E); + bool VisitSizeOfPackExpr(const SizeOfPackExpr *E); + bool VisitOpaqueValueExpr(const OpaqueValueExpr *E); + bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E); + bool VisitArrayInitIndexExpr(const ArrayInitIndexExpr *E); + bool VisitInitListExpr(const InitListExpr *E); + bool VisitTypeTraitExpr(const TypeTraitExpr *E); + bool VisitBlockExpr(const BlockExpr *E); + bool VisitCXXConstructExpr(const CXXConstructExpr *E); + bool VisitCXXMemberCallExpr(const CXXMemberCallExpr *E); + bool VisitCXXNullPtrLiteralExpr(const CXXNullPtrLiteralExpr *E); + bool VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr *E); + bool VisitCXXThisExpr(const CXXThisExpr *E); + bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E); + bool VisitCXXDefaultInitExpr(const CXXDefaultInitExpr *E); + bool VisitCXXScalarValueInitExpr(const CXXScalarValueInitExpr *E); + bool VisitCXXThrowExpr(const CXXThrowExpr *E); + bool VisitCXXTemporaryObjectExpr(const CXXTemporaryObjectExpr *E); + bool VisitCXXTypeidExpr(const CXXTypeidExpr *E); + bool VisitCXXReinterpretCastExpr(const CXXReinterpretCastExpr *E); + bool VisitCXXDynamicCastExpr(const CXXDynamicCastExpr *E); + bool VisitCXXInheritedCtorInitExpr(const CXXInheritedCtorInitExpr *E); + bool VisitCXXStdInitializerListExpr(const CXXStdInitializerListExpr *E); protected: bool visitExpr(const Expr *E) override; @@ -86,16 +137,26 @@ Record *getRecord(QualType Ty); Record *getRecord(const RecordDecl *RD); + /// Perform an action
if a record field is found. + bool withField(const FieldDecl *F, + llvm::function_ref GenField); + /// Returns the size int bits of an integer. unsigned getIntWidth(QualType Ty) { - auto &ASTContext = Ctx.getASTContext(); - return ASTContext.getIntWidth(Ty); + return Ctx.getASTContext().getIntWidth(Ty); + } + /// Returns the applicable fp semantics. + const fltSemantics *getFltSemantics(QualType Ty) { + return &Ctx.getASTContext().getFloatTypeSemantics(Ty); } - /// Returns the value of CHAR_BIT. unsigned getCharBit() const { - auto &ASTContext = Ctx.getASTContext(); - return ASTContext.getTargetInfo().getCharWidth(); + return Ctx.getASTContext().getTargetInfo().getCharWidth(); + } + + /// Canonicalizes an array type. + const ConstantArrayType *getAsConstantArrayType(QualType AT) { + return Ctx.getASTContext().getAsConstantArrayType(AT); } /// Classifies a type. @@ -108,6 +169,8 @@ /// Checks if a pointer needs adjustment. bool needsAdjust(QualType Ty) const { + if (llvm::Optional T = classify(Ty)) + return T != PT_MemPtr && T != PT_FnPtr; return true; } @@ -129,6 +192,9 @@ /// Visits an expression and converts it to a boolean. bool visitBool(const Expr *E); + /// Visits a base class initializer. + bool visitBaseInitializer(const Expr *Init, InitFnRef GenBase); + /// Visits an initializer for a local. bool visitLocalInitializer(const Expr *Init, unsigned I) { return visitInitializer(Init, [this, I, Init] { @@ -166,6 +232,32 @@ /// Emits a zero initializer. bool visitZeroInitializer(PrimType T, const Expr *E); + /// Emits a direct function call. + bool emitFunctionCall(const FunctionDecl *Callee, llvm::Optional T, + const Expr *Call); + /// Emits a direct method call. + bool emitMethodCall(const CXXMethodDecl *Callee, llvm::Optional T, + const Expr *Call); + + /// Compiles an offset calculator. + bool visitOffset(const Expr *Ptr, const Expr *Offset, const Expr *E, + UnaryFn OffsetFn); + + /// Fetches a member of a structure given by a pointer.
+ bool visitIndirectMember(const BinaryOperator *E); + + /// Compiles simple or compound assignments. + bool visitAssign(PrimType T, const BinaryOperator *BO); + bool visitShiftAssign(PrimType RHS, BinaryFn F, const BinaryOperator *BO); + bool visitCompoundAssign(PrimType RHS, UnaryFn F, const BinaryOperator *BO); + bool visitPtrAssign(PrimType RHS, UnaryFn F, const BinaryOperator *BO); + + /// Emits a cast between two types. + bool emitConv(PrimType From, QualType FromTy, PrimType To, QualType ToTy, + const Expr *Cast); + + bool visitShortCircuit(const BinaryOperator *E); + enum class DerefKind { /// Value is read and pushed to stack. Read, @@ -188,6 +280,12 @@ bool dereferenceVar(const Expr *LV, PrimType T, const VarDecl *PD, DerefKind AK, llvm::function_ref Direct, llvm::function_ref Indirect); + bool dereferenceMember(const MemberExpr *PD, PrimType T, DerefKind AK, + llvm::function_ref Direct, + llvm::function_ref Indirect); + + /// Converts an lvalue to an rvalue. + bool lvalueToRvalue(const Expr *LV, const Expr *E); /// Emits an APInt constant. bool emitConst(PrimType T, unsigned NumBits, const llvm::APInt &Value, @@ -201,18 +299,56 @@ return emitConst(*Ctx.classify(Ty), NumBits, WrappedValue, E); } + /// Compiles a list of arguments. + bool visitArguments(QualType CalleeTy, ArrayRef Args); + /// Compiles an argument. + bool visitArgument(const Expr *E, bool Discard); + + /// Returns a pointer to a function declaration. + bool getPtrConstFn(const FunctionDecl *FD, const Expr *E); /// Returns a pointer to a variable declaration. bool getPtrVarDecl(const VarDecl *VD, const Expr *E); /// Returns the index of a global. llvm::Optional getGlobalIdx(const VarDecl *VD); + /// Compiles a string initializer. + bool visitStringInitializer(const StringLiteral *S); + /// Compiles an array initializer. + bool visitArrayInitializer(const ConstantArrayType *AT, + const Expr *E, + llvm::function_ref Elem); + /// Visits a record initializer.
+ bool visitRecordInitializer(const RecordType *RT, const InitListExpr *List); + /// Visits a complex initializer. + bool visitComplexInitializer(const ComplexType *CT, const InitListExpr *List); + + /// Visits a cast to complex. + bool visitCastToComplex(const CastExpr *CE); + + /// Emits a zero initializer for a record. + bool visitZeroInitializer(const Expr *E, const RecordDecl *RD); + + /// Registers an opaque expression. + bool visitOpaqueExpr(const OpaqueValueExpr *Expr); + /// Visits a conditional operator. + bool visitConditionalOperator(const AbstractConditionalOperator *CO); + + /// Materializes a composite. + bool materialize(const Expr *Alloc, const Expr *Init, bool IsConst, + bool IsGlobal, bool IsExtended); + /// Emits the initialized pointer. bool emitInitFn() { - assert(InitFn && "missing initializer"); - return (*InitFn)(); + return InitFn ? (*InitFn)() : this->emitTrap(SourceInfo{}); } + /// Returns the alignment of a type. + CharUnits getAlignOfType(QualType T, bool IsPreferred); + + /// Returns the alignment of an expression. + CharUnits getAlignOfExpr(const Expr *E, bool IsPreferred); + protected: /// Variable to storage mapping. llvm::DenseMap Locals; @@ -231,6 +367,19 @@ /// Expression being initialized. llvm::Optional InitFn = {}; + + /// Enumeration of initializer kinds. + enum class InitKind { + /// Regular invocation. + ROOT, + /// Base class initializer. + BASE, + /// Activates a union field. + UNION, + }; + + /// Initializer kind of the object currently being initialized.
+ InitKind Initialiser = InitKind::ROOT; }; extern template class ByteCodeExprGen; diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -43,35 +43,45 @@ template class OptionScope { public: using InitFnRef = typename ByteCodeExprGen::InitFnRef; + using InitKind = typename ByteCodeExprGen::InitKind; using ChainedInitFnRef = std::function; /// Root constructor, compiling or discarding primitives. OptionScope(ByteCodeExprGen *Ctx, bool NewDiscardResult) : Ctx(Ctx), OldDiscardResult(Ctx->DiscardResult), - OldInitFn(std::move(Ctx->InitFn)) { + OldInitFn(std::move(Ctx->InitFn)), + OldInitialiser(Ctx->Initialiser) { Ctx->DiscardResult = NewDiscardResult; Ctx->InitFn = llvm::Optional{}; + Ctx->Initialiser = InitKind::ROOT; } /// Root constructor, setting up compilation state. - OptionScope(ByteCodeExprGen *Ctx, InitFnRef NewInitFn) + OptionScope(ByteCodeExprGen *Ctx, InitFnRef NewInitFn, + InitKind NewInitialiser = InitKind::ROOT) : Ctx(Ctx), OldDiscardResult(Ctx->DiscardResult), - OldInitFn(std::move(Ctx->InitFn)) { + OldInitFn(std::move(Ctx->InitFn)), + OldInitialiser(Ctx->Initialiser) { Ctx->DiscardResult = true; Ctx->InitFn = NewInitFn; + Ctx->Initialiser = NewInitialiser; } + ~OptionScope() { + Ctx->DiscardResult = OldDiscardResult; + Ctx->InitFn = std::move(OldInitFn); + Ctx->Initialiser = OldInitialiser; + } + +protected: /// Extends the chain of initialisation pointers. 
- OptionScope(ByteCodeExprGen *Ctx, ChainedInitFnRef NewInitFn) + OptionScope(ByteCodeExprGen *Ctx, ChainedInitFnRef NewInitFn, + InitKind NewInitialiser) : Ctx(Ctx), OldDiscardResult(Ctx->DiscardResult), - OldInitFn(std::move(Ctx->InitFn)) { + OldInitFn(std::move(Ctx->InitFn)), OldInitialiser(Ctx->Initialiser) { assert(OldInitFn && "missing initializer"); Ctx->InitFn = [this, NewInitFn] { return NewInitFn(*OldInitFn); }; - } - - ~OptionScope() { - Ctx->DiscardResult = OldDiscardResult; - Ctx->InitFn = std::move(OldInitFn); + Ctx->Initialiser = NewInitialiser; } private: @@ -81,29 +91,173 @@ bool OldDiscardResult; /// Old pointer emitter to restore. llvm::Optional OldInitFn; + /// Base flag to restore. + InitKind OldInitialiser; +}; + +// Scope which initialises a base class. +template class BaseScope : public OptionScope { +public: + using ChainedInitFnRef = typename OptionScope::ChainedInitFnRef; + using InitKind = typename OptionScope::InitKind; + BaseScope(ByteCodeExprGen *Ctx, ChainedInitFnRef FieldFn) + : OptionScope(Ctx, FieldFn, InitKind::BASE) {} +}; + +// Scope which initialises a union field. +template class UnionScope : public OptionScope { +public: + using ChainedInitFnRef = typename OptionScope::ChainedInitFnRef; + using InitKind = typename OptionScope::InitKind; + UnionScope(ByteCodeExprGen *Ctx, ChainedInitFnRef FieldFn) + : OptionScope(Ctx, FieldFn, InitKind::UNION) {} +}; + +// Scope which initialises a record field or array element. 
+template class FieldScope : public OptionScope { +public: + using ChainedInitFnRef = typename OptionScope::ChainedInitFnRef; + using InitKind = typename OptionScope::InitKind; + FieldScope(ByteCodeExprGen *Ctx, ChainedInitFnRef FieldFn) + : OptionScope(Ctx, FieldFn, InitKind::ROOT) {} }; } // namespace interp } // namespace clang +template +bool ByteCodeExprGen::VisitDeclRefExpr(const DeclRefExpr *DE) { + if (DiscardResult) + return true; + + if (auto *PD = dyn_cast(DE->getDecl())) { + QualType Ty = PD->getType(); + auto It = this->Params.find(PD); + if (It == this->Params.end()) { + // Pointers to parameters are not constant expressions. + return this->emitTrap(DE); + } else { + // Generate a pointer to a parameter. + if (Ty->isReferenceType() || !classify(Ty)) + return this->emitGetParamPtr(It->second, DE); + else + return this->emitGetPtrParam(It->second, DE); + } + } + if (auto *VD = dyn_cast(DE->getDecl())) { + auto It = Locals.find(VD); + if (It == Locals.end()) { + return getPtrVarDecl(VD, DE); + } else { + // Generate a pointer to a local. + if (VD->getType()->isReferenceType()) + return this->emitGetLocal(PT_Ptr, It->second.Offset, DE); + else + return this->emitGetPtrLocal(It->second.Offset, DE); + } + } + if (auto *ED = dyn_cast(DE->getDecl())) { + QualType Ty = ED->getType(); + if (Optional T = classify(Ty)) + return this->emitConst(*T, getIntWidth(Ty), ED->getInitVal(), DE); + return this->bail(DE); + } + if (auto *FD = dyn_cast(DE->getDecl())) { + if (auto *CD = dyn_cast(DE->getDecl())) + if (!CD->isStatic()) + return this->emitConstMem(CD, DE); + return getPtrConstFn(FD, DE); + } + if (auto *FD = dyn_cast(DE->getDecl())) + return this->emitConstMem(FD, DE); + + // TODO: compile other decls. 
+ return this->bail(DE); +} + template bool ByteCodeExprGen::VisitCastExpr(const CastExpr *CE) { auto *SubExpr = CE->getSubExpr(); switch (CE->getCastKind()) { - case CK_LValueToRValue: { - return dereference( - CE->getSubExpr(), DerefKind::Read, - [](PrimType) { - // Value loaded - nothing to do here. - return true; - }, - [this, CE](PrimType T) { - // Pointer on stack - dereference it. - if (!this->emitLoadPop(T, CE)) - return false; - return DiscardResult ? this->emitPop(T, CE) : true; - }); + case CK_LValueToRValue: + return lvalueToRvalue(CE->getSubExpr(), CE); + + case CK_IntegralToBoolean: { + if (DiscardResult) + return discard(SubExpr); + + if (!visit(SubExpr)) + return false; + + return this->emitTest(*classify(SubExpr->getType()), CE); + } + + case CK_IntegralCast: { + if (DiscardResult) + return discard(SubExpr); + + if (!visit(SubExpr)) + return false; + + QualType ArgTy = SubExpr->getType(); + QualType RetTy = CE->getType(); + auto ArgT = *classify(ArgTy); + auto RetT = *classify(RetTy); + if (isFixedIntegral(RetT)) + return this->emitCastFP(ArgT, RetT, getIntWidth(RetTy), CE); + else + return this->emitCast(ArgT, RetT, CE); + } + + case CK_FloatingCast: + case CK_IntegralToFloating: { + if (DiscardResult) + return discard(SubExpr); + + if (!visit(SubExpr)) + return false; + + QualType ArgTy = SubExpr->getType(); + QualType RetTy = CE->getType(); + auto ArgT = *classify(ArgTy); + return this->emitCastRealFP(ArgT, getFltSemantics(RetTy), CE); + } + + case CK_FloatingToIntegral: { + if (!visit(SubExpr)) + return false; + + QualType RetTy = CE->getType(); + PrimType ArgT = *classify(SubExpr->getType()); + PrimType RetT = *classify(RetTy); + if (isFixedIntegral(RetT)) { + if (!this->emitCastRealFPToAluFP(ArgT, RetT, getIntWidth(RetTy), CE)) + return false; + } else { + if (!this->emitCastRealFPToAlu(ArgT, RetT, CE)) + return false; + } + + return DiscardResult ? 
this->emitPop(RetT, CE) : true; + } + + case CK_NullToPointer: + case CK_NullToMemberPointer: + // Emit a null pointer, avoiding redundancy when casting from nullptr. + if (!isa(SubExpr)) { + if (!discard(SubExpr)) + return false; + } + if (DiscardResult) + return true; + return visitZeroInitializer(classifyPrim(CE->getType()), CE); + + case CK_PointerToBoolean: { + if (DiscardResult) + return discard(SubExpr); + else + return visit(SubExpr) && this->emitTestPtr(CE); } case CK_ArrayToPointerDecay: @@ -118,6 +272,108 @@ case CK_ToVoid: return discard(SubExpr); + case CK_IntegralRealToComplex: + case CK_FloatingRealToComplex: + return visitCastToComplex(CE); + + case CK_DerivedToBase: + case CK_UncheckedDerivedToBase: { + if (!visit(SubExpr)) + return false; + + const Record *R = getRecord(SubExpr->getType()); + if (!R) + return false; + + for (auto *Step : CE->path()) { + auto *Decl = Step->getType()->getAs()->getDecl(); + + const Record::Base *Base; + if (Step->isVirtual()) { + Base = R->getVirtualBase(Decl); + if (!this->emitGetPtrVirtBase(Decl, CE)) + return false; + } else { + Base = R->getBase(Decl); + if (!this->emitGetPtrBase(Base->Offset, CE)) + return false; + } + + R = Base->R; + } + + return DiscardResult ? this->emitPopPtr(CE) : true; + } + + case CK_BaseToDerived: { + if (!visit(SubExpr)) + return false; + if (!this->emitCastToDerived(CE, CE)) + return false; + return DiscardResult ? this->emitPopPtr(CE) : true; + } + + case CK_DerivedToBaseMemberPointer: { + if (!visit(SubExpr)) + return false; + for (auto *Step : CE->path()) { + auto *R = Step->getType()->getAs()->getDecl(); + if (!this->emitCastMemberToBase(cast(R), CE)) + return false; + } + return DiscardResult ? this->emitPopMemPtr(CE) : true; + } + + case CK_BaseToDerivedMemberPointer: { + if (CE->path_empty()) + return discard(SubExpr); + + if (!visit(SubExpr)) + return false; + + // Single cast step from base to derived. 
+ auto Step = [this, CE](const Type *T) { + auto *R = cast(T->getAs()->getDecl()); + return this->emitCastMemberToDerived(R, CE); + }; + + // Base-to-derived member pointer casts store the path in derived-to-base + // order, so iterate backwards. The CXXBaseSpecifier also provides us with + // the wrong end of the derived->base arc, so stagger the path by one class. + typedef std::reverse_iterator Rev; + for (auto It = Rev(CE->path_end() - 1); It != Rev(CE->path_begin()); ++It) { + if (!Step((*It)->getType().getTypePtr())) + return false; + } + // The final type is encoded in the type of the cast. + if (!Step(CE->getType()->getAs()->getClass())) + return false; + return DiscardResult ? this->emitPopMemPtr(CE) : true; + } + + case CK_BitCast: { + if (!this->Visit(SubExpr)) + return false; + QualType T = CE->getType(); + if (!T->isVoidPointerType()) { + if (T->isMemberFunctionPointerType()) + return false; + if (T->isMemberPointerType()) + return false; + if (T->isFunctionPointerType()) + return false; + return this->emitPointerBitCast(SubExpr); + } + return true; + } + + case CK_IntegralToPointer: + case CK_PointerToIntegral: { + if (!this->Visit(SubExpr)) + return false; + return this->emitPointerBitCast(SubExpr); + } + default: { // TODO: implement other casts. 
return this->bail(CE); @@ -125,6 +381,27 @@ } } +template +bool ByteCodeExprGen::VisitConstantExpr(const ConstantExpr *CE) { + if (DiscardResult) + return true; + + switch (CE->getResultStorageKind()) { + case ConstantExpr::RSK_Int64: { + QualType Ty = CE->getType(); + PrimType T = *classify(Ty); + return this->emitConst(T, getIntWidth(Ty), CE->getResultAsAPSInt(), CE); + } + + case ConstantExpr::RSK_APValue: + return this->bail(CE); + + case ConstantExpr::RSK_None: + return this->Visit(CE->getSubExpr()); + } + return false; +} + template bool ByteCodeExprGen::VisitIntegerLiteral(const IntegerLiteral *LE) { if (DiscardResult) @@ -137,6 +414,57 @@ return this->bail(LE); } +template +bool ByteCodeExprGen::VisitFloatingLiteral(const FloatingLiteral *E) { + if (DiscardResult) + return true; + return this->emitConstRealFP(E, E); +} + +template +bool ByteCodeExprGen::VisitStringLiteral(const StringLiteral *E) { + if (InitFn) + return visitStringInitializer(E); + if (DiscardResult) + return true; + return this->emitGetPtrGlobal(P.createGlobalString(E), E); +} + +template +bool ByteCodeExprGen::VisitCharacterLiteral( + const CharacterLiteral *CE) { + if (DiscardResult) + return true; + + QualType CharTy = CE->getType(); + if (Optional T = classify(CharTy)) { + const unsigned NumBits = sizeof(unsigned) * CHAR_BIT; + APInt Char(NumBits, static_cast(CE->getValue()), false); + return this->emitConst(*T, getIntWidth(CharTy), Char, CE); + } + return this->bail(CE); +} + +template +bool ByteCodeExprGen::VisitImaginaryLiteral( + const ImaginaryLiteral *E) { + const Expr *SubExpr = E->getSubExpr(); + QualType ElemType = SubExpr->getType(); + + PrimType T = classifyPrim(ElemType); + ImplicitValueInitExpr ZE(ElemType); + + if (!emitInitFn()) + return false; + if (!visit(&ZE)) + return false; + if (!this->emitInitElem(T, 0, E)) + return false; + if (!visit(SubExpr)) + return false; + return this->emitInitElemPop(T, 1, E); +} + template bool ByteCodeExprGen::VisitParenExpr(const 
ParenExpr *PE) { return this->Visit(PE->getSubExpr()); @@ -149,6 +477,9 @@ // Deal with operations which have composite or void types. switch (BO->getOpcode()) { + case BO_PtrMemD: + case BO_PtrMemI: + return visitIndirectMember(BO); case BO_Comma: if (!discard(LHS)) return false; @@ -167,6 +498,79 @@ } if (Optional T = classify(BO->getType())) { + switch (BO->getOpcode()) { + case BO_Assign: + return visitAssign(*T, BO); + + case BO_AddAssign: + if (LT == PT_Ptr) + return visitPtrAssign(*RT, &ByteCodeExprGen::emitAddOffset, BO); + else + return visitCompoundAssign(*RT, &ByteCodeExprGen::emitAdd, BO); + case BO_SubAssign: + if (LT == PT_Ptr) + return visitPtrAssign(*RT, &ByteCodeExprGen::emitSubOffset, BO); + else + return visitCompoundAssign(*RT, &ByteCodeExprGen::emitSub, BO); + + case BO_MulAssign: + return visitCompoundAssign(*RT, &ByteCodeExprGen::emitMul, BO); + case BO_DivAssign: + return visitCompoundAssign(*RT, &ByteCodeExprGen::emitDiv, BO); + case BO_RemAssign: + return visitCompoundAssign(*RT, &ByteCodeExprGen::emitRem, BO); + case BO_AndAssign: + return visitCompoundAssign(*RT, &ByteCodeExprGen::emitAnd, BO); + case BO_XorAssign: + return visitCompoundAssign(*RT, &ByteCodeExprGen::emitXor, BO); + case BO_OrAssign: + return visitCompoundAssign(*RT, &ByteCodeExprGen::emitOr, BO); + + case BO_ShlAssign: + if (Ctx.getLangOpts().OpenCL) + return this->bail(BO); + return visitShiftAssign(*RT, &ByteCodeExprGen::emitShl, BO); + + case BO_ShrAssign: + if (Ctx.getLangOpts().OpenCL) + return this->bail(BO); + return visitShiftAssign(*RT, &ByteCodeExprGen::emitShr, BO); + + case BO_Add: { + if (*LT == PT_Ptr && *RT != PT_Ptr) + return visitOffset(LHS, RHS, BO, &ByteCodeExprGen::emitAddOffset); + if (*LT != PT_Ptr && *RT == PT_Ptr) + return visitOffset(RHS, LHS, BO, &ByteCodeExprGen::emitAddOffset); + break; + } + case BO_Sub: { + if (*LT == PT_Ptr && *RT == PT_Ptr) { + if (!visit(LHS)) + return false; + if (!visit(RHS)) + return false; + if (isFixedIntegral(*T)) { 
+ if (!this->emitPtrDiffFP(*T, getIntWidth(BO->getType()), BO)) + return false; + } else { + if (!this->emitPtrDiff(*T, BO)) + return false; + } + return DiscardResult ? this->emitPop(*T, BO) : true; + } + if (*LT == PT_Ptr && *RT != PT_Ptr) + return visitOffset(LHS, RHS, BO, &ByteCodeExprGen::emitSubOffset); + break; + } + + case BO_LOr: + case BO_LAnd: + return visitShortCircuit(BO); + + default: + break; + } + if (!visit(LHS)) return false; if (!visit(RHS)) @@ -191,12 +595,30 @@ return Discard(this->emitGT(*LT, BO)); case BO_GE: return Discard(this->emitGE(*LT, BO)); + case BO_Or: + return Discard(this->emitOr(*T, BO)); + case BO_And: + return Discard(this->emitAnd(*T, BO)); + case BO_Xor: + return Discard(this->emitXor(*T, BO)); + case BO_Rem: + return Discard(this->emitRem(*T, BO)); + case BO_Div: + return Discard(this->emitDiv(*T, BO)); case BO_Sub: return Discard(this->emitSub(*T, BO)); case BO_Add: return Discard(this->emitAdd(*T, BO)); case BO_Mul: return Discard(this->emitMul(*T, BO)); + case BO_Shr: + if (Ctx.getLangOpts().OpenCL) + return this->bail(BO); + return Discard(this->emitShr(*LT, *RT, BO)); + case BO_Shl: + if (Ctx.getLangOpts().OpenCL) + return this->bail(BO); + return Discard(this->emitShl(*LT, *RT, BO)); default: return this->bail(BO); } @@ -206,51 +628,1187 @@ } template -bool ByteCodeExprGen::discard(const Expr *E) { - OptionScope Scope(this, /*discardResult=*/true); - return this->Visit(E); +bool ByteCodeExprGen::VisitUnaryPlus(const UnaryOperator *UO) { + return this->Visit(UO->getSubExpr()); +} + +template +bool ByteCodeExprGen::VisitUnaryMinus(const UnaryOperator *UM) { + if (!visit(UM->getSubExpr())) + return false; + + if (Optional T = classify(UM->getType())) { + if (!this->emitMinus(*T, UM)) + return false; + return DiscardResult ? 
this->emitPop(*T, UM) : true; + } + + return this->bail(UM); +} + +template +bool ByteCodeExprGen::VisitUnaryDeref(const UnaryOperator *E) { + if (!this->Visit(E->getSubExpr())) + return false; + if (DiscardResult) + return true; + if (needsAdjust(E->getType())) + return this->emitNarrowPtr(E); + return true; +} + +template +bool ByteCodeExprGen::VisitUnaryAddrOf(const UnaryOperator *E) { + if (!this->Visit(E->getSubExpr())) + return false; + if (DiscardResult) + return true; + if (needsAdjust(E->getType())) + return this->emitExpandPtr(E); + return true; +} + +template +bool ByteCodeExprGen::VisitUnaryPostInc(const UnaryOperator *E) { + if (!visit(E->getSubExpr())) + return false; + PrimType T = *classify(E->getType()); + if (!this->emitPostInc(T, E)) + return false; + return DiscardResult ? this->emitPop(T, E) : true; +} + +template +bool ByteCodeExprGen::VisitUnaryPostDec(const UnaryOperator *E) { + if (!visit(E->getSubExpr())) + return false; + PrimType T = *classify(E->getType()); + if (!this->emitPostDec(T, E)) + return false; + return DiscardResult ? this->emitPop(T, E) : true; +} + +template +bool ByteCodeExprGen::VisitUnaryPreInc(const UnaryOperator *E) { + if (!visit(E->getSubExpr())) + return false; + PrimType T = *classify(E->getType()); + if (!this->emitPreInc(T, E)) + return false; + return DiscardResult ? this->emitPopPtr(E) : true; +} + +template +bool ByteCodeExprGen::VisitUnaryPreDec(const UnaryOperator *E) { + if (!visit(E->getSubExpr())) + return false; + PrimType T = *classify(E->getType()); + if (!this->emitPreDec(T, E)) + return false; + return DiscardResult ? this->emitPopPtr(E) : true; +} + +template +bool ByteCodeExprGen::VisitUnaryNot(const UnaryOperator *UO) { + if (!this->Visit(UO->getSubExpr())) + return false; + PrimType T = *classify(UO->getType()); + return DiscardResult ? 
true : this->emitNot(T, UO);
+}
+
+template
+bool ByteCodeExprGen::VisitUnaryLNot(const UnaryOperator *UO) { // Logical not of the operand.
+ if (!this->Visit(UO->getSubExpr()))
+ return false;
+ PrimType T = *classify(UO->getType());
+ return DiscardResult ? true : this->emitLogicalNot(T, UO);
+}
+
+template
+bool ByteCodeExprGen::VisitUnaryReal(const UnaryOperator *UO) {
+ const Expr *SubExpr = UO->getSubExpr();
+ if (!isa(SubExpr->getType())) // __real of a non-complex value is the value itself.
+ return this->Visit(SubExpr);
+ if (!visit(SubExpr))
+ return false;
+ if (!this->emitRealElem(UO))
+ return false;
+ return DiscardResult ? this->emitPop(PT_Ptr, UO) : true;
+}
+
+template
+bool ByteCodeExprGen::VisitUnaryImag(const UnaryOperator *UO) {
+ const Expr *SubExpr = UO->getSubExpr();
+ if (!isa(SubExpr->getType())) // __imag of a non-complex value is a zero rvalue.
+ return discard(SubExpr) && (DiscardResult || visitZeroInitializer(classifyPrim(UO->getType()), UO));
+ if (!visit(SubExpr))
+ return false;
+ if (!this->emitImagElem(UO))
+ return false;
+ return DiscardResult ? this->emitPop(PT_Ptr, UO) : true;
+}
+
+template
+bool ByteCodeExprGen::VisitConditionalOperator(
+ const ConditionalOperator *CO) {
+ return visitConditionalOperator(CO);
+}
+
+template
+bool ByteCodeExprGen::VisitBinaryConditionalOperator(
+ const BinaryConditionalOperator *BO) {
+ if (!visitOpaqueExpr(BO->getOpaqueValue()))
+ return false;
+ return visitConditionalOperator(BO);
+}
+
+template
+bool ByteCodeExprGen::visitConditionalOperator(
+ const AbstractConditionalOperator *CO) { // Shared lowering for ?: and the GNU binary ?: form.
+ if (!visitBool(CO->getCond()))
+ return false;
+ LabelTy LabelFalse = this->getLabel();
+ LabelTy LabelEnd = this->getLabel();
+ if (!this->jumpFalse(LabelFalse))
+ return false;
+ if (!this->Visit(CO->getTrueExpr()))
+ return false;
+ if (!this->jump(LabelEnd))
+ return false;
+ this->emitLabel(LabelFalse);
+ if (!this->Visit(CO->getFalseExpr()))
+ return false;
+ if (!this->fallthrough(LabelEnd))
+ return false;
+ return true;
+}
+
+template
+bool ByteCodeExprGen::VisitMemberExpr(const MemberExpr *ME) {
+ // Fetch a pointer to the required field.
+ if (auto *FD = dyn_cast(ME->getMemberDecl())) { + return withField(FD, [this, FD, ME](const Record::Field *F) { + const bool IsReference = FD->getType()->isReferenceType(); + if (isa(ME->getBase())) { + if (IsReference) { + if (!this->emitGetThisFieldPtr(F->Offset, ME)) + return false; + } else { + if (!this->emitGetPtrThisField(F->Offset, ME)) + return false; + } + } else { + if (!visit(ME->getBase())) + return false; + if (IsReference) { + if (!this->emitGetFieldPopPtr(F->Offset, ME)) + return false; + } else { + if (!this->emitGetPtrField(F->Offset, ME)) + return false; + } + } + return DiscardResult ? this->emitPopPtr(ME) : true; + }); + } + + // Emit the enum constant value of the enum field. + if (auto *ED = dyn_cast(ME->getMemberDecl())) + return this->Visit(ED->getInitExpr()); + + // Pointer to static field. + if (auto *VD = dyn_cast(ME->getMemberDecl())) + return DiscardResult ? true : getPtrVarDecl(VD, ME); + + // Pointer to static method or method. + if (auto *MD = dyn_cast(ME->getMemberDecl())) { + assert(MD->isStatic() && "Method is not static"); + return DiscardResult ? true : getPtrConstFn(MD, ME); + } + + llvm_unreachable("Invalid member field"); +} + +template +bool ByteCodeExprGen::VisitCallExpr(const CallExpr *CE) { + // Emit the pointer to build the return value into. + if (InitFn && !emitInitFn()) + return false; + + auto Args = llvm::makeArrayRef(CE->getArgs(), CE->getNumArgs()); + auto T = classify(CE->getCallReturnType(Ctx.getASTContext())); + + // Emit the call. + if (unsigned BuiltinOp = CE->getBuiltinCallee()) { + // If the callee is a builtin, lower args and invoke. + if (!visitArguments(CE->getCallee()->getType(), Args)) + return false; + if (!this->emitBuiltin(BuiltinOp, CE)) + return false; + return DiscardResult && T ? this->emitPop(*T, CE) : true; + } else if (const FunctionDecl *Callee = CE->getDirectCallee()) { + // Emit a direct call if the callee is known. 
+ if (isa(Callee) && !Callee->isStatic()) {
+ if (Args.empty() || !visitArguments(CE->getCallee()->getType(), Args.slice(1)))
+ return false; // Args[0] is the object; check it exists before slicing it off.
+ if (!visit(Args[0]))
+ return false;
+ return emitMethodCall(dyn_cast(Callee), T, CE);
+ } else {
+ // Lower arguments.
+ if (!visitArguments(CE->getCallee()->getType(), Args))
+ return false;
+ return emitFunctionCall(Callee, T, CE);
+ }
+ } else {
+ // Function pointer call.
+ if (!visitArguments(CE->getCallee()->getType(), Args))
+ return false;
+ if (!visit(CE->getCallee()))
+ return false;
+ if (!this->emitIndirectCall(CE))
+ return false;
+ return DiscardResult && T ? this->emitPop(*T, CE) : true;
+ }
+}
+
+template
+bool ByteCodeExprGen::VisitCXXConstructExpr(
+ const CXXConstructExpr *E) { // Constructs a record, or each element of a constant array, in place.
+ CXXConstructorDecl *CD = E->getConstructor();
+
+ // Helper to invoke the constructor to initialize the pointer in InitFn.
+ auto InvokeConstructor = [this, E, CD] {
+ auto Args = llvm::makeArrayRef(E->getArgs(), E->getNumArgs());
+ if (E->requiresZeroInitialization() && CD->isTrivial()) {
+ // Do not invoke constructor, fill in with 0.
+ auto *RT = E->getType()->getAs();
+ return visitZeroInitializer(E, RT->getDecl());
+ } else {
+ // Avoid materialization if elidable.
+ if (E->isElidable()) {
+ if (auto *ME = dyn_cast(E->getArg(0)))
+ return this->Visit(ME->GetTemporaryExpr());
+ }
+
+ // Invoke constructor with 'this' as the pointer.
+ if (!visitArguments(E->getConstructor()->getType(), Args))
+ return false;
+ if (!emitInitFn())
+ return false;
+ return emitMethodCall(E->getConstructor(), {}, E);
+ }
+ };
+
+ // Invoke the constructor on each array element or on the single record.
+ if (auto *CAT = dyn_cast(E->getType())) { + uint64_t NumElems = CAT->getSize().getZExtValue(); + for (unsigned I = 0; I < NumElems; ++I) { + FieldScope Scope(this, [this, I, E](InitFnRef Base) { + if (!Base()) + return false; + if (!this->emitConstUint32(I, E)) + return false; + if (!this->emitAddOffsetUint32(E)) + return false; + return this->emitNarrowPtr(E); + }); + if (!InvokeConstructor()) + return false; + } + return true; + } + + return InvokeConstructor(); +} + +template +bool ByteCodeExprGen::VisitCXXMemberCallExpr( + const CXXMemberCallExpr *CE) { + // Emit the pointer to build the return value into. + if (InitFn && !emitInitFn()) + return false; + + // Emit the arguments, this pointer and the call. + auto T = classify(CE->getType()); + auto Args = llvm::makeArrayRef(CE->getArgs(), CE->getNumArgs()); + + // Identify member calls with direct targets, invoking them. + if (auto *MD = dyn_cast_or_null(CE->getDirectCallee())) { + if (!visitArguments(MD->getType(), Args)) + return false; + + // Lower the 'this' pointer. + if (!visit(CE->getImplicitObjectArgument())) + return false; + + // Direct member call. + if (!MD->isVirtual()) + return emitMethodCall(MD, T, CE); + + // Indirect virtual call. + if (!this->emitVirtualInvoke(MD, CE)) + return false; + return DiscardResult && T ? this->emitPop(*T, CE) : true; + } + + // Pattern match the callee for indirect calls to avoid materializing + // the bound member function pointer on the stack. + const Expr *Callee = CE->getCallee(); + if (Callee->getType()->isSpecificBuiltinType(BuiltinType::BoundMember)) { + auto *BO = dyn_cast(Callee->IgnoreParens()); + switch (BO->getOpcode()) { + case BO_PtrMemD: + case BO_PtrMemI: + if (!visitArguments(BO->getRHS()->getType(), Args)) + return false; + // Emit 'this' pointer. + if (!visit(BO->getLHS())) + return false; + // Emit member function. + if (!visit(BO->getRHS())) + return false; + // Emit invocation. 
+ if (!this->emitIndirectInvoke(CE)) + return false; + return DiscardResult && T ? this->emitPop(*T, CE) : true; + default: + llvm_unreachable("invalid indirect callee"); + } + } + + llvm_unreachable("invalid member call expression callee"); +} + +template +bool ByteCodeExprGen::emitFunctionCall(const FunctionDecl *Callee, + Optional T, + const Expr *E) { + if (Expected Func = P.getOrCreateFunction(Callee)) { + if (*Func) { + if (!this->emitCall(*Func, E)) + return false; + } else { + if (!this->emitNoCall(Callee, E)) + return false; + } + } else { + consumeError(Func.takeError()); + return this->bail(E); + } + return DiscardResult && T ? this->emitPop(*T, E) : true; +} + +template +bool ByteCodeExprGen::emitMethodCall(const CXXMethodDecl *Callee, + Optional T, + const Expr *E) { + if (Expected Func = P.getOrCreateFunction(Callee)) { + if (*Func) { + if (!this->emitInvoke(*Func, E)) + return false; + } else { + if (!this->emitNoInvoke(Callee, E)) + return false; + } + } else { + consumeError(Func.takeError()); + return this->bail(E); + } + return DiscardResult && T ? this->emitPop(*T, E) : true; +} + +template +bool ByteCodeExprGen::VisitArraySubscriptExpr( + const ArraySubscriptExpr *E) { + if (!visit(E->getBase())) + return false; + auto *Idx = E->getIdx(); + if (!visit(Idx)) + return false; + if (!this->emitAddOffset(*classify(Idx->getType()), E)) + return false; + return DiscardResult ? 
this->emitPopPtr(E) : this->emitNarrowPtr(E); +} + +template +bool ByteCodeExprGen::VisitExprWithCleanups( + const ExprWithCleanups *CE) { + return this->Visit(CE->getSubExpr()); +} + +template +bool ByteCodeExprGen::VisitCompoundLiteralExpr( + const CompoundLiteralExpr *E) { + return materialize(E, E->getInitializer(), + /*isConst=*/E->getType().isConstQualified(), + /*isGlobal=*/E->isFileScope(), + /*isExtended=*/false); +} + +template +bool ByteCodeExprGen::VisitMaterializeTemporaryExpr( + const MaterializeTemporaryExpr *ME) { + const bool IsConst = ME->getType().isConstQualified(); + const bool IsGlobal = ME->getStorageDuration() == SD_Static; + const bool IsExtended = ME->getStorageDuration() == SD_Automatic; + return materialize(ME, ME->GetTemporaryExpr(), IsConst, IsGlobal, IsExtended); +} + +template +bool ByteCodeExprGen::materialize(const Expr *Alloc, const Expr *Init, + bool IsConst, bool IsGlobal, + bool IsExtended) { + if (IsGlobal) { + if (auto I = P.createGlobal(Alloc)) { + if (Optional T = classify(Init->getType())) { + // Primitive global - compute and set. + if (!visit(Init)) + return false; + if (!this->emitInitGlobal(*T, *I, Alloc)) + return false; + return DiscardResult ? true : this->emitGetPtrGlobal(*I, Alloc); + } else { + // Composite global - initialize in place. + if (!visitGlobalInitializer(Init, *I)) + return false; + return DiscardResult ? true : this->emitGetPtrGlobal(*I, Alloc); + } + } + } else { + if (Optional T = classify(Init->getType())) { + auto I = allocateLocalPrimitive(Init, *T, IsConst, IsExtended); + if (!visit(Init)) + return false; + if (!this->emitSetLocal(*T, I, Init)) + return false; + return DiscardResult ? true : this->emitGetPtrLocal(I, Init); + } else { + // Composite types - allocate storage and initialize it. + // This operation leaves a pointer to the temporary on the stack. + if (auto I = allocateLocal(Alloc, IsExtended)) { + if (!visitLocalInitializer(Init, *I)) + return false; + return DiscardResult ? 
true : this->emitGetPtrLocal(*I, Alloc); + } + } + } + return this->bail(Alloc); +} + +template +CharUnits ByteCodeExprGen::getAlignOfType(QualType T, + bool IsPreferred) { + // C++ [expr.alignof]p3: + // When alignof is applied to a reference type, the result is the + // alignment of the referenced type. + if (const ReferenceType *Ref = T->getAs()) + T = Ref->getPointeeType(); + + if (T.getQualifiers().hasUnaligned()) + return CharUnits::One(); + + const auto &AST = Ctx.getASTContext(); + const bool AlignOfReturnsPreferred = + AST.getLangOpts().getClangABICompat() <= LangOptions::ClangABI::Ver7; + + // __alignof is defined to return the preferred alignment. + // Before 8, clang returned the preferred alignment for alignof and _Alignof + // as well. + if (IsPreferred || AlignOfReturnsPreferred) + return AST.toCharUnitsFromBits(AST.getPreferredTypeAlign(T.getTypePtr())); + // alignof and _Alignof are defined to return the ABI alignment. + return AST.getTypeAlignInChars(T.getTypePtr()); +} + +template +CharUnits ByteCodeExprGen::getAlignOfExpr(const Expr *E, + bool IsPreferred) { + E = E->IgnoreParens(); + + // The kinds of expressions that we have special-case logic here for + // should be kept up to date with the special checks for those + // expressions in Sema. + + // alignof decl is always accepted, even if it doesn't make sense: we default + // to 1 in those cases. + if (const DeclRefExpr *DRE = dyn_cast(E)) + return Ctx.getASTContext().getDeclAlign(DRE->getDecl(), + /*RefAsPointee*/ true); + + if (const MemberExpr *ME = dyn_cast(E)) + return Ctx.getASTContext().getDeclAlign(ME->getMemberDecl(), + /*RefAsPointee*/ true); + + return getAlignOfType(E->getType(), IsPreferred); +} + +template +bool ByteCodeExprGen::VisitImplicitValueInitExpr( + const ImplicitValueInitExpr *E) { + + QualType Ty = E->getType(); + if (auto *AT = Ty->getAs()) + Ty = AT->getValueType(); + + if (auto T = Ctx.classify(Ty)) + return DiscardResult ? 
true : visitZeroInitializer(*T, E);
+
+ if (auto *RT = Ty->getAs())
+ return visitZeroInitializer(E, RT->getDecl());
+
+ if (auto *CAT = dyn_cast(Ty)) {
+ ImplicitValueInitExpr ElemInit(CAT->getElementType()); // Do not shadow parameter E.
+ return visitArrayInitializer(CAT, &ElemInit, [&ElemInit](uint64_t) { return &ElemInit; });
+ }
+
+ return false;
+}
+
+template
+bool ByteCodeExprGen::VisitSubstNonTypeTemplateParmExpr(
+ const SubstNonTypeTemplateParmExpr *E) {
+ return this->Visit(E->getReplacement());
+}
+
+template
+bool ByteCodeExprGen::VisitUnaryExprOrTypeTraitExpr(
+ const UnaryExprOrTypeTraitExpr *TE) { // alignof/__alignof and sizeof; other traits bail.
+ switch (TE->getKind()) {
+ case UETT_PreferredAlignOf:
+ case UETT_AlignOf: {
+ const bool IsPreferred = TE->getKind() == UETT_PreferredAlignOf;
+ CharUnits Align;
+ if (TE->isArgumentType())
+ Align = getAlignOfType(TE->getArgumentType(), IsPreferred);
+ else
+ Align = getAlignOfExpr(TE->getArgumentExpr(), IsPreferred);
+ return DiscardResult ? true : emitConst(TE, Align.getQuantity());
+ }
+
+ case UETT_VecStep: {
+ return this->bail(TE);
+ }
+
+ case UETT_SizeOf: {
+ QualType ArgTy = TE->getTypeOfArgument();
+
+ // C++ [expr.sizeof]p2: "When applied to a reference or a reference type,
+ // the result is the size of the referenced type."
+ if (auto *Ref = ArgTy->getAs())
+ ArgTy = Ref->getPointeeType();
+
+ CharUnits Size;
+ if (ArgTy->isVoidType() || ArgTy->isFunctionType()) {
+ Size = CharUnits::One();
+ } else if (ArgTy->isDependentType()) {
+ return this->emitTrap(TE);
+ } else if (!ArgTy->isConstantSizeType()) {
+ // sizeof(vla) is not a constantexpr: C99 6.5.3.4p2.
+ // FIXME: Better diagnostic.
+ return this->emitTrap(TE);
+ } else {
+ Size = Ctx.getASTContext().getTypeSizeInChars(ArgTy);
+ }
+
+ // Emit the size as an integer.
+ return DiscardResult ?
true : emitConst(TE, Size.getQuantity());
+ }
+ case UETT_OpenMPRequiredSimdAlign: {
+ return this->bail(TE);
+ }
+ }
+
+ llvm_unreachable("unknown expr/type trait");
+}
+
+template
+bool ByteCodeExprGen::VisitSizeOfPackExpr(const SizeOfPackExpr *E) {
+ if (DiscardResult)
+ return true;
+ return emitConst(E, E->getPackLength());
+}
+
+template
+bool ByteCodeExprGen::VisitOpaqueValueExpr(const OpaqueValueExpr *E) {
+ // The opaque expression was evaluated earlier and the pointer was cached in
+ // a local variable. Load the pointer at this point.
+ PrimType Ty;
+ if (Optional T = classify(E->getType())) {
+ Ty = *T;
+ } else {
+ Ty = PT_Ptr;
+ }
+ auto It = OpaqueExprs.find(E);
+ if (It != OpaqueExprs.end()) {
+ return DiscardResult ? true : this->emitGetLocal(Ty, It->second, E);
+ } else {
+ return this->emitTrap(E);
+ }
+}
+
+template
+bool ByteCodeExprGen::VisitArrayInitLoopExpr(
+ const ArrayInitLoopExpr *E) {
+ // Evaluate the common expression, a pointer to the array copied from.
+ if (auto *C = E->getCommonExpr()) {
+ if (!visitOpaqueExpr(C))
+ return false;
+ }
+
+ // Initialise each element of the array. In this scope, ArrayIndex is set to
+ // refer to the index of the element being initialised in the callback which
+ // returns the initialiser of that element.
+ auto OldArrayIndex = ArrayIndex;
+
+ auto ElemInit = [this, E](uint64_t I) {
+ ArrayIndex = I;
+ return E->getSubExpr();
+ };
+
+ auto *AT = E->getType()->getAsArrayTypeUnsafe();
+ const bool Success = visitArrayInitializer(cast(AT), E, ElemInit);
+ // Restore the enclosing index on both the success and failure paths.
+ ArrayIndex = OldArrayIndex;
+
+ return Success;
+}
+
+template
+bool ByteCodeExprGen::VisitArrayInitIndexExpr(
+ const ArrayInitIndexExpr *E) {
+ assert(!DiscardResult && "ArrayInitIndexExpr should not be discarded");
+ return ArrayIndex ? this->emitConst(E, *ArrayIndex) : this->emitTrap(E);
+}
+
+template
+bool ByteCodeExprGen::VisitInitListExpr(const InitListExpr *E) {
+ // Initialise a scalar with a given value or zero.
+ if (auto T = Ctx.classify(E->getType())) { + if (E->getNumInits() == 0) + return DiscardResult ? true : visitZeroInitializer(*T, E); + if (E->getNumInits() == 1) + return this->Visit(E->getInit(0)); + return false; + } + + // Desugar the type and decide which initializer to call based on it. + QualType Ty = E->getType(); + if (auto *AT = dyn_cast(Ty)) + Ty = AT->getValueType(); + + if (auto *RT = Ty->getAs()) + return visitRecordInitializer(RT, E); + + if (auto *CT = Ty->getAs()) + return visitComplexInitializer(CT, E); + + if (E->isStringLiteralInit()) { + auto *S = cast(E->getInit(0)->IgnoreParens()); + return visitStringInitializer(S); + } + + if (auto CT = Ty->getAsArrayTypeUnsafe()) { + if (auto CAT = dyn_cast(CT)) { + return visitArrayInitializer(CAT, E, [E](uint64_t I) { + if (I < E->getNumInits()) + return E->getInit(I); + else + return E->getArrayFiller(); + }); + } + } + + return this->bail(E); +} + +template +bool ByteCodeExprGen::VisitTypeTraitExpr(const TypeTraitExpr *E) { + return DiscardResult ? true : this->emitConstBool(E->getValue(), E); +} + +template +bool ByteCodeExprGen::VisitBlockExpr(const BlockExpr *E) { + if (E->getBlockDecl()->hasCaptures()) + return this->emitTrap(E); + // FIXME: blocks with no captures. + return this->bail(E); +} + +template +bool ByteCodeExprGen::VisitCXXNullPtrLiteralExpr( + const CXXNullPtrLiteralExpr *E) { + return DiscardResult ? true : this->emitNullPtr(E); +} + +template +bool ByteCodeExprGen::VisitCXXBoolLiteralExpr( + const CXXBoolLiteralExpr *B) { + return DiscardResult ? true : this->emitConstBool(B->getValue(), B); +} + +template +bool ByteCodeExprGen::VisitCXXThisExpr(const CXXThisExpr *TE) { + if (!this->emitThis(TE)) + return false; + return DiscardResult ? 
this->emitPopPtr(TE) : true; +} + +template +bool ByteCodeExprGen::VisitCXXDefaultArgExpr( + const CXXDefaultArgExpr *DE) { + return this->Visit(DE->getExpr()); +} + +template +bool ByteCodeExprGen::VisitCXXDefaultInitExpr( + const CXXDefaultInitExpr *DE) { + if (auto *E = DE->getExpr()) + return this->Visit(E); + return this->emitTrap(DE); +} + +template +bool ByteCodeExprGen::VisitCXXScalarValueInitExpr( + const CXXScalarValueInitExpr *E) { + QualType Ty = E->getType(); + if (auto T = Ctx.classify(Ty)) + return DiscardResult ? true : visitZeroInitializer(*T, E); + + // Complex numbers are not really scalar. + if (auto *CT = Ty->getAs()) { + QualType ElemType = CT->getElementType(); + PrimType T = classifyPrim(ElemType); + ImplicitValueInitExpr ZE(ElemType); + if (!emitInitFn()) + return false; + if (!visit(&ZE)) + return false; + if (!this->emitInitElem(T, 0, E)) + return false; + if (!visit(&ZE)) + return false; + return this->emitInitElemPop(T, 1, E); + } + return false; +} + +template +bool ByteCodeExprGen::VisitCXXThrowExpr(const CXXThrowExpr *E) { + return this->emitTrap(E); +} + +template +bool ByteCodeExprGen::VisitCXXTemporaryObjectExpr( + const CXXTemporaryObjectExpr *E) { + // In-place construction, simply invoke constructor. + if (InitFn) + return VisitCXXConstructExpr(cast(E)); + + // Allocate a temporary and build into it. + if (auto I = allocateLocal(E, /*isExtended=*/false)) { + OptionScope Scope(this, InitFnRef{[this, &I, E] { + return this->emitGetPtrLocal(*I, E); + }}); + if (!VisitCXXConstructExpr(cast(E))) + return false; + return DiscardResult ? true : this->emitGetPtrLocal(*I, E); + } + return this->bail(E); +} + +template +bool ByteCodeExprGen::VisitCXXTypeidExpr(const CXXTypeidExpr *E) { + // TODO: implement typeid. 
+ return this->bail(E);
+}
+
+template
+bool ByteCodeExprGen::VisitCXXReinterpretCastExpr(
+ const CXXReinterpretCastExpr *E) {
+ if (!this->emitReinterpretCastWarning(E))
+ return false;
+ return VisitCastExpr(E);
+}
+
+template
+bool ByteCodeExprGen::VisitCXXDynamicCastExpr(
+ const CXXDynamicCastExpr *E) { // Warns before C++2a, then lowers like any other cast.
+ if (!Ctx.getLangOpts().CPlusPlus2a) {
+ if (!this->emitDynamicCastWarning(E))
+ return false;
+ }
+ return VisitCastExpr(E);
+}
+
+template
+bool ByteCodeExprGen::VisitCXXInheritedCtorInitExpr(
+ const CXXInheritedCtorInitExpr *E) {
+ auto *Ctor = E->getConstructor();
+
+ // Forward all parameters.
+ unsigned Offset = 0;
+ for (auto *PD : Ctor->parameters()) {
+ PrimType Ty;
+ if (Optional T = classify(PD->getType())) {
+ Ty = *T;
+ } else {
+ Ty = PT_Ptr;
+ }
+ if (!this->emitGetParam(Ty, Offset, E))
+ return false;
+ Offset += align(primSize(Ty));
+ }
+
+ // Invoke the constructor on the same 'this'.
+ return emitInitFn() && emitMethodCall(Ctor, {}, E);
+}
+
+template
+bool ByteCodeExprGen::VisitCXXStdInitializerListExpr(
+ const CXXStdInitializerListExpr *E) {
+ auto *AT = getAsConstantArrayType(E->getSubExpr()->getType());
+
+ // Push the pointer to the std::initializer_list being initialized.
+ if (!this->emitInitFn())
+ return false;
+
+
+ Record *R = getRecord(E->getType());
+
+ // std::initializer_list should not have base classes.
+ if (!R || R->getNumBases() != 0 || R->getNumVirtualBases() != 0)
+ return this->emitTrap(E);
+
+ // pointer + pointer or pointer + length
+ if (R->getNumFields() != 2)
+ return this->emitTrap(E);
+
+ // The first field must be a pointer.
+ Record::Field *Fst = R->getField(0);
+ if (Fst->Desc->ElemTy != PT_Ptr)
+ return this->emitTrap(E);
+
+ // Set the second field - either an end pointer or an integer length.
+ Record::Field *Snd = R->getField(1); + if (Optional T = Snd->Desc->ElemTy) { + QualType Ty = Snd->Decl->getType(); + switch (*T) { + case PT_Sint8: + case PT_Uint8: + case PT_Sint16: + case PT_Uint16: + case PT_Sint32: + case PT_Uint32: + case PT_Sint64: + case PT_Uint64: + case PT_SintFP: + case PT_UintFP: + case PT_Bool: + /// Lower the array. + if (!visit(E->getSubExpr())) + return false; + // Set the first field. + if (!this->emitInitFieldPtr(Fst->Offset, E)) + return false; + // Set the length. + if (!this->emitConst(*T, getIntWidth(Ty), AT->getSize(), E)) + return false; + if (DiscardResult) + return this->emitInitFieldPop(*T, Snd->Offset, E); + else + return this->emitInitField(*T, Snd->Offset, E); + case PT_Ptr: + /// Lower the array. + if (!visit(E->getSubExpr())) + return false; + if (!this->emitInitFieldPeekPtr(Fst->Offset, E)) + return false; + // Set the end pointer. + if (!this->emitConstUint64(AT->getSize().getZExtValue(), E)) + return false; + if (!this->emitAddOffsetUint64(E)) + return false; + if (DiscardResult) + return this->emitInitFieldPopPtr(Snd->Offset, E); + else + return this->emitInitFieldPtr(Snd->Offset, E); + default: + return this->emitTrap(E); + } + } + return this->emitTrap(E); +} + +template +bool ByteCodeExprGen::discard(const Expr *E) { + OptionScope Scope(this, /*discardResult=*/true); + return this->Visit(E); +} + +template +bool ByteCodeExprGen::visit(const Expr *E) { + OptionScope Scope(this, /*discardResult=*/false); + return this->Visit(E); +} + +template +bool ByteCodeExprGen::visitBool(const Expr *E) { + if (Optional T = classify(E->getType())) { + if (!visit(E)) + return false; + return (*T != PT_Bool) ? 
this->emitTest(*T, E) : true; + } else { + return this->bail(E); + } +} + +template +bool ByteCodeExprGen::visitZeroInitializer(PrimType T, const Expr *E) { + switch (T) { + case PT_Bool: + return this->emitZeroBool(E); + case PT_Sint8: + return this->emitZeroSint8(E); + case PT_Uint8: + return this->emitZeroUint8(E); + case PT_Sint16: + return this->emitZeroSint16(E); + case PT_Uint16: + return this->emitZeroUint16(E); + case PT_Sint32: + return this->emitZeroSint32(E); + case PT_Uint32: + return this->emitZeroUint32(E); + case PT_Sint64: + return this->emitZeroSint64(E); + case PT_Uint64: + return this->emitZeroUint64(E); + case PT_UintFP: + return this->emitZeroFPUintFP(getIntWidth(E->getType()), E); + case PT_SintFP: + return this->emitZeroFPSintFP(getIntWidth(E->getType()), E); + case PT_RealFP: + return this->emitZeroRealFP(getFltSemantics(E->getType()), E); + case PT_Ptr: + return this->emitNullPtr(E); + case PT_FnPtr: + return this->emitNullFnPtr(E); + case PT_MemPtr: + return this->emitNullMemPtr(E); + } + return false; +} + +template +bool ByteCodeExprGen::visitOffset(const Expr *Ptr, const Expr *Offset, + const Expr *E, UnaryFn OffsetFn) { + if (!visit(Ptr)) + return false; + if (!visit(Offset)) + return false; + if (!(this->*OffsetFn)(*classify(Offset->getType()), E)) + return false; + return DiscardResult ? this->emitPopPtr(E) : true; +} + +template +bool ByteCodeExprGen::visitIndirectMember(const BinaryOperator *E) { + if (!visit(E->getLHS())) + return false; + if (!visit(E->getRHS())) + return false; + + if (!this->emitGetPtrFieldIndirect(E)) + return false; + if (DiscardResult && !this->emitPopPtr(E)) + return false; + return true; +} + +template +bool ByteCodeExprGen::visitAssign(PrimType T, + const BinaryOperator *BO) { + return dereference( + BO->getLHS(), DerefKind::Write, + [this, BO](PrimType) { + // Generate a value to store - will be set. 
+ return visit(BO->getRHS());
+ },
+ [this, BO](PrimType T) {
+ // Pointer on stack - compile RHS and assign to pointer.
+ if (!visit(BO->getRHS()))
+ return false;
+
+ if (BO->getLHS()->refersToBitField()) {
+ if (DiscardResult)
+ return this->emitStoreBitFieldPop(T, BO);
+ else
+ return this->emitStoreBitField(T, BO);
+ } else {
+ if (DiscardResult)
+ return this->emitStorePop(T, BO);
+ else
+ return this->emitStore(T, BO);
+ }
+ });
+}
+
+template
+bool ByteCodeExprGen::visitShiftAssign(PrimType RHS, BinaryFn F,
+ const BinaryOperator *BO) {
+ return dereference(
+ BO->getLHS(), DerefKind::ReadWrite,
+ [this, F, RHS, BO](PrimType LHS) {
+ if (!visit(BO->getRHS()))
+ return false;
+ // Propagate a failure of the shift emitter instead of dropping it.
+ return (this->*F)(LHS, RHS, BO);
+ },
+ [this, F, RHS, BO](PrimType LHS) {
+ if (!this->emitLoad(LHS, BO))
+ return false;
+ if (!visit(BO->getRHS()))
+ return false;
+ if (!(this->*F)(LHS, RHS, BO))
+ return false;
+
+ if (BO->getLHS()->refersToBitField()) {
+ if (DiscardResult)
+ return this->emitStoreBitFieldPop(LHS, BO);
+ else
+ return this->emitStoreBitField(LHS, BO);
+ } else {
+ if (DiscardResult)
+ return this->emitStorePop(LHS, BO);
+ else
+ return this->emitStore(LHS, BO);
+ }
+ });
+}
+
+template
+bool ByteCodeExprGen::visitCompoundAssign(PrimType RHS, UnaryFn F,
+ const BinaryOperator *BO) {
+ auto ApplyOperation = [this, F, BO, RHS](PrimType LHS) {
+ auto *ExprRHS = BO->getRHS();
+ auto *ExprLHS = BO->getLHS();
+ if (ExprLHS->getType() == ExprRHS->getType()) {
+ if (!visit(BO->getRHS()))
+ return false;
+ if (!(this->*F)(LHS, BO))
+ return false;
+ } else {
+ if (!emitConv(LHS, ExprLHS->getType(), RHS, ExprRHS->getType(), ExprRHS))
+ return false;
+ if (!visit(ExprRHS))
+ return false;
+ if (!(this->*F)(RHS, BO))
+ return false;
+ if (!emitConv(RHS, ExprRHS->getType(), LHS, ExprLHS->getType(), ExprLHS))
+ return false;
+ }
+ return true;
+ };
+
+ return dereference(
+ BO->getLHS(), DerefKind::ReadWrite, ApplyOperation,
+ [this, BO, &ApplyOperation](PrimType LHS)
-> bool { + if (!this->emitLoad(LHS, BO)) + return false; + if (!ApplyOperation(LHS)) + return false; + + if (BO->getLHS()->refersToBitField()) { + if (DiscardResult) + return this->emitStoreBitFieldPop(LHS, BO); + else + return this->emitStoreBitField(LHS, BO); + } else { + if (DiscardResult) + return this->emitStorePop(LHS, BO); + else + return this->emitStore(LHS, BO); + } + }); } template -bool ByteCodeExprGen::visit(const Expr *E) { - OptionScope Scope(this, /*discardResult=*/false); - return this->Visit(E); +bool ByteCodeExprGen::visitPtrAssign(PrimType RHS, UnaryFn F, + const BinaryOperator *BO) { + return dereference( + BO->getLHS(), DerefKind::ReadWrite, + [this, F, BO, RHS](PrimType LHS) { + if (!visit(BO->getRHS())) + return false; + if (!(this->*F)(RHS, BO)) + return false; + return true; + }, + [this, BO, F, RHS](PrimType LHS) -> bool { + if (!this->emitLoad(LHS, BO)) + return false; + if (!visit(BO->getRHS())) + return false; + if (!(this->*F)(RHS, BO)) + return false; + + if (BO->getLHS()->refersToBitField()) + return this->bail(BO); + + if (DiscardResult) + return this->emitStorePop(LHS, BO); + else + return this->emitStore(LHS, BO); + }); } template -bool ByteCodeExprGen::visitBool(const Expr *E) { - if (Optional T = classify(E->getType())) { - return visit(E); - } else { - return this->bail(E); +bool ByteCodeExprGen::emitConv(PrimType LHS, QualType TyLHS, + PrimType RHS, QualType TyRHS, + const Expr *Cast) { + if (isPrimitiveIntegral(RHS)) { + if (isFixedReal(LHS)) { + return this->emitCastRealFPToAlu(LHS, RHS, Cast); + } else { + return this->emitCast(LHS, RHS, Cast); + } + } else if (isFixedIntegral(RHS)) { + if (isFixedReal(LHS)) { + return this->emitCastRealFPToAluFP(LHS, RHS, getIntWidth(TyRHS), Cast); + } else { + return this->emitCastFP(LHS, RHS, getIntWidth(TyRHS), Cast); + } + } else if (isFixedReal(RHS)) { + return this->emitCastRealFP(LHS, getFltSemantics(TyRHS), Cast); } + llvm_unreachable("Invalid conversion"); } template -bool 
ByteCodeExprGen::visitZeroInitializer(PrimType T, const Expr *E) { - switch (T) { - case PT_Bool: - return this->emitZeroBool(E); - case PT_Sint8: - return this->emitZeroSint8(E); - case PT_Uint8: - return this->emitZeroUint8(E); - case PT_Sint16: - return this->emitZeroSint16(E); - case PT_Uint16: - return this->emitZeroUint16(E); - case PT_Sint32: - return this->emitZeroSint32(E); - case PT_Uint32: - return this->emitZeroUint32(E); - case PT_Sint64: - return this->emitZeroSint64(E); - case PT_Uint64: - return this->emitZeroUint64(E); - case PT_Ptr: - return this->emitNullPtr(E); +bool ByteCodeExprGen::visitShortCircuit(const BinaryOperator *E) { + if (!visitBool(E->getLHS())) + return false; + + LabelTy Short = this->getLabel(); + const bool IsTrue = E->getOpcode() == BO_LAnd; + if (DiscardResult) { + if (!(IsTrue ? this->jumpFalse(Short) : this->jumpTrue(Short))) + return false; + if (!this->discard(E->getRHS())) + return false; + } else { + if (!this->emitDupBool(E)) + return false; + if (!(IsTrue ? this->jumpFalse(Short) : this->jumpTrue(Short))) + return false; + if (!this->emitPopBool(E)) + return false; + if (!visitBool(E->getRHS())) + return false; } - llvm_unreachable("unknown primitive type"); + this->fallthrough(Short); + return true; } template @@ -268,6 +1826,10 @@ return dereferenceVar(LV, *T, VD, AK, Direct, Indirect); } } + if (auto *ME = dyn_cast(LV)) { + if (!ME->getMemberDecl()->getType()->isReferenceType()) + return dereferenceMember(ME, *T, AK, Direct, Indirect); + } } if (!visit(LV)) @@ -390,6 +1952,87 @@ return visit(LV) && Indirect(T); } +template +bool ByteCodeExprGen::dereferenceMember( + const MemberExpr *ME, PrimType T, DerefKind AK, + llvm::function_ref Direct, + llvm::function_ref Indirect) { + // Fetch or read the required field. 
+ if (auto *FD = dyn_cast(ME->getMemberDecl())) { + return withField(FD, [this, ME, AK, T, Direct](const Record::Field *F) { + const unsigned Off = F->Offset; + if (isa(ME->getBase())) { + switch (AK) { + case DerefKind::Read: + if (!this->emitGetThisField(T, Off, ME)) + return false; + return DiscardResult ? this->emitPop(T, ME) : true; + + case DerefKind::Write: + if (!Direct(T)) + return false; + if (!this->emitSetThisField(T, Off, ME)) + return false; + return DiscardResult ? true : this->emitGetPtrThisField(Off, ME); + + case DerefKind::ReadWrite: + if (!this->emitGetThisField(T, Off, ME)) + return false; + if (!Direct(T)) + return false; + if (!this->emitSetThisField(T, Off, ME)) + return false; + return DiscardResult ? true : this->emitGetPtrThisField(Off, ME); + } + } else { + if (!visit(ME->getBase())) + return false; + + switch (AK) { + case DerefKind::Read: + if (!this->emitGetFieldPop(T, Off, ME)) + return false; + return DiscardResult ? this->emitPop(T, ME) : true; + + case DerefKind::Write: + if (!Direct(T)) + return false; + if (!this->emitSetField(T, Off, ME)) + return false; + return DiscardResult ? true : this->emitGetPtrField(Off, ME); + + case DerefKind::ReadWrite: + if (!this->emitGetField(T, Off, ME)) + return false; + if (!Direct(T)) + return false; + if (!this->emitSetField(T, Off, ME)) + return false; + return DiscardResult ? true : this->emitGetPtrField(Off, ME); + } + } + return false; + }); + } + + // Value cannot be produced - try to emit pointer. + return visit(ME) && Indirect(T); +} + +template +bool ByteCodeExprGen::lvalueToRvalue(const Expr *LV, const Expr *E) { + return dereference( + LV, DerefKind::Read, + [](PrimType) { + // Value loaded - nothing to do here. + return true; + }, + [this, E](PrimType T) { + // Pointer on stack - dereference it. 
+ return this->emitLoadPop(T, E); + }); +} + template bool ByteCodeExprGen::emitConst(PrimType T, unsigned NumBits, const APInt &Value, const Expr *E) { @@ -410,9 +2053,16 @@ return this->emitConstSint64(Value.getSExtValue(), E); case PT_Uint64: return this->emitConstUint64(Value.getZExtValue(), E); + case PT_SintFP: + return this->emitConstFPSintFP(NumBits, Value.getSExtValue(), E); + case PT_UintFP: + return this->emitConstFPUintFP(NumBits, Value.getZExtValue(), E); case PT_Bool: return this->emitConstBool(Value.getBoolValue(), E); case PT_Ptr: + case PT_FnPtr: + case PT_RealFP: + case PT_MemPtr: llvm_unreachable("Invalid integral type"); break; } @@ -460,6 +2110,45 @@ return Local.Offset; } +template +bool ByteCodeExprGen::visitArguments(QualType CalleeTy, + ArrayRef Args) { + if (auto *MemberTy = CalleeTy->getAs()) + CalleeTy = MemberTy->getPointeeType(); + if (auto *PtrTy = CalleeTy->getAs()) + CalleeTy = PtrTy->getPointeeType(); + + + if (auto *Ty = CalleeTy->getAs()) { + unsigned NumParams = Ty->getNumParams(); + for (unsigned I = 0, N = Args.size(); I < N; ++I) + if (!visitArgument(Args[I], I >= NumParams)) + return false; + } else { + for (unsigned I = 0, N = Args.size(); I < N; ++I) + if (!visitArgument(Args[I], /*discard=*/true)) + return false; + } + + return true; +} + +template +bool ByteCodeExprGen::visitArgument(const Expr *E, bool Discard) { + // Primitive or pointer argument - leave it on the stack. + if (Optional T = classify(E)) + return Discard ? discard(E) : visit(E); + + // Composite argument - copy construct and push pointer. + if (auto I = allocateLocal(E, /*isExtended=*/false)) { + if (!visitLocalInitializer(E, *I)) + return false; + return Discard ? 
true : this->emitGetPtrLocal(*I, E); + } + + return this->bail(E); +} + template bool ByteCodeExprGen::visitInitializer( const Expr *Init, InitFnRef InitFn) { @@ -467,6 +2156,27 @@ return this->Visit(Init); } +template +bool ByteCodeExprGen::visitBaseInitializer(const Expr *Init, + InitFnRef GenBase) { + OptionScope Scope(this, GenBase, InitKind::BASE); + return this->Visit(Init); +} + +template +bool ByteCodeExprGen::getPtrConstFn(const FunctionDecl *FD, + const Expr *E) { + if (Expected Func = P.getOrCreateFunction(FD)) { + if (*Func) + return this->emitConstFn(*Func, E); + else + return this->emitConstNoFn(FD, E); + } else { + consumeError(Func.takeError()); + return this->emitConstNoFn(FD, E); + } +} + template bool ByteCodeExprGen::getPtrVarDecl(const VarDecl *VD, const Expr *E) { // Generate a pointer to the local, loading refs. @@ -476,7 +2186,7 @@ else return this->emitGetPtrGlobal(*Idx, E); } - return this->bail(VD); + return false; } template @@ -494,6 +2204,293 @@ return {}; } +template +bool ByteCodeExprGen::visitStringInitializer(const StringLiteral *S) { + // Find the type to represent characters. + const unsigned CharBytes = S->getCharByteWidth(); + PrimType CharType; + switch (CharBytes) { + case 1: + CharType = PT_Sint8; + break; + case 2: + CharType = PT_Uint16; + break; + case 4: + CharType = PT_Uint32; + break; + default: + llvm_unreachable("Unsupported character width!"); + } + + if (!emitInitFn()) + return false; + + // Initialise elements one by one, advancing the pointer. + const unsigned NumBits = CharBytes * getCharBit(); + for (unsigned I = 0, N = S->getLength(); I < N; ++I) { + // Set individual characters here. + APInt CodeUnit(NumBits, S->getCodeUnit(I)); + // Lower the code point. + if (!emitConst(CharType, NumBits, CodeUnit, S)) + return false; + // Set the character. + if (!this->emitInitElem(CharType, I, S)) + return false; + } + + // Set the null terminator. 
+ if (!emitConst(CharType, NumBits, APInt(NumBits, 0), S)) + return false; + if (!this->emitInitElemPop(CharType, S->getLength(), S)) + return false; + + // String was initialised. + return true; +} + +template +bool ByteCodeExprGen::visitRecordInitializer( + const RecordType *RT, const InitListExpr *List) { + auto *R = P.getOrCreateRecord(RT->getDecl()); + if (!R) + return this->bail(List); + + auto InitElem = [this, R](const Record::Field *F, const Expr *E) { + if (Optional T = classify(E)) { + if (!emitInitFn()) + return false; + if (!visit(E)) + return false; + if (R->isUnion()) + return this->emitInitFieldActive(*T, F->Offset, E); + else if (F->Decl->isBitField()) + return this->emitInitBitField(*T, F, E); + else + return this->emitInitFieldPop(*T, F->Offset, E); + } else { + auto FieldFn = [this, F, E](InitFnRef Ptr) { + return Ptr() && this->emitGetPtrField(F->Offset, E); + }; + if (R->isUnion()) { + UnionScope Scope(this, FieldFn); + return this->Visit(E); + } else { + FieldScope Scope(this, FieldFn); + return this->Visit(E); + } + } + }; + + if (R->isUnion()) { + if (auto *InitField = List->getInitializedFieldInUnion()) { + auto *F = R->getField(InitField); + switch (List->getNumInits()) { + case 0: { + // Clang doesn't provide an initializer - need to default init. + ImplicitValueInitExpr Init(InitField->getType()); + if (!InitElem(F, &Init)) + return false; + break; + } + + case 1: { + // Initialise the element. + if (!InitElem(F, List->getInit(0))) + return false; + break; + } + + default: + // Malformed initializer. 
+ return false; + } + } + } else { + for (unsigned I = 0, N = R->getNumBases(); I < N; ++I) { + unsigned Off = R->getBase(I)->Offset; + const Expr *Init = List->getInit(I); + + BaseScope Scope(this, [this, Off, Init] (InitFnRef Ptr) { + return Ptr() && this->emitGetPtrBase(Off, Init); + }); + if (!this->Visit(Init)) + return false; + } + for (unsigned I = 0, N = R->getNumFields(); I < N; ++I) { + const Expr * Init = List->getInit(I + R->getNumBases()); + if (!InitElem(R->getField(I), Init)) + return false; + } + } + + switch (Initialiser) { + case InitKind::ROOT: + // Nothing to initialise. + return true; + case InitKind::BASE: + // Initialise the base class. + return this->emitInitFn() && this->emitInitialise(List); + case InitKind::UNION: + // Activate the union field. + return this->emitInitFn() && this->emitActivate(List); + } + return false; +} + + +template +bool ByteCodeExprGen::visitComplexInitializer( + const ComplexType *CT, const InitListExpr *List) { + QualType ElemTy = CT->getElementType(); + + ImplicitValueInitExpr ZE(ElemTy); + PrimType T = classifyPrim(ElemTy); + + auto InitElem = [this, T](unsigned Index, const Expr *E) -> bool { + if (!emitInitFn()) + return false; + if (!visit(E)) + return false; + return this->emitInitElemPop(T, Index, E); + }; + + switch (List->getNumInits()) { + case 0: + return InitElem(0, &ZE) && InitElem(1, &ZE); + case 1: + return this->Visit(List->getInit(0)); + case 2: + return InitElem(0, List->getInit(0)) && InitElem(1, List->getInit(1)); + default: + llvm_unreachable("invalid complex initializer"); + } +} + +template +bool ByteCodeExprGen::visitCastToComplex(const CastExpr *CE) { + QualType SubType = CE->getSubExpr()->getType(); + PrimType T = classifyPrim(SubType); + + if (!emitInitFn()) + return false; + if (!this->visit(CE->getSubExpr())) + return false; + if (!this->emitInitElem(T, 0, CE)) + return false; + ImplicitValueInitExpr ZE(SubType); + if (!visit(&ZE)) + return false; + return this->emitInitElemPop(T, 1, 
CE); +} + +template +bool ByteCodeExprGen::visitArrayInitializer( + const ConstantArrayType *AT, + const Expr *List, + llvm::function_ref Elem) { + uint64_t NumElems = AT->getSize().getZExtValue(); + if (Optional T = classify(AT->getElementType())) { + if (NumElems > 0 && !emitInitFn()) + return false; + for (unsigned I = 0; I < NumElems; ++I) { + const Expr *ElemInit = Elem(I); + // Construct the value. + if (!visit(ElemInit)) + return false; + // Set the element. + if (I + 1 != NumElems) { + if (!this->emitInitElem(*T, I, ElemInit)) + return false; + } else { + if (!this->emitInitElemPop(*T, I, ElemInit)) + return false; + } + } + } else { + for (unsigned I = 0; I < NumElems; ++I) { + const Expr *ElemInit = Elem(I); + + // Generate a pointer to the field from the array pointer. + FieldScope Scope(this, [this, I, ElemInit](InitFnRef Ptr) { + if (!Ptr()) + return false; + if (!this->emitConstUint32(I, ElemInit)) + return false; + if (!this->emitAddOffsetUint32(ElemInit)) + return false; + return this->emitNarrowPtr(ElemInit); + }); + if (!this->Visit(ElemInit)) + return false; + } + } + + switch (Initialiser) { + case InitKind::ROOT: + return true; + case InitKind::BASE: + llvm_unreachable("arrays cannot be base classes"); + case InitKind::UNION: + return this->emitInitFn() && this->emitActivate(List); + } + return false; +} + +template +bool ByteCodeExprGen::visitZeroInitializer(const Expr *E, + const RecordDecl *RD) { + Record *R = P.getOrCreateRecord(RD); + if (!R) + return false; + + for (auto &B : R->bases()) { + BaseScope Scope(this, [this, &B, E](InitFnRef Ptr) { + if (!Ptr()) + return false; + return this->emitGetPtrBase(B.Offset, E); + }); + if (!visitZeroInitializer(E, B.Decl)) + return false; + } + + for (auto &F : R->fields()) { + ImplicitValueInitExpr Init(F.Decl->getType()); + if (Optional T = classify(F.Decl->getType())) { + if (!emitInitFn()) + return false; + if (!visitZeroInitializer(*T, &Init)) + return false; + if (F.Decl->isBitField()) { + if 
(!this->emitInitBitField(*T, &F, E)) + return false; + } else { + if (!this->emitInitFieldPop(*T, F.Offset, E)) + return false; + } + } else { + FieldScope Scope(this, [this, &F, E](InitFnRef Ptr) { + if (!Ptr()) + return false; + return this->emitGetPtrField(F.Offset, E); + }); + if (!this->Visit(&Init)) + return false; + } + } + return true; +} + +template +bool ByteCodeExprGen::withField( + const FieldDecl *F, + llvm::function_ref GenField) { + if (auto *R = getRecord(F->getParent())) + if (auto *Field = R->getField(F)) + return GenField(Field); + return false; +} + template const RecordType *ByteCodeExprGen::getRecordTy(QualType Ty) { if (auto *PT = dyn_cast(Ty)) @@ -515,16 +2512,59 @@ return P.getOrCreateRecord(RD); } +template +bool ByteCodeExprGen::visitOpaqueExpr(const OpaqueValueExpr *E) { + PrimType Ty; + if (Optional T = classify(E->getType())) { + Ty = *T; + } else { + Ty = PT_Ptr; + } + + auto Off = allocateLocalPrimitive(E, Ty, /*isConst=*/true); + if (!visit(E->getSourceExpr())) + return false; + if (!this->emitSetLocal(Ty, Off, E)) return false; // Do not cache a value that was never stored. + OpaqueExprs.insert({E, Off}); + return true; +} + template bool ByteCodeExprGen::visitExpr(const Expr *Exp) { ExprScope RootScope(this); - if (!visit(Exp)) - return false; - if (Optional T = classify(Exp)) - return this->emitRet(*T, Exp); - else - return this->emitRetValue(Exp); + if (Optional ExpTy = classify(Exp)) { + if (Exp->isGLValue()) { + // If the expression is an lvalue, dereference it. + if (!lvalueToRvalue(Exp, Exp)) + return false; + + // Find the type of the expression. + PrimType Ty; + if (Optional T = classify(Exp->getType())) { + Ty = *T; + } else { + Ty = PT_Ptr; + } + + return this->emitRet(Ty, Exp); + } else { + // Otherwise, simply evaluate the rvalue. + if (!visit(Exp)) + return false; + return this->emitRet(*ExpTy, Exp); + } + } else { + // Composite declaration - allocate a local and initialize it.
+ if (auto I = allocateLocal(Exp, /*isExtended=*/false)) { + if (!visitLocalInitializer(Exp, *I)) + return false; + if (!this->emitGetPtrLocal(*I, Exp)) + return false; + return this->emitRetValue(Exp); + } + return this->bail(Exp); + } } template diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.h b/clang/lib/AST/Interp/ByteCodeStmtGen.h --- a/clang/lib/AST/Interp/ByteCodeStmtGen.h +++ b/clang/lib/AST/Interp/ByteCodeStmtGen.h @@ -60,12 +60,23 @@ bool visitStmt(const Stmt *S); bool visitCompoundStmt(const CompoundStmt *S); bool visitDeclStmt(const DeclStmt *DS); + bool visitForStmt(const ForStmt *FS); + bool visitWhileStmt(const WhileStmt *DS); + bool visitDoStmt(const DoStmt *DS); bool visitReturnStmt(const ReturnStmt *RS); bool visitIfStmt(const IfStmt *IS); + bool visitBreakStmt(const BreakStmt *BS); + bool visitContinueStmt(const ContinueStmt *CS); + bool visitSwitchStmt(const SwitchStmt *SS); + bool visitCaseStmt(const SwitchCase *CS); + bool visitCXXForRangeStmt(const CXXForRangeStmt *FS); /// Compiles a variable declaration. bool visitVarDecl(const VarDecl *VD); + /// Visits a field initializer. + bool visitCtorInit(Record *This, const CXXCtorInitializer *Init); + private: /// Type of the expression returned by the function. llvm::Optional ReturnType; diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp --- a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp @@ -27,10 +27,9 @@ /// Scope managing label targets. template class LabelScope { public: - virtual ~LabelScope() { } - protected: LabelScope(ByteCodeStmtGen *Ctx) : Ctx(Ctx) {} + /// ByteCodeStmtGen instance. ByteCodeStmtGen *Ctx; }; @@ -97,8 +96,34 @@ ReturnType = this->classify(F->getReturnType()); // Set up fields and context if a constructor. 
- if (auto *MD = dyn_cast(F)) - return this->bail(MD); + if (auto *MD = dyn_cast(F)) { + const CXXRecordDecl *RD = MD->getParent()->getCanonicalDecl(); + if (auto *R = this->getRecord(RD)) { + if (auto *CD = dyn_cast(MD)) { + // If this is a copy or move constructor for a union, emit a special + // opcode since the operation cannot be represented in the AST easily. + if (CD->isDefaulted() && CD->isCopyOrMoveConstructor() && RD->isUnion()) + return this->emitInitUnion(CD) && this->emitRetVoid(CD); + + if (CD->isDelegatingConstructor()) { + CXXConstructorDecl::init_const_iterator I = CD->init_begin(); + { + ExprScope InitScope(this); + if (!this->visitThisInitializer((*I)->getInit())) + return false; + } + } else { + // Compile all member initializers. + for (const CXXCtorInitializer *Init : CD->inits()) { + if (!visitCtorInit(R, Init)) + return false; + } + } + } + } else { + return this->bail(RD); + } + } if (auto *Body = F->getBody()) if (!visitStmt(Body)) @@ -118,10 +143,27 @@ return visitCompoundStmt(cast(S)); case Stmt::DeclStmtClass: return visitDeclStmt(cast(S)); + case Stmt::ForStmtClass: + return visitForStmt(cast(S)); + case Stmt::WhileStmtClass: + return visitWhileStmt(cast(S)); + case Stmt::DoStmtClass: + return visitDoStmt(cast(S)); case Stmt::ReturnStmtClass: return visitReturnStmt(cast(S)); case Stmt::IfStmtClass: return visitIfStmt(cast(S)); + case Stmt::BreakStmtClass: + return visitBreakStmt(cast(S)); + case Stmt::ContinueStmtClass: + return visitContinueStmt(cast(S)); + case Stmt::SwitchStmtClass: + return visitSwitchStmt(cast(S)); + case Stmt::CaseStmtClass: + case Stmt::DefaultStmtClass: + return visitCaseStmt(cast(S)); + case Stmt::CXXForRangeStmtClass: + return visitCXXForRangeStmt(cast(S)); case Stmt::NullStmtClass: return true; default: { @@ -161,6 +203,114 @@ return true; } +template +bool ByteCodeStmtGen::visitForStmt(const ForStmt *FS) { + // Compile the initialisation statement in an outer scope. 
+ BlockScope OuterScope(this); + if (auto *Init = FS->getInit()) + if (!visitStmt(Init)) + return false; + + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + // Compile the condition, body and increment in the loop scope. + this->emitLabel(LabelStart); + { + BlockScope InnerScope(this); + + if (auto *Cond = FS->getCond()) { + if (auto *CondDecl = FS->getConditionVariableDeclStmt()) + if (!visitDeclStmt(CondDecl)) + return false; + + if (!this->visit(Cond)) + return false; + + if (!this->jumpFalse(LabelEnd)) + return false; + } + + if (auto *Body = FS->getBody()) { + LabelTy LabelSkip = this->getLabel(); + LoopScope FlowScope(this, LabelEnd, LabelSkip); + if (!visitStmt(Body)) + return false; + this->emitLabel(LabelSkip); + } + + if (auto *Inc = FS->getInc()) { + ExprScope IncScope(this); + if (!this->discard(Inc)) + return false; + } + if (!this->jump(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + return true; +} + +template +bool ByteCodeStmtGen::visitWhileStmt(const WhileStmt *WS) { + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + this->emitLabel(LabelStart); + { + BlockScope InnerScope(this); + if (auto *CondDecl = WS->getConditionVariableDeclStmt()) + if (!visitDeclStmt(CondDecl)) + return false; + + if (!this->visit(WS->getCond())) + return false; + + if (!this->jumpFalse(LabelEnd)) + return false; + + { + LoopScope FlowScope(this, LabelEnd, LabelStart); + if (!visitStmt(WS->getBody())) + return false; + } + if (!this->jump(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + + return true; +} + +template +bool ByteCodeStmtGen::visitDoStmt(const DoStmt *DS) { + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + LabelTy LabelSkip = this->getLabel(); + + this->emitLabel(LabelStart); + { + { + LoopScope FlowScope(this, LabelEnd, LabelSkip); + if (!visitStmt(DS->getBody())) + return false; + this->emitLabel(LabelSkip); + } + + { + ExprScope 
CondScope(this); + if (!this->visitBool(DS->getCond())) + return false; + } + + if (!this->jumpTrue(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + + return true; +} + template bool ByteCodeStmtGen::visitReturnStmt(const ReturnStmt *RS) { if (const Expr *RE = RS->getRetValue()) { @@ -226,6 +376,167 @@ return true; } +template +bool ByteCodeStmtGen::visitBreakStmt(const BreakStmt *BS) { + if (!BreakLabel) + return this->bail(BS); + return this->jump(*BreakLabel); +} + +template +bool ByteCodeStmtGen::visitContinueStmt(const ContinueStmt *CS) { + if (!ContinueLabel) + return this->bail(CS); + return this->jump(*ContinueLabel); +} + +template +bool ByteCodeStmtGen::visitSwitchStmt(const SwitchStmt *SS) { + BlockScope InnerScope(this); + + if (Optional T = this->classify(SS->getCond()->getType())) { + // The condition is stored in a local and fetched for every test. + unsigned Off = this->allocateLocalPrimitive(SS->getCond(), *T, + /*isConst=*/true); + + // The init-statement and condition variable stay in scope for the body. + if (const Stmt *CondInit = SS->getInit()) + if (!visitStmt(CondInit)) + return false; + if (const DeclStmt *CondDecl = SS->getConditionVariableDeclStmt()) + if (!visitDeclStmt(CondDecl)) + return false; + + // Compile the condition in its own scope. + { + ExprScope CondScope(this); + if (!this->visit(SS->getCond())) + return false; + + if (!this->emitSetLocal(*T, Off, SS->getCond())) + return false; + } + + LabelTy LabelEnd = this->getLabel(); + + // Generate code to inspect all case labels, jumping to the matched one.
+ const DefaultStmt *Default = nullptr; + CaseMap Labels; + for (auto *SC = SS->getSwitchCaseList(); SC; SC = SC->getNextSwitchCase()) { + LabelTy Label = this->getLabel(); + Labels.insert({SC, Label}); + + if (auto *DS = dyn_cast(SC)) { + Default = DS; + continue; + } + + if (auto *CS = dyn_cast(SC)) { + if (!this->emitGetLocal(*T, Off, CS)) + return false; + if (!this->visit(CS->getLHS())) + return false; + + if (auto *RHS = CS->getRHS()) { + if (!this->visit(CS->getRHS())) + return false; + if (!this->emitInRange(*T, CS)) + return false; + } else { + if (!this->emitEQ(*T, CS)) + return false; + } + + if (!this->jumpTrue(Label)) + return false; + continue; + } + + return this->bail(SS); + } + + // If a case wasn't matched, jump to default or skip the body. + if (!this->jump(Default ? Labels[Default] : LabelEnd)) + return false; + OptLabelTy DefaultLabel = Default ? Labels[Default] : OptLabelTy{}; + + // Compile the body, using labels defined previously. + SwitchScope LabelScope(this, std::move(Labels), LabelEnd, + DefaultLabel); + if (!visitStmt(SS->getBody())) + return false; + this->emitLabel(LabelEnd); + return true; + } else { + return this->bail(SS); + } +} + +template +bool ByteCodeStmtGen::visitCaseStmt(const SwitchCase *CS) { + auto It = CaseLabels.find(CS); + if (It == CaseLabels.end()) + return this->bail(CS); + + this->emitLabel(It->second); + return visitStmt(CS->getSubStmt()); +} + +template +bool ByteCodeStmtGen::visitCXXForRangeStmt(const CXXForRangeStmt *FS) { + BlockScope Scope(this); + + // Emit the optional init-statement. + if (auto *Init = FS->getInit()) { + if (!visitStmt(Init)) + return false; + } + + // Initialise the __range variable. + if (!visitStmt(FS->getRangeStmt())) + return false; + + // Create the __begin and __end iterators. 
+ if (!visitStmt(FS->getBeginStmt()) || !visitStmt(FS->getEndStmt())) + return false; + + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + this->emitLabel(LabelStart); + { + // Lower the condition. + if (!this->visitBool(FS->getCond())) + return false; + if (!this->jumpFalse(LabelEnd)) + return false; + + // Lower the loop var and body, marking labels for continue/break. + { + BlockScope InnerScope(this); + if (!visitStmt(FS->getLoopVarStmt())) + return false; + + LabelTy LabelSkip = this->getLabel(); + { + LoopScope FlowScope(this, LabelEnd, LabelSkip); + + if (!visitStmt(FS->getBody())) + return false; + } + this->emitLabel(LabelSkip); + } + + // Increment: ++__begin + if (!visitStmt(FS->getInc())) + return false; + if (!this->jump(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + return true; +} + template bool ByteCodeStmtGen::visitVarDecl(const VarDecl *VD) { auto DT = VD->getType(); @@ -256,6 +567,104 @@ } } +template +bool ByteCodeStmtGen::visitCtorInit(Record *This, + const CXXCtorInitializer *Ctor) { + auto *Init = Ctor->getInit(); + if (auto *Base = Ctor->getBaseClass()) { + auto *Decl = Base->getAs()->getDecl(); + return this->visitBaseInitializer(Init, [this, &This, Init, Ctor, Decl] { + if (Ctor->isBaseVirtual()) + return this->emitGetPtrThisVirtBase(Decl, Init); + else + return this->emitGetPtrThisBase(This->getBase(Decl)->Offset, Init); + }); + } + + if (const FieldDecl *FD = Ctor->getMember()) { + ExprScope Scope(this); + return this->withField(FD, [this, FD, Init, This](const Record::Field *F) { + if (Optional T = this->classify(FD->getType())) { + // Primitive type, can be computed and set. + if (!this->visit(Init)) + return false; + if (This->isUnion()) + return this->emitInitThisFieldActive(*T, F->Offset, Init); + if (F->Decl->isBitField()) + return this->emitInitThisBitField(*T, F, Init); + return this->emitInitThisField(*T, F->Offset, Init); + } else { + // Nested structures. 
+ return this->visitInitializer(Init, [this, F, Init] { + return this->emitGetPtrThisField(F->Offset, Init); + }); + } + }); + } + + if (auto *FD = Ctor->getIndirectMember()) { + auto *Init = Ctor->getInit(); + if (Optional T = this->classify(Init->getType())) { + ArrayRef Chain = FD->chain(); + Record *R = This; + for (unsigned I = 0, N = Chain.size(); I < N; ++I) { + const bool IsUnion = R->isUnion(); + auto *Member = cast(Chain[I]); + auto MemberTy = Member->getType(); + + auto *FD = R->getField(Member); + if (!FD) + return this->bail(Member); + + const unsigned Off = FD->Offset; + if (I + 1 == N) { + // Last member - set the field. + if (!this->visit(Init)) + return false; + if (IsUnion) { + return this->emitInitFieldActive(*T, Off, Init); + } else { + return this->emitInitFieldPop(*T, Off, Init); + } + } else { + // Next field must be a record - fetch it. + R = this->getRecord(cast(MemberTy)->getDecl()); + if (!R) + return this->bail(Member); + + if (IsUnion) { + if (I == 0) { + // Base member - activate this pointer. + if (!this->emitGetPtrActiveThisField(Off, Member)) + return false; + } else { + // Intermediate - active subfield. + if (!this->emitGetPtrActiveField(Off, Member)) + return false; + } + } else { + if (I == 0) { + // Base member - get this pointer. + if (!this->emitGetPtrThisField(Off, Member)) + return false; + } else { + // Intermediate - get pointer. 
+ if (!this->emitGetPtrField(Off, Member)) + return false; + } + } + } + } + } else { + return this->bail(FD); + } + + return true; + } + + llvm_unreachable("unknown base initializer kind"); +} + namespace clang { namespace interp { diff --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp --- a/clang/lib/AST/Interp/Context.cpp +++ b/clang/lib/AST/Interp/Context.cpp @@ -11,7 +11,7 @@ #include "ByteCodeExprGen.h" #include "ByteCodeStmtGen.h" #include "EvalEmitter.h" -#include "Interp.h" +#include "InterpLoop.h" #include "InterpFrame.h" #include "InterpStack.h" #include "PrimType.h" @@ -67,6 +67,8 @@ llvm::Optional Context::classify(QualType T) { if (T->isReferenceType() || T->isPointerType()) { + if (T->isFunctionPointerType()) + return PT_FnPtr; return PT_Ptr; } @@ -84,7 +86,7 @@ case 8: return PT_Sint8; default: - return {}; + return PT_SintFP; } } @@ -99,13 +101,22 @@ case 8: return PT_Uint8; default: - return {}; + return PT_UintFP; } } + if (T->isRealFloatingType()) + return PT_RealFP; + if (T->isNullPtrType()) return PT_Ptr; + if (T->isMemberPointerType()) + return PT_MemPtr; + + if (T->isFunctionProtoType()) + return PT_FnPtr; + if (auto *AT = dyn_cast(T)) return classify(AT->getValueType()); diff --git a/clang/lib/AST/Interp/Descriptor.h b/clang/lib/AST/Interp/Descriptor.h --- a/clang/lib/AST/Interp/Descriptor.h +++ b/clang/lib/AST/Interp/Descriptor.h @@ -70,6 +70,8 @@ Record *const ElemRecord = nullptr; /// Descriptor of the array element. Descriptor *const ElemDesc = nullptr; + /// Type of primitive elements. + llvm::Optional ElemTy = {}; /// Flag indicating if the block is mutable. const bool IsConst = false; /// Flag indicating if a field is mutable. 
diff --git a/clang/lib/AST/Interp/Descriptor.cpp b/clang/lib/AST/Interp/Descriptor.cpp --- a/clang/lib/AST/Interp/Descriptor.cpp +++ b/clang/lib/AST/Interp/Descriptor.cpp @@ -188,26 +188,26 @@ Descriptor::Descriptor(const DeclTy &D, PrimType Type, bool IsConst, bool IsTemporary, bool IsMutable) : Source(D), ElemSize(primSize(Type)), Size(ElemSize), AllocSize(Size), - IsConst(IsConst), IsMutable(IsMutable), IsTemporary(IsTemporary), - CtorFn(getCtorPrim(Type)), DtorFn(getDtorPrim(Type)), - MoveFn(getMovePrim(Type)) { + ElemTy(Type), IsConst(IsConst), IsMutable(IsMutable), + IsTemporary(IsTemporary), CtorFn(getCtorPrim(Type)), + DtorFn(getDtorPrim(Type)), MoveFn(getMovePrim(Type)) { assert(Source && "Missing source"); } Descriptor::Descriptor(const DeclTy &D, PrimType Type, size_t NumElems, bool IsConst, bool IsTemporary, bool IsMutable) : Source(D), ElemSize(primSize(Type)), Size(ElemSize * NumElems), - AllocSize(align(Size) + sizeof(InitMap *)), IsConst(IsConst), - IsMutable(IsMutable), IsTemporary(IsTemporary), IsArray(true), - CtorFn(getCtorArrayPrim(Type)), DtorFn(getDtorArrayPrim(Type)), - MoveFn(getMoveArrayPrim(Type)) { + AllocSize(align(Size) + sizeof(InitMap *)), ElemTy(Type), + IsConst(IsConst), IsMutable(IsMutable), IsTemporary(IsTemporary), + IsArray(true), CtorFn(getCtorArrayPrim(Type)), + DtorFn(getDtorArrayPrim(Type)), MoveFn(getMoveArrayPrim(Type)) { assert(Source && "Missing source"); } Descriptor::Descriptor(const DeclTy &D, PrimType Type, bool IsTemporary, UnknownSize) : Source(D), ElemSize(primSize(Type)), Size(UnknownSizeMark), - AllocSize(alignof(void *)), IsConst(true), IsMutable(false), + AllocSize(alignof(void *)), ElemTy(Type), IsConst(true), IsMutable(false), IsTemporary(IsTemporary), IsArray(true), CtorFn(getCtorArrayPrim(Type)), DtorFn(getDtorArrayPrim(Type)), MoveFn(getMoveArrayPrim(Type)) { assert(Source && "Missing source"); diff --git a/clang/lib/AST/Interp/EvalEmitter.h b/clang/lib/AST/Interp/EvalEmitter.h --- 
a/clang/lib/AST/Interp/EvalEmitter.h +++ b/clang/lib/AST/Interp/EvalEmitter.h @@ -113,9 +113,9 @@ bool isActive() { return CurrentLabel == ActiveLabel; } /// Helper to invoke a method. - bool ExecuteCall(Function *F, Pointer &&This, const SourceInfo &Info); + bool executeCall(Function *F, Pointer &&This, const SourceInfo &Info); /// Helper to emit a diagnostic on a missing method. - bool ExecuteNoCall(const FunctionDecl *F, const SourceInfo &Info); + bool executeNoCall(const FunctionDecl *F, const SourceInfo &Info); protected: #define GET_EVAL_PROTO diff --git a/clang/lib/AST/Interp/EvalEmitter.cpp b/clang/lib/AST/Interp/EvalEmitter.cpp --- a/clang/lib/AST/Interp/EvalEmitter.cpp +++ b/clang/lib/AST/Interp/EvalEmitter.cpp @@ -9,6 +9,8 @@ #include "EvalEmitter.h" #include "Context.h" #include "Interp.h" +#include "InterpHelper.h" +#include "InterpLoop.h" #include "Opcode.h" #include "Program.h" #include "clang/AST/DeclCXX.h" @@ -42,9 +44,7 @@ return false; } -void EvalEmitter::emitLabel(LabelTy Label) { - CurrentLabel = Label; -} +void EvalEmitter::emitLabel(LabelTy Label) { CurrentLabel = Label; } EvalEmitter::LabelTy EvalEmitter::getLabel() { return NextLabel++; } @@ -95,102 +95,70 @@ return true; } +bool EvalEmitter::emitInvoke(Function *Fn, const SourceInfo &Info) { + if (!isActive()) + return true; + + Pointer This = S.Stk.pop(); + if (!S.CheckInvoke(OpPC, This)) + return false; + if (S.checkingPotentialConstantExpression()) + return false; + + // Interpret the method in its own frame. + return executeCall(Fn, std::move(This), Info); +} + +bool EvalEmitter::emitCall(Function *Fn, const SourceInfo &Info) { + if (!isActive()) + return true; + if (S.checkingPotentialConstantExpression()) + return false; + + // Interpret the function in its own frame. + return executeCall(Fn, Pointer(), Info); +} + +bool EvalEmitter::emitVirtualInvoke(const CXXMethodDecl *MD, + const SourceInfo &Info) { + if (!isActive()) + return true; + // Get the this pointer, ensure it's live. 
+ Pointer This = S.Stk.pop(); + if (!S.CheckInvoke(OpPC, This)) + return false; + if (S.checkingPotentialConstantExpression()) + return false; + + // Dispatch to the virtual method. + CurrentSource = Info; + const CXXMethodDecl *Override = VirtualLookup(This, MD); + if (!S.CheckPure(OpPC, Override)) + return false; + if (Expected Func = P.getOrCreateFunction(Override)) { + if (*Func) { + return executeCall(*Func, std::move(This), Info); + } else { + return executeNoCall(Override, Info); + } + } else { + // Failed to compile method - bail out. + return false; + } +} + template bool EvalEmitter::emitRet(const SourceInfo &Info) { if (!isActive()) return true; using T = typename PrimConv::T; - return ReturnValue(S.Stk.pop(), Result); + return PrimitiveToValue(S.Stk.pop(), Result); } bool EvalEmitter::emitRetVoid(const SourceInfo &Info) { return true; } bool EvalEmitter::emitRetValue(const SourceInfo &Info) { - // Method to recursively traverse composites. - std::function Composite; - Composite = [this, &Composite](QualType Ty, const Pointer &Ptr, APValue &R) { - if (auto *AT = Ty->getAs()) - Ty = AT->getValueType(); - - if (auto *RT = Ty->getAs()) { - auto *Record = Ptr.getRecord(); - assert(Record && "Missing record descriptor"); - - bool Ok = true; - if (RT->getDecl()->isUnion()) { - const FieldDecl *ActiveField = nullptr; - APValue Value; - for (auto &F : Record->fields()) { - const Pointer &FP = Ptr.atField(F.Offset); - QualType FieldTy = F.Decl->getType(); - if (FP.isActive()) { - if (llvm::Optional T = Ctx.classify(FieldTy)) { - TYPE_SWITCH(*T, Ok &= ReturnValue(FP.deref(), Value)); - } else { - Ok &= Composite(FieldTy, FP, Value); - } - break; - } - } - R = APValue(ActiveField, Value); - } else { - unsigned NF = Record->getNumFields(); - unsigned NB = Record->getNumBases(); - unsigned NV = Ptr.isBaseClass() ? 
0 : Record->getNumVirtualBases(); - - R = APValue(APValue::UninitStruct(), NB, NF); - - for (unsigned I = 0; I < NF; ++I) { - const Record::Field *FD = Record->getField(I); - QualType FieldTy = FD->Decl->getType(); - const Pointer &FP = Ptr.atField(FD->Offset); - APValue &Value = R.getStructField(I); - - if (llvm::Optional T = Ctx.classify(FieldTy)) { - TYPE_SWITCH(*T, Ok &= ReturnValue(FP.deref(), Value)); - } else { - Ok &= Composite(FieldTy, FP, Value); - } - } - - for (unsigned I = 0; I < NB; ++I) { - const Record::Base *BD = Record->getBase(I); - QualType BaseTy = Ctx.getASTContext().getRecordType(BD->Decl); - const Pointer &BP = Ptr.atField(BD->Offset); - Ok &= Composite(BaseTy, BP, R.getStructBase(I)); - } - - for (unsigned I = 0; I < NV; ++I) { - const Record::Base *VD = Record->getVirtualBase(I); - QualType VirtBaseTy = Ctx.getASTContext().getRecordType(VD->Decl); - const Pointer &VP = Ptr.atField(VD->Offset); - Ok &= Composite(VirtBaseTy, VP, R.getStructBase(NB + I)); - } - } - return Ok; - } - if (auto *AT = Ty->getAsArrayTypeUnsafe()) { - const size_t NumElems = Ptr.getNumElems(); - QualType ElemTy = AT->getElementType(); - R = APValue(APValue::UninitArray{}, NumElems, NumElems); - - bool Ok = true; - for (unsigned I = 0; I < NumElems; ++I) { - APValue &Slot = R.getArrayInitializedElt(I); - const Pointer &EP = Ptr.atIndex(I); - if (llvm::Optional T = Ctx.classify(ElemTy)) { - TYPE_SWITCH(*T, Ok &= ReturnValue(EP.deref(), Slot)); - } else { - Ok &= Composite(ElemTy, EP.narrow(), Slot); - } - } - return Ok; - } - llvm_unreachable("invalid value to return"); - }; - // Return the composite type. 
- const auto &Ptr = S.Stk.pop(); - return Composite(Ptr.getType(), Ptr, Result); + return PointerToValue(S.Stk.pop(), Result); } bool EvalEmitter::emitGetPtrLocal(uint32_t I, const SourceInfo &Info) { @@ -244,6 +212,104 @@ return true; } +bool EvalEmitter::emitNoCall(const FunctionDecl *F, const SourceInfo &Info) { + if (!isActive()) + return true; + if (S.checkingPotentialConstantExpression()) + return false; + return executeNoCall(F, Info); +} + +bool EvalEmitter::emitNoInvoke(const CXXMethodDecl *F, const SourceInfo &Info) { + if (!isActive()) + return true; + if (S.checkingPotentialConstantExpression()) + return false; + return executeNoCall(F, Info); +} + +bool EvalEmitter::emitIndirectCall(const SourceInfo &Info) { + if (!isActive()) + return true; + if (S.checkingPotentialConstantExpression()) + return false; + + const FnPointer &FnPtr = S.Stk.pop(); + if (FnPtr.isZero()) { + S.FFDiag(Info); + return false; + } + + if (Function *F = FnPtr.asFunction()) + return executeCall(F, Pointer(), Info); + + const FunctionDecl *Decl = FnPtr.asDecl(); + if (Expected Func = S.P.getOrCreateFunction(Decl)) { + if (*Func) + return executeCall(*Func, Pointer(), Info); + else + return executeNoCall(Decl, Info); + } else { + consumeError(Func.takeError()); + return false; + } +} + +bool EvalEmitter::emitIndirectInvoke(const SourceInfo &Info) { + if (!isActive()) + return true; + if (S.checkingPotentialConstantExpression()) + return false; + + // Fetch and validate the object and the pointer. + const MemberPointer &Field = S.Stk.pop(); + Pointer This = S.Stk.pop(); + if (!S.CheckInvoke(OpPC, This)) + return false; + if (Field.isZero()) { + S.FFDiag(Info); + return false; + } + + // Fetch a pointer to the member function. + const CXXMethodDecl *Method = IndirectLookup(This, Field); + if (!Method) + return false; + if (!S.CheckPure(OpPC, Method)) + return false; + + // Execute the function. 
+ if (Expected Func = S.P.getOrCreateFunction(Method)) { + if (*Func) + return executeCall(*Func, std::move(This), Info); + else + return executeNoCall(Method, Info); + } else { + consumeError(Func.takeError()); + return false; + } +} + +bool EvalEmitter::executeCall(Function *F, Pointer &&This, + const SourceInfo &Info) { + CurrentSource = Info; + if (!S.CheckCallable(OpPC, F)) + return false; + S.Current = new InterpFrame(S, F, S.Current, OpPC, std::move(This)); + return Interpret(S, Result); +} + +bool EvalEmitter::executeNoCall(const FunctionDecl *F, const SourceInfo &Info) { + if (S.getLangOpts().CPlusPlus11) { + S.FFDiag(Info, diag::note_constexpr_invalid_function, 1) + << F->isConstexpr() << (bool)dyn_cast(F) << F; + S.Note(F->getLocation(), diag::note_declared_at); + } else { + S.FFDiag(Info, diag::note_invalid_subexpr_in_const_expr); + } + return false; +} + //===----------------------------------------------------------------------===// // Opcode evaluators //===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/Interp/FixedIntegral.h b/clang/lib/AST/Interp/FixedIntegral.h new file mode 100644 --- /dev/null +++ b/clang/lib/AST/Interp/FixedIntegral.h @@ -0,0 +1,237 @@ +//===--- FixedIntegral.h - Wrapper for fixed precision ints -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_AST_INTERP_FIXEDINTEGRAL_H +#define LLVM_CLANG_AST_INTERP_FIXEDINTEGRAL_H + +#include +#include +#include "Boolean.h" +#include "clang/AST/APValue.h" +#include "clang/AST/ComparisonCategories.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/Support/raw_ostream.h" + +namespace clang { +namespace interp { + +using APInt = llvm::APInt; +using APSInt = llvm::APSInt; + +/// Specialisation for fixed precision integers. +template +class FixedIntegral { + private: + template friend class FixedIntegral; + + /// Value is stored in an APSInt. + APSInt V; + + public: + /// Integral initialized to an undefined value. + FixedIntegral() : V() {} + + /// Construct an integral from a value based on signedness. + explicit FixedIntegral(const APSInt &V) : V(V, !Signed) {} + + bool operator<(const FixedIntegral &RHS) const { return V < RHS.V; } + bool operator>(const FixedIntegral &RHS) const { return V > RHS.V; } + bool operator<=(const FixedIntegral &RHS) const { return V <= RHS.V; } + bool operator>=(const FixedIntegral &RHS) const { return V >= RHS.V; } + bool operator==(const FixedIntegral &RHS) const { return V == RHS.V; } + bool operator!=(const FixedIntegral &RHS) const { return V != RHS.V; } + + template + bool operator>(T RHS) const { + constexpr bool ArgSigned = std::is_signed::value; + return V > APSInt(APInt(V.getBitWidth(), RHS, ArgSigned), !ArgSigned); + } + + template + bool operator<(T RHS) const { + constexpr bool ArgSigned = std::is_signed::value; + return V < APSInt(APInt(V.getBitWidth(), RHS, ArgSigned), !ArgSigned); + } + + FixedIntegral operator+(FixedIntegral RHS) const { + return FixedIntegral(V + RHS.V); + } + FixedIntegral operator-(FixedIntegral RHS) const { + return FixedIntegral(V - RHS.V); + } + FixedIntegral operator*(FixedIntegral RHS) const { + return FixedIntegral(V * RHS.V); + } + 
FixedIntegral operator/(FixedIntegral RHS) const { + return FixedIntegral(V / RHS.V); + } + FixedIntegral operator%(FixedIntegral RHS) const { + return FixedIntegral(V % RHS.V); + } + FixedIntegral operator&(FixedIntegral RHS) const { + return FixedIntegral(V & RHS.V); + } + FixedIntegral operator|(FixedIntegral RHS) const { + return FixedIntegral(V | RHS.V); + } + FixedIntegral operator^(FixedIntegral RHS) const { + return FixedIntegral(V ^ RHS.V); + } + + FixedIntegral operator-() const { return FixedIntegral(-V); } + FixedIntegral operator~() const { return FixedIntegral(~V); } + + FixedIntegral operator>>(unsigned RHS) const { + return FixedIntegral(V >> RHS); + } + FixedIntegral operator<<(unsigned RHS) const { + return FixedIntegral(V << RHS); + } + + explicit operator off_t() const { + return Signed ? V.getSExtValue() : V.getZExtValue(); + } + explicit operator unsigned() const { + return Signed ? V.getZExtValue() : V.getZExtValue(); + } + + APSInt toAPSInt() const { return V; } + APSInt toAPSInt(unsigned NumBits) const { + if (Signed) + return APSInt(toAPSInt().sextOrTrunc(NumBits), !Signed); + else + return APSInt(toAPSInt().zextOrTrunc(NumBits), !Signed); + } + APValue toAPValue() const { return APValue(toAPSInt()); } + + FixedIntegral toUnsigned() const { + return FixedIntegral::from(*this, bitWidth()); + } + + unsigned bitWidth() const { return V.getBitWidth(); } + + bool isZero() const { return V.isNullValue(); } + + bool isMin() const { + if (Signed) + return V.isMinSignedValue(); + else + return V.isMinValue(); + } + + bool isMinusOne() const { return Signed && V.isAllOnesValue(); } + + constexpr static bool isSigned() { return Signed; } + + bool isNegative() const { return V.isNegative(); } + bool isPositive() const { return !isNegative(); } + + ComparisonCategoryResult compare(const FixedIntegral &RHS) const { + if (V < RHS.V) return ComparisonCategoryResult::Less; + if (V > RHS.V) return ComparisonCategoryResult::Greater; + return 
ComparisonCategoryResult::Equal; + } + + unsigned countLeadingZeros() const { return V.countLeadingZeros(); } + + FixedIntegral truncate(unsigned TruncBits) const { + if (TruncBits >= bitWidth()) return *this; + return FixedIntegral(V.trunc(TruncBits).extend(bitWidth())); + } + + void print(llvm::raw_ostream &OS) const { OS << V; } + + static FixedIntegral min(unsigned NumBits) { + if (Signed) + return FixedIntegral(APSInt(APInt::getSignedMinValue(NumBits), !Signed)); + else + return FixedIntegral(APSInt(APInt::getMinValue(NumBits), !Signed)); + } + + static FixedIntegral max(unsigned NumBits) { + if (Signed) + return FixedIntegral(APSInt(APInt::getSignedMaxValue(NumBits), !Signed)); + else + return FixedIntegral(APSInt(APInt::getMaxValue(NumBits), !Signed)); + } + + template + static + typename std::enable_if::value, FixedIntegral>::type + from(T Value, unsigned NumBits) { + const bool SrcSign = std::is_signed::value; + APInt IntVal(NumBits, static_cast(Value), SrcSign); + return FixedIntegral(APSInt(IntVal, !SrcSign)); + } + + template + static FixedIntegral from(const Integral &I, + unsigned NumBits); + + template + static FixedIntegral from(FixedIntegral Value, unsigned NumBits) { + if (SrcSign) + return FixedIntegral(APSInt(Value.V.sextOrTrunc(NumBits), !Signed)); + else + return FixedIntegral(APSInt(Value.V.zextOrTrunc(NumBits), !Signed)); + } + + static FixedIntegral from(Boolean Value, unsigned NumBits) { + return FixedIntegral( + APSInt(NumBits, static_cast(!Value.isZero()))); + } + + static FixedIntegral zero(unsigned NumBits) { return from(0, NumBits); } + + static bool inRange(int64_t Value, unsigned NumBits) { + APSInt V(APInt(NumBits, Value, true), false); + return min(NumBits).V <= V && V <= max(NumBits).V; + } + + static bool increment(FixedIntegral A, FixedIntegral *R) { + *R = A + FixedIntegral::from(1, A.bitWidth()); + return A.isSigned() && !A.isNegative() && R->isNegative(); + } + + static bool decrement(FixedIntegral A, FixedIntegral *R) { + 
*R = A - FixedIntegral::from(1, A.bitWidth()); + return !A.isSigned() && A.isNegative() && !R->isNegative(); + } + + static bool add(FixedIntegral A, FixedIntegral B, unsigned OpBits, + FixedIntegral *R) { + APSInt Value(A.V.extend(OpBits) + B.V.extend(OpBits)); + *R = FixedIntegral(Value.trunc(A.bitWidth())); + return R->V.extend(OpBits) != Value; + } + + static bool sub(FixedIntegral A, FixedIntegral B, unsigned OpBits, + FixedIntegral *R) { + APSInt Value(A.V.extend(OpBits) - B.V.extend(OpBits)); + *R = FixedIntegral(Value.trunc(A.bitWidth())); + return R->V.extend(OpBits) != Value; + } + + static bool mul(FixedIntegral A, FixedIntegral B, unsigned OpBits, + FixedIntegral *R) { + APSInt Value(A.V.extend(OpBits) * B.V.extend(OpBits)); + *R = FixedIntegral(Value.trunc(A.bitWidth())); + return R->V.extend(OpBits) != Value; + } +}; + +template +llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, FixedIntegral I) { + I.print(OS); + return OS; +} + +} // namespace interp +} // namespace clang + +#endif diff --git a/clang/lib/AST/Interp/Function.h b/clang/lib/AST/Interp/Function.h --- a/clang/lib/AST/Interp/Function.h +++ b/clang/lib/AST/Interp/Function.h @@ -96,6 +96,12 @@ /// Returns a specific scope. Scope &getScope(unsigned Idx) { return Scopes[Idx]; } + /// Solves a call relocation. + void relocateCall(CodePtr Loc, Function *Callee); + + /// Solves an invoke relocation. + void relocateInvoke(CodePtr Loc, Function *Callee); + /// Returns the source information at a given PC. SourceInfo getSource(CodePtr PC) const; @@ -108,6 +114,9 @@ /// Checks if the function is a constructor. bool isConstructor() const { return isa(F); } + /// Checks if the function was default. + bool isDefaulted() const { return F->isDefaulted(); } + private: /// Construct a function representing an actual function. 
Function(Program &P, const FunctionDecl *F, unsigned ArgSize, diff --git a/clang/lib/AST/Interp/Function.cpp b/clang/lib/AST/Interp/Function.cpp --- a/clang/lib/AST/Interp/Function.cpp +++ b/clang/lib/AST/Interp/Function.cpp @@ -31,13 +31,32 @@ return It->second; } +void Function::relocateCall(CodePtr Loc, Function *Callee) { + using namespace llvm::support; + + char *Ptr = Code.data() + (Loc - getCodeBegin()); + const auto Addr = reinterpret_cast(Callee); + *(reinterpret_cast(Ptr) - 1) = OP_Call; + + endian::write(Ptr, Addr); +} + +void Function::relocateInvoke(CodePtr Loc, Function *Callee) { + using namespace llvm::support; + + char *Ptr = Code.data() + (Loc - getCodeBegin()); + const auto Addr = reinterpret_cast(Callee); + *(reinterpret_cast(Ptr) - 1) = OP_Invoke; + + endian::write(Ptr, Addr); +} + SourceInfo Function::getSource(CodePtr PC) const { unsigned Offset = PC - getCodeBegin(); using Elem = std::pair; auto It = std::lower_bound(SrcMap.begin(), SrcMap.end(), Elem{Offset, {}}, [](Elem A, Elem B) { return A.first < B.first; }); - if (It == SrcMap.end() || It->first != Offset) - llvm::report_fatal_error("missing source location"); + assert(It != SrcMap.end() && It->first == Offset); return It->second; } diff --git a/clang/lib/AST/Interp/Integral.h b/clang/lib/AST/Interp/Integral.h --- a/clang/lib/AST/Interp/Integral.h +++ b/clang/lib/AST/Interp/Integral.h @@ -13,13 +13,15 @@ #ifndef LLVM_CLANG_AST_INTERP_INTEGRAL_H #define LLVM_CLANG_AST_INTERP_INTEGRAL_H -#include "clang/AST/ComparisonCategories.h" +#include +#include +#include "Boolean.h" +#include "FixedIntegral.h" #include "clang/AST/APValue.h" +#include "clang/AST/ComparisonCategories.h" #include "llvm/ADT/APSInt.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include -#include namespace clang { namespace interp { @@ -30,45 +32,71 @@ /// Helper to compare two comparable types. 
template ComparisonCategoryResult Compare(const T &X, const T &Y) { - if (X < Y) - return ComparisonCategoryResult::Less; - if (X > Y) - return ComparisonCategoryResult::Greater; + if (X < Y) return ComparisonCategoryResult::Less; + if (X > Y) return ComparisonCategoryResult::Greater; return ComparisonCategoryResult::Equal; } // Helper structure to select the representation. -template struct Repr; -template <> struct Repr<8, false> { using Type = uint8_t; }; -template <> struct Repr<16, false> { using Type = uint16_t; }; -template <> struct Repr<32, false> { using Type = uint32_t; }; -template <> struct Repr<64, false> { using Type = uint64_t; }; -template <> struct Repr<8, true> { using Type = int8_t; }; -template <> struct Repr<16, true> { using Type = int16_t; }; -template <> struct Repr<32, true> { using Type = int32_t; }; -template <> struct Repr<64, true> { using Type = int64_t; }; +template +struct Repr; +template <> +struct Repr<8, false> { + using Type = uint8_t; +}; +template <> +struct Repr<16, false> { + using Type = uint16_t; +}; +template <> +struct Repr<32, false> { + using Type = uint32_t; +}; +template <> +struct Repr<64, false> { + using Type = uint64_t; +}; +template <> +struct Repr<8, true> { + using Type = int8_t; +}; +template <> +struct Repr<16, true> { + using Type = int16_t; +}; +template <> +struct Repr<32, true> { + using Type = int32_t; +}; +template <> +struct Repr<64, true> { + using Type = int64_t; +}; /// Wrapper around numeric types. /// /// These wrappers are required to shared an interface between APSint and /// builtin primitive numeral types, while optimising for storage and /// allowing methods operating on primitive type to compile to fast code. -template class Integral { -private: - template friend class Integral; +template +class Integral { + private: + template + friend class Integral; // The primitive representing the integral. using T = typename Repr::Type; T V; /// Primitive representing limits. 
- static const auto Min = std::numeric_limits::min(); - static const auto Max = std::numeric_limits::max(); + static constexpr auto Min = std::numeric_limits::min(); + static constexpr auto Max = std::numeric_limits::max(); /// Construct an integral from anything that is convertible to storage. - template explicit Integral(T V) : V(V) {} + template + explicit Integral(T V) : V(V) {} -public: + public: /// Zero-initializes an integral. Integral() : V(0) {} @@ -91,9 +119,21 @@ return V >= 0 && static_cast(V) > RHS; } + Integral operator+(Integral RHS) const { return Integral(V + RHS.V); } + Integral operator-(Integral RHS) const { return Integral(V - RHS.V); } + Integral operator*(Integral RHS) const { return Integral(V * RHS.V); } + Integral operator/(Integral RHS) const { return Integral(V / RHS.V); } + Integral operator%(Integral RHS) const { return Integral(V % RHS.V); } + Integral operator&(Integral RHS) const { return Integral(V & RHS.V); } + Integral operator|(Integral RHS) const { return Integral(V | RHS.V); } + Integral operator^(Integral RHS) const { return Integral(V ^ RHS.V); } + Integral operator-() const { return Integral(-V); } Integral operator~() const { return Integral(~V); } + Integral operator>>(unsigned RHS) const { return Integral(V >> RHS); } + Integral operator<<(unsigned RHS) const { return Integral(V << RHS); } + template explicit operator Integral() const { return Integral(V); @@ -135,11 +175,12 @@ return Compare(V, RHS.V); } - unsigned countLeadingZeros() const { return llvm::countLeadingZeros(V); } + unsigned countLeadingZeros() const { + return llvm::countLeadingZeros(toUnsigned().V); + } Integral truncate(unsigned TruncBits) const { - if (TruncBits >= Bits) - return *this; + if (TruncBits >= Bits) return *this; const T BitMask = (T(1) << T(TruncBits)) - 1; const T SignBit = T(1) << (TruncBits - 1); const T ExtMask = ~BitMask; @@ -148,12 +189,8 @@ void print(llvm::raw_ostream &OS) const { OS << V; } - static Integral min(unsigned 
NumBits) { - return Integral(Min); - } - static Integral max(unsigned NumBits) { - return Integral(Max); - } + static Integral min(unsigned NumBits) { return Integral(Min); } + static Integral max(unsigned NumBits) { return Integral(Max); } template static typename std::enable_if::value, Integral>::type @@ -162,21 +199,29 @@ } template - static typename std::enable_if::type - from(Integral Value) { + static Integral from(const Integral &Value) { return Integral(Value.V); } - template static Integral from(Integral<0, SrcSign> Value) { + template + static Integral from(const FixedIntegral &Value) { + return Integral(Value.toAPSInt().getExtValue()); + } + + template + static Integral from(Integral<0, SrcSign> Value) { if (SrcSign) return Integral(Value.V.getSExtValue()); else return Integral(Value.V.getZExtValue()); } + static Integral from(Boolean Value) { return Integral(!Value.isZero()); } + static Integral zero() { return from(0); } - template static Integral from(T Value, unsigned NumBits) { + template + static Integral from(T Value, unsigned NumBits) { return Integral(Value); } @@ -204,7 +249,7 @@ return CheckMulUB(A.V, B.V, R->V); } -private: + private: template static typename std::enable_if::value, bool>::type CheckAddUB(T A, T B, T &R) { @@ -263,7 +308,24 @@ return OS; } -} // namespace interp -} // namespace clang +template +Boolean Boolean::from(const Integral &V) { + return Boolean(!V.isZero()); +} + +template +Boolean Boolean::from(const FixedIntegral &V) { + return Boolean(!V.isZero()); +} + +template +template +FixedIntegral FixedIntegral::from( + const Integral &I, unsigned NumBits) { + return FixedIntegral(I.toAPSInt().extOrTrunc(NumBits)); +} + +} // namespace interp +} // namespace clang #endif diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h --- a/clang/lib/AST/Interp/Interp.h +++ b/clang/lib/AST/Interp/Interp.h @@ -13,10 +13,10 @@ #ifndef LLVM_CLANG_AST_INTERP_INTERP_H #define LLVM_CLANG_AST_INTERP_INTERP_H -#include 
-#include +#include "Builtin.h" #include "Function.h" #include "InterpFrame.h" +#include "InterpHelper.h" #include "InterpStack.h" #include "InterpState.h" #include "Opcode.h" @@ -30,6 +30,8 @@ #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APSInt.h" #include "llvm/Support/Endian.h" +#include +#include namespace clang { namespace interp { @@ -37,62 +39,6 @@ using APInt = llvm::APInt; using APSInt = llvm::APSInt; -/// Convers a value to an APValue. -template bool ReturnValue(const T &V, APValue &R) { - R = V.toAPValue(); - return true; -} - -/// Checks if the variable has externally defined storage. -bool CheckExtern(InterpState &S, CodePtr OpPC, const Pointer &Ptr); - -/// Checks if the array is offsetable. -bool CheckArray(InterpState &S, CodePtr OpPC, const Pointer &Ptr); - -/// Checks if a pointer is live and accesible. -bool CheckLive(InterpState &S, CodePtr OpPC, const Pointer &Ptr, - AccessKinds AK); -/// Checks if a pointer is null. -bool CheckNull(InterpState &S, CodePtr OpPC, const Pointer &Ptr, - CheckSubobjectKind CSK); - -/// Checks if a pointer is in range. -bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr, - AccessKinds AK); - -/// Checks if a field from which a pointer is going to be derived is valid. -bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr, - CheckSubobjectKind CSK); - -/// Checks if a pointer points to const storage. -bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr); - -/// Checks if a pointer points to a mutable field. -bool CheckMutable(InterpState &S, CodePtr OpPC, const Pointer &Ptr); - -/// Checks if a value can be loaded from a block. -bool CheckLoad(InterpState &S, CodePtr OpPC, const Pointer &Ptr); - -/// Checks if a value can be stored in a block. -bool CheckStore(InterpState &S, CodePtr OpPC, const Pointer &Ptr); - -/// Checks if a method can be invoked on an object. -bool CheckInvoke(InterpState &S, CodePtr OpPC, const Pointer &Ptr); - -/// Checks if a value can be initialized. 
-bool CheckInit(InterpState &S, CodePtr OpPC, const Pointer &Ptr); - -/// Checks if a method can be called. -bool CheckCallable(InterpState &S, CodePtr OpPC, Function *F); - -/// Checks the 'this' pointer. -bool CheckThis(InterpState &S, CodePtr OpPC, const Pointer &This); - -/// Checks if a method is pure virtual. -bool CheckPure(InterpState &S, CodePtr OpPC, const CXXMethodDecl *MD); - -template inline bool IsTrue(const T &V) { return !V.isZero(); } - //===----------------------------------------------------------------------===// // Add, Sub, Mul //===----------------------------------------------------------------------===// @@ -115,7 +61,7 @@ APSInt Value = OpAP()(LHS.toAPSInt(Bits), RHS.toAPSInt(Bits)); // Report undefined behaviour, stopping if required. - const Expr *E = S.Current->getExpr(OpPC); + const Expr *E = S.getSource(OpPC).asExpr(); QualType Type = E->getType(); if (S.checkingForUndefinedBehavior()) { auto Trunc = Value.trunc(Result.bitWidth()).toString(10); @@ -152,6 +98,92 @@ return AddSubMulHelper(S, OpPC, Bits, LHS, RHS); } +template +bool RealArithHelper(InterpState &S, CodePtr OpPC) { + const Real &RHS = S.Stk.pop(); + const Real &LHS = S.Stk.pop(); + + Real Result = Op(LHS, RHS, M); + S.Stk.push(Result); + if (!Result.isNaN()) + return true; + + const SourceInfo &Src = S.getSource(OpPC); + S.CCEDiag(Src, diag::note_constexpr_float_arithmetic) << Result.isNaN(); + return S.noteUndefinedBehavior(); +} + +template <> inline bool Add(InterpState &S, CodePtr OpPC) { + return RealArithHelper(S, OpPC); +} + +template <> inline bool Sub(InterpState &S, CodePtr OpPC) { + return RealArithHelper(S, OpPC); +} + +template <> inline bool Mul(InterpState &S, CodePtr OpPC) { + return RealArithHelper(S, OpPC); +} + +//===----------------------------------------------------------------------===// +// Div, Rem +//===----------------------------------------------------------------------===// + +template +bool DivRemHelper(InterpState &S, CodePtr OpPC) { + const 
T RHS = S.Stk.pop(); + const T LHS = S.Stk.pop(); + + // Bail on division by zero. + if (RHS.isZero()) { + S.FFDiag(S.getSource(OpPC), diag::note_expr_divide_by_zero); + return false; + } + + if (RHS.isSigned()) { + // (-MAX - 1) / -1 = MAX + 1 overflows. + if (LHS.isMin() && RHS.isMinusOne()) { + // Push the truncated value in case the interpreter continues. + S.Stk.push(T::min(RHS.bitWidth())); + + // Compute the actual value for the diagnostic. + const size_t Bits = RHS.bitWidth() + 1; + APSInt Value = OpAP()(LHS.toAPSInt(Bits), RHS.toAPSInt(Bits)); + return S.reportOverflow(S.getSource(OpPC).asExpr(), Value); + } + } + + // Safe to execute division here. + S.Stk.push(OpFW(LHS, RHS)); + return true; +} + +template T DivFn(T a, T b) { return a / b; } + +template T RemFn(T a, T b) { return a % b; } + +template ::T> +bool Div(InterpState &S, CodePtr OpPC) { + return DivRemHelper, std::divides>(S, OpPC); +} + +template <> inline bool Div(InterpState &S, CodePtr OpPC) { + const Real RHS = S.Stk.pop(); + const Real LHS = S.Stk.pop(); + + if (RHS.isZero()) { + S.FFDiag(S.getSource(OpPC), diag::note_expr_divide_by_zero); + return false; + } + S.Stk.push(Real::div(LHS, RHS, APFloat::rmNearestTiesToEven)); + return true; +} + +template ::T> +bool Rem(InterpState &S, CodePtr OpPC) { + return DivRemHelper, std::modulus>(S, OpPC); +} + //===----------------------------------------------------------------------===// // EQ, NE, GT, GE, LT, LE //===----------------------------------------------------------------------===// @@ -179,7 +211,7 @@ const Pointer &LHS = S.Stk.pop(); if (!Pointer::hasSameBase(LHS, RHS)) { - const SourceInfo &Loc = S.Current->getSource(OpPC); + const SourceInfo &Loc = S.getSource(OpPC); S.FFDiag(Loc, diag::note_invalid_subexpr_in_const_expr); return false; } else { @@ -215,6 +247,38 @@ } } +template <> +inline bool CmpHelperEQ(InterpState &S, CodePtr OpPC, + CompareFn Fn) { + using BoolT = PrimConv::T; + const auto &RHS = S.Stk.pop(); + const auto &LHS 
= S.Stk.pop(); + + auto CheckVirtual = [OpPC, &S](const MemberPointer &Ptr) { + if (auto *MD = dyn_cast(Ptr.getDecl())) { + if (MD->isVirtual()) { + const SourceInfo &E = S.getSource(OpPC); + S.CCEDiag(E, diag::note_constexpr_compare_virtual_mem_ptr) << MD; + return false; + } + } + return true; + }; + + using CCR = ComparisonCategoryResult; + if (!LHS.getDecl() || !RHS.getDecl()) { + CCR Result = !LHS.getDecl() && !RHS.getDecl() ? CCR::Equal : CCR::Nonequal; + S.Stk.push(BoolT::from(Fn(Result))); + return true; + } + + if (!CheckVirtual(LHS) || !CheckVirtual(RHS)) + return false; + + S.Stk.push(BoolT::from(Fn(LHS == RHS ? CCR::Equal : CCR::Nonequal))); + return true; +} + template ::T> bool EQ(InterpState &S, CodePtr OpPC) { return CmpHelperEQ(S, OpPC, [](ComparisonCategoryResult R) { @@ -273,6 +337,45 @@ return true; } +//===----------------------------------------------------------------------===// +// Minus, Not, LogicalNot +//===----------------------------------------------------------------------===// + +template ::T> +bool Minus(InterpState &S, CodePtr OpPC) { + const T &Arg = S.Stk.pop(); + + if (Arg.isSigned() && Arg.isMin()) { + // Push the truncated value in case the interpreter continues. + S.Stk.push(T::min(Arg.bitWidth())); + + // Compute the actual value for the diagnostic. 
+ const size_t Bits = Arg.bitWidth() + 1; + const APSInt Value = std::negate()(Arg.toAPSInt(Bits)); + return S.reportOverflow(S.getSource(OpPC).asExpr(), Value); + } + + S.Stk.push(std::negate()(Arg)); + return true; +} + +template <> inline bool Minus(InterpState &S, CodePtr OpPC) { + S.Stk.push(Real::negate(S.Stk.pop())); + return true; +} + +template ::T> +bool Not(InterpState &S, CodePtr OpPC) { + S.Stk.push(~S.Stk.pop()); + return true; +} + +template ::T> +bool LogicalNot(InterpState &S, CodePtr OpPC) { + S.Stk.push(S.Stk.pop().isZero()); + return true; +} + //===----------------------------------------------------------------------===// // Dup, Pop, Test //===----------------------------------------------------------------------===// @@ -289,6 +392,12 @@ return true; } +template ::T> +bool Test(InterpState &S, CodePtr OpPC) { + S.Stk.push(!S.Stk.pop().isZero()); + return true; +} + //===----------------------------------------------------------------------===// // Const //===----------------------------------------------------------------------===// @@ -299,6 +408,38 @@ return true; } +template ::T> +bool ConstFP(InterpState &S, CodePtr OpPC, uint32_t Bits, uint64_t Value) { + S.Stk.push(T::from(Value, Bits)); + return true; +} + +inline bool ConstRealFP(InterpState &S, CodePtr OpPC, + const FloatingLiteral *Value) { + S.Stk.push(Real::from(Value->getValue(), Value->getSemantics())); + return true; +} + +//===----------------------------------------------------------------------===// +// ConstFn, ConstNoFn, ConstMem +//===----------------------------------------------------------------------===// + +inline bool ConstFn(InterpState &S, CodePtr OpPC, Function *Fn) { + S.Stk.push(Fn); + return true; +} + +inline bool ConstNoFn(InterpState &S, CodePtr OpPC, const FunctionDecl *FD) { + // TODO: try to compile and rewrite opcode. 
+ S.Stk.push(FD); + return true; +} + +inline bool ConstMem(InterpState &S, CodePtr OpPC, const ValueDecl *VD) { + S.Stk.push(VD); + return true; +} + //===----------------------------------------------------------------------===// // Get/Set Local/Param/Global/This //===----------------------------------------------------------------------===// @@ -333,12 +474,12 @@ template ::T> bool GetField(InterpState &S, CodePtr OpPC, uint32_t I) { const Pointer &Obj = S.Stk.peek(); - if (!CheckNull(S, OpPC, Obj, CSK_Field)) - return false; - if (!CheckRange(S, OpPC, Obj, CSK_Field)) + if (!S.CheckNull(OpPC, Obj, CSK_Field)) + return false; + if (!S.CheckRange(OpPC, Obj, CSK_Field)) return false; const Pointer &Field = Obj.atField(I); - if (!CheckLoad(S, OpPC, Field)) + if (!S.CheckLoad(OpPC, Field)) return false; S.Stk.push(Field.deref()); return true; @@ -348,12 +489,12 @@ bool SetField(InterpState &S, CodePtr OpPC, uint32_t I) { const T &Value = S.Stk.pop(); const Pointer &Obj = S.Stk.peek(); - if (!CheckNull(S, OpPC, Obj, CSK_Field)) + if (!S.CheckNull(OpPC, Obj, CSK_Field)) return false; - if (!CheckRange(S, OpPC, Obj, CSK_Field)) + if (!S.CheckRange(OpPC, Obj, CSK_Field)) return false; const Pointer &Field = Obj.atField(I); - if (!CheckStore(S, OpPC, Field)) + if (!S.CheckStore(OpPC, Field)) return false; Field.deref() = Value; return true; @@ -362,12 +503,12 @@ template ::T> bool GetFieldPop(InterpState &S, CodePtr OpPC, uint32_t I) { const Pointer &Obj = S.Stk.pop(); - if (!CheckNull(S, OpPC, Obj, CSK_Field)) + if (!S.CheckNull(OpPC, Obj, CSK_Field)) return false; - if (!CheckRange(S, OpPC, Obj, CSK_Field)) + if (!S.CheckRange(OpPC, Obj, CSK_Field)) return false; const Pointer &Field = Obj.atField(I); - if (!CheckLoad(S, OpPC, Field)) + if (!S.CheckLoad(OpPC, Field)) return false; S.Stk.push(Field.deref()); return true; @@ -378,10 +519,10 @@ if (S.checkingPotentialConstantExpression()) return false; const Pointer &This = S.Current->getThis(); - if (!CheckThis(S, OpPC, 
This)) + if (!S.CheckThis(OpPC, This)) return false; const Pointer &Field = This.atField(I); - if (!CheckLoad(S, OpPC, Field)) + if (!S.CheckLoad(OpPC, Field)) return false; S.Stk.push(Field.deref()); return true; @@ -393,10 +534,10 @@ return false; const T &Value = S.Stk.pop(); const Pointer &This = S.Current->getThis(); - if (!CheckThis(S, OpPC, This)) + if (!S.CheckThis(OpPC, This)) return false; const Pointer &Field = This.atField(I); - if (!CheckStore(S, OpPC, Field)) + if (!S.CheckStore(OpPC, Field)) return false; Field.deref() = Value; return true; @@ -405,6 +546,13 @@ template ::T> bool GetGlobal(InterpState &S, CodePtr OpPC, uint32_t I) { auto *B = S.P.getGlobal(I); + if (isIntegral(Name) && !B->isConst()) { + const SourceInfo &E = S.getSource(OpPC); + const ValueDecl *VD = B->getDescriptor()->asValueDecl(); + S.FFDiag(E, diag::note_constexpr_ltor_non_const_int, 1) << VD; + S.Note(VD->getLocation(), diag::note_declared_at); + return false; + } if (B->isExtern()) return false; S.Stk.push(B->deref()); @@ -428,7 +576,7 @@ if (S.checkingPotentialConstantExpression()) return false; const Pointer &This = S.Current->getThis(); - if (!CheckThis(S, OpPC, This)) + if (!S.CheckThis(OpPC, This)) return false; const Pointer &Field = This.atField(I); Field.deref() = S.Stk.pop(); @@ -441,7 +589,7 @@ if (S.checkingPotentialConstantExpression()) return false; const Pointer &This = S.Current->getThis(); - if (!CheckThis(S, OpPC, This)) + if (!S.CheckThis(OpPC, This)) return false; const Pointer &Field = This.atField(F->Offset); const auto &Value = S.Stk.pop(); @@ -455,7 +603,7 @@ if (S.checkingPotentialConstantExpression()) return false; const Pointer &This = S.Current->getThis(); - if (!CheckThis(S, OpPC, This)) + if (!S.CheckThis(OpPC, This)) return false; const Pointer &Field = This.atField(I); Field.deref() = S.Stk.pop(); @@ -466,6 +614,16 @@ template ::T> bool InitField(InterpState &S, CodePtr OpPC, uint32_t I) { + const T &Value = S.Stk.pop(); + const Pointer &Field = 
S.Stk.peek().atField(I); + Field.deref() = Value; + Field.activate(); + Field.initialize(); + return true; +} + +template ::T> +bool InitFieldPop(InterpState &S, CodePtr OpPC, uint32_t I) { const T &Value = S.Stk.pop(); const Pointer &Field = S.Stk.pop().atField(I); Field.deref() = Value; @@ -474,6 +632,17 @@ return true; } +template ::T> +bool InitFieldPeek(InterpState &S, CodePtr OpPC, uint32_t I) { + const T &Value = S.Stk.pop(); + const Pointer &Field = S.Stk.peek().atField(I); + Field.deref() = Value; + Field.activate(); + Field.initialize(); + S.Stk.push(Value); + return true; +} + template ::T> bool InitBitField(InterpState &S, CodePtr OpPC, const Record::Field *F) { const T &Value = S.Stk.pop(); @@ -519,11 +688,9 @@ inline bool GetPtrField(InterpState &S, CodePtr OpPC, uint32_t Off) { const Pointer &Ptr = S.Stk.pop(); - if (!CheckNull(S, OpPC, Ptr, CSK_Field)) + if (!S.CheckNull(OpPC, Ptr, CSK_Field)) return false; - if (!CheckExtern(S, OpPC, Ptr)) - return false; - if (!CheckRange(S, OpPC, Ptr, CSK_Field)) + if (!S.CheckRange(OpPC, Ptr, CSK_Field)) return false; S.Stk.push(Ptr.atField(Off)); return true; @@ -533,7 +700,7 @@ if (S.checkingPotentialConstantExpression()) return false; const Pointer &This = S.Current->getThis(); - if (!CheckThis(S, OpPC, This)) + if (!S.CheckThis(OpPC, This)) return false; S.Stk.push(This.atField(Off)); return true; @@ -541,9 +708,9 @@ inline bool GetPtrActiveField(InterpState &S, CodePtr OpPC, uint32_t Off) { const Pointer &Ptr = S.Stk.pop(); - if (!CheckNull(S, OpPC, Ptr, CSK_Field)) + if (!S.CheckNull(OpPC, Ptr, CSK_Field)) return false; - if (!CheckRange(S, OpPC, Ptr, CSK_Field)) + if (!S.CheckRange(OpPC, Ptr, CSK_Field)) return false; Pointer Field = Ptr.atField(Off); Ptr.deactivate(); @@ -553,10 +720,10 @@ } inline bool GetPtrActiveThisField(InterpState &S, CodePtr OpPC, uint32_t Off) { - if (S.checkingPotentialConstantExpression()) + if (S.checkingPotentialConstantExpression()) return false; const Pointer &This = 
S.Current->getThis(); - if (!CheckThis(S, OpPC, This)) + if (!S.CheckThis(OpPC, This)) return false; Pointer Field = This.atField(Off); This.deactivate(); @@ -567,7 +734,7 @@ inline bool GetPtrBase(InterpState &S, CodePtr OpPC, uint32_t Off) { const Pointer &Ptr = S.Stk.pop(); - if (!CheckNull(S, OpPC, Ptr, CSK_Base)) + if (!S.CheckNull(OpPC, Ptr, CSK_Base)) return false; S.Stk.push(Ptr.atField(Off)); return true; @@ -577,7 +744,7 @@ if (S.checkingPotentialConstantExpression()) return false; const Pointer &This = S.Current->getThis(); - if (!CheckThis(S, OpPC, This)) + if (!S.CheckThis(OpPC, This)) return false; S.Stk.push(This.atField(Off)); return true; @@ -596,7 +763,7 @@ inline bool GetPtrVirtBase(InterpState &S, CodePtr OpPC, const RecordDecl *D) { const Pointer &Ptr = S.Stk.pop(); - if (!CheckNull(S, OpPC, Ptr, CSK_Base)) + if (!S.CheckNull(OpPC, Ptr, CSK_Base)) return false; return VirtBaseHelper(S, OpPC, D, Ptr); } @@ -606,11 +773,40 @@ if (S.checkingPotentialConstantExpression()) return false; const Pointer &This = S.Current->getThis(); - if (!CheckThis(S, OpPC, This)) + if (!S.CheckThis(OpPC, This)) return false; return VirtBaseHelper(S, OpPC, D, S.Current->getThis()); } +inline bool GetPtrFieldIndirect(InterpState &S, CodePtr OpPC) { + // Fetch a pointer to the member. + const MemberPointer &Field = S.Stk.pop(); + + // Fetch and validate the object. + Pointer Ptr = S.Stk.pop(); + if (!S.CheckNull(OpPC, Ptr, CSK_Field)) + return false; + if (!S.CheckRange(OpPC, Ptr, CSK_Field)) + return false; + + // Validate the pointer. + if (Field.isZero()) { + S.FFDiag(S.getSource(OpPC)); + return false; + } + + // Find the correct class in the hierarchy and fetch the pointer. 
+ if (auto *Decl = dyn_cast(Field.getDecl())) { + if (Record *R = PointerLookup(Ptr, Field)) { + S.Stk.push(Ptr.atField(R->getField(Decl)->Offset)); + return true; + } + } + + S.FFDiag(S.getSource(OpPC)); + return false; +} + //===----------------------------------------------------------------------===// // Load, Store, Init //===----------------------------------------------------------------------===// @@ -618,7 +814,7 @@ template ::T> bool Load(InterpState &S, CodePtr OpPC) { const Pointer &Ptr = S.Stk.peek(); - if (!CheckLoad(S, OpPC, Ptr)) + if (!S.CheckLoad(OpPC, Ptr)) return false; S.Stk.push(Ptr.deref()); return true; @@ -627,7 +823,7 @@ template ::T> bool LoadPop(InterpState &S, CodePtr OpPC) { const Pointer &Ptr = S.Stk.pop(); - if (!CheckLoad(S, OpPC, Ptr)) + if (!S.CheckLoad(OpPC, Ptr)) return false; S.Stk.push(Ptr.deref()); return true; @@ -637,7 +833,7 @@ bool Store(InterpState &S, CodePtr OpPC) { const T &Value = S.Stk.pop(); const Pointer &Ptr = S.Stk.peek(); - if (!CheckStore(S, OpPC, Ptr)) + if (!S.CheckStore(OpPC, Ptr)) return false; Ptr.deref() = Value; return true; @@ -647,7 +843,7 @@ bool StorePop(InterpState &S, CodePtr OpPC) { const T &Value = S.Stk.pop(); const Pointer &Ptr = S.Stk.pop(); - if (!CheckStore(S, OpPC, Ptr)) + if (!S.CheckStore(OpPC, Ptr)) return false; Ptr.deref() = Value; return true; @@ -657,7 +853,7 @@ bool StoreBitField(InterpState &S, CodePtr OpPC) { const T &Value = S.Stk.pop(); const Pointer &Ptr = S.Stk.peek(); - if (!CheckStore(S, OpPC, Ptr)) + if (!S.CheckStore(OpPC, Ptr)) return false; if (auto *FD = Ptr.getField()) { Ptr.deref() = Value.truncate(FD->getBitWidthValue(S.getCtx())); @@ -671,7 +867,7 @@ bool StoreBitFieldPop(InterpState &S, CodePtr OpPC) { const T &Value = S.Stk.pop(); const Pointer &Ptr = S.Stk.pop(); - if (!CheckStore(S, OpPC, Ptr)) + if (!S.CheckStore(OpPC, Ptr)) return false; if (auto *FD = Ptr.getField()) { Ptr.deref() = Value.truncate(FD->getBitWidthValue(S.getCtx())); @@ -685,7 +881,7 @@ bool 
InitPop(InterpState &S, CodePtr OpPC) { const T &Value = S.Stk.pop(); const Pointer &Ptr = S.Stk.pop(); - if (!CheckInit(S, OpPC, Ptr)) + if (!S.CheckInit(OpPC, Ptr)) return false; Ptr.initialize(); new (&Ptr.deref()) T(Value); @@ -696,7 +892,7 @@ bool InitElem(InterpState &S, CodePtr OpPC, uint32_t Idx) { const T &Value = S.Stk.pop(); const Pointer &Ptr = S.Stk.peek().atIndex(Idx); - if (!CheckInit(S, OpPC, Ptr)) + if (!S.CheckInit(OpPC, Ptr)) return false; Ptr.initialize(); new (&Ptr.deref()) T(Value); @@ -707,13 +903,23 @@ bool InitElemPop(InterpState &S, CodePtr OpPC, uint32_t Idx) { const T &Value = S.Stk.pop(); const Pointer &Ptr = S.Stk.pop().atIndex(Idx); - if (!CheckInit(S, OpPC, Ptr)) + if (!S.CheckInit(OpPC, Ptr)) return false; Ptr.initialize(); new (&Ptr.deref()) T(Value); return true; } +inline bool Initialise(InterpState &S, CodePtr OpPC) { + S.Stk.pop().initialize(); + return true; +} + +inline bool Activate(InterpState &S, CodePtr OpPC) { + S.Stk.pop().activate(); + return true; +} + //===----------------------------------------------------------------------===// // AddOffset, SubOffset //===----------------------------------------------------------------------===// @@ -722,9 +928,9 @@ // Fetch the pointer and the offset. const T &Offset = S.Stk.pop(); const Pointer &Ptr = S.Stk.pop(); - if (!CheckNull(S, OpPC, Ptr, CSK_ArrayIndex)) + if (!S.CheckNull(OpPC, Ptr, CSK_ArrayIndex)) return false; - if (!CheckRange(S, OpPC, Ptr, CSK_ArrayToPointer)) + if (!S.CheckRange(OpPC, Ptr, CSK_ArrayToPointer)) return false; // Get a version of the index comparable to the type. @@ -736,7 +942,7 @@ return true; } // Arrays of unknown bounds cannot have pointers into them. - if (!CheckArray(S, OpPC, Ptr)) + if (!S.CheckArray(OpPC, Ptr)) return false; // Compute the largest index into the array. @@ -748,9 +954,8 @@ APSInt APOffset(Offset.toAPSInt().extend(Bits + 2), false); APSInt APIndex(Index.toAPSInt().extend(Bits + 2), false); APSInt NewIndex = Add ? 
(APIndex + APOffset) : (APIndex - APOffset); - S.CCEDiag(S.Current->getSource(OpPC), diag::note_constexpr_array_index) - << NewIndex - << /*array*/ static_cast(!Ptr.inArray()) + S.CCEDiag(S.getSource(OpPC), diag::note_constexpr_array_index) + << NewIndex << /*array*/ static_cast(!Ptr.inArray()) << static_cast(MaxIndex); return false; }; @@ -786,7 +991,6 @@ return OffsetHelper(S, OpPC); } - //===----------------------------------------------------------------------===// // Destroy //===----------------------------------------------------------------------===// @@ -807,6 +1011,304 @@ return true; } +template +bool CastFP(InterpState &S, CodePtr OpPC, uint32_t Bits) { + using T = typename PrimConv::T; + using U = typename PrimConv::T; + S.Stk.push(U::from(S.Stk.pop(), Bits)); + return true; +} + +template +bool CastRealFP(InterpState &S, CodePtr OpPC, const fltSemantics *Sema) { + using T = typename PrimConv::T; + S.Stk.push(Real::from(S.Stk.pop(), *Sema)); + return true; +} + +template +bool RealToIntHelper(InterpState &S, CodePtr OpPC, const APFloat &Value, + uintptr_t Bits) { + APSInt Result(Bits, !U::isSigned()); + bool Ignored; + + // Conver to int and check for overflow. 
+ if (Value.convertToInteger(Result, llvm::APFloat::rmTowardZero, &Ignored) & + APFloat::opInvalidOp) { + const Expr *E = S.getSource(OpPC).asExpr(); + S.CCEDiag(E, diag::note_constexpr_overflow) << Value << E->getType(); + return S.noteUndefinedBehavior(); + } + + S.Stk.push(Result); + return true; +} + +template +bool CastRealFPToAlu(InterpState &S, CodePtr OpPC) { + using T = typename PrimConv::T; + using U = typename PrimConv::T; + return RealToIntHelper(S, OpPC, S.Stk.pop().toAPFloat(), U::bitWidth()); +} + +template +bool CastRealFPToAluFP(InterpState &S, CodePtr OpPC, uint32_t Bits) { + using T = typename PrimConv::T; + using U = typename PrimConv::T; + return RealToIntHelper(S, OpPC, S.Stk.pop().toAPFloat(), Bits); +} + +//===----------------------------------------------------------------------===// +// CastMemberToBase, CastMemberToDerived +//===----------------------------------------------------------------------===// + +inline bool CastMemberToBase(InterpState &S, CodePtr OpPC, + const CXXRecordDecl *R) { + if (S.Stk.peek().toBase(R)) + return true; + S.FFDiag(S.getSource(OpPC)); + return false; +} + +inline bool CastMemberToDerived(InterpState &S, CodePtr OpPC, + const CXXRecordDecl *R) { + if (S.Stk.peek().toDerived(R)) + return true; + S.FFDiag(S.getSource(OpPC)); + return false; +} + +//===----------------------------------------------------------------------===// +// PointerBitCast, DynCastWarning +//===----------------------------------------------------------------------===// + +inline bool PointerBitCast(InterpState &S, CodePtr OpPC) { + auto *E = S.getSource(OpPC).asExpr(); + if (E->getType()->isVoidPointerType()) + S.CCEDiag(E, diag::note_constexpr_invalid_cast) << 3 << E->getType(); + else + S.CCEDiag(E, diag::note_constexpr_invalid_cast) << 2; + + // TODO: invalidate pointer. 
+ return false; +} + +inline bool ReinterpretCastWarning(InterpState &S, CodePtr OpPC) { + auto *E = S.getSource(OpPC).asExpr(); + S.CCEDiag(E, diag::note_constexpr_invalid_cast) << 0; + return true; +} + +inline bool DynamicCastWarning(InterpState &S, CodePtr OpPC) { + auto *E = S.getSource(OpPC).asExpr(); + S.CCEDiag(E, diag::note_constexpr_invalid_cast) << 1; + return true; +} + +//===----------------------------------------------------------------------===// +// CastToDerived +//===----------------------------------------------------------------------===// + +inline bool CastToDerived(InterpState &S, CodePtr OpPC, const Expr *E) { + QualType Ty = E->getType(); + if (const PointerType *PT = Ty->getAs()) + Ty = PT->getPointeeType(); + + Pointer Base = S.Stk.pop(); + if (!S.CheckNull(OpPC, Base, CSK_Derived)) + return false; + + if (auto *R = S.P.getOrCreateRecord(Ty->getAs()->getDecl())) { + // Traverse the chain of bases (actually derived classes). + while (Base.isBaseClass()) { + // Step to deriving class. + Base = Base.getBase(); + + // Found the derived class. + if (Base.getRecord() == R) { + S.Stk.push(std::move(Base)); + return true; + } + } + } + + // Emit the diagnostic. + S.CCEDiag(E, diag::note_constexpr_invalid_downcast) << Base.getType() << Ty; + return false; +} + +//===----------------------------------------------------------------------===// +// Inc, Dec +//===----------------------------------------------------------------------===// + +template +using IncDecFn = bool (*)(InterpState &, CodePtr, const T &, T &); + +template +bool Post(InterpState &S, CodePtr OpPC, AccessKinds AK, IncDecFn Op) { + // Bail if not C++14. + if (!S.getLangOpts().CPlusPlus14 && !S.keepEvaluatingAfterFailure()) { + S.FFDiag(S.getSource(OpPC)); + return false; + } + + // Get the pointer to the object to mutate. + const Pointer &Ptr = S.Stk.pop(); + if (!S.CheckLoad(OpPC, Ptr, AK) || !S.CheckStore(OpPC, Ptr, AK)) + return false; + + // Return the original value. 
+ const T Arg = Ptr.deref(); + S.Stk.push(Arg); + + // Perform the operation. + return Op(S, OpPC, Arg, Ptr.deref()); +} + +template +bool Pre(InterpState &S, CodePtr OpPC, AccessKinds AK, IncDecFn Op) { + // Bail if not C++14. + if (!S.getLangOpts().CPlusPlus14 && !S.keepEvaluatingAfterFailure()) { + S.FFDiag(S.getSource(OpPC)); + return false; + } + + // Read the value to mutate. + Pointer Ptr = S.Stk.peek(); + if (!S.CheckLoad(OpPC, Ptr, AK) || !S.CheckStore(OpPC, Ptr, AK)) + return false; + + // Perform the operation. + const T Arg = Ptr.deref(); + return Op(S, OpPC, Arg, Ptr.deref()); +} + +template class OpAP> +bool IncDecInt(InterpState &S, CodePtr OpPC, const T &Arg, T &R) { + if (std::is_unsigned::value) { + // Unsigned increment/decrement is defined to wrap around. + R = OpAP()(Arg, T::from(1, Arg.bitWidth())); + return true; + } else { + // Signed increment/decrement is undefined, catch that. + if (!OpFW(Arg, &R)) + return true; + // Compute the actual value for the diagnostic. + const size_t Bits = R.bitWidth() + 1; + APSInt One(APInt(Bits, 1, Arg.isSigned()), !Arg.isSigned()); + APSInt Value = OpAP()(Arg.toAPSInt(Bits), One); + return S.reportOverflow(S.getSource(OpPC).asExpr(), Value); + } +} + +template