// Patch: make CheckReductionDIM (flang/lib/Evaluate/fold-reduction.{h,cpp})
// distinguish "DIM= absent" from "DIM= present but unusable".
//
// Header hunk: replaces the one-line comment "Returns false on error." with a
// four-line description of the true/false contract below.
//
// Contract of CheckReductionDIM after the cpp hunk:
//   - No DIM= (dimIndex empty, *dimIndex >= arg.size(), or arg[*dimIndex]
//     empty): dim.reset(), return true.
//   - DIM= folds to a constant scalar in [1, rank]: dim is set, return true.
//   - DIM= folds to a constant scalar outside [1, rank]: the error
//     "DIM=%jd is not valid for an array of rank %d" is emitted, return false.
//   - DIM= does not fold to a constant scalar: return false, no message.
//
// Behavior changes compared with the removed lines:
//   - A DIM= that did not fold to a constant scalar used to fall through to
//     the final `return true` with dim untouched. It now returns false.
//   - An absent DIM= used to leave dim untouched. It is now reset.
//   - An empty arg[*dimIndex] slot used to be passed to Folding. It is now
//     treated as "no DIM=".
//
// NOTE(review): on both false paths dim keeps whatever value the caller
//   passed in. Callers should not read dim after a false return.
// NOTE(review): this copy of the patch is damaged and will not apply with
//   `git apply` or `patch`:
//   - Line breaks and indentation were collapsed.
//   - Template arguments inside <...> were stripped, e.g. `std::optional &dim`,
//     `static_cast(*dimIndex)`, `Folder{context}`, `Constant *GetReductionMASK(`.
//   Presumably these were std::optional<int> and static_cast<std::intmax_t>
//   (the %jd conversion expects intmax_t). Regenerate the patch from the
//   upstream commit rather than hand-repairing it.
//   Text before the first "diff --git" line is ignored by git apply / patch.
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h --- a/flang/lib/Evaluate/fold-reduction.h +++ b/flang/lib/Evaluate/fold-reduction.h @@ -15,7 +15,10 @@ namespace Fortran::evaluate { -// Fold and validate a DIM= argument. Returns false on error. +// Fold and validate a DIM= argument. Returns true (with &dim empty) +// when DIM= is not present or (with &dim set) when DIM= is present, constant, +// and valid. Returns false, possibly with an error message, when +// DIM= is present but either not constant or not valid. bool CheckReductionDIM(std::optional &dim, FoldingContext &, ActualArguments &, std::optional dimIndex, int rank); diff --git a/flang/lib/Evaluate/fold-reduction.cpp b/flang/lib/Evaluate/fold-reduction.cpp --- a/flang/lib/Evaluate/fold-reduction.cpp +++ b/flang/lib/Evaluate/fold-reduction.cpp @@ -11,23 +11,26 @@ namespace Fortran::evaluate { bool CheckReductionDIM(std::optional &dim, FoldingContext &context, ActualArguments &arg, std::optional dimIndex, int rank) { - if (dimIndex && static_cast(*dimIndex) < arg.size()) { - if (auto *dimConst{ - Folder{context}.Folding(arg[*dimIndex])}) { - if (auto dimScalar{dimConst->GetScalarValue()}) { - auto dimVal{dimScalar->ToInt64()}; - if (dimVal >= 1 && dimVal <= rank) { - dim = dimVal; - } else { - context.messages().Say( - "DIM=%jd is not valid for an array of rank %d"_err_en_US, - static_cast(dimVal), rank); - return false; - } + if (!dimIndex || static_cast(*dimIndex) >= arg.size() || + !arg[*dimIndex]) { + dim.reset(); + return true; // no DIM= argument + } + if (auto *dimConst{ + Folder{context}.Folding(arg[*dimIndex])}) { + if (auto dimScalar{dimConst->GetScalarValue()}) { + auto dimVal{dimScalar->ToInt64()}; + if (dimVal >= 1 && dimVal <= rank) { + dim = dimVal; + return true; // DIM= exists and is a valid constant + } else { + context.messages().Say( + "DIM=%jd is not valid for an array of rank %d"_err_en_US, + static_cast(dimVal), rank); } } } - return true; + 
return false; // DIM= bad or not scalar constant } Constant *GetReductionMASK(