diff --git a/flang/lib/Semantics/check-case.cpp b/flang/lib/Semantics/check-case.cpp index 3f1586dfd8fd..5aee11e106cf 100644 --- a/flang/lib/Semantics/check-case.cpp +++ b/flang/lib/Semantics/check-case.cpp @@ -1,252 +1,254 @@ //===-- lib/Semantics/check-case.cpp --------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "check-case.h" #include "flang/Common/idioms.h" #include "flang/Common/reference.h" +#include "flang/Common/template.h" #include "flang/Evaluate/fold.h" #include "flang/Evaluate/type.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/semantics.h" #include "flang/Semantics/tools.h" #include namespace Fortran::semantics { template class CaseValues { public: CaseValues(SemanticsContext &c, const evaluate::DynamicType &t) : context_{c}, caseExprType_{t} {} void Check(const std::list &cases) { for (const parser::CaseConstruct::Case &c : cases) { AddCase(c); } if (!hasErrors_) { cases_.sort(Comparator{}); if (!AreCasesDisjoint()) { // C1149 ReportConflictingCases(); } } } private: using Value = evaluate::Scalar; void AddCase(const parser::CaseConstruct::Case &c) { const auto &stmt{std::get>(c.t)}; const parser::CaseStmt &caseStmt{stmt.statement}; const auto &selector{std::get(caseStmt.t)}; std::visit( common::visitors{ [&](const std::list &ranges) { for (const auto &range : ranges) { auto pair{ComputeBounds(range)}; if (pair.first && pair.second && *pair.first > *pair.second) { context_.Say(stmt.source, "CASE has lower bound greater than upper bound"_en_US); } else { if constexpr (T::category == TypeCategory::Logical) { // C1148 if ((pair.first || pair.second) && (!pair.first || !pair.second || *pair.first != *pair.second)) { context_.Say(stmt.source, "CASE range is not allowed for LOGICAL"_err_en_US); } } cases_.emplace_back(stmt); cases_.back().lower = std::move(pair.first); cases_.back().upper = std::move(pair.second); } } }, [&](const parser::Default &) { cases_.emplace_front(stmt); }, }, selector.u); } std::optional GetValue(const parser::CaseValue &caseValue) { const parser::Expr &expr{caseValue.thing.thing.value()}; auto *x{expr.typedExpr.get()}; if (x && x->v) { // C1147 auto type{x->v->GetType()}; if (type && type->category() == caseExprType_.category() && (type->category() != TypeCategory::Character || type->kind() == caseExprType_.kind())) { x->v = evaluate::Fold(context_.foldingContext(), evaluate::ConvertToType(T::GetType(), std::move(*x->v))); if (x->v) { if (auto value{evaluate::GetScalarConstantValue(*x->v)}) { return *value; } } context_.Say( expr.source, "CASE value must be a constant scalar"_err_en_US); } else { std::string typeStr{type ? type->AsFortran() : "typeless"s}; context_.Say(expr.source, "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US, typeStr, caseExprType_.AsFortran()); } hasErrors_ = true; } return std::nullopt; } using PairOfValues = std::pair, std::optional>; PairOfValues ComputeBounds(const parser::CaseValueRange &range) { return std::visit(common::visitors{ [&](const parser::CaseValue &x) { auto value{GetValue(x)}; return PairOfValues{value, value}; }, [&](const parser::CaseValueRange::Range &x) { std::optional lo, hi; if (x.lower) { lo = GetValue(*x.lower); } if (x.upper) { hi = GetValue(*x.upper); } if ((x.lower && !lo) || (x.upper && !hi)) { return PairOfValues{}; // error case } return PairOfValues{std::move(lo), std::move(hi)}; }, }, range.u); } struct Case { explicit Case(const parser::Statement &s) : stmt{s} {} bool IsDefault() const { return !lower && !upper; } std::string AsFortran() const { std::string result; { llvm::raw_string_ostream bs{result}; if (lower) { evaluate::Constant{*lower}.AsFortran(bs << '('); if (!upper) { bs << ':'; } else if (*lower != *upper) { evaluate::Constant{*upper}.AsFortran(bs << ':'); } bs << ')'; } else if (upper) { evaluate::Constant{*upper}.AsFortran(bs << "(:") << ')'; } else { bs << "DEFAULT"; } } return result; } const parser::Statement &stmt; std::optional lower, upper; }; // Defines a comparator for use with std::list<>::sort(). // Returns true if and only if the highest value in range x is less // than the least value in range y. The DEFAULT case is arbitrarily // defined to be less than all others. When two ranges overlap, // neither is less than the other. struct Comparator { bool operator()(const Case &x, const Case &y) const { if (x.IsDefault()) { return !y.IsDefault(); } else { return x.upper && y.lower && *x.upper < *y.lower; } } }; bool AreCasesDisjoint() const { auto endIter{cases_.end()}; for (auto iter{cases_.begin()}; iter != endIter; ++iter) { auto next{iter}; if (++next != endIter && !Comparator{}(*iter, *next)) { return false; } } return true; } // This has quadratic time, but only runs in error cases void ReportConflictingCases() { for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) { parser::Message *msg{nullptr}; for (auto p{cases_.begin()}; p != cases_.end(); ++p) { if (p->stmt.source.begin() < iter->stmt.source.begin() && !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) { if (!msg) { msg = &context_.Say(iter->stmt.source, "CASE %s conflicts with previous cases"_err_en_US, iter->AsFortran()); } msg->Attach( p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran()); } } } } SemanticsContext &context_; const evaluate::DynamicType &caseExprType_; std::list cases_; bool hasErrors_{false}; }; +template struct TypeVisitor { + using Result = bool; + using Types = evaluate::CategoryTypes; + template Result Test() { + if (T::kind == exprType.kind()) { + CaseValues(context, exprType).Check(caseList); + return true; + } else { + return false; + } + } + SemanticsContext &context; + const evaluate::DynamicType &exprType; + const std::list &caseList; +}; + void CaseChecker::Enter(const parser::CaseConstruct &construct) { const auto &selectCaseStmt{ std::get>(construct.t)}; const auto &selectCase{selectCaseStmt.statement}; const auto &selectExpr{ std::get>(selectCase.t).thing}; const auto *x{GetExpr(selectExpr)}; if (!x) { return; // expression semantics failed } if (auto exprType{x->GetType()}) { const auto &caseList{ std::get>(construct.t)}; switch (exprType->category()) { case TypeCategory::Integer: - CaseValues>{context_, *exprType} - .Check(caseList); + common::SearchTypes( + TypeVisitor{context_, *exprType, caseList}); return; case TypeCategory::Logical: CaseValues>{context_, *exprType} .Check(caseList); return; case TypeCategory::Character: - switch (exprType->kind()) { - SWITCH_COVERS_ALL_CASES - case 1: - CaseValues>{ - context_, *exprType} - .Check(caseList); - return; - case 2: - CaseValues>{ - context_, *exprType} - .Check(caseList); - return; - case 4: - CaseValues>{ - context_, *exprType} - .Check(caseList); - return; - } + common::SearchTypes( + TypeVisitor{context_, *exprType, caseList}); + return; default: break; } } context_.Say(selectExpr.source, "SELECT CASE expression must be integer, logical, or character"_err_en_US); } } // namespace Fortran::semantics diff --git a/flang/test/Semantics/case01.f90 b/flang/test/Semantics/case01.f90 index fc85ee7091b6..e1965db573b6 100644 --- a/flang/test/Semantics/case01.f90 +++ b/flang/test/Semantics/case01.f90 @@ -1,165 +1,165 @@ ! RUN: %S/test_errors.sh %s %t %f18 ! Test SELECT CASE Constraints: C1145, C1146, C1147, C1148, C1149 program selectCaseProg implicit none ! local variable declaration character :: grade1 = 'B' integer :: grade2 = 3 logical :: grade3 = .false. real :: grade4 = 2.0 character (len = 10) :: name = 'test' logical, parameter :: grade5 = .false. CHARACTER(KIND=1), parameter :: ASCII_parm1 = 'a', ASCII_parm2='b' CHARACTER(KIND=2), parameter :: UCS16_parm = 'c' CHARACTER(KIND=4), parameter :: UCS32_parm ='d' type scores integer :: val end type type (scores) :: score = scores(25) type (scores), parameter :: score_val = scores(50) ! Valid Cases select case (grade1) case ('A') case ('B') case ('C') case default end select select case (grade2) case (1) case (2) case (3) case default end select select case (grade3) case (.true.) case (.false.) end select select case (name) case default case ('now') case ('test') end select ! C1145 !ERROR: SELECT CASE expression must be integer, logical, or character select case (grade4) case (1.0) case (2.0) case (3.0) case default end select !ERROR: SELECT CASE expression must be integer, logical, or character select case (score) case (score_val) case (scores(100)) end select ! C1146 select case (grade3) case default case (.true.) !ERROR: CASE DEFAULT conflicts with previous cases case default end select ! C1147 select case (grade2) !ERROR: CASE value has type 'CHARACTER(1)' which is not compatible with the SELECT CASE expression's type 'INTEGER(4)' case (:'Z') case default end select select case (grade1) !ERROR: CASE value has type 'INTEGER(4)' which is not compatible with the SELECT CASE expression's type 'CHARACTER(KIND=1,LEN=1_8)' case (:1) case default end select select case (grade3) case default case (.true.) !ERROR: CASE value has type 'INTEGER(4)' which is not compatible with the SELECT CASE expression's type 'LOGICAL(4)' case (3) end select select case (grade2) case default case (2 :) !ERROR: CASE value has type 'LOGICAL(4)' which is not compatible with the SELECT CASE expression's type 'INTEGER(4)' case (.true. :) !ERROR: CASE value has type 'REAL(4)' which is not compatible with the SELECT CASE expression's type 'INTEGER(4)' case (1.0) !ERROR: CASE value has type 'CHARACTER(1)' which is not compatible with the SELECT CASE expression's type 'INTEGER(4)' case ('wow') end select select case (ASCII_parm1) case (ASCII_parm2) !ERROR: CASE value has type 'CHARACTER(4)' which is not compatible with the SELECT CASE expression's type 'CHARACTER(1)' case (UCS32_parm) !ERROR: CASE value has type 'CHARACTER(2)' which is not compatible with the SELECT CASE expression's type 'CHARACTER(1)' case (UCS16_parm) !ERROR: CASE value has type 'CHARACTER(4)' which is not compatible with the SELECT CASE expression's type 'CHARACTER(1)' case (4_"ucs-32") !ERROR: CASE value has type 'CHARACTER(2)' which is not compatible with the SELECT CASE expression's type 'CHARACTER(1)' case (2_"ucs-16") case default end select ! C1148 select case (grade3) case default !ERROR: CASE range is not allowed for LOGICAL case (.true. :) end select ! C1149 select case (grade3) case (.true.) case (.false.) !ERROR: CASE (.true._1) conflicts with previous cases case (.true.) !ERROR: CASE (.false._1) conflicts with previous cases case (grade5) end select select case (grade2) case (51:50) ! warning case (100:) case (:30) case (40) case (90) case (91:99) - !ERROR: CASE (81_16:90_16) conflicts with previous cases + !ERROR: CASE (81_4:90_4) conflicts with previous cases case (81:90) - !ERROR: CASE (:80_16) conflicts with previous cases + !ERROR: CASE (:80_4) conflicts with previous cases case (:80) - !ERROR: CASE (200_16) conflicts with previous cases + !ERROR: CASE (200_4) conflicts with previous cases case (200) case default end select select case (name) case ('hello') case ('hey') !ERROR: CASE (:"hh") conflicts with previous cases case (:'hh') !ERROR: CASE (:"hd") conflicts with previous cases case (:'hd') case ( 'hu':) case ('hi':'ho') !ERROR: CASE ("hj") conflicts with previous cases case ('hj') !ERROR: CASE ("ha") conflicts with previous cases case ('ha') !ERROR: CASE ("hz") conflicts with previous cases case ('hz') case default end select end program