diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -54,8 +54,8 @@
                          Fortran::semantics::SemanticsContext &,
                          pft::Evaluation &, const parser::OpenACCConstruct &);
 void genOpenACCDeclarativeConstruct(
-    AbstractConverter &, pft::Evaluation &,
-    const parser::OpenACCDeclarativeConstruct &);
+    AbstractConverter &, Fortran::semantics::SemanticsContext &,
+    pft::Evaluation &, const parser::OpenACCDeclarativeConstruct &);
 
 /// Get a acc.private.recipe op for the given type or create it if it does not
 /// exist yet.
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2229,7 +2229,8 @@
 
   void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
-    genOpenACCDeclarativeConstruct(*this, getEval(), accDecl);
+    genOpenACCDeclarativeConstruct(*this, bridge.getSemanticsContext(),
+                                   getEval(), accDecl);
     for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
       genFIR(e);
     builder->restoreInsertionPoint(insertPt);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -14,6 +14,7 @@
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertType.h"
+#include "flang/Lower/Mangler.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/Support/Utils.h"
@@ -2282,6 +2283,138 @@
     waitOp.setAsyncAttr(firOpBuilder.getUnitAttr());
 }
 
+/// Attach an `acc.declare` attribute wrapping \p clause to \p op.
+static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
+                           mlir::acc::DataClause clause) {
+  op->setAttr(mlir::acc::getDeclareAttrName(),
+              mlir::acc::DeclareAttr::get(builder.getContext(),
+                                          mlir::acc::DataClauseAttr::get(
+                                              builder.getContext(), clause)));
+}
+
+/// For each variable in \p accObjectList, tag its fir.global with an
+/// `acc.declare` attribute and emit, right after the global, an
+/// acc.global_ctor whose body emits acc.create (using \p clause) and
+/// acc.declare_enter.
+static void genGlobalCtors(Fortran::lower::AbstractConverter &converter,
+                           mlir::OpBuilder &modBuilder,
+                           const Fortran::parser::AccObjectList &accObjectList,
+                           mlir::acc::DataClause clause) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  for (const auto &accObject : accObjectList.v) {
+    mlir::Location operandLocation = genOperandLocation(converter, accObject);
+    std::visit(
+        Fortran::common::visitors{
+            [&](const Fortran::parser::Designator &designator) {
+              if (const auto *name =
+                      Fortran::semantics::getDesignatorNameIfDataRef(
+                          designator)) {
+                std::string globalName = converter.mangleName(*name->symbol);
+                fir::GlobalOp globalOp = builder.getNamedGlobal(globalName);
+                if (!globalOp)
+                  llvm::report_fatal_error("could not retrieve global symbol");
+
+                // Create the new global constructor op after the FIR global.
+                std::string globalCtorName = globalName + "_acc_ctor";
+                auto crtPos = builder.saveInsertionPoint();
+                addDeclareAttr(builder, globalOp.getOperation(), clause);
+                modBuilder.setInsertionPointAfter(globalOp);
+                auto globalCtor =
+                    modBuilder.create<mlir::acc::GlobalConstructorOp>(
+                        operandLocation, globalCtorName);
+                builder.createBlock(&globalCtor.getRegion(),
+                                    globalCtor.getRegion().end(), {}, {});
+                builder.setInsertionPointToEnd(&globalCtor.getRegion().back());
+
+                // Fill up the global constructor region.
+                fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
+                    operandLocation,
+                    fir::ReferenceType::get(globalOp.getType()),
+                    globalOp.getSymbol());
+                addDeclareAttr(builder, addrOp.getOperation(), clause);
+                std::stringstream asFortran;
+                asFortran << Fortran::lower::mangle::demangleName(globalName);
+                llvm::SmallVector<mlir::Value> bounds;
+                mlir::acc::CreateOp entry =
+                    createDataEntryOp<mlir::acc::CreateOp>(
+                        builder, operandLocation, addrOp.getResTy(), asFortran,
+                        bounds, true, clause, addrOp.getResTy().getType());
+                builder.create<mlir::acc::DeclareEnterOp>(
+                    operandLocation, mlir::ValueRange{entry.getAccPtr()});
+                builder.create<mlir::acc::TerminatorOp>(operandLocation);
+                builder.restoreInsertionPoint(crtPos);
+
+                // TODO: global destructor.
+              }
+            },
+            [&](const Fortran::parser::Name &name) {
+              TODO(operandLocation, "OpenACC Global Ctor from parser::Name");
+            }},
+        accObject.u);
+  }
+}
+
+/// Lower the objects of a data clause that accepts a modifier: use
+/// \p clauseWithModifier when the clause carries modifier \p mod and \p clause
+/// otherwise (e.g. `create(zero: x)` maps to acc_create_zero).
+template <typename Clause>
+static void
+genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter,
+                           mlir::OpBuilder &modBuilder, const Clause *x,
+                           Fortran::parser::AccDataModifier::Modifier mod,
+                           const mlir::acc::DataClause clause,
+                           const mlir::acc::DataClause clauseWithModifier) {
+  const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
+  const auto &accObjectList =
+      std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
+  const auto &modifier =
+      std::get<std::optional<Fortran::parser::AccDataModifier>>(
+          listWithModifier.t);
+  mlir::acc::DataClause dataClause =
+      (modifier && modifier->v == mod) ? clauseWithModifier : clause;
+  genGlobalCtors(converter, modBuilder, accObjectList, dataClause);
+}
+
+/// Lower a standalone declarative construct. Only `!$acc declare` outside of
+/// a function (module scope) is handled so far, and only its `create` clause.
+static void genACC(Fortran::lower::AbstractConverter &converter,
+                   Fortran::semantics::SemanticsContext &semanticsContext,
+                   Fortran::lower::pft::Evaluation &eval,
+                   const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
+                       &declareConstruct) {
+  const auto &declarativeDir =
+      std::get<Fortran::parser::AccDeclarativeDirective>(declareConstruct.t);
+  const auto &accClauseList =
+      std::get<Fortran::parser::AccClauseList>(declareConstruct.t);
+
+  if (declarativeDir.v == llvm::acc::Directive::ACCD_declare) {
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    auto moduleOp =
+        builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+    auto funcOp =
+        builder.getBlock()->getParent()->getParentOfType<mlir::func::FuncOp>();
+    if (funcOp) {
+      TODO(funcOp.getLoc(), "OpenACC declare in function/subroutine");
+    } else if (moduleOp) {
+      mlir::OpBuilder modBuilder(moduleOp.getBodyRegion());
+      for (const Fortran::parser::AccClause &clause : accClauseList.v) {
+        mlir::Location clauseLocation = converter.genLocation(clause.source);
+        if (const auto *createClause =
+                std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
+          genGlobalCtorsWithModifier<Fortran::parser::AccClause::Create>(
+              converter, modBuilder, createClause,
+              Fortran::parser::AccDataModifier::Modifier::Zero,
+              mlir::acc::DataClause::acc_create,
+              mlir::acc::DataClause::acc_create_zero);
+        } else {
+          TODO(clauseLocation, "OpenACC declare clause");
+        }
+      }
+    }
+    return;
+  }
+  llvm_unreachable("unsupported declarative directive");
+}
+
 void Fortran::lower::genOpenACCConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semanticsContext,
