diff --git a/clang/lib/AST/Interp/ByteCodeEmitter.cpp b/clang/lib/AST/Interp/ByteCodeEmitter.cpp --- a/clang/lib/AST/Interp/ByteCodeEmitter.cpp +++ b/clang/lib/AST/Interp/ByteCodeEmitter.cpp @@ -57,12 +57,17 @@ Function *Func = P.createFunction(F, ParamOffset, std::move(ParamTypes), std::move(ParamDescriptors)); // Compile the function body. - if (!F->isConstexpr() || !visitFunc(F)) { - // Return a dummy function if compilation failed. - if (BailLocation) + if (!F->isConstexpr()) { + // Non-constexpr functions are not compiled: return the bodyless dummy. + return Func; + } else if (!visitFunc(F)) { + if (BailLocation) { + // The compiler bailed on an unsupported node: report it as an error. return llvm::make_error(*BailLocation); - else + } else { + // Otherwise return the bodyless dummy, which is treated as not constexpr. return Func; + } } else { // Create scopes from descriptors. llvm::SmallVector Scopes; diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.h b/clang/lib/AST/Interp/ByteCodeExprGen.h --- a/clang/lib/AST/Interp/ByteCodeExprGen.h +++ b/clang/lib/AST/Interp/ByteCodeExprGen.h @@ -77,6 +77,10 @@ bool VisitUnaryMinus(const UnaryOperator *E); bool VisitCallExpr(const CallExpr *E); + // Fallbacks: statements are unreachable here; unsupported exprs bail. + bool VisitStmt(const Stmt *E) { llvm_unreachable("not an expression"); } + bool VisitExpr(const Expr *E) { return this->bail(E); } + protected: bool visitExpr(const Expr *E) override; bool visitDecl(const VarDecl *VD) override; @@ -147,6 +151,8 @@ bool emitFunctionCall(const FunctionDecl *Callee, llvm::Optional T, const Expr *Call); + bool visitAssign(PrimType T, const BinaryOperator *BO); + enum class DerefKind { /// Value is read and pushed to stack.
Read, diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -299,6 +299,8 @@ }; switch (BO->getOpcode()) { + case BO_Assign: // NOTE(review): check LHS/RHS are not already on the stack. + return visitAssign(*T, BO); case BO_EQ: return Discard(this->emitEQ(*LT, *T, BO)); case BO_NE: return Discard(this->emitNE(*LT, *T, BO)); case BO_LT: return Discard(this->emitLT(*LT, *T, BO)); @@ -368,8 +370,10 @@ return false; } } else { - consumeError(Func.takeError()); - return this->bail(E); + handleAllErrors(Func.takeError(), [this](ByteCodeGenError &Err) { + S.FFDiag(Err.getLoc(), diag::err_experimental_clang_interp_failed); + }); + return false; } return DiscardResult && T ? this->emitPop(*T, E) : true; } @@ -406,6 +410,31 @@ } } +template +bool ByteCodeExprGen::visitAssign(PrimType T, + const BinaryOperator *BO) { + return dereference( + BO->getLHS(), DerefKind::Write, + [this, BO](PrimType) { + // Direct access: only push the RHS value; the store is left to deref. + return visit(BO->getRHS()); + }, + [this, BO](PrimType T) { + // Pointer to the LHS is on the stack: compile the RHS and store to it.
+ if (!visit(BO->getRHS())) + return false; + // Stores to bit-fields are not supported: bail out. + if (BO->getLHS()->refersToBitField()) { + return this->bail(BO); + } else { + if (DiscardResult) + return this->emitStorePop(T, BO); + else + return this->emitStore(T, BO); + } + }); +} + template bool ByteCodeExprGen::dereference( const Expr *LV, DerefKind AK, llvm::function_ref Direct, diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.h b/clang/lib/AST/Interp/ByteCodeStmtGen.h --- a/clang/lib/AST/Interp/ByteCodeStmtGen.h +++ b/clang/lib/AST/Interp/ByteCodeStmtGen.h @@ -60,8 +60,16 @@ bool visitStmt(const Stmt *S); bool visitCompoundStmt(const CompoundStmt *S); bool visitDeclStmt(const DeclStmt *DS); + bool visitForStmt(const ForStmt *FS); + bool visitWhileStmt(const WhileStmt *WS); + bool visitDoStmt(const DoStmt *DS); bool visitReturnStmt(const ReturnStmt *RS); bool visitIfStmt(const IfStmt *IS); + bool visitBreakStmt(const BreakStmt *BS); + bool visitContinueStmt(const ContinueStmt *CS); + bool visitSwitchStmt(const SwitchStmt *SS); + bool visitCaseStmt(const SwitchCase *CS); + bool visitCXXForRangeStmt(const CXXForRangeStmt *FS); /// Compiles a variable declaration.
bool visitVarDecl(const VarDecl *VD); diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp --- a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp @@ -118,10 +118,27 @@ return visitCompoundStmt(cast(S)); case Stmt::DeclStmtClass: return visitDeclStmt(cast(S)); + case Stmt::ForStmtClass: + return visitForStmt(cast(S)); + case Stmt::WhileStmtClass: + return visitWhileStmt(cast(S)); + case Stmt::DoStmtClass: + return visitDoStmt(cast(S)); case Stmt::ReturnStmtClass: return visitReturnStmt(cast(S)); case Stmt::IfStmtClass: return visitIfStmt(cast(S)); + case Stmt::BreakStmtClass: + return visitBreakStmt(cast(S)); + case Stmt::ContinueStmtClass: + return visitContinueStmt(cast(S)); + case Stmt::SwitchStmtClass: + return visitSwitchStmt(cast(S)); + case Stmt::CaseStmtClass: + case Stmt::DefaultStmtClass: + return visitCaseStmt(cast(S)); + case Stmt::CXXForRangeStmtClass: + return visitCXXForRangeStmt(cast(S)); case Stmt::NullStmtClass: return true; default: { @@ -136,9 +153,10 @@ bool ByteCodeStmtGen::visitCompoundStmt( const CompoundStmt *CompoundStmt) { BlockScope Scope(this); - for (auto *InnerStmt : CompoundStmt->body()) + for (auto *InnerStmt : CompoundStmt->body()) { if (!visitStmt(InnerStmt)) return false; + } return true; } @@ -161,6 +179,114 @@ return true; } +template +bool ByteCodeStmtGen::visitForStmt(const ForStmt *FS) { + // The init-statement lives in an outer scope spanning all iterations. + BlockScope OuterScope(this); + if (auto *Init = FS->getInit()) + if (!visitStmt(Init)) + return false; + + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + // Condition, body and increment share an inner per-iteration scope.
+ this->emitLabel(LabelStart); + { + BlockScope InnerScope(this); + + if (auto *Cond = FS->getCond()) { + if (auto *CondDecl = FS->getConditionVariableDeclStmt()) + if (!visitDeclStmt(CondDecl)) + return false; + + if (!this->visit(Cond)) + return false; + + if (!this->jumpFalse(LabelEnd)) + return false; + } + // break exits to LabelEnd; continue resumes at LabelSkip, before Inc. + if (auto *Body = FS->getBody()) { + LabelTy LabelSkip = this->getLabel(); + LoopScope FlowScope(this, LabelEnd, LabelSkip); + if (!visitStmt(Body)) + return false; + this->emitLabel(LabelSkip); + } + + if (auto *Inc = FS->getInc()) { + ExprScope IncScope(this); + if (!this->discard(Inc)) + return false; + } + if (!this->jump(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + return true; +} + +template +bool ByteCodeStmtGen::visitWhileStmt(const WhileStmt *WS) { + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + this->emitLabel(LabelStart); + { + BlockScope InnerScope(this); + if (auto *CondDecl = WS->getConditionVariableDeclStmt()) + if (!visitDeclStmt(CondDecl)) + return false; + + if (!this->visit(WS->getCond())) + return false; + + if (!this->jumpFalse(LabelEnd)) + return false; + // continue jumps back to LabelStart and re-tests the condition. + { + LoopScope FlowScope(this, LabelEnd, LabelStart); + if (!visitStmt(WS->getBody())) + return false; + } + if (!this->jump(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + + return true; +} + +template +bool ByteCodeStmtGen::visitDoStmt(const DoStmt *DS) { + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + LabelTy LabelSkip = this->getLabel(); + + this->emitLabel(LabelStart); + { + { + LoopScope FlowScope(this, LabelEnd, LabelSkip); + if (!visitStmt(DS->getBody())) + return false; + this->emitLabel(LabelSkip); + } + // continue lands on LabelSkip, right before the condition. + { + ExprScope CondScope(this); + if (!this->visitBool(DS->getCond())) + return false; + } + // Repeat the body while the condition holds. + if (!this->jumpTrue(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + + return true; +} + template bool ByteCodeStmtGen::visitReturnStmt(const ReturnStmt *RS) {
if (const Expr *RE = RS->getRetValue()) { @@ -222,6 +348,167 @@ return true; } +template +bool ByteCodeStmtGen::visitBreakStmt(const BreakStmt *BS) { + if (!BreakLabel) + return this->bail(BS); + return this->jump(*BreakLabel); +} + +template +bool ByteCodeStmtGen::visitContinueStmt(const ContinueStmt *CS) { + if (!ContinueLabel) + return this->bail(CS); + return this->jump(*ContinueLabel); +} + +template +bool ByteCodeStmtGen::visitSwitchStmt(const SwitchStmt *SS) { + BlockScope InnerScope(this); + + if (Optional T = this->classify(SS->getCond()->getType())) { + // The condition is stored in a local and fetched for every test. + unsigned Off = this->allocateLocalPrimitive(SS->getCond(), *T, + /*isConst=*/true); + + // The init-statement and condition variable belong to the switch scope, + // not to the condition's ExprScope: the body may still refer to them. + if (const Stmt *CondInit = SS->getInit()) + if (!visitStmt(CondInit)) + return false; + if (const DeclStmt *CondDecl = SS->getConditionVariableDeclStmt()) + if (!visitDeclStmt(CondDecl)) + return false; + + // Compile the condition value itself in its own scope. + { + ExprScope CondScope(this); + if (!this->visit(SS->getCond())) + return false; + if (!this->emitSetLocal(*T, Off, SS->getCond())) + return false; + } + + LabelTy LabelEnd = this->getLabel(); + + // Generate code to inspect all case labels, jumping to the matched one.
+ const DefaultStmt *Default = nullptr; + CaseMap Labels; + for (auto *SC = SS->getSwitchCaseList(); SC; SC = SC->getNextSwitchCase()) { + LabelTy Label = this->getLabel(); + Labels.insert({SC, Label}); + + if (auto *DS = dyn_cast(SC)) { + Default = DS; + continue; + } + + if (auto *CS = dyn_cast(SC)) { + if (!this->emitGetLocal(*T, Off, CS)) + return false; + if (!this->visit(CS->getLHS())) + return false; + // GNU case range: test LHS <= cond <= RHS; otherwise test cond == LHS. + if (auto *RHS = CS->getRHS()) { + if (!this->visit(RHS)) + return false; + if (!this->emitInRange(*T, CS)) + return false; + } else { + if (!this->emitEQ(*T, PT_Bool, CS)) + return false; + } + + if (!this->jumpTrue(Label)) + return false; + continue; + } + + return this->bail(SS); + } + + // If no case matched, jump to the default label or skip the body. + OptLabelTy DefaultLabel = Default ? Labels[Default] : OptLabelTy{}; + if (!this->jump(DefaultLabel ? *DefaultLabel : LabelEnd)) + return false; + + // Compile the body, using labels defined previously. + SwitchScope LabelScope(this, std::move(Labels), LabelEnd, + DefaultLabel); + if (!visitStmt(SS->getBody())) + return false; + this->emitLabel(LabelEnd); + return true; + } else { + return this->bail(SS); + } +} + +template +bool ByteCodeStmtGen::visitCaseStmt(const SwitchCase *CS) { + auto It = CaseLabels.find(CS); + if (It == CaseLabels.end()) + return this->bail(CS); + + this->emitLabel(It->second); + return visitStmt(CS->getSubStmt()); +} + +template +bool ByteCodeStmtGen::visitCXXForRangeStmt(const CXXForRangeStmt *FS) { + BlockScope Scope(this); + + // Emit the optional init-statement. + if (auto *Init = FS->getInit()) { + if (!visitStmt(Init)) + return false; + } + + // Initialise the __range variable. + if (!visitStmt(FS->getRangeStmt())) + return false; + + // Create the __begin and __end iterators.
+ if (!visitStmt(FS->getBeginStmt()) || !visitStmt(FS->getEndStmt())) + return false; + + LabelTy LabelStart = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + this->emitLabel(LabelStart); + { + // Lower the condition and leave the loop once it is false. + if (!this->visitBool(FS->getCond())) + return false; + if (!this->jumpFalse(LabelEnd)) + return false; + + // Lower the loop var and body, marking labels for continue/break. + { + BlockScope InnerScope(this); + if (!visitStmt(FS->getLoopVarStmt())) + return false; + + LabelTy LabelSkip = this->getLabel(); + { + LoopScope FlowScope(this, LabelEnd, LabelSkip); + // continue resumes at LabelSkip, just before the increment. + if (!visitStmt(FS->getBody())) + return false; + } + this->emitLabel(LabelSkip); + } + + // Increment (++__begin), then jump back to re-test the condition. + if (!visitStmt(FS->getInc())) + return false; + if (!this->jump(LabelStart)) + return false; + } + this->emitLabel(LabelEnd); + return true; +} + template bool ByteCodeStmtGen::visitVarDecl(const VarDecl *VD) { auto DT = VD->getType(); diff --git a/clang/lib/AST/Interp/Context.h b/clang/lib/AST/Interp/Context.h --- a/clang/lib/AST/Interp/Context.h +++ b/clang/lib/AST/Interp/Context.h @@ -70,13 +70,6 @@ /// Classifies an expression. llvm::Optional classify(QualType T); -private: - /// Runs a function. - bool Run(State &Parent, Function *Func, APValue &Result); - - /// Checks a result fromt the interpreter. - bool Check(State &Parent, llvm::Expected &&R); - private: /// Current compilation context. ASTContext &Ctx; diff --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp --- a/clang/lib/AST/Interp/Context.cpp +++ b/clang/lib/AST/Interp/Context.cpp @@ -27,6 +27,8 @@ Context::~Context() {} bool Context::isPotentialConstantExpr(State &Parent, const FunctionDecl *FD) { + // Try to compile the function. This either produces an error message (if this + // is the first attempt to compile) or returns a dummy function with no body.
Function *Func = P->getFunction(FD); if (!Func) { ByteCodeStmtGen C(*this, *P, Parent); @@ -40,22 +42,43 @@ } } + // If function has no body, it is definitely not constexpr. if (!Func->isConstexpr()) return false; - APValue Dummy; - return Run(Parent, Func, Dummy); + // Run the function without a caller frame; only success or failure matters. + APValue DummyResult; + InterpState State(Parent, *P, Stk, *this); + State.Current = new InterpFrame(State, Func, nullptr, {}, {}); + if (Interpret(State, DummyResult)) + return true; + Stk.clear(); // Drop values left on the stack by the failed run. + return false; } bool Context::evaluateAsRValue(State &Parent, const Expr *E, APValue &Result) { ByteCodeExprGen C(*this, *P, Parent, Stk, Result); - return Check(Parent, C.interpretExpr(E)); + if (auto Flag = C.interpretExpr(E)) { + return *Flag; + } else { // NOTE(review): same error path as evaluateAsInitializer. + handleAllErrors(Flag.takeError(), [&Parent](ByteCodeGenError &Err) { + Parent.FFDiag(Err.getLoc(), diag::err_experimental_clang_interp_failed); + }); + return false; + } } bool Context::evaluateAsInitializer(State &Parent, const VarDecl *VD, APValue &Result) { ByteCodeExprGen C(*this, *P, Parent, Stk, Result); - return Check(Parent, C.interpretDecl(VD)); + if (auto Flag = C.interpretDecl(VD)) { + return *Flag; + } else { + handleAllErrors(Flag.takeError(), [&Parent](ByteCodeGenError &Err) { + Parent.FFDiag(Err.getLoc(), diag::err_experimental_clang_interp_failed); + }); + return false; + } } const LangOptions &Context::getLangOpts() const { return Ctx.getLangOpts(); } @@ -115,21 +138,3 @@ unsigned Context::getCharBit() const { return Ctx.getTargetInfo().getCharWidth(); } - -bool Context::Run(State &Parent, Function *Func, APValue &Result) { - InterpState State(Parent, *P, Stk, *this); - State.Current = new InterpFrame(State, Func, nullptr, {}, {}); - if (Interpret(State, Result)) - return true; - Stk.clear(); - return false; -} - -bool Context::Check(State &Parent, llvm::Expected &&Flag) { - if (Flag) - return *Flag; - handleAllErrors(Flag.takeError(), [&Parent](ByteCodeGenError &Err) {
Parent.FFDiag(Err.getLoc(), diag::err_experimental_clang_interp_failed); - }); - return false; -} diff --git a/clang/lib/AST/Interp/EvalEmitter.cpp b/clang/lib/AST/Interp/EvalEmitter.cpp --- a/clang/lib/AST/Interp/EvalEmitter.cpp +++ b/clang/lib/AST/Interp/EvalEmitter.cpp @@ -177,6 +177,8 @@ return false; if (S.checkingPotentialConstantExpression()) return false; + if (!F->isConstexpr()) // Bodyless dummy. NOTE(review): no diagnostic emitted. + return false; S.Current = new InterpFrame(S, F, S.Current, OpPC, std::move(This)); return Interpret(S, Result); } diff --git a/clang/lib/AST/Interp/InterpLoop.cpp b/clang/lib/AST/Interp/InterpLoop.cpp --- a/clang/lib/AST/Interp/InterpLoop.cpp +++ b/clang/lib/AST/Interp/InterpLoop.cpp @@ -108,6 +108,8 @@ return false; if (S.checkingPotentialConstantExpression()) return false; + if (!F->isConstexpr()) // Bodyless dummy. NOTE(review): no diagnostic emitted. + return false; // Adjust the state. S.CallStackDepth++; diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td --- a/clang/lib/AST/Interp/Opcodes.td +++ b/clang/lib/AST/Interp/Opcodes.td @@ -75,7 +75,11 @@ } def FPTypeClass : TypeClass { - let Types = IntTypeClass.Types; + let Types = []; // Must not repeat IntTypeClass.Types: see IntFPTypeClass. +} + +def IntFPTypeClass : TypeClass { + let Types = !listconcat(IntTypeClass.Types, FPTypeClass.Types); } def AluFPRealFPTypeClass : TypeClass { @@ -334,6 +338,16 @@ def GT : ComparisonOpcode; def GE : ComparisonOpcode; +//===----------------------------------------------------------------------===// +// Range test. +//===----------------------------------------------------------------------===// + +// [Real, Real, Real] -> [Bool]: [Value, Low, High] -> Low <= Value <= High +def InRange : Opcode { + let Types = [IntFPTypeClass]; + let HasGroup = 1; +} + //===----------------------------------------------------------------------===// // Stack management.
//===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/Interp/Opcodes/Comparison.h b/clang/lib/AST/Interp/Opcodes/Comparison.h --- a/clang/lib/AST/Interp/Opcodes/Comparison.h +++ b/clang/lib/AST/Interp/Opcodes/Comparison.h @@ -148,5 +148,16 @@ return false; } +template +bool InRange(InterpState &S, CodePtr OpPC) { + using T = typename PrimConv::T; + const T &RHS = S.Stk.pop(); // Upper bound, pushed last. + const T &LHS = S.Stk.pop(); // Lower bound. + const T &Value = S.Stk.pop(); // Value under test, pushed first. + // Push whether Value lies in the inclusive range [LHS, RHS]. + S.Stk.push(LHS <= Value && Value <= RHS); + return true; +} + #endif