diff --git a/flang/include/flang/Evaluate/call.h b/flang/include/flang/Evaluate/call.h
--- a/flang/include/flang/Evaluate/call.h
+++ b/flang/include/flang/Evaluate/call.h
@@ -199,6 +199,7 @@
   std::optional<DynamicType> GetType() const;
   int Rank() const;
   bool IsElemental() const;
+  bool IsPure() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
   llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
 
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1007,17 +1007,34 @@
 // Predicate: is a scalar expression suitable for naive scalar expansion
 // in the flattening of an array expression?
 // TODO: capture such scalar expansions in temporaries, flatten everything
-struct UnexpandabilityFindingVisitor
+class UnexpandabilityFindingVisitor
     : public AnyTraverse<UnexpandabilityFindingVisitor> {
+public:
   using Base = AnyTraverse<UnexpandabilityFindingVisitor>;
   using Base::operator();
-  UnexpandabilityFindingVisitor() : Base{*this} {}
-  template <typename T> bool operator()(const FunctionRef<T> &) { return true; }
+  // With admitPureCall, a reference to a PURE function is acceptable when
+  // its designator and actual arguments are themselves expandable.
+  explicit UnexpandabilityFindingVisitor(bool admitPureCall)
+      : Base{*this}, admitPureCall_{admitPureCall} {}
+  template <typename T> bool operator()(const FunctionRef<T> &procRef) {
+    if (admitPureCall_ && procRef.proc().IsPure()) {
+      // A PURE call may be replicated, but an impure call or a coindexed
+      // reference nested in its arguments may not; keep looking.
+      return Base::operator()(static_cast<const ProcedureRef &>(procRef));
+    }
+    return true;
+  }
   bool operator()(const CoarrayRef &) { return true; }
+
+private:
+  bool admitPureCall_{false};
 };
 
-template <typename T> bool IsExpandableScalar(const Expr<T> &expr) {
-  return !UnexpandabilityFindingVisitor{}(expr);
+// With admitPureCall, references to PURE functions may be expanded (see
+// UnexpandabilityFindingVisitor).
+template <typename T>
+bool IsExpandableScalar(const Expr<T> &expr, bool admitPureCall = false) {
+  return !UnexpandabilityFindingVisitor{admitPureCall}(expr);
 }
 
 // Common handling for procedure pointer compatibility of left- and right-hand
diff --git a/flang/lib/Evaluate/call.cpp b/flang/lib/Evaluate/call.cpp
--- a/flang/lib/Evaluate/call.cpp
+++ b/flang/lib/Evaluate/call.cpp
@@ -145,6 +145,20 @@
   return false;
 }
 
+bool ProcedureDesignator::IsPure() const {
+  if (const Symbol * interface{GetInterfaceSymbol()}) {
+    return IsPureProcedure(*interface);
+  } else if (const Symbol * symbol{GetSymbol()}) {
+    return IsPureProcedure(*symbol);
+  } else if (const auto *intrinsic{std::get_if<SpecificIntrinsic>(&u)}) {
+    return intrinsic->characteristics.value().attrs.test(
+        characteristics::Procedure::Attr::Pure);
+  } else {
+    DIE("ProcedureDesignator::IsPure(): no case");
+  }
+  return false;
+}
+
 const SpecificIntrinsic *ProcedureDesignator::GetSpecificIntrinsic() const {
   return std::get_if<SpecificIntrinsic>(&u);
 }
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -1833,7 +1833,7 @@
                         "component", "value")};
                 if (checked && *checked && GetRank(*componentShape) > 0 &&
                     GetRank(*valueShape) == 0 &&
-                    !IsExpandableScalar(*converted)) {
+                    !IsExpandableScalar(*converted, true /*admit PURE call*/)) {
                   AttachDeclaration(
                       Say(expr.source,
                           "Scalar value cannot be expanded to shape of array component '%s'"_err_en_US,