diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -107,6 +107,18 @@ return op; } +/* Creates an `Op` at `loc` from `operands`, passing an empty type range (`argTy`) to the builder (presumably the result types; confirm against the op definition), and stores `operandSegments` as the op's operand segment size attribute via `getI32VectorAttr`. */ template +static Op createSimpleOp(Fortran::lower::FirOpBuilder &builder, + mlir::Location loc, + const SmallVectorImpl &operands, + const SmallVectorImpl &operandSegments) { + llvm::ArrayRef argTy; + Op op = builder.create(loc, argTy, operands); + op.setAttr(Op::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr(operandSegments)); + return op; +} + static void genACC(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { @@ -516,6 +528,128 @@ } } +/* Lowers the clauses of an OpenACC `exit data` directive and creates the exit data op at the converter's current location. The if condition (converted to i1), async value, wait devnum, wait values and the copyout/delete/detach object lists become operands; a value-less async or wait clause and the finalize clause become unit attributes. Clause kinds not matched below are skipped. */ static void +genACCExitDataOp(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClauseList &accClauseList) { + mlir::Value ifCond, async, waitDevnum; + SmallVector copyoutOperands, deleteOperands, detachOperands, + waitOperands; + + // Async and wait clauses have optional values but can be present with + // no value as well. When there is no value, the op has an attribute to + // represent the clause. + bool addAsyncAttr = false; + bool addWaitAttr = false; + bool addFinalizeAttr = false; + + auto &firOpBuilder = converter.getFirOpBuilder(); + auto currentLocation = converter.getCurrentLocation(); + + // Lower clause values to their mapped operands. + // Keep track of each group of operands separately as clauses can appear + // more than once. Single-value operands (if condition, async value, wait devnum) are overwritten, so the last occurrence wins. + for (const auto &clause : accClauseList.v) { + if (const auto *ifClause = + std::get_if(&clause.u)) { + Value cond = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + ifCond = firOpBuilder.createConvert(currentLocation, + firOpBuilder.getI1Type(), cond); + } else if (const auto *asyncClause = + std::get_if(&clause.u)) { + const auto &asyncClauseValue = asyncClause->v; + if (asyncClauseValue) { // async has a value.
+ async = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*asyncClauseValue))); + } else { + addAsyncAttr = true; + } + } else if (const auto *waitClause = + std::get_if(&clause.u)) { + const auto &waitClauseValue = waitClause->v; + if (waitClauseValue) { // wait has a value. + const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; + const std::list &waitList = + std::get>(waitArg.t); + for (const Fortran::parser::ScalarIntExpr &value : waitList) { + Value v = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(value))); + waitOperands.push_back(v); + } + + const std::optional &waitDevnumValue = + std::get>(waitArg.t); + if (waitDevnumValue) + waitDevnum = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*waitDevnumValue))); + } else { + addWaitAttr = true; + } + } else if (const auto *copyoutClause = + std::get_if( + &clause.u)) { + const Fortran::parser::AccObjectListWithModifier &listWithModifier = + copyoutClause->v; + const Fortran::parser::AccObjectList &accObjectList = + std::get(listWithModifier.t); + /* Only the object list is lowered; any modifier carried by the clause is not read here. */ genObjectList(accObjectList, converter, copyoutOperands); + } else if (const auto *deleteClause = + std::get_if(&clause.u)) { + genObjectList(deleteClause->v, converter, deleteOperands); + } else if (const auto *detachClause = + std::get_if(&clause.u)) { + genObjectList(detachClause->v, converter, detachOperands); + } else if (std::get_if(&clause.u)) { + addFinalizeAttr = true; + } + } + + // Prepare the operand segment size attribute and the operand value range.
+ SmallVector operands; + SmallVector operandSegments; + /* The order below is presumably the exit data op's operand-segment order; keep it in sync with the op definition (TODO confirm). */ addOperand(operands, operandSegments, ifCond); + addOperand(operands, operandSegments, async); + addOperand(operands, operandSegments, waitDevnum); + addOperands(operands, operandSegments, waitOperands); + addOperands(operands, operandSegments, copyoutOperands); + addOperands(operands, operandSegments, deleteOperands); + addOperands(operands, operandSegments, detachOperands); + + auto exitDataOp = createSimpleOp( + firOpBuilder, currentLocation, operands, operandSegments); + + /* Value-less async/wait clauses and the finalize clause are recorded as unit attributes on the created op. */ if (addAsyncAttr) + exitDataOp.asyncAttr(firOpBuilder.getUnitAttr()); + if (addWaitAttr) + exitDataOp.waitAttr(firOpBuilder.getUnitAttr()); + if (addFinalizeAttr) + exitDataOp.finalizeAttr(firOpBuilder.getUnitAttr()); +} + +/* Lowers an OpenACC standalone construct by dispatching on its directive. Only `exit data` is lowered so far; enter data, init, shutdown, set and update reach a TODO. A directive value not listed below falls through without any diagnostic, and `eval` is currently unused. */ static void +genACC(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) { + const auto &standaloneDirective = + std::get(standaloneConstruct.t); + const auto &accClauseList = + std::get(standaloneConstruct.t); + + if (standaloneDirective.v == llvm::acc::Directive::ACCD_enter_data) { + TODO("OpenACC enter data directive not lowered yet!"); + } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_exit_data) { + genACCExitDataOp(converter, accClauseList); + } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_init) { + TODO("OpenACC init directive not lowered yet!"); + } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_shutdown) { + TODO("OpenACC shutdown directive not lowered yet!"); + } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) { + TODO("OpenACC set directive not lowered yet!"); + } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) { + TODO("OpenACC update directive not lowered yet!"); + } +} + void Fortran::lower::genOpenACCConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -535,7 +669,7 @@
}, [&](const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) { - TODO("OpenACC Standalone construct not lowered yet!"); + genACC(converter, eval, standaloneConstruct); }, [&](const Fortran::parser::OpenACCRoutineConstruct &routineConstruct) {