diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.h b/clang/lib/AST/Interp/ByteCodeStmtGen.h
--- a/clang/lib/AST/Interp/ByteCodeStmtGen.h
+++ b/clang/lib/AST/Interp/ByteCodeStmtGen.h
@@ -62,6 +62,9 @@
   bool visitForStmt(const ForStmt *S);
   bool visitBreakStmt(const BreakStmt *S);
   bool visitContinueStmt(const ContinueStmt *S);
+  bool visitSwitchStmt(const SwitchStmt *S);
+  bool visitCaseStmt(const CaseStmt *S);
+  bool visitDefaultStmt(const DefaultStmt *S);
 
   /// Type of the expression returned by the function.
   std::optional<PrimType> ReturnType;
diff --git a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp
--- a/clang/lib/AST/Interp/ByteCodeStmtGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeStmtGen.cpp
@@ -182,6 +182,12 @@
     return visitBreakStmt(cast<BreakStmt>(S));
   case Stmt::ContinueStmtClass:
     return visitContinueStmt(cast<ContinueStmt>(S));
+  case Stmt::SwitchStmtClass:
+    return visitSwitchStmt(cast<SwitchStmt>(S));
+  case Stmt::CaseStmtClass:
+    return visitCaseStmt(cast<CaseStmt>(S));
+  case Stmt::DefaultStmtClass:
+    return visitDefaultStmt(cast<DefaultStmt>(S));
   case Stmt::NullStmtClass:
     return true;
   default: {
@@ -391,6 +397,102 @@
   return this->jump(*ContinueLabel);
 }
 
