diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h --- a/flang/include/flang/Lower/AbstractConverter.h +++ b/flang/include/flang/Lower/AbstractConverter.h @@ -57,6 +57,7 @@ using SomeExpr = Fortran::evaluate::Expr; using SymbolRef = Fortran::common::Reference; +class StatementContext; //===----------------------------------------------------------------------===// // AbstractConverter interface @@ -79,26 +80,33 @@ //===--------------------------------------------------------------------===// /// Generate the address of the location holding the expression, someExpr. - virtual fir::ExtendedValue genExprAddr(const SomeExpr &, + virtual fir::ExtendedValue genExprAddr(const SomeExpr &, StatementContext &, mlir::Location *loc = nullptr) = 0; /// Generate the address of the location holding the expression, someExpr - fir::ExtendedValue genExprAddr(const SomeExpr *someExpr, mlir::Location loc) { - return genExprAddr(*someExpr, &loc); + fir::ExtendedValue genExprAddr(const SomeExpr *someExpr, + StatementContext &stmtCtx, + mlir::Location loc) { + return genExprAddr(*someExpr, stmtCtx, &loc); } /// Generate the computations of the expression to produce a value - virtual fir::ExtendedValue genExprValue(const SomeExpr &, + virtual fir::ExtendedValue genExprValue(const SomeExpr &, StatementContext &, mlir::Location *loc = nullptr) = 0; /// Generate the computations of the expression, someExpr, to produce a value fir::ExtendedValue genExprValue(const SomeExpr *someExpr, + StatementContext &stmtCtx, mlir::Location loc) { - return genExprValue(*someExpr, &loc); + return genExprValue(*someExpr, stmtCtx, &loc); } /// Get FoldingContext that is required for some expression /// analysis. virtual Fortran::evaluate::FoldingContext &getFoldingContext() = 0; + /// Host associated variables are grouped as a tuple. This returns that value, + /// which is itself a reference. Use bindTuple() to set this value. 
+ virtual mlir::Value hostAssocTupleValue() = 0; + //===--------------------------------------------------------------------===// // Types //===--------------------------------------------------------------------===// diff --git a/flang/include/flang/Lower/CallInterface.h b/flang/include/flang/Lower/CallInterface.h --- a/flang/include/flang/Lower/CallInterface.h +++ b/flang/include/flang/Lower/CallInterface.h @@ -52,10 +52,16 @@ /// inside the input vector for the CallOp (caller side. It will be up to the /// CallInterface user to produce the mlir::Value that will go in this input /// vector). +class CallerInterface; class CalleeInterface; template struct PassedEntityTypes {}; template <> +struct PassedEntityTypes { + using FortranEntity = const Fortran::evaluate::ActualArgument *; + using FirValue = int; +}; +template <> struct PassedEntityTypes { using FortranEntity = std::optional>; @@ -165,6 +171,15 @@ nullptr; }; + /// Return the mlir::FuncOp. Note that front block is added by this + /// utility if callee side. + mlir::FuncOp getFuncOp() const { return func; } + /// Number of MLIR inputs/outputs of the created FuncOp. + std::size_t getNumFIRArguments() const { return inputs.size(); } + std::size_t getNumFIRResults() const { return outputs.size(); } + /// Return the MLIR output types. + llvm::SmallVector getResultType() const; + /// Return a container of Symbol/ActualArgument* and how they must /// be plugged with the mlir::FuncOp. llvm::ArrayRef getPassedArguments() const { @@ -182,6 +197,21 @@ determineInterface(bool isImplicit, const Fortran::evaluate::characteristics::Procedure &); + /// Does the caller need to allocate storage for the result ? + bool callerAllocateResult() const { + return mustPassResult() || mustSaveResult(); + } + + /// Is the Fortran result passed as an extra MLIR argument ? + bool mustPassResult() const { return passedResult.has_value(); } + /// Must the MLIR result be saved with a fir.save_result ? 
+ bool mustSaveResult() const { return saveResult; } + + /// Can the associated procedure be called via an implicit interface? + bool canBeCalledViaImplicitInterface() const { + return characteristic && characteristic->CanBeCalledViaImplicitInterface(); + } + protected: CallInterface(Fortran::lower::AbstractConverter &c) : converter{c} {} /// CRTP handle. @@ -199,6 +229,7 @@ mlir::FuncOp func; llvm::SmallVector passedArguments; std::optional passedResult; + bool saveResult = false; Fortran::lower::AbstractConverter &converter; /// Store characteristic once created, it is required for further information @@ -207,6 +238,102 @@ std::nullopt; }; +//===----------------------------------------------------------------------===// +// Caller side interface +//===----------------------------------------------------------------------===// + +/// The CallerInterface provides the helpers needed by CallInterface +/// (getting the characteristic...) and a safe way for the user to +/// place the mlir::Value arguments into the input vector +/// once they are lowered. +class CallerInterface : public CallInterface { +public: + CallerInterface(const Fortran::evaluate::ProcedureRef &p, + Fortran::lower::AbstractConverter &c) + : CallInterface{c}, procRef{p} { + declare(); + mapPassedEntities(); + actualInputs.resize(getNumFIRArguments()); + } + + using ExprVisitor = std::function)>; + + /// CRTP callbacks + bool hasAlternateReturns() const; + std::string getMangledName() const; + mlir::Location getCalleeLocation() const; + Fortran::evaluate::characteristics::Procedure characterize() const; + + const Fortran::evaluate::ProcedureRef &getCallDescription() const { + return procRef; + } + + bool isMainProgram() const { return false; } + + /// Returns true if this is a call to a procedure pointer or a dummy + /// procedure. + bool isIndirectCall() const; + + /// Return the procedure symbol if this is a call to a user defined + /// procedure. 
+ const Fortran::semantics::Symbol *getProcedureSymbol() const; + + /// Helpers to place the lowered arguments at the right place once they + /// have been lowered. + void placeInput(const PassedEntity &passedEntity, mlir::Value arg); + void placeAddressAndLengthInput(const PassedEntity &passedEntity, + mlir::Value addr, mlir::Value len); + + /// If this is a call to a procedure pointer or dummy, returns the related + /// symbol. Nullptr otherwise. + const Fortran::semantics::Symbol *getIfIndirectCallSymbol() const; + + /// Get the input vector once it is complete. + llvm::ArrayRef getInputs() const { + if (!verifyActualInputs()) + llvm::report_fatal_error("lowered arguments are incomplete"); + return actualInputs; + } + + /// Must the caller map function interface symbols in order to evaluate + /// the result specification expressions (extents and lengths)? If needed, + /// this mapping must be done after argument lowering, and before the call + /// itself. + bool mustMapInterfaceSymbols() const; + + /// Walk the result non-deferred extent specification expressions. + void walkResultExtents(ExprVisitor) const; + + /// Walk the result non-deferred length specification expressions. + void walkResultLengths(ExprVisitor) const; + + /// Get the mlir::Value that is passed as argument \p sym of the function + /// being called. The arguments must have been placed before calling this + /// function. + mlir::Value getArgumentValue(const semantics::Symbol &sym) const; + + /// Returns the symbol for the result in the explicit interface. If this is + /// called on an intrinsic or function without explicit interface, this will + /// crash. + const Fortran::semantics::Symbol &getResultSymbol() const; + + /// If some storage needs to be allocated for the result, + /// returns the storage type. + mlir::Type getResultStorageType() const; + + // Copy of base implementation. 
+ static constexpr bool hasHostAssociated() { return false; } + mlir::Type getHostAssociatedTy() const { + llvm_unreachable("getting host associated type in CallerInterface"); + } + +private: + /// Check that the input vector is complete. + bool verifyActualInputs() const; + const Fortran::evaluate::ProcedureRef &procRef; + llvm::SmallVector actualInputs; +}; + //===----------------------------------------------------------------------===// // Callee side interface //===----------------------------------------------------------------------===// diff --git a/flang/include/flang/Lower/ConvertExpr.h b/flang/include/flang/Lower/ConvertExpr.h --- a/flang/include/flang/Lower/ConvertExpr.h +++ b/flang/include/flang/Lower/ConvertExpr.h @@ -34,6 +34,7 @@ namespace Fortran::lower { class AbstractConverter; +class StatementContext; class SymMap; using SomeExpr = Fortran::evaluate::Expr; @@ -41,13 +42,24 @@ fir::ExtendedValue createSomeExtendedExpression(mlir::Location loc, AbstractConverter &converter, const SomeExpr &expr, - SymMap &symMap); + SymMap &symMap, + StatementContext &stmtCtx); /// Create an extended expression address. fir::ExtendedValue createSomeExtendedAddress(mlir::Location loc, AbstractConverter &converter, const SomeExpr &expr, - SymMap &symMap); + SymMap &symMap, + StatementContext &stmtCtx); + +/// Lower a subroutine call. This handles both elemental and non elemental +/// subroutines. \p isUserDefAssignment must be set if this is called in the +/// context of a user defined assignment. For subroutines with alternate +/// returns, the returned value indicates which label the code should jump to. +/// The returned value is null otherwise. +mlir::Value createSubroutineCall(AbstractConverter &converter, + const evaluate::ProcedureRef &call, + SymMap &symMap, StatementContext &stmtCtx); // Attribute for an alloca that is a trivial adaptor for converting a value to // pass-by-ref semantics for a VALUE parameter. 
The optimizer may be able to diff --git a/flang/include/flang/Lower/ConvertVariable.h b/flang/include/flang/Lower/ConvertVariable.h --- a/flang/include/flang/Lower/ConvertVariable.h +++ b/flang/include/flang/Lower/ConvertVariable.h @@ -19,6 +19,7 @@ namespace Fortran ::lower { class AbstractConverter; +class CallerInterface; class SymMap; namespace pft { struct Variable; @@ -31,5 +32,12 @@ void instantiateVariable(AbstractConverter &, const pft::Variable &var, SymMap &symMap); +/// Instantiate the variables that appear in the specification expressions +/// of the result of a function call. The instantiated variables are added +/// to \p symMap. +void mapCallInterfaceSymbols(AbstractConverter &, + const Fortran::lower::CallerInterface &caller, + SymMap &symMap); + } // namespace Fortran::lower #endif // FORTRAN_LOWER_CONVERT_VARIABLE_H diff --git a/flang/include/flang/Lower/StatementContext.h b/flang/include/flang/Lower/StatementContext.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Lower/StatementContext.h @@ -0,0 +1,85 @@ +//===-- StatementContext.h --------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_LOWER_STATEMENTCONTEXT_H +#define FORTRAN_LOWER_STATEMENTCONTEXT_H + +#include + +namespace Fortran::lower { + +/// When lowering a statement, temporaries for intermediate results may be +/// allocated on the heap. A StatementContext enables their deallocation +/// either explicitly with finalize() calls, or implicitly at the end of +/// the context. 
A context may prohibit temporary allocation. Otherwise, +/// an initial "outer" context scope may have nested context scopes, which +/// must make explicit subscope finalize() calls. +class StatementContext { +public: + explicit StatementContext(bool cleanupProhibited = false) { + if (cleanupProhibited) + return; + cufs.push_back({}); + } + + ~StatementContext() { + if (!cufs.empty()) + finalize(/*popScope=*/true); + assert(cufs.empty() && "invalid StatementContext destructor call"); + } + + using CleanupFunction = std::function; + + /// Push a context subscope. + void pushScope() { + assert(!cufs.empty() && "invalid pushScope statement context"); + cufs.push_back({}); + } + + /// Append a cleanup function to the "list" of cleanup functions. + void attachCleanup(CleanupFunction cuf) { + assert(!cufs.empty() && "invalid attachCleanup statement context"); + if (cufs.back()) { + CleanupFunction oldCleanup = *cufs.back(); + cufs.back() = [=]() { + cuf(); + oldCleanup(); + }; + } else { + cufs.back() = cuf; + } + } + + /// Make cleanup calls. Pop or reset the stack top list. + void finalize(bool popScope = false) { + assert(!cufs.empty() && "invalid finalize statement context"); + if (cufs.back()) + (*cufs.back())(); + if (popScope) + cufs.pop_back(); + else + cufs.back().reset(); + } + +private: + // A statement context should never be copied or moved. + StatementContext(const StatementContext &) = delete; + StatementContext &operator=(const StatementContext &) = delete; + StatementContext(StatementContext &&) = delete; + + // Stack of cleanup function "lists" (nested cleanup function calls). 
+ llvm::SmallVector> cufs; +}; + +} // namespace Fortran::lower + +#endif // FORTRAN_LOWER_STATEMENTCONTEXT_H diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -63,14 +63,15 @@ getKindMap().getIntegerBitsize(getKindMap().defaultIntegerKind())); } - /// The LHS and RHS are not always in agreement in terms of - /// type. In some cases, the disagreement is between COMPLEX and other scalar - /// types. In that case, the conversion must insert/extract out of a COMPLEX - /// value to have the proper semantics and be strongly typed. For e.g for - /// converting an integer/real to a complex, the real part is filled using - /// the integer/real after type conversion and the imaginary part is zero. + /// The LHS and RHS are not always in agreement in terms of type. In some + /// cases, the disagreement is between COMPLEX and other scalar types. In that + /// case, the conversion must insert (extract) out of a COMPLEX value to have + /// the proper semantics and be strongly typed. E.g., converting an integer + /// (real) to a complex, the real part is filled using the integer (real) + /// after type conversion and the imaginary part is zero. mlir::Value convertWithSemantics(mlir::Location loc, mlir::Type toTy, - mlir::Value val); + mlir::Value val, + bool allowCharacterConversion = false); /// Get the entry block of the current Function mlir::Block *getEntryBlock() { return &getFunction().front(); } @@ -97,9 +98,18 @@ return getI64Type(); } + /// Wrap `str` to a SymbolRefAttr. + mlir::SymbolRefAttr getSymbolRefAttr(llvm::StringRef str) { + return mlir::SymbolRefAttr::get(getContext(), str); + } + /// Get the mlir real type that implements fortran REAL(kind). 
mlir::Type getRealType(int kind); + fir::BoxProcType getBoxProcType(mlir::FunctionType funcTy) { + return fir::BoxProcType::get(getContext(), funcTy); + } + /// Create a null constant memory reference of type \p ptrType. /// If \p ptrType is not provided, !fir.ref type will be used. mlir::Value createNullConstant(mlir::Location loc, mlir::Type ptrType = {}); @@ -213,6 +223,14 @@ static mlir::FuncOp getNamedFunction(mlir::ModuleOp module, llvm::StringRef name); + /// Get a function by symbol name. The result will be null if there is no + /// function with the given symbol in the module. + mlir::FuncOp getNamedFunction(mlir::SymbolRefAttr symbol) { + return getNamedFunction(getModule(), symbol); + } + static mlir::FuncOp getNamedFunction(mlir::ModuleOp module, + mlir::SymbolRefAttr symbol); + fir::GlobalOp getNamedGlobal(llvm::StringRef name) { return getNamedGlobal(getModule(), name); } @@ -382,6 +400,14 @@ mlir::Location loc, const fir::ExtendedValue &box); +/// Read a fir::BoxValue into an fir::UnboxValue, a fir::ArrayBoxValue or a +/// fir::CharArrayBoxValue. This should only be called if the fir::BoxValue is +/// known to be contiguous given the context (or if the resulting address will +/// not be used). If the value is polymorphic, its dynamic type will be lost. +/// This must not be used on unlimited polymorphic and assumed rank entities. 
+fir::ExtendedValue readBoxValue(fir::FirOpBuilder &builder, mlir::Location loc, + const fir::BoxValue &box); + //===----------------------------------------------------------------------===// // String literal helper helpers //===----------------------------------------------------------------------===// diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h --- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h +++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h @@ -77,6 +77,11 @@ /// Attribute to keep track of Fortran scoping information for a symbol. static constexpr llvm::StringRef getSymbolAttrName() { return "fir.sym_name"; } +/// Attribute to mark a function that takes a host associations argument. +static constexpr llvm::StringRef getHostAssocAttrName() { + return "fir.host_assoc"; +} + /// Tell if \p value is: /// - a function argument that has attribute \p attributeName /// - or, the result of fir.alloca/fir.allocamem op that has attribute \p @@ -87,6 +92,11 @@ /// previous cases. bool valueHasFirAttribute(mlir::Value value, llvm::StringRef attributeName); +/// Scan the arguments of a FuncOp to determine if any arguments have the +/// attribute `attr` placed on them. This can be used to determine if the +/// function has any host associations, for example. 
+bool anyFuncArgsHaveAttr(mlir::FuncOp func, llvm::StringRef attr); + } // namespace fir #endif // FORTRAN_OPTIMIZER_DIALECT_FIROPSSUPPORT_H diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -19,6 +19,7 @@ #include "flang/Lower/Mangler.h" #include "flang/Lower/PFTBuilder.h" #include "flang/Lower/Runtime.h" +#include "flang/Lower/StatementContext.h" #include "flang/Lower/SymbolMap.h" #include "flang/Lower/Todo.h" #include "flang/Optimizer/Support/FIRContext.h" @@ -77,15 +78,17 @@ } fir::ExtendedValue genExprAddr(const Fortran::lower::SomeExpr &expr, + Fortran::lower::StatementContext &context, mlir::Location *loc = nullptr) override final { return createSomeExtendedAddress(loc ? *loc : toLocation(), *this, expr, - localSymbols); + localSymbols, context); } fir::ExtendedValue genExprValue(const Fortran::lower::SomeExpr &expr, + Fortran::lower::StatementContext &context, mlir::Location *loc = nullptr) override final { return createSomeExtendedExpression(loc ? *loc : toLocation(), *this, expr, - localSymbols); + localSymbols, context); } Fortran::evaluate::FoldingContext &getFoldingContext() override final { @@ -224,6 +227,7 @@ {builder->getRegion()}); // remove dead code delete builder; builder = nullptr; + hostAssocTuple = mlir::Value{}; localSymbols.clear(); } @@ -357,6 +361,8 @@ lowerFunc(f); // internal procedure } + mlir::Value hostAssocTupleValue() override final { return hostAssocTuple; } + private: FirConverter() = delete; FirConverter(const FirConverter &) = delete; @@ -476,8 +482,8 @@ } void genAssignment(const Fortran::evaluate::Assignment &assign) { + Fortran::lower::StatementContext stmtCtx; mlir::Location loc = toLocation(); - std::visit( Fortran::common::visitors{ // [1] Plain old assignment. @@ -512,15 +518,16 @@ const bool isNumericScalar = isNumericScalarCategory(lhsType->category()); fir::ExtendedValue rhs = isNumericScalar - ? 
genExprValue(assign.rhs) - : genExprAddr(assign.rhs); + ? genExprValue(assign.rhs, stmtCtx) + : genExprAddr(assign.rhs, stmtCtx); if (isNumericScalar) { // Fortran 2018 10.2.1.3 p8 and p9 // Conversions should have been inserted by semantic analysis, // but they can be incorrect between the rhs and lhs. Correct // that here. - mlir::Value addr = fir::getBase(genExprAddr(assign.lhs)); + mlir::Value addr = + fir::getBase(genExprAddr(assign.lhs, stmtCtx)); mlir::Value val = fir::getBase(rhs); // A function with multiple entry points returning different // types tags all result variables with one of the largest @@ -568,8 +575,16 @@ assign.u); } + /// Lowering of CALL statement void genFIR(const Fortran::parser::CallStmt &stmt) { - TODO(toLocation(), "CallStmt lowering"); + Fortran::lower::StatementContext stmtCtx; + setCurrentPosition(stmt.v.source); + assert(stmt.typedCall && "Call was not analyzed"); + // Call statement lowering shares code with function call lowering. + mlir::Value res = Fortran::lower::createSubroutineCall( + *this, *stmt.typedCall, localSymbols, stmtCtx); + if (!res) + return; // "Normal" subroutine call. } void genFIR(const Fortran::parser::ComputedGotoStmt &stmt) { @@ -999,6 +1014,9 @@ Fortran::lower::pft::Evaluation *evalPtr = nullptr; Fortran::lower::SymMap localSymbols; Fortran::parser::CharBlock currentPosition; + + /// Tuple of host associated variables. + mlir::Value hostAssocTuple; }; } // namespace diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp --- a/flang/lib/Lower/CallInterface.cpp +++ b/flang/lib/Lower/CallInterface.cpp @@ -30,6 +30,261 @@ return bindName ? 
*bindName : Fortran::lower::mangle::mangleName(symbol); } +//===----------------------------------------------------------------------===// +// Caller side interface implementation +//===----------------------------------------------------------------------===// + +bool Fortran::lower::CallerInterface::hasAlternateReturns() const { + return procRef.hasAlternateReturns(); +} + +std::string Fortran::lower::CallerInterface::getMangledName() const { + const Fortran::evaluate::ProcedureDesignator &proc = procRef.proc(); + if (const Fortran::semantics::Symbol *symbol = proc.GetSymbol()) + return ::getMangledName(symbol->GetUltimate()); + assert(proc.GetSpecificIntrinsic() && + "expected intrinsic procedure in designator"); + return proc.GetName(); +} + +const Fortran::semantics::Symbol * +Fortran::lower::CallerInterface::getProcedureSymbol() const { + return procRef.proc().GetSymbol(); +} + +bool Fortran::lower::CallerInterface::isIndirectCall() const { + if (const Fortran::semantics::Symbol *symbol = procRef.proc().GetSymbol()) + return Fortran::semantics::IsPointer(*symbol) || + Fortran::semantics::IsDummy(*symbol); + return false; +} + +const Fortran::semantics::Symbol * +Fortran::lower::CallerInterface::getIfIndirectCallSymbol() const { + if (const Fortran::semantics::Symbol *symbol = procRef.proc().GetSymbol()) + if (Fortran::semantics::IsPointer(*symbol) || + Fortran::semantics::IsDummy(*symbol)) + return symbol; + return nullptr; +} + +mlir::Location Fortran::lower::CallerInterface::getCalleeLocation() const { + const Fortran::evaluate::ProcedureDesignator &proc = procRef.proc(); + // FIXME: If the callee is defined in the same file but after the current + // unit we cannot get its location here and the funcOp is created at the + // wrong location (i.e, the caller location). + if (const Fortran::semantics::Symbol *symbol = proc.GetSymbol()) + return converter.genLocation(symbol->name()); + // Use current location for intrinsics. 
+ return converter.getCurrentLocation(); +} + +// Get dummy argument characteristic for a procedure with implicit interface +// from the actual argument characteristic. The actual argument may not be a F77 +// entity. The attribute must be dropped and the shape, if any, must be made +// explicit. +static Fortran::evaluate::characteristics::DummyDataObject +asImplicitArg(Fortran::evaluate::characteristics::DummyDataObject &&dummy) { + Fortran::evaluate::Shape shape = + dummy.type.attrs().none() ? dummy.type.shape() + : Fortran::evaluate::Shape(dummy.type.Rank()); + return Fortran::evaluate::characteristics::DummyDataObject( + Fortran::evaluate::characteristics::TypeAndShape(dummy.type.type(), + std::move(shape))); +} + +static Fortran::evaluate::characteristics::DummyArgument +asImplicitArg(Fortran::evaluate::characteristics::DummyArgument &&dummy) { + return std::visit( + Fortran::common::visitors{ + [&](Fortran::evaluate::characteristics::DummyDataObject &obj) { + return Fortran::evaluate::characteristics::DummyArgument( + std::move(dummy.name), asImplicitArg(std::move(obj))); + }, + [&](Fortran::evaluate::characteristics::DummyProcedure &proc) { + return Fortran::evaluate::characteristics::DummyArgument( + std::move(dummy.name), std::move(proc)); + }, + [](Fortran::evaluate::characteristics::AlternateReturn &x) { + return Fortran::evaluate::characteristics::DummyArgument( + std::move(x)); + }}, + dummy.u); +} + +Fortran::evaluate::characteristics::Procedure +Fortran::lower::CallerInterface::characterize() const { + Fortran::evaluate::FoldingContext &foldingContext = + converter.getFoldingContext(); + std::optional characteristic = + Fortran::evaluate::characteristics::Procedure::Characterize( + procRef.proc(), foldingContext); + assert(characteristic && "Failed to get characteristic from procRef"); + // The characteristic may not contain the argument characteristic if the + // ProcedureDesignator has no interface. 
+ if (!characteristic->HasExplicitInterface()) { + for (const std::optional &arg : + procRef.arguments()) { + if (arg.value().isAlternateReturn()) { + characteristic->dummyArguments.emplace_back( + Fortran::evaluate::characteristics::AlternateReturn{}); + } else { + // Argument cannot be optional with implicit interface + const Fortran::lower::SomeExpr *expr = arg.value().UnwrapExpr(); + assert( + expr && + "argument in call with implicit interface cannot be assumed type"); + std::optional + argCharacteristic = + Fortran::evaluate::characteristics::DummyArgument::FromActual( + "actual", *expr, foldingContext); + assert(argCharacteristic && + "failed to characterize argument in implicit call"); + characteristic->dummyArguments.emplace_back( + asImplicitArg(std::move(*argCharacteristic))); + } + } + } + return *characteristic; +} + +void Fortran::lower::CallerInterface::placeInput( + const PassedEntity &passedEntity, mlir::Value arg) { + assert(static_cast(actualInputs.size()) > passedEntity.firArgument && + passedEntity.firArgument >= 0 && + passedEntity.passBy != CallInterface::PassEntityBy::AddressAndLength && + "bad arg position"); + actualInputs[passedEntity.firArgument] = arg; +} + +void Fortran::lower::CallerInterface::placeAddressAndLengthInput( + const PassedEntity &passedEntity, mlir::Value addr, mlir::Value len) { + assert(static_cast(actualInputs.size()) > passedEntity.firArgument && + static_cast(actualInputs.size()) > passedEntity.firLength && + passedEntity.firArgument >= 0 && passedEntity.firLength >= 0 && + passedEntity.passBy == CallInterface::PassEntityBy::AddressAndLength && + "bad arg position"); + actualInputs[passedEntity.firArgument] = addr; + actualInputs[passedEntity.firLength] = len; +} + +bool Fortran::lower::CallerInterface::verifyActualInputs() const { + if (getNumFIRArguments() != actualInputs.size()) + return false; + for (mlir::Value arg : actualInputs) { + if (!arg) + return false; + } + return true; +} + +void 
Fortran::lower::CallerInterface::walkResultLengths( + ExprVisitor visitor) const { + assert(characteristic && "characteristic was not computed"); + const Fortran::evaluate::characteristics::FunctionResult &result = + characteristic->functionResult.value(); + const Fortran::evaluate::characteristics::TypeAndShape *typeAndShape = + result.GetTypeAndShape(); + assert(typeAndShape && "no result type"); + Fortran::evaluate::DynamicType dynamicType = typeAndShape->type(); + // Visit result length specification expressions that are explicit. + if (dynamicType.category() == Fortran::common::TypeCategory::Character) { + if (std::optional length = + dynamicType.GetCharLength()) + visitor(toEvExpr(*length)); + } else if (dynamicType.category() == common::TypeCategory::Derived) { + const Fortran::semantics::DerivedTypeSpec &derivedTypeSpec = + dynamicType.GetDerivedTypeSpec(); + if (Fortran::semantics::CountLenParameters(derivedTypeSpec) > 0) + TODO(converter.getCurrentLocation(), + "function result with derived type length parameters"); + } +} + +// Compute extent expr from shapeSpec of an explicit shape. +// TODO: Allow evaluate shape analysis to work in a mode where it disregards +// the non-constant aspects when building the shape to avoid having this here. +static Fortran::evaluate::ExtentExpr +getExtentExpr(const Fortran::semantics::ShapeSpec &shapeSpec) { + const auto &ubound = shapeSpec.ubound().GetExplicit(); + const auto &lbound = shapeSpec.lbound().GetExplicit(); + assert(lbound && ubound && "shape must be explicit"); + return Fortran::common::Clone(*ubound) - Fortran::common::Clone(*lbound) + + Fortran::evaluate::ExtentExpr{1}; +} + +void Fortran::lower::CallerInterface::walkResultExtents( + ExprVisitor visitor) const { + // Walk directly the result symbol shape (the characteristic shape may contain + // descriptor inquiries to it that would fail to lower on the caller side). 
+ const Fortran::semantics::Symbol *interfaceSymbol = + procRef.proc().GetInterfaceSymbol(); + if (interfaceSymbol) { + const Fortran::semantics::Symbol &result = + interfaceSymbol->get().result(); + if (const auto *objectDetails = + result.detailsIf()) + if (objectDetails->shape().IsExplicitShape()) + for (const Fortran::semantics::ShapeSpec &shapeSpec : + objectDetails->shape()) + visitor(Fortran::evaluate::AsGenericExpr(getExtentExpr(shapeSpec))); + } else { + if (procRef.Rank() != 0) + fir::emitFatalError( + converter.getCurrentLocation(), + "only scalar functions may not have an interface symbol"); + } +} + +bool Fortran::lower::CallerInterface::mustMapInterfaceSymbols() const { + assert(characteristic && "characteristic was not computed"); + const std::optional + &result = characteristic->functionResult; + if (!result || result->CanBeReturnedViaImplicitInterface() || + !procRef.proc().GetInterfaceSymbol()) + return false; + bool allResultSpecExprConstant = true; + auto visitor = [&](const Fortran::lower::SomeExpr &e) { + allResultSpecExprConstant &= Fortran::evaluate::IsConstantExpr(e); + }; + walkResultLengths(visitor); + walkResultExtents(visitor); + return !allResultSpecExprConstant; +} + +mlir::Value Fortran::lower::CallerInterface::getArgumentValue( + const semantics::Symbol &sym) const { + mlir::Location loc = converter.getCurrentLocation(); + const Fortran::semantics::Symbol *iface = procRef.proc().GetInterfaceSymbol(); + if (!iface) + fir::emitFatalError( + loc, "mapping actual and dummy arguments requires an interface"); + const std::vector &dummies = + iface->get().dummyArgs(); + auto it = std::find(dummies.begin(), dummies.end(), &sym); + if (it == dummies.end()) + fir::emitFatalError(loc, "symbol is not a dummy in this call"); + FirValue mlirArgIndex = passedArguments[it - dummies.begin()].firArgument; + return actualInputs[mlirArgIndex]; +} + +mlir::Type Fortran::lower::CallerInterface::getResultStorageType() const { + if (passedResult) + return 
fir::dyn_cast_ptrEleTy(inputs[passedResult->firArgument].type); + assert(saveResult && !outputs.empty()); + return outputs[0].type; +} + +const Fortran::semantics::Symbol & +Fortran::lower::CallerInterface::getResultSymbol() const { + mlir::Location loc = converter.getCurrentLocation(); + const Fortran::semantics::Symbol *iface = procRef.proc().GetInterfaceSymbol(); + if (!iface) + fir::emitFatalError( + loc, "mapping actual and dummy arguments requires an interface"); + return iface->get().result(); +} + //===----------------------------------------------------------------------===// // Callee side interface implementation //===----------------------------------------------------------------------===// @@ -162,6 +417,12 @@ passedEntity.firArgument = firValue; } +/// Helpers to access ActualArgument/Symbols +static const Fortran::evaluate::ActualArguments & +getEntityContainer(const Fortran::evaluate::ProcedureRef &proc) { + return proc.arguments(); +} + static const std::vector & getEntityContainer(Fortran::lower::pft::FunctionLikeUnit &funit) { return funit.getSubprogramSymbol() @@ -169,6 +430,13 @@ .dummyArgs(); } +static const Fortran::evaluate::ActualArgument *getDataObjectEntity( + const std::optional &arg) { + if (arg) + return &*arg; + return nullptr; +} + static const Fortran::semantics::Symbol & getDataObjectEntity(const Fortran::semantics::Symbol *arg) { assert(arg && "expect symbol for data object entity"); @@ -400,6 +668,26 @@ mlir::MLIRContext &mlirContext; }; +template +bool Fortran::lower::CallInterface::PassedEntity::isOptional() const { + if (!characteristics) + return false; + return characteristics->IsOptional(); +} +template +bool Fortran::lower::CallInterface::PassedEntity::mayBeModifiedByCall() + const { + if (!characteristics) + return true; + return characteristics->GetIntent() != Fortran::common::Intent::In; +} +template +bool Fortran::lower::CallInterface::PassedEntity::mayBeReadByCall() const { + if (!characteristics) + return true; + 
return characteristics->GetIntent() != Fortran::common::Intent::Out; +} + template void Fortran::lower::CallInterface::determineInterface( bool isImplicit, @@ -424,3 +712,4 @@ } template class Fortran::lower::CallInterface; +template class Fortran::lower::CallInterface; diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -12,14 +12,17 @@ #include "flang/Lower/ConvertExpr.h" #include "flang/Evaluate/fold.h" -#include "flang/Evaluate/real.h" #include "flang/Evaluate/traverse.h" #include "flang/Lower/AbstractConverter.h" +#include "flang/Lower/CallInterface.h" #include "flang/Lower/ConvertType.h" +#include "flang/Lower/ConvertVariable.h" #include "flang/Lower/IntrinsicCall.h" +#include "flang/Lower/StatementContext.h" #include "flang/Lower/SymbolMap.h" #include "flang/Lower/Todo.h" #include "flang/Optimizer/Builder/Complex.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Semantics/expression.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" @@ -67,6 +70,25 @@ return fir::substBase(exv, temp); } +/// Is this a variable wrapped in parentheses? +template +static bool isParenthesizedVariable(const A &) { + return false; +} +template +static bool isParenthesizedVariable(const Fortran::evaluate::Expr &expr) { + using ExprVariant = decltype(Fortran::evaluate::Expr::u); + using Parentheses = Fortran::evaluate::Parentheses; + if constexpr (Fortran::common::HasMember) { + if (const auto *parentheses = std::get_if(&expr.u)) + return Fortran::evaluate::IsVariable(parentheses->left()); + return false; + } else { + return std::visit([&](const auto &x) { return isParenthesizedVariable(x); }, + expr.u); + } +} + /// Generate a load of a value from an address. Beware that this will lose /// any dynamic type information for polymorphic entities (note that unlimited /// polymorphic cannot be loaded and must not be provided here). 
@@ -103,6 +125,22 @@ return true; return false; } + +/// If \p arg is the address of a function with a denoted host-association tuple +/// argument, then return the host-associations tuple value of the current +/// procedure. Otherwise, return nullptr. +static mlir::Value +argumentHostAssocs(Fortran::lower::AbstractConverter &converter, + mlir::Value arg) { + if (auto addr = mlir::dyn_cast_or_null(arg.getDefiningOp())) { + auto &builder = converter.getFirOpBuilder(); + if (auto funcOp = builder.getNamedFunction(addr.getSymbol())) + if (fir::anyFuncArgsHaveAttr(funcOp, fir::getHostAssocAttrName())) + return converter.hostAssocTupleValue(); + } + return {}; +} + namespace { /// Lowering of Fortran::evaluate::Expr expressions @@ -112,9 +150,29 @@ explicit ScalarExprLowering(mlir::Location loc, Fortran::lower::AbstractConverter &converter, - Fortran::lower::SymMap &symMap) + Fortran::lower::SymMap &symMap, + Fortran::lower::StatementContext &stmtCtx) : location{loc}, converter{converter}, - builder{converter.getFirOpBuilder()}, symMap{symMap} {} + builder{converter.getFirOpBuilder()}, stmtCtx{stmtCtx}, symMap{symMap} { + } + + ExtValue genExtAddr(const Fortran::lower::SomeExpr &expr) { + return gen(expr); + } + + /// Lower `expr` to be passed as a fir.box argument. Do not create a temp + /// for the expr if it is a variable that can be described as a fir.box. 
+ ExtValue genBoxArg(const Fortran::lower::SomeExpr &expr) { + bool saveUseBoxArg = useBoxArg; + useBoxArg = true; + ExtValue result = gen(expr); + useBoxArg = saveUseBoxArg; + return result; + } + + ExtValue genExtValue(const Fortran::lower::SomeExpr &expr) { + return genval(expr); + } mlir::Location getLoc() { return location; } @@ -516,6 +574,501 @@ TODO(getLoc(), "gen FunctionRef"); } + /// helper to detect statement functions + static bool + isStatementFunctionCall(const Fortran::evaluate::ProcedureRef &procRef) { + if (const Fortran::semantics::Symbol *symbol = procRef.proc().GetSymbol()) + if (const auto *details = + symbol->detailsIf()) + return details->stmtFunction().has_value(); + return false; + } + + /// Helper to package a Value and its properties into an ExtendedValue. + static ExtValue toExtendedValue(mlir::Location loc, mlir::Value base, + llvm::ArrayRef extents, + llvm::ArrayRef lengths) { + mlir::Type type = base.getType(); + if (type.isa()) + return fir::BoxValue(base, /*lbounds=*/{}, lengths, extents); + type = fir::unwrapRefType(type); + if (type.isa()) + return fir::MutableBoxValue(base, lengths, /*mutableProperties*/ {}); + if (auto seqTy = type.dyn_cast()) { + if (seqTy.getDimension() != extents.size()) + fir::emitFatalError(loc, "incorrect number of extents for array"); + if (seqTy.getEleTy().isa()) { + if (lengths.empty()) + fir::emitFatalError(loc, "missing length for character"); + assert(lengths.size() == 1); + return fir::CharArrayBoxValue(base, lengths[0], extents); + } + return fir::ArrayBoxValue(base, extents); + } + if (type.isa()) { + if (lengths.empty()) + fir::emitFatalError(loc, "missing length for character"); + assert(lengths.size() == 1); + return fir::CharBoxValue(base, lengths[0]); + } + return base; + } + + // Find the argument that corresponds to the host associations. + // Verify some assumptions about how the signature was built here. 
+ [[maybe_unused]] static unsigned findHostAssocTuplePos(mlir::FuncOp fn) { + // Scan the argument list from last to first as the host associations are + // appended for now. + for (unsigned i = fn.getNumArguments(); i > 0; --i) + if (fn.getArgAttr(i - 1, fir::getHostAssocAttrName())) { + // Host assoc tuple must be last argument (for now). + assert(i == fn.getNumArguments() && "tuple must be last"); + return i - 1; + } + llvm_unreachable("anyFuncArgsHaveAttr failed"); + } + + /// Lower a non-elemental procedure reference and read allocatable and pointer + /// results into normal values. + ExtValue genProcedureRef(const Fortran::evaluate::ProcedureRef &procRef, + llvm::Optional resultType) { + ExtValue res = genRawProcedureRef(procRef, resultType); + return res; + } + + /// Given a call site for which the arguments were already lowered, generate + /// the call and return the result. This function deals with explicit result + /// allocation and lowering if needed. It also deals with passing the host + /// link to internal procedures. + ExtValue genCallOpAndResult(Fortran::lower::CallerInterface &caller, + mlir::FunctionType callSiteType, + llvm::Optional resultType) { + mlir::Location loc = getLoc(); + using PassBy = Fortran::lower::CallerInterface::PassEntityBy; + // Handle cases where caller must allocate the result or a fir.box for it. + bool mustPopSymMap = false; + if (caller.mustMapInterfaceSymbols()) { + symMap.pushScope(); + mustPopSymMap = true; + Fortran::lower::mapCallInterfaceSymbols(converter, caller, symMap); + } + // If this is an indirect call, retrieve the function address. Also retrieve + // the result length if this is a character function (note that this length + // will be used only if there is no explicit length in the local interface). 
+ mlir::Value funcPointer; + mlir::Value charFuncPointerLength; + if (caller.getIfIndirectCallSymbol()) { + TODO(loc, "genCallOpAndResult indirect call"); + } + + mlir::IndexType idxTy = builder.getIndexType(); + auto lowerSpecExpr = [&](const auto &expr) -> mlir::Value { + return builder.createConvert( + loc, idxTy, fir::getBase(converter.genExprValue(expr, stmtCtx))); + }; + llvm::SmallVector resultLengths; + auto allocatedResult = [&]() -> llvm::Optional { + llvm::SmallVector extents; + llvm::SmallVector lengths; + if (!caller.callerAllocateResult()) + return {}; + mlir::Type type = caller.getResultStorageType(); + if (type.isa()) + caller.walkResultExtents([&](const Fortran::lower::SomeExpr &e) { + extents.emplace_back(lowerSpecExpr(e)); + }); + caller.walkResultLengths([&](const Fortran::lower::SomeExpr &e) { + lengths.emplace_back(lowerSpecExpr(e)); + }); + + // Result length parameters should not be provided to box storage + // allocation and save_results, but they are still useful information to + // keep in the ExtendedValue if non-deferred. + if (!type.isa()) { + if (fir::isa_char(fir::unwrapSequenceType(type)) && lengths.empty()) { + // Calling an assumed length function. This is only possible if this + // is a call to a character dummy procedure. + if (!charFuncPointerLength) + fir::emitFatalError(loc, "failed to retrieve character function " + "length while calling it"); + lengths.push_back(charFuncPointerLength); + } + resultLengths = lengths; + } + + if (!extents.empty() || !lengths.empty()) { + TODO(loc, "genCallOpResult extents and length"); + } + mlir::Value temp = + builder.createTemporary(loc, type, ".result", extents, resultLengths); + return toExtendedValue(loc, temp, extents, lengths); + }(); + + if (mustPopSymMap) + symMap.popScope(); + + // Place allocated result or prepare the fir.save_result arguments. 
+ mlir::Value arrayResultShape; + if (allocatedResult) { + if (std::optional::PassedEntity> + resultArg = caller.getPassedResult()) { + if (resultArg->passBy == PassBy::AddressAndLength) + caller.placeAddressAndLengthInput(*resultArg, + fir::getBase(*allocatedResult), + fir::getLen(*allocatedResult)); + else if (resultArg->passBy == PassBy::BaseAddress) + caller.placeInput(*resultArg, fir::getBase(*allocatedResult)); + else + fir::emitFatalError( + loc, "only expect character scalar result to be passed by ref"); + } else { + assert(caller.mustSaveResult()); + arrayResultShape = allocatedResult->match( + [&](const fir::CharArrayBoxValue &) { + return builder.createShape(loc, *allocatedResult); + }, + [&](const fir::ArrayBoxValue &) { + return builder.createShape(loc, *allocatedResult); + }, + [&](const auto &) { return mlir::Value{}; }); + } + } + + // In older Fortran, procedure argument types are inferred. This may lead + // different view of what the function signature is in different locations. + // Casts are inserted as needed below to accommodate this. + + // The mlir::FuncOp type prevails, unless it has a different number of + // arguments which can happen in legal program if it was passed as a dummy + // procedure argument earlier with no further type information. + mlir::SymbolRefAttr funcSymbolAttr; + bool addHostAssociations = false; + if (!funcPointer) { + mlir::FunctionType funcOpType = caller.getFuncOp().getType(); + mlir::SymbolRefAttr symbolAttr = + builder.getSymbolRefAttr(caller.getMangledName()); + if (callSiteType.getNumResults() == funcOpType.getNumResults() && + callSiteType.getNumInputs() + 1 == funcOpType.getNumInputs() && + fir::anyFuncArgsHaveAttr(caller.getFuncOp(), + fir::getHostAssocAttrName())) { + // The number of arguments is off by one, and we're lowering a function + // with host associations. Modify call to include host associations + // argument by appending the value at the end of the operands. 
+ assert(funcOpType.getInput(findHostAssocTuplePos(caller.getFuncOp())) == + converter.hostAssocTupleValue().getType()); + addHostAssociations = true; + } + if (!addHostAssociations && + (callSiteType.getNumResults() != funcOpType.getNumResults() || + callSiteType.getNumInputs() != funcOpType.getNumInputs())) { + // Deal with argument number mismatch by making a function pointer so + // that function type cast can be inserted. Do not emit a warning here + // because this can happen in legal program if the function is not + // defined here and it was first passed as an argument without any more + // information. + funcPointer = + builder.create(loc, funcOpType, symbolAttr); + } else if (callSiteType.getResults() != funcOpType.getResults()) { + // Implicit interface result type mismatch are not standard Fortran, but + // some compilers are not complaining about it. The front end is not + // protecting lowering from this currently. Support this with a + // discouraging warning. + LLVM_DEBUG(mlir::emitWarning( + loc, "a return type mismatch is not standard compliant and may " + "lead to undefined behavior.")); + // Cast the actual function to the current caller implicit type because + // that is the behavior we would get if we could not see the definition. + funcPointer = + builder.create(loc, funcOpType, symbolAttr); + } else { + funcSymbolAttr = symbolAttr; + } + } + + mlir::FunctionType funcType = + funcPointer ? callSiteType : caller.getFuncOp().getType(); + llvm::SmallVector operands; + // First operand of indirect call is the function pointer. Cast it to + // required function type for the call to handle procedures that have a + // compatible interface in Fortran, but that have different signatures in + // FIR. + if (funcPointer) { + operands.push_back( + funcPointer.getType().isa() + ? builder.create(loc, funcType, funcPointer) + : builder.createConvert(loc, funcType, funcPointer)); + } + + // Deal with potential mismatches in arguments types. 
Passing an array to a + // scalar argument should for instance be tolerated here. + bool callingImplicitInterface = caller.canBeCalledViaImplicitInterface(); + for (auto [fst, snd] : + llvm::zip(caller.getInputs(), funcType.getInputs())) { + // When passing arguments to a procedure that can be called via an implicit + // interface, allow character actual arguments to be passed to dummy + // arguments of any type and vice versa + mlir::Value cast; + auto *context = builder.getContext(); + if (snd.isa() && + fst.getType().isa()) { + auto funcTy = mlir::FunctionType::get(context, llvm::None, llvm::None); + auto boxProcTy = builder.getBoxProcType(funcTy); + if (mlir::Value host = argumentHostAssocs(converter, fst)) { + cast = builder.create( + loc, boxProcTy, llvm::ArrayRef{fst, host}); + } else { + cast = builder.create(loc, boxProcTy, fst); + } + } else { + cast = builder.convertWithSemantics(loc, snd, fst, + callingImplicitInterface); + } + operands.push_back(cast); + } + + // Add host associations as necessary. + if (addHostAssociations) + operands.push_back(converter.hostAssocTupleValue()); + + auto call = builder.create(loc, funcType.getResults(), + funcSymbolAttr, operands); + + if (caller.mustSaveResult()) + builder.create( + loc, call.getResult(0), fir::getBase(allocatedResult.getValue()), + arrayResultShape, resultLengths); + + if (allocatedResult) { + allocatedResult->match( + [&](const fir::MutableBoxValue &box) { + if (box.isAllocatable()) { + TODO(loc, "allocatedResult for allocatable"); + } + }, + [](const auto &) {}); + return *allocatedResult; + } + + if (!resultType.hasValue()) + return mlir::Value{}; // subroutine call + // For now, Fortran return values are implemented with a single MLIR + // function return value. 
+ assert(call.getNumResults() == 1 && + "Expected exactly one result in FUNCTION call"); + return call.getResult(0); + } + + /// Like genExtAddr, but ensure the address returned is a temporary even if \p + /// expr is variable inside parentheses. + ExtValue genTempExtAddr(const Fortran::lower::SomeExpr &expr) { + // In general, genExtAddr might not create a temp for variable inside + // parentheses to avoid creating array temporary in sub-expressions. It only + // ensures the sub-expression is not re-associated with other parts of the + // expression. In the call semantics, there is a difference between expr and + // variable (see R1524). For expressions, a variable storage must not be + // argument associated since it could be modified inside the call, or the + // variable could also be modified by other means during the call. + if (!isParenthesizedVariable(expr)) + return genExtAddr(expr); + mlir::Location loc = getLoc(); + if (expr.Rank() > 0) + TODO(loc, "genTempExtAddr array"); + return genExtValue(expr).match( + [&](const fir::CharBoxValue &boxChar) -> ExtValue { + TODO(loc, "genTempExtAddr CharBoxValue"); + }, + [&](const fir::UnboxedValue &v) -> ExtValue { + mlir::Type type = v.getType(); + mlir::Value value = v; + if (fir::isa_ref_type(type)) + value = builder.create(loc, value); + mlir::Value temp = builder.createTemporary(loc, value.getType()); + builder.create(loc, value, temp); + return temp; + }, + [&](const fir::BoxValue &x) -> ExtValue { + // Derived type scalar that may be polymorphic. + assert(!x.hasRank() && x.isDerived()); + if (x.isDerivedWithLengthParameters()) + fir::emitFatalError( + loc, "making temps for derived type with length parameters"); + // TODO: polymorphic aspects should be kept but for now the temp + // created always has the declared type. 
+ mlir::Value var = + fir::getBase(fir::factory::readBoxValue(builder, loc, x)); + auto value = builder.create(loc, var); + mlir::Value temp = builder.createTemporary(loc, value.getType()); + builder.create(loc, value, temp); + return temp; + }, + [&](const auto &) -> ExtValue { + fir::emitFatalError(loc, "expr is not a scalar value"); + }); + } + + /// Helper structure to track potential copy-in of non contiguous variable + /// argument into a contiguous temp. It is used to deallocate the temp that + /// may have been created as well as to the copy-out from the temp to the + /// variable after the call. + struct CopyOutPair { + ExtValue var; + ExtValue temp; + // Flag to indicate if the argument may have been modified by the + // callee, in which case it must be copied-out to the variable. + bool argMayBeModifiedByCall; + // Optional boolean value that, if present and false, prevents + // the copy-out and temp deallocation. + llvm::Optional restrictCopyAndFreeAtRuntime; + }; + using CopyOutPairs = llvm::SmallVector; + + /// Helper to read any fir::BoxValue into other fir::ExtendedValue categories + /// not based on fir.box. + /// This will lose any non contiguous stride information and dynamic type and + /// should only be called if \p exv is known to be contiguous or if its base + /// address will be replaced by a contiguous one. If \p exv is not a + /// fir::BoxValue, this is a no-op. + ExtValue readIfBoxValue(const ExtValue &exv) { + if (const auto *box = exv.getBoxOf()) + return fir::factory::readBoxValue(builder, getLoc(), *box); + return exv; + } + + /// Lower a non-elemental procedure reference. 
+ ExtValue genRawProcedureRef(const Fortran::evaluate::ProcedureRef &procRef, + llvm::Optional resultType) { + mlir::Location loc = getLoc(); + if (isElementalProcWithArrayArgs(procRef)) + fir::emitFatalError(loc, "trying to lower elemental procedure with array " + "arguments as normal procedure"); + if (const Fortran::evaluate::SpecificIntrinsic *intrinsic = + procRef.proc().GetSpecificIntrinsic()) + return genIntrinsicRef(procRef, *intrinsic, resultType); + + if (isStatementFunctionCall(procRef)) + TODO(loc, "Lower statement function call"); + + Fortran::lower::CallerInterface caller(procRef, converter); + using PassBy = Fortran::lower::CallerInterface::PassEntityBy; + + llvm::SmallVector mutableModifiedByCall; + // List of where temp must be copied into var after the call. + CopyOutPairs copyOutPairs; + + mlir::FunctionType callSiteType = caller.genFunctionType(); + + // Lower the actual arguments and map the lowered values to the dummy + // arguments. + for (const Fortran::lower::CallInterface< + Fortran::lower::CallerInterface>::PassedEntity &arg : + caller.getPassedArguments()) { + const auto *actual = arg.entity; + mlir::Type argTy = callSiteType.getInput(arg.firArgument); + if (!actual) { + // Optional dummy argument for which there is no actual argument. 
+ caller.placeInput(arg, builder.create(loc, argTy)); + continue; + } + const auto *expr = actual->UnwrapExpr(); + if (!expr) + TODO(loc, "assumed type actual argument lowering"); + + if (arg.passBy == PassBy::Value) { + ExtValue argVal = genval(*expr); + if (!fir::isUnboxedValue(argVal)) + fir::emitFatalError( + loc, "internal error: passing non trivial value by value"); + caller.placeInput(arg, fir::getBase(argVal)); + continue; + } + + if (arg.passBy == PassBy::MutableBox) { + TODO(loc, "arg passby MutableBox"); + } + const bool actualArgIsVariable = Fortran::evaluate::IsVariable(*expr); + if (arg.passBy == PassBy::BaseAddress || arg.passBy == PassBy::BoxChar) { + auto argAddr = [&]() -> ExtValue { + ExtValue baseAddr; + if (actualArgIsVariable && arg.isOptional()) { + if (Fortran::evaluate::IsAllocatableOrPointerObject( + *expr, converter.getFoldingContext())) { + TODO(loc, "Allocatable or pointer argument"); + } + if (const Fortran::semantics::Symbol *wholeSymbol = + Fortran::evaluate::UnwrapWholeSymbolOrComponentDataRef( + *expr)) + if (Fortran::semantics::IsOptional(*wholeSymbol)) { + TODO(loc, "procedureref optional arg"); + } + // Fall through: The actual argument can safely be + // copied-in/copied-out without any care if needed. + } + if (actualArgIsVariable && expr->Rank() > 0) { + TODO(loc, "procedureref arrays"); + } + // Actual argument is a non optional/non pointer/non allocatable + // scalar. + if (actualArgIsVariable) + return genExtAddr(*expr); + // Actual argument is not a variable. Make sure a variable address is + // not passed. + return genTempExtAddr(*expr); + }(); + // Scalar and contiguous expressions may be lowered to a fir.box, + // either to account for potential polymorphism, or because lowering + // did not account for some contiguity hints. 
+ // Here, polymorphism does not matter (an entity of the declared type + // is passed, not one of the dynamic type), and the expr is known to + // be simply contiguous, so it is safe to unbox it and pass the + // address without making a copy. + argAddr = readIfBoxValue(argAddr); + + if (arg.passBy == PassBy::BaseAddress) { + caller.placeInput(arg, fir::getBase(argAddr)); + } else { + TODO(loc, "procedureref PassBy::BoxChar"); + } + } else if (arg.passBy == PassBy::Box) { + // Before lowering to an address, handle the allocatable/pointer actual + // argument to optional fir.box dummy. It is legal to pass + // unallocated/disassociated entity to an optional. In this case, an + // absent fir.box must be created instead of a fir.box with a null value + // (Fortran 2018 15.5.2.12 point 1). + if (arg.isOptional() && Fortran::evaluate::IsAllocatableOrPointerObject( + *expr, converter.getFoldingContext())) { + TODO(loc, "optional allocatable or pointer argument"); + } else { + // Make sure a variable address is only passed if the expression is + // actually a variable. + mlir::Value box = + actualArgIsVariable + ? builder.createBox(loc, genBoxArg(*expr)) + : builder.createBox(getLoc(), genTempExtAddr(*expr)); + caller.placeInput(arg, box); + } + } else if (arg.passBy == PassBy::AddressAndLength) { + ExtValue argRef = genExtAddr(*expr); + caller.placeAddressAndLengthInput(arg, fir::getBase(argRef), + fir::getLen(argRef)); + } else if (arg.passBy == PassBy::CharProcTuple) { + TODO(loc, "procedureref CharProcTuple"); + } else { + TODO(loc, "pass by value in non elemental function call"); + } + } + + ExtValue result = genCallOpAndResult(caller, callSiteType, resultType); + + // // Copy-out temps that were created for non contiguous variable arguments + // if + // // needed. 
// for (const auto &copyOutPair : copyOutPairs) + // genCopyOut(copyOutPair); + + return result; + } + template ExtValue genval(const Fortran::evaluate::FunctionRef &funcRef) { ExtValue result = genFunctionRef(funcRef); @@ -525,7 +1078,10 @@ } ExtValue genval(const Fortran::evaluate::ProcedureRef &procRef) { - TODO(getLoc(), "genval ProcedureRef"); + llvm::Optional resTy; + if (procRef.hasAlternateReturns()) + resTy = builder.getIndexType(); + return genProcedureRef(procRef, resTy); } /// Generate a call to an intrinsic function. @@ -586,28 +1142,6 @@ TODO(getLoc(), "genval Expr arrays"); } - /// Lower a non-elemental procedure reference. - // TODO: Handle read allocatable and pointer results. - ExtValue genProcedureRef(const Fortran::evaluate::ProcedureRef &procRef, - llvm::Optional resultType) { - ExtValue res = genRawProcedureRef(procRef, resultType); - return res; - } - - /// Lower a non-elemental procedure reference. - ExtValue genRawProcedureRef(const Fortran::evaluate::ProcedureRef &procRef, - llvm::Optional resultType) { - mlir::Location loc = getLoc(); - if (isElementalProcWithArrayArgs(procRef)) - fir::emitFatalError(loc, "trying to lower elemental procedure with array " - "arguments as normal procedure"); - if (const Fortran::evaluate::SpecificIntrinsic *intrinsic = - procRef.proc().GetSpecificIntrinsic()) - return genIntrinsicRef(procRef, *intrinsic, resultType); - - return {}; - } - /// Helper to detect Transformational function reference. 
template bool isTransformationalRef(const T &) { @@ -679,20 +1213,35 @@ mlir::Location location; Fortran::lower::AbstractConverter &converter; fir::FirOpBuilder &builder; + Fortran::lower::StatementContext &stmtCtx; Fortran::lower::SymMap &symMap; + bool useBoxArg = false; // expression lowered as argument }; } // namespace fir::ExtendedValue Fortran::lower::createSomeExtendedExpression( mlir::Location loc, Fortran::lower::AbstractConverter &converter, - const Fortran::lower::SomeExpr &expr, Fortran::lower::SymMap &symMap) { + const Fortran::lower::SomeExpr &expr, Fortran::lower::SymMap &symMap, + Fortran::lower::StatementContext &stmtCtx) { LLVM_DEBUG(expr.AsFortran(llvm::dbgs() << "expr: ") << '\n'); - return ScalarExprLowering{loc, converter, symMap}.genval(expr); + return ScalarExprLowering{loc, converter, symMap, stmtCtx}.genval(expr); } fir::ExtendedValue Fortran::lower::createSomeExtendedAddress( mlir::Location loc, Fortran::lower::AbstractConverter &converter, - const Fortran::lower::SomeExpr &expr, Fortran::lower::SymMap &symMap) { + const Fortran::lower::SomeExpr &expr, Fortran::lower::SymMap &symMap, + Fortran::lower::StatementContext &stmtCtx) { LLVM_DEBUG(expr.AsFortran(llvm::dbgs() << "address: ") << '\n'); - return ScalarExprLowering{loc, converter, symMap}.gen(expr); + return ScalarExprLowering{loc, converter, symMap, stmtCtx}.gen(expr); +} + +mlir::Value Fortran::lower::createSubroutineCall( + AbstractConverter &converter, const evaluate::ProcedureRef &call, + SymMap &symMap, StatementContext &stmtCtx) { + mlir::Location loc = converter.getCurrentLocation(); + + // Simple subroutine call, with potential alternate return. 
+ auto res = Fortran::lower::createSomeExtendedExpression( + loc, converter, toEvExpr(call), symMap, stmtCtx); + return fir::getBase(res); } diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -107,3 +107,50 @@ instantiateLocal(converter, var, symMap); } } + +void Fortran::lower::mapCallInterfaceSymbols( + AbstractConverter &converter, const Fortran::lower::CallerInterface &caller, + SymMap &symMap) { + const Fortran::semantics::Symbol &result = caller.getResultSymbol(); + for (Fortran::lower::pft::Variable var : + Fortran::lower::pft::buildFuncResultDependencyList(result)) { + if (var.isAggregateStore()) { + instantiateVariable(converter, var, symMap); + } else { + const Fortran::semantics::Symbol &sym = var.getSymbol(); + const auto *hostDetails = + sym.detailsIf(); + if (hostDetails && !var.isModuleVariable()) { + // The callee is an internal procedure `A` whose result properties + // depend on host variables. The caller may be the host, or another + // internal procedure `B` contained in the same host. In the first + // case, the host symbol is obviously mapped, in the second case, it + // must also be mapped because + // HostAssociations::internalProcedureBindings that was called when + // lowering `B` will have mapped all host symbols of captured variables + // to the tuple argument containing the composite of all host associated + // variables, whether or not the host symbol is actually referred to in + // `B`. Hence it is possible to simply lookup the variable associated to + // the host symbol without having to go back to the tuple argument. 
+ Fortran::lower::SymbolBox hostValue = + symMap.lookupSymbol(hostDetails->symbol()); + assert(hostValue && "callee host symbol must be mapped on caller side"); + symMap.addSymbol(sym, hostValue.toExtendedValue()); + // The SymbolBox associated to the host symbols is complete, skip + // instantiateVariable that would try to allocate a new storage. + continue; + } + if (Fortran::semantics::IsDummy(sym) && sym.owner() == result.owner()) { + // Get the argument for the dummy argument symbols of the current call. + symMap.addSymbol(sym, caller.getArgumentValue(sym)); + // All the properties of the dummy variable may not come from the actual + // argument, let instantiateVariable handle this. + } + // If this is neither a host associated or dummy symbol, it must be a + // module or common block variable to satisfy specification expression + // requirements in 10.1.11, instantiateVariable will get its address and + // properties. + instantiateVariable(converter, var, symMap); + } + } +} diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -14,6 +14,7 @@ #include "flang/Common/idioms.h" #include "flang/Lower/Bridge.h" #include "flang/Lower/PFTBuilder.h" +#include "flang/Lower/StatementContext.h" #include "flang/Lower/Todo.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" @@ -120,7 +121,7 @@ static void genACC(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { - + Fortran::lower::StatementContext stmtCtx; const auto &beginLoopDirective = std::get(loopConstruct.t); const auto &loopDirective = @@ -151,7 +152,7 @@ std::get>( x.t)) { gangNum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(gangNumValue.value()))); + *Fortran::semantics::GetExpr(gangNumValue.value()), stmtCtx)); } if (const auto &gangStaticValue = std::get>(x.t)) 
{ @@ -159,8 +160,8 @@ std::get>( gangStaticValue.value().t); if (expr) { - gangStatic = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(*expr))); + gangStatic = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*expr), stmtCtx)); } else { // * was passed as value and will be represented as a -1 constant // integer. @@ -176,7 +177,7 @@ &clause.u)) { if (workerClause->v) { workerNum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*workerClause->v))); + *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)); } executionMapping |= mlir::acc::OpenACCExecMapping::WORKER; } else if (const auto *vectorClause = @@ -184,7 +185,7 @@ &clause.u)) { if (vectorClause->v) { vectorLength = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*vectorClause->v))); + *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)); } executionMapping |= mlir::acc::OpenACCExecMapping::VECTOR; } else if (const auto *tileClause = @@ -195,8 +196,8 @@ std::get>( accTileExpr.t); if (expr) { - tileOperands.push_back(fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(*expr)))); + tileOperands.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*expr), stmtCtx))); } else { // * was passed as value and will be represented as a -1 constant // integer. @@ -281,6 +282,7 @@ auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; // Lower clauses values mapped to operands. // Keep track of each group of operands separatly as clauses can appear @@ -291,7 +293,7 @@ const auto &asyncClauseValue = asyncClause->v; if (asyncClauseValue) { // async has a value. 
async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); + *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); } else { addAsyncAttr = true; } @@ -303,8 +305,8 @@ const std::list &waitList = std::get>(waitArg.t); for (const Fortran::parser::ScalarIntExpr &value : waitList) { - Value v = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(value))); + Value v = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(value), stmtCtx)); waitOperands.push_back(v); } } else { @@ -314,21 +316,21 @@ std::get_if( &clause.u)) { numGangs = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numGangsClause->v))); + *Fortran::semantics::GetExpr(numGangsClause->v), stmtCtx)); } else if (const auto *numWorkersClause = std::get_if( &clause.u)) { numWorkers = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numWorkersClause->v))); + *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)); } else if (const auto *vectorLengthClause = std::get_if( &clause.u)) { vectorLength = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(vectorLengthClause->v))); + *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)); } else if (const auto *ifClause = std::get_if(&clause.u)) { - Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + Value cond = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); ifCond = firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), cond); } else if (const auto *selfClause = @@ -339,7 +341,7 @@ &accSelfClause.u)) { if (*optCondition) { Value cond = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*optCondition))); + *Fortran::semantics::GetExpr(*optCondition), stmtCtx)); selfCond = firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), cond); } else { @@ -442,6 +444,7 @@ auto &firOpBuilder = 
converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; // Lower clauses values mapped to operands. // Keep track of each group of operands separatly as clauses can appear @@ -449,8 +452,8 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + Value cond = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); ifCond = firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), cond); } else if (const auto *copyClause = @@ -546,6 +549,7 @@ auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; // Lower clauses values mapped to operands. // Keep track of each group of operands separatly as clauses can appear @@ -553,8 +557,8 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + mlir::Value cond = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); ifCond = firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), cond); } else if (const auto *asyncClause = @@ -562,7 +566,7 @@ const auto &asyncClauseValue = asyncClause->v; if (asyncClauseValue) { // async has a value. 
async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); + *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); } else { addAsyncAttr = true; } @@ -574,8 +578,8 @@ const std::list &waitList = std::get>(waitArg.t); for (const Fortran::parser::ScalarIntExpr &value : waitList) { - mlir::Value v = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(value))); + mlir::Value v = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(value), stmtCtx)); waitOperands.push_back(v); } @@ -583,7 +587,7 @@ std::get>(waitArg.t); if (waitDevnumValue) waitDevnum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*waitDevnumValue))); + *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)); } else { addWaitAttr = true; } @@ -646,6 +650,7 @@ auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; // Lower clauses values mapped to operands. // Keep track of each group of operands separatly as clauses can appear @@ -653,8 +658,8 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + Value cond = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); ifCond = firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), cond); } else if (const auto *asyncClause = @@ -662,7 +667,7 @@ const auto &asyncClauseValue = asyncClause->v; if (asyncClauseValue) { // async has a value. 
async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); + *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); } else { addAsyncAttr = true; } @@ -674,8 +679,8 @@ const std::list &waitList = std::get>(waitArg.t); for (const Fortran::parser::ScalarIntExpr &value : waitList) { - Value v = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(value))); + Value v = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(value), stmtCtx)); waitOperands.push_back(v); } @@ -683,7 +688,7 @@ std::get>(waitArg.t); if (waitDevnumValue) waitDevnum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*waitDevnumValue))); + *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)); } else { addWaitAttr = true; } @@ -737,6 +742,7 @@ auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; // Lower clauses values mapped to operands. 
// Keep track of each group of operands separatly as clauses can appear @@ -744,15 +750,15 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + mlir::Value cond = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); ifCond = firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), cond); } else if (const auto *deviceNumClause = std::get_if( &clause.u)) { deviceNum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(deviceNumClause->v))); + *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx)); } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { @@ -761,7 +767,7 @@ if (deviceTypeValue) { for (const auto &scalarIntExpr : *deviceTypeValue) { mlir::Value expr = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(scalarIntExpr))); + *Fortran::semantics::GetExpr(scalarIntExpr), stmtCtx)); deviceTypeOperands.push_back(expr); } } else { @@ -800,6 +806,7 @@ auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; // Lower clauses values mapped to operands. // Keep track of each group of operands separatly as clauses can appear @@ -807,8 +814,8 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + mlir::Value cond = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); ifCond = firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), cond); } else if (const auto *asyncClause = @@ -816,7 +823,7 @@ const auto &asyncClauseValue = asyncClause->v; if (asyncClauseValue) { // async has a value. 
async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); + *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); } else { addAsyncAttr = true; } @@ -828,8 +835,8 @@ const std::list &waitList = std::get>(waitArg.t); for (const Fortran::parser::ScalarIntExpr &value : waitList) { - mlir::Value v = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(value))); + mlir::Value v = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(value), stmtCtx)); waitOperands.push_back(v); } @@ -837,7 +844,7 @@ std::get>(waitArg.t); if (waitDevnumValue) waitDevnum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*waitDevnumValue))); + *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)); } else { addWaitAttr = true; } @@ -849,7 +856,7 @@ if (deviceTypeValue) { for (const auto &scalarIntExpr : *deviceTypeValue) { mlir::Value expr = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(scalarIntExpr))); + *Fortran::semantics::GetExpr(scalarIntExpr), stmtCtx)); deviceTypeOperands.push_back(expr); } } else { @@ -935,6 +942,7 @@ auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; if (waitArgument) { // wait has a value. const Fortran::parser::AccWaitArgument &waitArg = *waitArgument; @@ -942,7 +950,7 @@ std::get>(waitArg.t); for (const Fortran::parser::ScalarIntExpr &value : waitList) { mlir::Value v = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(value))); + converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx)); waitOperands.push_back(v); } @@ -950,7 +958,7 @@ std::get>(waitArg.t); if (waitDevnumValue) waitDevnum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*waitDevnumValue))); + *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)); } // Lower clauses values mapped to operands. 
@@ -959,8 +967,8 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + mlir::Value cond = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); ifCond = firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), cond); } else if (const auto *asyncClause = @@ -968,7 +976,7 @@ const auto &asyncClauseValue = asyncClause->v; if (asyncClauseValue) { // async has a value. async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); + *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); } else { addAsyncAttr = true; } diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -14,6 +14,7 @@ #include "flang/Common/idioms.h" #include "flang/Lower/Bridge.h" #include "flang/Lower/PFTBuilder.h" +#include "flang/Lower/StatementContext.h" #include "flang/Lower/Todo.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" @@ -139,6 +140,7 @@ auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; llvm::ArrayRef argTy; if (blockDirective.v == llvm::omp::OMPD_parallel) { @@ -152,14 +154,14 @@ std::get_if(&clause.u)) { auto &expr = std::get(ifClause->v.t); - ifClauseOperand = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(expr))); + ifClauseOperand = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(expr), stmtCtx)); } else if (const auto &numThreadsClause = std::get_if( &clause.u)) { // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`. 
numThreadsClauseOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numThreadsClause->v))); + *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); } // TODO: Handle private, firstprivate, shared and copyin } diff --git a/flang/lib/Lower/Runtime.cpp b/flang/lib/Lower/Runtime.cpp --- a/flang/lib/Lower/Runtime.cpp +++ b/flang/lib/Lower/Runtime.cpp @@ -8,6 +8,7 @@ #include "flang/Lower/Runtime.h" #include "flang/Lower/Bridge.h" +#include "flang/Lower/StatementContext.h" #include "flang/Lower/Todo.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" @@ -38,13 +39,15 @@ const Fortran::parser::StopStmt &stmt) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::Location loc = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; llvm::SmallVector operands; mlir::FuncOp callee; mlir::FunctionType calleeType; // First operand is stop code (zero if absent) if (const auto &code = std::get>(stmt.t)) { - auto expr = converter.genExprValue(*Fortran::semantics::GetExpr(*code)); + auto expr = + converter.genExprValue(*Fortran::semantics::GetExpr(*code), stmtCtx); LLVM_DEBUG(llvm::dbgs() << "stop expression: "; expr.dump(); llvm::dbgs() << '\n'); expr.match( @@ -88,7 +91,7 @@ std::get>(stmt.t)) { const SomeExpr *expr = Fortran::semantics::GetExpr(*quiet); assert(expr && "failed getting typed expression"); - mlir::Value q = fir::getBase(converter.genExprValue(*expr)); + mlir::Value q = fir::getBase(converter.genExprValue(*expr, stmtCtx)); operands.push_back( builder.createConvert(loc, calleeType.getInput(operands.size()), q)); } else { diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -38,6 +38,11 @@ return modOp.lookupSymbol(name); } +mlir::FuncOp fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp, + 
mlir::SymbolRefAttr symbol) { + return modOp.lookupSymbol(symbol); +} + fir::GlobalOp fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp, llvm::StringRef name) { return modOp.lookupSymbol(name); @@ -258,9 +263,10 @@ return glob; } -mlir::Value fir::FirOpBuilder::convertWithSemantics(mlir::Location loc, - mlir::Type toTy, - mlir::Value val) { +mlir::Value +fir::FirOpBuilder::convertWithSemantics(mlir::Location loc, mlir::Type toTy, + mlir::Value val, + bool allowCharacterConversion) { assert(toTy && "store location must be typed"); auto fromTy = val.getType(); if (fromTy == toTy) @@ -282,6 +288,35 @@ auto rp = helper.extractComplexPart(val, /*isImagPart=*/false); return createConvert(loc, toTy, rp); } + if (allowCharacterConversion) { + if (fromTy.isa()) { + // Extract the address of the character string and pass it + fir::factory::CharacterExprHelper charHelper{*this, loc}; + std::pair unboxchar = + charHelper.createUnboxChar(val); + return createConvert(loc, toTy, unboxchar.first); + } + if (auto boxType = toTy.dyn_cast()) { + // Extract the address of the actual argument and create a boxed + // character value with an undefined length + // TODO: We should really calculate the total size of the actual + // argument in characters and use it as the length of the string + auto refType = getRefType(boxType.getEleTy()); + mlir::Value charBase = createConvert(loc, refType, val); + mlir::Value unknownLen = create(loc, getIndexType()); + fir::factory::CharacterExprHelper charHelper{*this, loc}; + return charHelper.createEmboxChar(charBase, unknownLen); + } + } + if (fir::isa_ref_type(toTy) && fir::isa_box_type(fromTy)) { + // Call is expecting a raw data pointer, not a box. Get the data pointer out + // of the box and pass that. 
+ assert((fir::unwrapRefType(toTy) == + fir::unwrapRefType(fir::unwrapPassByRefType(fromTy)) && + "element types expected to match")); + return create(loc, toTy, val); + } + return createConvert(loc, toTy, val); } @@ -523,6 +558,29 @@ [&](const auto &) -> llvm::SmallVector { return {}; }); } +fir::ExtendedValue fir::factory::readBoxValue(fir::FirOpBuilder &builder, + mlir::Location loc, + const fir::BoxValue &box) { + assert(!box.isUnlimitedPolymorphic() && !box.hasAssumedRank() && + "cannot read unlimited polymorphic or assumed rank fir.box"); + auto addr = + builder.create(loc, box.getMemTy(), box.getAddr()); + if (box.isCharacter()) { + auto len = fir::factory::readCharLen(builder, loc, box); + if (box.rank() == 0) + return fir::CharBoxValue(addr, len); + return fir::CharArrayBoxValue(addr, len, + fir::factory::readExtents(builder, loc, box), + box.getLBounds()); + } + if (box.isDerivedWithLengthParameters()) + TODO(loc, "read fir.box with length parameters"); + if (box.rank() == 0) + return addr; + return fir::ArrayBoxValue(addr, fir::factory::readExtents(builder, loc, box), + box.getLBounds()); +} + std::string fir::factory::uniqueCGIdent(llvm::StringRef prefix, llvm::StringRef name) { // For "long" identifiers use a hash value diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -3291,6 +3291,13 @@ return false; } +bool fir::anyFuncArgsHaveAttr(mlir::FuncOp func, llvm::StringRef attr) { + for (unsigned i = 0, end = func.getNumArguments(); i < end; ++i) + if (func.getArgAttr(i, attr)) + return true; + return false; +} + mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) { for (auto i = path.begin(), end = path.end(); eleTy && i < end;) { eleTy = llvm::TypeSwitch(eleTy) diff --git a/flang/test/Lower/basic-call.f90 b/flang/test/Lower/basic-call.f90 new file mode 100644 --- /dev/null +++ 
b/flang/test/Lower/basic-call.f90 @@ -0,0 +1,49 @@ +! RUN: bbc %s -o "-" -emit-fir | FileCheck %s + +subroutine sub1() +end +! CHECK-LABEL: func @_QPsub1() + +subroutine sub2() + call sub1() +end + +! CHECK-LABEL: func @_QPsub2() +! CHECK: fir.call @_QPsub1() : () -> () + +subroutine sub3(a, b) + integer :: a + real :: b +end + +! CHECK-LABEL: func @_QPsub3( +! CHECK-SAME: %{{.*}}: !fir.ref {fir.bindc_name = "a"}, +! CHECK-SAME: %{{.*}}: !fir.ref {fir.bindc_name = "b"}) + +subroutine sub4() + call sub3(2, 3.0) +end + +! CHECK-LABEL: func @_QPsub4() { +! CHECK-DAG: %[[REAL_VALUE:.*]] = fir.alloca f32 {adapt.valuebyref} +! CHECK-DAG: %[[INT_VALUE:.*]] = fir.alloca i32 {adapt.valuebyref} +! CHECK: %[[C2:.*]] = arith.constant 2 : i32 +! CHECK: fir.store %[[C2]] to %[[INT_VALUE]] : !fir.ref +! CHECK: %[[C3:.*]] = arith.constant 3.000000e+00 : f32 +! CHECK: fir.store %[[C3]] to %[[REAL_VALUE]] : !fir.ref +! CHECK: fir.call @_QPsub3(%[[INT_VALUE]], %[[REAL_VALUE]]) : (!fir.ref, !fir.ref) -> () + +subroutine call_fct1() + real :: a, b, c + c = fct1(a, b) +end + +! CHECK-LABEL: func @_QPcall_fct1() +! CHECK: %[[A:.*]] = fir.alloca f32 {bindc_name = "a", uniq_name = "_QFcall_fct1Ea"} +! CHECK: %[[B:.*]] = fir.alloca f32 {bindc_name = "b", uniq_name = "_QFcall_fct1Eb"} +! CHECK: %[[C:.*]] = fir.alloca f32 {bindc_name = "c", uniq_name = "_QFcall_fct1Ec"} +! CHECK: %[[RES:.*]] = fir.call @_QPfct1(%[[A]], %[[B]]) : (!fir.ref, !fir.ref) -> f32 +! CHECK: fir.store %[[RES]] to %[[C]] : !fir.ref +! CHECK: return + +! CHECK: func private @_QPfct1(!fir.ref, !fir.ref) -> f32