@@ -2321,6 +2454,7 @@
 
 void Fortran::lower::genOpenACCDeclarativeConstruct(
     Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OpenACCDeclarativeConstruct &accDeclConstruct) {
 
@@ -2328,8 +2462,8 @@
       common::visitors{
           [&](const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
                   &standaloneDeclarativeConstruct) {
-            TODO(converter.genLocation(standaloneDeclarativeConstruct.source),
-                 "OpenACC Standalone Declarative construct not lowered yet!");
+            genACC(converter, semanticsContext, eval,
+                   standaloneDeclarativeConstruct);
           },
           [&](const Fortran::parser::OpenACCRoutineConstruct
                   &routineConstruct) {
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -122,6 +122,10 @@
   bool Pre(const parser::OpenACCDeclarativeConstruct &);
   void Post(const parser::OpenACCDeclarativeConstruct &) { PopContext(); }
 
+  void Post(const parser::AccDeclarativeDirective &) {
+    GetContext().withinConstruct = true;
+  }
+
   bool Pre(const parser::OpenACCRoutineConstruct &);
   bool Pre(const parser::AccBindClause &);
   void Post(const parser::OpenACCStandaloneDeclarativeConstruct &);
diff --git a/flang/test/Lower/OpenACC/acc-declare.f90 b/flang/test/Lower/OpenACC/acc-declare.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-declare.f90
@@ -0,0 +1,27 @@
+! This test checks lowering of OpenACC declare directive.
+
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+
+module acc_declare_test
+  integer, parameter :: n = 100000
+  real, dimension(n) :: data1, data2
+  !$acc declare create(data1) create(zero: data2)
+end module
+
+! CHECK-LABEL: fir.global @_QMacc_declare_testEdata1 {acc.declare = #acc.declare<dataClause = acc_create>} : !fir.array<100000xf32>
+
+! CHECK-LABEL: acc.global_ctor @_QMacc_declare_testEdata1_acc_ctor {
+! CHECK: %[[GLOBAL_ADDR:.*]] = fir.address_of(@_QMacc_declare_testEdata1) {acc.declare = #acc.declare<dataClause = acc_create>} : !fir.ref<!fir.array<100000xf32>>
+! CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[GLOBAL_ADDR]] : !fir.ref<!fir.array<100000xf32>>) -> !fir.ref<!fir.array<100000xf32>> {name = "data1"}
+! CHECK: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100000xf32>>)
+! CHECK: acc.terminator
+! CHECK: }
+
+! CHECK-LABEL: fir.global @_QMacc_declare_testEdata2 {acc.declare = #acc.declare<dataClause = acc_create_zero>} : !fir.array<100000xf32>
+
+! CHECK-LABEL: acc.global_ctor @_QMacc_declare_testEdata2_acc_ctor {
+! CHECK: %[[GLOBAL_ADDR:.*]] = fir.address_of(@_QMacc_declare_testEdata2) {acc.declare = #acc.declare<dataClause = acc_create_zero>} : !fir.ref<!fir.array<100000xf32>>
+! CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[GLOBAL_ADDR]] : !fir.ref<!fir.array<100000xf32>>) -> !fir.ref<!fir.array<100000xf32>> {dataClause = #acc<data_clause acc_create_zero>, name = "data2"}
+! CHECK: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100000xf32>>)
+! CHECK: acc.terminator
+! CHECK: }