diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -372,6 +372,20 @@
   return op;
 }
 
+/// Attach an `acc.declare` attribute recording \p clause to \p op.
+/// \p op may be null, e.g. when a data operand's varPtr is a block argument
+/// (such as a dummy argument) rather than an operation result; in that case
+/// there is no operation to annotate and this is a no-op.
+static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
+                           mlir::acc::DataClause clause) {
+  if (!op)
+    return;
+  op->setAttr(mlir::acc::getDeclareAttrName(),
+              mlir::acc::DeclareAttr::get(builder.getContext(),
+                                          mlir::acc::DataClauseAttr::get(
+                                              builder.getContext(), clause)));
+}
+
 template <typename Op>
 static void
 genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
@@ -379,7 +393,8 @@
                          Fortran::semantics::SemanticsContext &semanticsContext,
                          Fortran::lower::StatementContext &stmtCtx,
                          llvm::SmallVectorImpl<mlir::Value> &dataOperands,
-                         mlir::acc::DataClause dataClause, bool structured) {
+                         mlir::acc::DataClause dataClause, bool structured,
+                         bool setDeclareAttr = false) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   for (const auto &accObject : objectList.v) {
     llvm::SmallVector<mlir::Value> bounds;
@@ -392,6 +407,8 @@
                                   bounds, structured, dataClause,
                                   baseAddr.getType());
     dataOperands.push_back(op.getAccPtr());
+    if (setDeclareAttr)
+      addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
   }
 }
 
@@ -2283,14 +2300,6 @@
     waitOp.setAsyncAttr(firOpBuilder.getUnitAttr());
 }
 
-static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
-                           mlir::acc::DataClause clause) {
-  op->setAttr(mlir::acc::getDeclareAttrName(),
-              mlir::acc::DeclareAttr::get(builder.getContext(),
-                                          mlir::acc::DataClauseAttr::get(
-                                              builder.getContext(), clause)));
-}
-
 template <typename GlobalOp, typename EntryOp, typename DeclareOp,
           typename ExitOp>
 static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
@@ -2391,6 +2400,42 @@
                                   dataClause);
 }
 
+/// Lower an `acc declare` directive appearing inside a function or
+/// subroutine. Only the `copy` clause is supported so far: it produces
+/// structured `acc.copyin` data operands (setting the `acc.declare` attribute
+/// on the variable's defining operation, when there is one) that feed a
+/// single `acc.declare_enter`. Any other clause reaches a TODO.
+static void
+genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
+                     Fortran::semantics::SemanticsContext &semanticsContext,
+                     mlir::Location loc,
+                     const Fortran::parser::AccClauseList &accClauseList) {
+  llvm::SmallVector<mlir::Value> dataClauseOperands, copyEntryOperands;
+  Fortran::lower::StatementContext stmtCtx;
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
+    if (const auto *copyClause =
+            std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
+      auto crtDataStart = dataClauseOperands.size();
+
+      genDataOperandOperations<mlir::acc::CopyinOp>(
+          copyClause->v, converter, semanticsContext, stmtCtx,
+          dataClauseOperands, mlir::acc::DataClause::acc_copy,
+          /*structured=*/true, /*setDeclareAttr=*/true);
+      // TODO: copyEntryOperands is collected but not consumed yet. Only
+      // acc.declare_enter is generated below; the exit side of the `copy`
+      // clause (presumably copyout + acc.declare_exit) is not emitted.
+      copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
+                               dataClauseOperands.end());
+    } else {
+      mlir::Location clauseLocation = converter.genLocation(clause.source);
+      TODO(clauseLocation, "clause on declare directive");
+    }
+  }
+
+  builder.create<mlir::acc::DeclareEnterOp>(loc, dataClauseOperands);
+}
+
 static void
 genDeclareInModule(Fortran::lower::AbstractConverter &converter,
                    mlir::ModuleOp &moduleOp,
@@ -2439,6 +2484,8 @@
 
   const auto &declarativeDir =
       std::get<Fortran::parser::AccDeclarativeDirective>(declareConstruct.t);
+  mlir::Location directiveLocation =
+      converter.genLocation(declarativeDir.source);
   const auto &accClauseList =
       std::get<Fortran::parser::AccClauseList>(declareConstruct.t);
 
@@ -2449,7 +2496,8 @@
     auto funcOp =
         builder.getBlock()->getParent()->getParentOfType<mlir::func::FuncOp>();
     if (funcOp)
-      TODO(funcOp.getLoc(), "OpenACC declare in function/subroutine");
+      genDeclareInFunction(converter, semanticsContext, directiveLocation,
+                           accClauseList);
     else if (moduleOp)
       genDeclareInModule(converter, moduleOp, accClauseList);
     return;
diff --git a/flang/test/Lower/OpenACC/acc-declare.f90 b/flang/test/Lower/OpenACC/acc-declare.f90
--- a/flang/test/Lower/OpenACC/acc-declare.f90
+++ b/flang/test/Lower/OpenACC/acc-declare.f90
@@ -85,3 +85,35 @@
 ! CHECK: acc.declare_enter dataOperands(%[[LINK]] : !fir.ref<!fir.array<100xi32>>)
 ! CHECK: acc.terminator
 ! CHECK: }
+
+module acc_declare
+  contains
+
+  subroutine acc_declare_copy()
+    integer :: a(100), i
+    !$acc declare copy(a)
+
+    do i = 1, 100
+      a(i) = i
+    end do
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMacc_declarePacc_declare_copy()
+! CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<100xi32> {acc.declare = #acc.declare<dataClause = acc_copy>, bindc_name = "a", uniq_name = "_QMacc_declareFacc_declare_copyEa"}
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%c1 : index) startIdx(%c1 : index)
+! CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[ALLOCA]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copy>, name = "a"}
+! CHECK: acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
+
+  subroutine acc_declare_copy_dummy(a)
+    integer :: a(100)
+    !$acc declare copy(a)
+  end subroutine
+
+! A dummy argument's varPtr is expected to be a block argument with no
+! defining operation; lowering must still succeed (no acc.declare attribute).
+! CHECK-LABEL: func.func @_QMacc_declarePacc_declare_copy_dummy(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>
+! CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[ARG0]] : !fir.ref<!fir.array<100xi32>>)
+! CHECK: acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
+
+end module