NOTE(review): the patch text below is damaged and will not apply as-is.
  1. Its line breaks were lost. Whole hunks, and even the file header
     "+++ flang/lib/Evaluate/fold-integer.cpp", run together on a few long
     lines, and the original indentation of context lines is gone.
  2. Every angle-bracket template-argument list appears to have been
     stripped. Evidence:
     - "static_cast( results.size() ? ...)" in fold-implementation.h has
       no target type.
     - Several hunks now have removed and added lines that read
       identically: the DoReduction calls in fold-logical.cpp and
       fold-reduction.h, and the DoReduction template header. Those
       changes presumably only added or changed template arguments.
  Regenerate the diff from the original revision before applying it;
  do not hand-repair it from this text.

Summary of the change, as far as the surviving text shows:
  * fold-reduction.h and the new fold-reduction.cpp: DIM= validation moves
    out of ProcessReductionArgs into CheckDIM(context, arg, rank).
    - If DIM folds to a scalar constant in [1, rank], CheckDIM returns
      that value.
    - For an out-of-range constant, it emits
      "DIM=%jd is not valid for an array of rank %d" and returns
      std::nullopt.
    - If DIM is absent or does not fold to a scalar constant, it returns
      std::nullopt without a message.
  * fold-integer.cpp: a new FoldCount() folds COUNT(MASK [, DIM]) through
    DoReduction. It replaces the old path, which only handled a constant
    MASK with no DIM=. If DIM= is present but CheckDIM rejects it,
    FoldCount returns the call unfolded.
  * DoReduction's template header changes, and its call sites appear to
    gain explicit template arguments. Presumably this lets the result
    element type differ from the array's, since COUNT reduces a LOGICAL
    mask to an INTEGER result. Confirm against the original revision.
  * fold-implementation.h: the "size() > 0" and "size() ?" tests become
    empty() tests.
  * fold-reduction.h: the TODO list drops the intrinsics that are now
    folded.
  * test/Evaluate/folding29.f90: new COUNT folding tests covering the
    whole array, DIM=1, DIM=2, and LOGICAL kinds 1 and 2.

Review notes on the new code:
  * FoldCount keeps only ".value" of element.AddSigned(...), presumably
    dropping an overflow indicator. By contrast, the SUM and PRODUCT
    folders in fold-reduction.h report "... data overflowed". This is
    probably unreachable for realistic constant sizes, but it deserves a
    comment or a matching check. Confirm the intent.
  * FoldCount tests arg[1] before calling CheckDIM, which tests its
    argument again. This is harmless but redundant.
  * folding29.f90 does not exercise COUNT's KIND= argument or an invalid
    DIM= (for example, dim=3 on the rank-2 arr). Consider adding both.

