diff --git a/flang/include/flang/Lower/IO.h b/flang/include/flang/Lower/IO.h
--- a/flang/include/flang/Lower/IO.h
+++ b/flang/include/flang/Lower/IO.h
@@ -19,6 +19,8 @@
 
 namespace Fortran {
 namespace parser {
+struct CloseStmt;
+struct OpenStmt;
 struct ReadStmt;
 struct PrintStmt;
 struct WriteStmt;
@@ -28,10 +30,18 @@
 
 class AbstractConverter;
 
+/// Generate IO call(s) for CLOSE; return the IOSTAT code
+mlir::Value genCloseStatement(AbstractConverter &converter,
+                              const parser::CloseStmt &stmt);
+
 /// Generate IO call(s) for READ; return the IOSTAT code
 mlir::Value genReadStatement(AbstractConverter &converter,
                              const parser::ReadStmt &stmt);
 
+/// Generate IO call(s) for OPEN; return the IOSTAT code
+mlir::Value genOpenStatement(AbstractConverter &converter,
+                             const parser::OpenStmt &stmt);
+
 /// Generate IO call(s) for PRINT
 void genPrintStatement(AbstractConverter &converter,
                        const parser::PrintStmt &stmt);
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -815,7 +815,8 @@
   }
 
   void genFIR(const Fortran::parser::CloseStmt &stmt) {
-    TODO(toLocation(), "CloseStmt lowering");
+    mlir::Value iostat = genCloseStatement(*this, stmt);
+    genIoConditionBranches(getEval(), stmt.v, iostat);
   }
 
   void genFIR(const Fortran::parser::EndfileStmt &stmt) {
@@ -831,7 +832,8 @@
   }
 
   void genFIR(const Fortran::parser::OpenStmt &stmt) {
-    TODO(toLocation(), "OpenStmt lowering");
+    mlir::Value iostat = genOpenStatement(*this, stmt);
+    genIoConditionBranches(getEval(), stmt.v, iostat);
   }
 
   void genFIR(const Fortran::parser::PrintStmt &stmt) {
diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp
--- a/flang/lib/Lower/IO.cpp
+++ b/flang/lib/Lower/IO.cpp
@@ -1059,6 +1059,25 @@
   return false;
 }
 
+/// Return true if the specifier list of \p stmt contains a specifier of
+/// alternative type \p SEEK.
+template <typename SEEK, typename A>
+static bool hasMem(const A &stmt) {
+  return hasX<SEEK>(stmt.v);
+}
+
+/// Return the expression of the first \p SEEK specifier in the specifier list
+/// of \p stmt. The specifier must be present: if it is not, compilation aborts
+/// via report_fatal_error. Only used for FileUnitNumber today, which is why
+/// the diagnostic mentions a file unit.
+template <typename SEEK, typename A>
+static const Fortran::lower::SomeExpr *getExpr(const A &stmt) {
+  for (const auto &spec : stmt.v)
+    if (auto *f = std::get_if<SEEK>(&spec.u))
+      return Fortran::semantics::GetExpr(f->v);
+  llvm::report_fatal_error("must have a file unit");
+}
+
 /// For each specifier, build the appropriate call, threading the cookie.
 template <typename A>
 static void threadSpecs(Fortran::lower::AbstractConverter &converter,
@@ -1469,6 +1488,95 @@
       loc, builder.getIntegerAttr(ty, Fortran::runtime::io::DefaultUnit));
 }
 