+/// Compiles a switch statement as a linear sequence of comparisons.
+///
+/// The condition is evaluated once into a local; then, for every CaseStmt,
+/// that local is compared against the case value and a conditional jump to
+/// the case's label is emitted. If nothing matched, control jumps to the
+/// default label (if any) or past the switch. Finally, the body is compiled
+/// inside a SwitchScope built from the case, end and default labels.
+template <class Emitter>
+bool ByteCodeStmtGen<Emitter>::visitSwitchStmt(const SwitchStmt *S) {
+  const Expr *Cond = S->getCond();
+  PrimType CondT = this->classifyPrim(Cond->getType());
+
+  LabelTy EndLabel = this->getLabel();
+  OptLabelTy DefaultLabel = std::nullopt;
+  unsigned CondVar = this->allocateLocalPrimitive(Cond, CondT, true, false);
+
+  if (const auto *CondInit = S->getInit())
+    if (!visitStmt(CondInit))
+      return false;
+
+  // A condition variable (`switch (int X = ...)`) is not part of Cond; Cond
+  // only refers to it, so it must be declared and initialized first.
+  if (const DeclStmt *CondDecl = S->getConditionVariableDeclStmt())
+    if (!visitDeclStmt(CondDecl))
+      return false;
+
+  // Evaluate the condition exactly once and store it in CondVar, so that
+  // every case comparison below can reload it.
+  if (!this->visit(Cond))
+    return false;
+  if (!this->emitSetLocal(CondT, CondVar, S))
+    return false;
+
+  CaseMap CaseLabels;
+  // Create labels and comparison ops for all case statements.
+  for (const SwitchCase *SC = S->getSwitchCaseList(); SC;
+       SC = SC->getNextSwitchCase()) {
+    if (const auto *CS = dyn_cast<CaseStmt>(SC)) {
+      // FIXME: Implement ranges.
+      if (CS->caseStmtIsGNURange())
+        return false;
+      CaseLabels[SC] = this->getLabel();
+
+      const Expr *Value = CS->getLHS();
+      PrimType ValueT = this->classifyPrim(Value->getType());
+
+      // Compare the case statement's value to the switch condition.
+      if (!this->emitGetLocal(CondT, CondVar, CS))
+        return false;
+      if (!this->visit(Value))
+        return false;
+
+      // Compare and jump to the case label.
+      if (!this->emitEQ(ValueT, S))
+        return false;
+      if (!this->jumpTrue(CaseLabels[CS]))
+        return false;
+    } else {
+      assert(!DefaultLabel);
+      DefaultLabel = this->getLabel();
+    }
+  }
+
+  // If none of the conditions above were true, fall through to the default
+  // statement or jump after the switch statement.
+  if (DefaultLabel) {
+    if (!this->jump(*DefaultLabel))
+      return false;
+  } else {
+    if (!this->jump(EndLabel))
+      return false;
+  }
+
+  SwitchScope SS(this, std::move(CaseLabels), EndLabel, DefaultLabel);
+  if (!this->visitStmt(S->getBody()))
+    return false;
+  this->emitLabel(EndLabel);
+  return true;
+}
+
+/// Emits the label recorded for this case in CaseLabels, then compiles the
+/// case's sub-statement. No jump precedes the label, so fallthrough works.
+template <class Emitter>
+bool ByteCodeStmtGen<Emitter>::visitCaseStmt(const CaseStmt *S) {
+  this->emitLabel(CaseLabels[S]);
+  return this->visitStmt(S->getSubStmt());
+}
+
+/// Emits the default label created by visitSwitchStmt, then compiles the
+/// default statement's sub-statement.
+template <class Emitter>
+bool ByteCodeStmtGen<Emitter>::visitDefaultStmt(const DefaultStmt *S) {
+  this->emitLabel(*DefaultLabel);
+  return this->visitStmt(S->getSubStmt());
+}
+
 namespace clang {
 namespace interp {
 
diff --git a/clang/test/AST/Interp/switch.cpp b/clang/test/AST/Interp/switch.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/AST/Interp/switch.cpp
@@ -0,0 +1,105 @@
+// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify %s
+// RUN: %clang_cc1 -verify=ref %s
+
+constexpr bool isEven(int a) {
+  bool v = false;
+  switch(a) {
+  case 2: return true;
+  case 4: return true;
+  case 6: return true;
+
+  case 8:
+  case 10:
+  case 12:
+  case 14:
+  case 16:
+    return true;
+  case 18:
+    v = true;
+    break;
+
+  default:
+  switch(a) {
+  case 1:
+    break;
+  case 3:
+    return false;
+  default:
+    break;
+  }
+  }
+
+  return v;
+}
+static_assert(isEven(2), "");
+static_assert(isEven(8), "");
+static_assert(isEven(10), "");
+static_assert(isEven(18), "");
+static_assert(!isEven(1), "");
+static_assert(!isEven(3), "");
+
+
+constexpr int withInit() {
+  switch(int a = 2; a) {
+    case 1: return -1;
+    case 2: return 2;
+  }
+  return -1;
+}
+static_assert(withInit() == 2, "");
+
+constexpr int withCondVar(int a) {
+  switch (int b = a * 2) {
+  case 2: return b;
+  case 4: return b + 1;
+  default: return -1;
+  }
+}
+static_assert(withCondVar(1) == 2, "");
+static_assert(withCondVar(2) == 5, "");
+static_assert(withCondVar(3) == -1, "");
+
+constexpr int FT(int a) {
+  int m = 0;
+  switch(a) {
+  case 4: m++;
+  case 3: m++;
+  case 2: m++;
+  case 1: m++;
+    return m;
+  }
+
+  return -1;
+}
+static_assert(FT(1) == 1, "");
+static_assert(FT(4) == 4, "");
+static_assert(FT(5) == -1, "");
+
+
+constexpr int good() { return 1; }
+constexpr int test(int val) {
+  switch (val) {
+  case good(): return 100;
+  default: return -1;
+  }
+  return 0;
+}
+static_assert(test(1) == 100, "");
+
+constexpr int bad(int val) { return val / 0; } // expected-warning {{division by zero}} \
+                                               // ref-warning {{division by zero}}
+constexpr int another_test(int val) { // expected-note {{declared here}} \
+                                      // ref-note {{declared here}}
+  switch (val) {
+  case bad(val): return 100; // expected-error {{case value is not a constant expression}} \
+                             // expected-note {{cannot be used in a constant expression}} \
+                             // ref-error {{case value is not a constant expression}} \
+                             // ref-note {{cannot be used in a constant expression}}
+  default: return -1;
+  }
+  return 0;
+}
+static_assert(another_test(1) == 100, ""); // expected-error {{static assertion failed}} \
+                                           // expected-note {{evaluates to}} \
+                                           // ref-error {{static assertion failed}} \
+                                           // ref-note {{evaluates to}}