Index: flang/lib/Evaluate/CMakeLists.txt =================================================================== --- flang/lib/Evaluate/CMakeLists.txt +++ flang/lib/Evaluate/CMakeLists.txt @@ -24,6 +24,7 @@ fold-integer.cpp fold-logical.cpp fold-real.cpp + fold-reduction.cpp formatting.cpp host.cpp initial-image.cpp Index: flang/lib/Evaluate/fold-implementation.h =================================================================== --- flang/lib/Evaluate/fold-implementation.h +++ flang/lib/Evaluate/fold-implementation.h @@ -492,7 +492,7 @@ // Build and return constant result if constexpr (TR::category == TypeCategory::Character) { auto len{static_cast( - results.size() ? results[0].length() : 0)}; + results.empty() ? 0 : results[0].length())}; return Expr{Constant{len, std::move(results), std::move(shape)}}; } else { return Expr{Constant{std::move(results), std::move(shape)}}; @@ -944,7 +944,7 @@ if (constantArgs.size() != funcRef.arguments().size()) { return Expr(std::move(funcRef)); } - CHECK(constantArgs.size() > 0); + CHECK(!constantArgs.empty()); Expr result{std::move(*constantArgs[0])}; for (std::size_t i{1}; i < constantArgs.size(); ++i) { Extremum extremum{order, result, Expr{std::move(*constantArgs[i])}}; @@ -1075,7 +1075,7 @@ Expr folded{Fold(context_, common::Clone(expr.value()))}; if (const auto *c{UnwrapConstantValue(folded)}) { // Copy elements in Fortran array element order - if (c->size() > 0) { + if (!c->empty()) { ConstantSubscripts index{c->lbounds()}; do { elements_.emplace_back(c->At(index)); @@ -1156,7 +1156,7 @@ std::optional> AsFlatArrayConstructor(const Expr &expr) { if (const auto *c{UnwrapConstantValue(expr)}) { ArrayConstructor result{expr}; - if (c->size() > 0) { + if (!c->empty()) { ConstantSubscripts at{c->lbounds()}; do { result.Push(Expr{Constant{c->At(at)}}); Index: flang/lib/Evaluate/fold-integer.cpp =================================================================== --- flang/lib/Evaluate/fold-integer.cpp +++ 
flang/lib/Evaluate/fold-integer.cpp @@ -174,21 +174,47 @@ return Expr{std::move(funcRef)}; } +// COUNT() +template +static Expr FoldCount(FoldingContext &context, FunctionRef &&ref) { + static_assert(T::category == TypeCategory::Integer); + ActualArguments &arg{ref.arguments()}; + if (const Constant *mask{arg.empty() + ? nullptr + : Folder{context}.Folding(arg[0])}) { + std::optional dim; + if (arg.size() > 1 && arg[1]) { + dim = CheckDIM(context, arg[1], mask->Rank()); + if (!dim) { + mask = nullptr; + } + } + if (mask) { + auto accumulator{[&](Scalar &element, const ConstantSubscripts &at) { + if (mask->At(at).IsTrue()) { + element = element.AddSigned(Scalar{1}).value; + } + }}; + return Expr{DoReduction(*mask, dim, Scalar{}, accumulator)}; + } + } + return Expr{std::move(ref)}; +} + // for IALL, IANY, & IPARITY template static Expr FoldBitReduction(FoldingContext &context, FunctionRef &&ref, Scalar (Scalar::*operation)(const Scalar &) const, Scalar identity) { static_assert(T::category == TypeCategory::Integer); - using Element = Scalar; std::optional dim; if (std::optional> array{ ProcessReductionArgs(context, ref.arguments(), dim, identity, /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { - auto accumulator{[&](Element &element, const ConstantSubscripts &at) { + auto accumulator{[&](Scalar &element, const ConstantSubscripts &at) { element = (element.*operation)(array->At(at)); }}; - return Expr{DoReduction(*array, dim, identity, accumulator)}; + return Expr{DoReduction(*array, dim, identity, accumulator)}; } return Expr{std::move(ref)}; } @@ -237,17 +263,7 @@ cx->u); } } else if (name == "count") { - if (!args[1]) { // TODO: COUNT(x,DIM=d) - if (const auto *constant{UnwrapConstantValue(args[0])}) { - std::int64_t result{0}; - for (const auto &element : constant->values()) { - if (element.IsTrue()) { - ++result; - } - } - return Expr{result}; - } - } + return FoldCount(context, std::move(funcRef)); } else if (name == "digits") { if (const auto *cx{UnwrapExpr>(args[0])}) 
{ return Expr{std::visit( Index: flang/lib/Evaluate/fold-logical.cpp =================================================================== --- flang/lib/Evaluate/fold-logical.cpp +++ flang/lib/Evaluate/fold-logical.cpp @@ -26,7 +26,7 @@ auto accumulator{[&](Element &element, const ConstantSubscripts &at) { element = (element.*operation)(array->At(at)); }}; - return Expr{DoReduction(*array, dim, identity, accumulator)}; + return Expr{DoReduction(*array, dim, identity, accumulator)}; } return Expr{std::move(ref)}; } Index: flang/lib/Evaluate/fold-reduction.h =================================================================== --- flang/lib/Evaluate/fold-reduction.h +++ flang/lib/Evaluate/fold-reduction.h @@ -6,8 +6,7 @@ // //===----------------------------------------------------------------------===// -// TODO: ALL, ANY, COUNT, DOT_PRODUCT, FINDLOC, IALL, IANY, IPARITY, -// NORM2, MAXLOC, MINLOC, PARITY, PRODUCT, SUM +// TODO: DOT_PRODUCT, FINDLOC, NORM2, MAXLOC, MINLOC, PARITY #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_ #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_ @@ -16,6 +15,10 @@ namespace Fortran::evaluate { +// Folds & validates a DIM= actual argument. +std::optional CheckDIM( + FoldingContext &, std::optional &, int rank); + // Common preprocessing for reduction transformational intrinsic function // folding. If the intrinsic can have DIM= &/or MASK= arguments, extract // and check them. 
If a MASK= is present, apply it to the array data and @@ -35,18 +38,7 @@ return std::nullopt; } if (dimIndex && arg.size() >= *dimIndex + 1 && arg[*dimIndex]) { - if (auto *dimConst{ - Folder{context}.Folding(arg[*dimIndex])}) { - if (auto dimScalar{dimConst->GetScalarValue()}) { - dim.emplace(dimScalar->ToInt64()); - if (*dim < 1 || *dim > folded->Rank()) { - context.messages().Say( - "DIM=%jd is not valid for an array of rank %d"_err_en_US, - static_cast(*dim), folded->Rank()); - dim.reset(); - } - } - } + dim = CheckDIM(context, arg[*dimIndex], folded->Rank()); if (!dim) { return std::nullopt; } @@ -96,8 +88,8 @@ // Generalized reduction to an array of one dimension fewer (w/ DIM=) // or to a scalar (w/o DIM=). -template -static Constant DoReduction(const Constant &array, +template +static Constant DoReduction(const Constant &array, std::optional &dim, const Scalar &identity, ACCUMULATOR &accumulator) { ConstantSubscripts at{array.lbounds()}; @@ -154,7 +146,7 @@ element = array->At(at); } }}; - return Expr{DoReduction(*array, dim, identity, accumulator)}; + return Expr{DoReduction(*array, dim, identity, accumulator)}; } return Expr{std::move(ref)}; } @@ -187,7 +179,7 @@ context.messages().Say( "PRODUCT() of %s data overflowed"_en_US, T::AsFortran()); } else { - return Expr{DoReduction(*array, dim, identity, accumulator)}; + return Expr{DoReduction(*array, dim, identity, accumulator)}; } } return Expr{std::move(ref)}; @@ -226,7 +218,7 @@ context.messages().Say( "SUM() of %s data overflowed"_en_US, T::AsFortran()); } else { - return Expr{DoReduction(*array, dim, identity, accumulator)}; + return Expr{DoReduction(*array, dim, identity, accumulator)}; } } return Expr{std::move(ref)}; Index: flang/lib/Evaluate/fold-reduction.cpp =================================================================== --- /dev/null +++ flang/lib/Evaluate/fold-reduction.cpp @@ -0,0 +1,32 @@ +//===-- lib/Evaluate/fold-reduction.cpp -----------------------------------===// +// +// Part of the 
LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "fold-reduction.h" + +namespace Fortran::evaluate { + +std::optional CheckDIM( + FoldingContext &context, std::optional &arg, int rank) { + if (arg) { + if (auto *dimConst{Folder{context}.Folding(arg)}) { + if (auto dimScalar{dimConst->GetScalarValue()}) { + auto dim{dimScalar->ToInt64()}; + if (dim >= 1 && dim <= rank) { + return {dim}; + } else { + context.messages().Say( + "DIM=%jd is not valid for an array of rank %d"_err_en_US, + static_cast(dim), rank); + } + } + } + } + return std::nullopt; +} + +} // namespace Fortran::evaluate Index: flang/test/Evaluate/folding29.f90 =================================================================== --- /dev/null +++ flang/test/Evaluate/folding29.f90 @@ -0,0 +1,11 @@ +! RUN: %python %S/test_folding.py %s %flang_fc1 +! Tests folding of COUNT() +module m + logical, parameter :: arr(3,4) = reshape([(modulo(j, 2) == 1, j = 1, size(arr))], shape(arr)) + logical, parameter :: test_1 = count([1, 2, 3, 2, 1] < [(j, j=1, 5)]) == 2 + logical, parameter :: test_2 = count(arr) == 6 + logical, parameter :: test_3 = all(count(arr, dim=1) == [2, 1, 2, 1]) + logical, parameter :: test_4 = all(count(arr, dim=2) == [2, 2, 2]) + logical, parameter :: test_5 = count(logical(arr, kind=1)) == 6 + logical, parameter :: test_6 = count(logical(arr, kind=2)) == 6 +end module