+//===----------------------------------------------------------------------===//
+// Generators for each IO statement type.
+//===----------------------------------------------------------------------===//
+
+/// Lower an IO statement whose runtime begin call takes exactly
+/// (unit, sourceFile, sourceLine), e.g. BeginClose. \p K selects that begin
+/// function; the unit is the statement's FileUnitNumber specifier, which must
+/// be present (getExpr aborts otherwise). Returns the IOSTAT value from
+/// genEndIO.
+template <typename K, typename S>
+static mlir::Value genBasicIOStmt(Fortran::lower::AbstractConverter &converter,
+                                  const S &stmt) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Location loc = converter.getCurrentLocation();
+  mlir::FuncOp beginFunc = getIORuntimeFunc<K>(loc, builder);
+  mlir::FunctionType beginFuncTy = beginFunc.getType();
+  mlir::Value unit = fir::getBase(converter.genExprValue(
+      getExpr<Fortran::parser::FileUnitNumber>(stmt), stmtCtx, loc));
+  mlir::Value un = builder.createConvert(loc, beginFuncTy.getInput(0), unit);
+  mlir::Value file = locToFilename(converter, loc, beginFuncTy.getInput(1));
+  mlir::Value line = locToLineNo(converter, loc, beginFuncTy.getInput(2));
+  auto call = builder.create<fir::CallOp>(loc, beginFunc,
+                                          mlir::ValueRange{un, file, line});
+  mlir::Value cookie = call.getResult(0);
+  ConditionSpecInfo csi;
+  genConditionHandlerCall(converter, loc, cookie, stmt.v, csi);
+  mlir::Value ok;
+  // threadSpecs may move the builder's insertion point (presumably into
+  // guarded regions when error-condition specifiers are present); restore it
+  // so EndIoStatement is emitted at the statement level.
+  auto insertPt = builder.saveInsertionPoint();
+  threadSpecs(converter, loc, cookie, stmt.v, csi.hasErrorConditionSpec(), ok);
+  builder.restoreInsertionPoint(insertPt);
+  return genEndIO(converter, loc, cookie, csi, stmtCtx);
+}
+
+/// Lower an OPEN statement. With a UNIT= (FileUnitNumber) specifier, begin
+/// with BeginOpenUnit(unit, sourceFile, sourceLine); otherwise the statement
+/// must have NEWUNIT= and begins with
+/// BeginOpenNewUnit(sourceFile, sourceLine). Returns the IOSTAT value from
+/// genEndIO.
+mlir::Value
+Fortran::lower::genOpenStatement(Fortran::lower::AbstractConverter &converter,
+                                 const Fortran::parser::OpenStmt &stmt) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::FuncOp beginFunc;
+  llvm::SmallVector<mlir::Value> beginArgs;
+  mlir::Location loc = converter.getCurrentLocation();
+  if (hasMem<Fortran::parser::FileUnitNumber>(stmt)) {
+    beginFunc = getIORuntimeFunc<mkIOKey(BeginOpenUnit)>(loc, builder);
+    mlir::FunctionType beginFuncTy = beginFunc.getType();
+    mlir::Value unit = fir::getBase(converter.genExprValue(
+        getExpr<Fortran::parser::FileUnitNumber>(stmt), stmtCtx, loc));
+    beginArgs.push_back(
+        builder.createConvert(loc, beginFuncTy.getInput(0), unit));
+    beginArgs.push_back(locToFilename(converter, loc, beginFuncTy.getInput(1)));
+    beginArgs.push_back(locToLineNo(converter, loc, beginFuncTy.getInput(2)));
+  } else {
+    // OPEN requires UNIT= or NEWUNIT= (presumably already enforced by
+    // semantic analysis), so only the NEWUNIT= form can reach here.
+    assert(hasMem<Fortran::parser::ConnectSpec::Newunit>(stmt) &&
+           "OPEN must have a UNIT= or NEWUNIT= specifier");
+    beginFunc = getIORuntimeFunc<mkIOKey(BeginOpenNewUnit)>(loc, builder);
+    mlir::FunctionType beginFuncTy = beginFunc.getType();
+    beginArgs.push_back(locToFilename(converter, loc, beginFuncTy.getInput(0)));
+    beginArgs.push_back(locToLineNo(converter, loc, beginFuncTy.getInput(1)));
+  }
+  auto cookie =
+      builder.create<fir::CallOp>(loc, beginFunc, beginArgs).getResult(0);
+  ConditionSpecInfo csi;
+  genConditionHandlerCall(converter, loc, cookie, stmt.v, csi);
+  mlir::Value ok;
+  // Save/restore around threadSpecs for the same reason as genBasicIOStmt.
+  auto insertPt = builder.saveInsertionPoint();
+  threadSpecs(converter, loc, cookie, stmt.v, csi.hasErrorConditionSpec(), ok);
+  builder.restoreInsertionPoint(insertPt);
+  return genEndIO(converter, loc, cookie, csi, stmtCtx);
+}
+
+/// Lower a CLOSE statement: BeginClose(unit, sourceFile, sourceLine), then
+/// the specifier calls, then EndIoStatement. Returns the IOSTAT value.
+mlir::Value
+Fortran::lower::genCloseStatement(Fortran::lower::AbstractConverter &converter,
+                                  const Fortran::parser::CloseStmt &stmt) {
+  return genBasicIOStmt<mkIOKey(BeginClose)>(converter, stmt);
+}
+
 //===----------------------------------------------------------------------===//
 // Data transfer statements.
 //
diff --git a/flang/test/Lower/io-statement-1.f90 b/flang/test/Lower/io-statement-1.f90
--- a/flang/test/Lower/io-statement-1.f90
+++ b/flang/test/Lower/io-statement-1.f90
@@ -6,6 +6,12 @@
   real :: a(100)
 
   ! CHECK-LABEL: _QQmain
+  ! CHECK: call {{.*}}BeginOpenUnit
+  ! CHECK-DAG: call {{.*}}SetFile
+  ! CHECK-DAG: call {{.*}}SetAccess
+  ! CHECK: call {{.*}}EndIoStatement
+  open(8, file="foo", access="sequential")
+
   ! CHECK: call {{.*}}BeginExternalListInput
   ! CHECK: call {{.*}}InputInteger
   ! CHECK: call {{.*}}InputReal32
@@ -18,6 +24,10 @@
   ! CHECK: call {{.*}}EndIoStatement
   write (8,*) i, f
 
+  ! CHECK: call {{.*}}BeginClose
+  ! CHECK: call {{.*}}EndIoStatement
+  close(8)
+
   ! CHECK: call {{.*}}BeginExternalListOutput
   ! CHECK: call {{.*}}OutputAscii
   ! CHECK: call {{.*}}EndIoStatement