diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -2068,14 +2068,13 @@ } } -static void -createBodyOfTargetOp(Fortran::lower::AbstractConverter &converter, - mlir::omp::DataOp &dataOp, - const llvm::SmallVector &useDeviceTypes, - const llvm::SmallVector &useDeviceLocs, - const llvm::SmallVector - &useDeviceSymbols, - const mlir::Location ¤tLocation) { +static void createBodyOfTargetDataOp( + Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp, + const llvm::SmallVector &useDeviceTypes, + const llvm::SmallVector &useDeviceLocs, + const llvm::SmallVector + &useDeviceSymbols, + const mlir::Location ¤tLocation) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::Region ®ion = dataOp.getRegion(); @@ -2112,15 +2111,150 @@ } } -static void createTargetOp(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClauseList &opClauseList, - const llvm::omp::Directive &directive, - mlir::Location currentLocation, - Fortran::lower::pft::Evaluation *eval = nullptr) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); +template +static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + mlir::Location currentLocation, bool outerCombined, + const Fortran::parser::OmpClauseList *clauseList, + Args &&...args) { + auto op = converter.getFirOpBuilder().create( + currentLocation, std::forward(args)...); + createBodyOfOp(op, converter, currentLocation, eval, clauseList, + /*args=*/{}, outerCombined); + return op; +} + +static mlir::omp::MasterOp +genMasterOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + mlir::Location currentLocation) { + return genOpWithBody(converter, eval, currentLocation, + /*outerCombined=*/false, + /*clauseList=*/nullptr, + /*resultTypes=*/mlir::TypeRange()); +} + +static mlir::omp::OrderedRegionOp +genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + mlir::Location currentLocation) { + return genOpWithBody( + converter, eval, currentLocation, /*outerCombined=*/false, + /*clauseList=*/nullptr, /*simd=*/false); +} + +static mlir::omp::ParallelOp +genParallelOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &clauseList, + bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand; + mlir::Value ifClauseOperand, numThreadsClauseOperand; + mlir::omp::ClauseProcBindKindAttr procBindKindAttr; + llvm::SmallVector allocateOperands, allocatorOperands, + reductionVars; + llvm::SmallVector reductionDeclSymbols; + + // TODO: Handle the following clauses + // 1. default + ClauseProcessor cp(converter, clauseList); + cp.processIf(stmtCtx, + Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel, + ifClauseOperand); + cp.processNumThreads(stmtCtx, numThreadsClauseOperand); + cp.processProcBind(procBindKindAttr); + cp.processDefault(); + cp.processAllocate(allocatorOperands, allocateOperands); + cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols); + + return genOpWithBody( + converter, eval, currentLocation, outerCombined, &clauseList, + /*resultTypes=*/mlir::TypeRange(), ifClauseOperand, + numThreadsClauseOperand, allocateOperands, allocatorOperands, + reductionVars, + reductionDeclSymbols.empty() + ? nullptr + : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), + reductionDeclSymbols), + procBindKindAttr); +} + +static mlir::omp::SingleOp +genSingleOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList &endClauseList) { + llvm::SmallVector allocateOperands, allocatorOperands; mlir::UnitAttr nowaitAttr; + + ClauseProcessor(converter, beginClauseList) + .processAllocate(allocatorOperands, allocateOperands); + ClauseProcessor(converter, endClauseList).processNowait(nowaitAttr); + + return genOpWithBody( + converter, eval, currentLocation, /*outerCombined=*/false, + &beginClauseList, allocateOperands, allocatorOperands, nowaitAttr); +} + +static mlir::omp::TaskOp +genTaskOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &clauseList) { + Fortran::lower::StatementContext stmtCtx; + mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand; + mlir::UnitAttr untiedAttr, mergeableAttr; + llvm::SmallVector dependTypeOperands; + llvm::SmallVector allocateOperands, allocatorOperands, + dependOperands; + + ClauseProcessor cp(converter, clauseList); + cp.processIf(stmtCtx, + Fortran::parser::OmpIfClause::DirectiveNameModifier::Task, + ifClauseOperand); + cp.processAllocate(allocatorOperands, allocateOperands); + cp.processDefault(); + cp.processFinal(stmtCtx, finalClauseOperand); + cp.processUntied(untiedAttr); + cp.processMergeable(mergeableAttr); + cp.processPriority(stmtCtx, priorityClauseOperand); + cp.processDepend(dependTypeOperands, dependOperands); + + return genOpWithBody( + converter, eval, currentLocation, /*outerCombined=*/false, &clauseList, + ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr, + /*in_reduction_vars=*/mlir::ValueRange(), + /*in_reductions=*/nullptr, priorityClauseOperand, + dependTypeOperands.empty() + ? nullptr + : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), + dependTypeOperands), + dependOperands, allocateOperands, allocatorOperands); +} + +static mlir::omp::TaskGroupOp +genTaskGroupOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &clauseList) { + llvm::SmallVector allocateOperands, allocatorOperands; + // TODO: Add task_reduction support + ClauseProcessor(converter, clauseList) + .processAllocate(allocatorOperands, allocateOperands); + return genOpWithBody( + converter, eval, currentLocation, /*outerCombined=*/false, &clauseList, + /*task_reduction_vars=*/mlir::ValueRange(), + /*task_reductions=*/nullptr, allocateOperands, allocatorOperands); +} + +static mlir::omp::DataOp +genDataOp(Fortran::lower::AbstractConverter &converter, + mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &clauseList) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + Fortran::lower::StatementContext stmtCtx; + mlir::Value ifClauseOperand, deviceOperand; llvm::SmallVector mapOperands, devicePtrOperands, deviceAddrOperands; llvm::SmallVector mapTypes; @@ -2128,79 +2262,106 @@ llvm::SmallVector useDeviceLocs; llvm::SmallVector useDeviceSymbols; + ClauseProcessor cp(converter, clauseList); + cp.processIf(stmtCtx, + Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData, + ifClauseOperand); + cp.processDevice(stmtCtx, deviceOperand); + cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs, + useDeviceSymbols); + cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs, + useDeviceSymbols); + cp.processMap(mapOperands, mapTypes); + + llvm::SmallVector mapTypesAttr(mapTypes.begin(), + mapTypes.end()); + mlir::ArrayAttr mapTypesArrayAttr = + mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr); + + auto dataOp = converter.getFirOpBuilder().create( + currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands, + deviceAddrOperands, mapOperands, mapTypesArrayAttr); + createBodyOfTargetDataOp(converter, dataOp, useDeviceTypes, useDeviceLocs, + useDeviceSymbols, currentLocation); + return dataOp; +} + +template +static OpTy +genEnterExitDataOp(Fortran::lower::AbstractConverter &converter, + mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &clauseList) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + Fortran::lower::StatementContext stmtCtx; + mlir::Value ifClauseOperand, deviceOperand; + mlir::UnitAttr nowaitAttr; + llvm::SmallVector mapOperands; + llvm::SmallVector mapTypes; + Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName; - switch (directive) { - case llvm::omp::Directive::OMPD_target: - directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::Target; - break; - case llvm::omp::Directive::OMPD_target_data: - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData; - break; - case llvm::omp::Directive::OMPD_target_enter_data: + if constexpr (std::is_same_v) { directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData; - break; - case llvm::omp::Directive::OMPD_target_exit_data: + } else if constexpr (std::is_same_v) { directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData; - break; - default: - TODO(currentLocation, "OMPD_target directive unknown"); - break; + } else { + return nullptr; } - ClauseProcessor cp(converter, opClauseList); + ClauseProcessor cp(converter, clauseList); cp.processIf(stmtCtx, directiveName, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); - cp.processThreadLimit(stmtCtx, threadLmtOperand); cp.processNowait(nowaitAttr); - cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs, - useDeviceSymbols); - cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs, - useDeviceSymbols); cp.processMap(mapOperands, mapTypes); - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { - if (!std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u)) { - mlir::Location clauseLocation = converter.genLocation(clause.source); - TODO(clauseLocation, "OMPD_target unhandled clause"); - } - } + llvm::SmallVector mapTypesAttr(mapTypes.begin(), + mapTypes.end()); + mlir::ArrayAttr mapTypesArrayAttr = + mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr); + + return firOpBuilder.create(currentLocation, ifClauseOperand, + deviceOperand, nowaitAttr, mapOperands, + mapTypesArrayAttr); +} + +static mlir::omp::TargetOp +genTargetOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &clauseList, + bool outerCombined = false) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + Fortran::lower::StatementContext stmtCtx; + mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand; + mlir::UnitAttr nowaitAttr; + llvm::SmallVector mapOperands; + llvm::SmallVector mapTypes; + + ClauseProcessor cp(converter, clauseList); + cp.processIf(stmtCtx, + Fortran::parser::OmpIfClause::DirectiveNameModifier::Target, + ifClauseOperand); + cp.processDevice(stmtCtx, deviceOperand); + cp.processThreadLimit(stmtCtx, threadLimitOperand); + cp.processNowait(nowaitAttr); + cp.processMap(mapOperands, mapTypes); llvm::SmallVector mapTypesAttr(mapTypes.begin(), mapTypes.end()); mlir::ArrayAttr mapTypesArrayAttr = mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr); - if (directive == llvm::omp::Directive::OMPD_target) { - auto targetOp = firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, threadLmtOperand, - nowaitAttr, mapOperands, mapTypesArrayAttr); - createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList); - } else if (directive == llvm::omp::Directive::OMPD_target_data) { - auto dataOp = firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands, - deviceAddrOperands, mapOperands, mapTypesArrayAttr); - createBodyOfTargetOp(converter, dataOp, useDeviceTypes, useDeviceLocs, - useDeviceSymbols, currentLocation); - } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) { - firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, nowaitAttr, - mapOperands, mapTypesArrayAttr); - } else if (directive == llvm::omp::Directive::OMPD_target_exit_data) { - firOpBuilder.create(currentLocation, ifClauseOperand, - deviceOperand, nowaitAttr, - mapOperands, mapTypesArrayAttr); - } + return genOpWithBody( + converter, eval, currentLocation, outerCombined, &clauseList, + ifClauseOperand, deviceOperand, threadLimitOperand, nowaitAttr, + mapOperands, mapTypesArrayAttr); } +//===----------------------------------------------------------------------===// +// genOMP() Code generation helper functions +//===----------------------------------------------------------------------===// + static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSimpleStandaloneConstruct @@ -2226,9 +2387,15 @@ firOpBuilder.create(currentLocation); break; case llvm::omp::Directive::OMPD_target_data: + genDataOp(converter, currentLocation, opClauseList); + break; case llvm::omp::Directive::OMPD_target_enter_data: + genEnterExitDataOp(converter, currentLocation, + opClauseList); + break; case llvm::omp::Directive::OMPD_target_exit_data: - createTargetOp(converter, opClauseList, directive.v, currentLocation); + genEnterExitDataOp(converter, currentLocation, + opClauseList); break; case llvm::omp::Directive::OMPD_target_update: TODO(currentLocation, "OMPD_target_update"); @@ -2273,89 +2440,6 @@ standaloneConstruct.u); } -/* When parallel is used in a combined construct, then use this function to - * create the parallel operation. It handles the parallel specific clauses - * and leaves the rest for handling at the inner operations. - */ -template -static void -createCombinedParallelOp(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Directive &directive) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - Fortran::lower::StatementContext stmtCtx; - llvm::ArrayRef argTy; - mlir::Value ifClauseOperand, numThreadsClauseOperand; - llvm::SmallVector allocatorOperands, allocateOperands; - mlir::omp::ClauseProcBindKindAttr procBindKindAttr; - const auto &opClauseList = - std::get(directive.t); - // TODO: Handle the following clauses - // 1. default - // Note: rest of the clauses are handled when the inner operation is created - ClauseProcessor cp(converter, opClauseList); - cp.processIf(stmtCtx, - Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel, - ifClauseOperand); - cp.processNumThreads(stmtCtx, numThreadsClauseOperand); - cp.processProcBind(procBindKindAttr); - - // Create and insert the operation. - auto parallelOp = firOpBuilder.create( - currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, - allocateOperands, allocatorOperands, - /*reduction_vars=*/mlir::ValueRange(), - /*reductions=*/nullptr, procBindKindAttr); - - createBodyOfOp(parallelOp, converter, currentLocation, - eval, &opClauseList, /*iv=*/{}, - /*isCombined=*/true); -} - -/* When target is used in a combined construct, then use this function to - * create the target operation. It handles the target specific clauses - * and leaves the rest for handling at the inner operations. - */ -template -static void createCombinedTargetOp(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Directive &directive) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand; - llvm::SmallVector mapOperands; - llvm::SmallVector mapTypes; - mlir::UnitAttr nowaitAttr; - const auto &opClauseList = - std::get(directive.t); - - // Note: rest of the clauses are handled when the inner operation is created - ClauseProcessor cp(converter, opClauseList); - cp.processIf(stmtCtx, - Fortran::parser::OmpIfClause::DirectiveNameModifier::Target, - ifClauseOperand); - cp.processDevice(stmtCtx, deviceOperand); - cp.processThreadLimit(stmtCtx, threadLimitOperand); - cp.processNowait(nowaitAttr); - cp.processMap(mapOperands, mapTypes); - - llvm::SmallVector mapTypesAttr(mapTypes.begin(), - mapTypes.end()); - mlir::ArrayAttr mapTypesArrayAttr = - mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr); - - // Create and insert the operation. - auto targetOp = firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand, - nowaitAttr, mapOperands, mapTypesArrayAttr); - - createBodyOfOp(targetOp, converter, currentLocation, - eval, &opClauseList, - /*iv=*/{}, /*isCombined=*/true); -} - static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { @@ -2392,8 +2476,8 @@ if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; - createCombinedTargetOp( - converter, eval, beginLoopDirective); + genTargetOp(converter, eval, currentLocation, loopOpClauseList, + /*outerCombined=*/true); } if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet) .test(ompDirective)) { @@ -2407,8 +2491,8 @@ if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; - createCombinedParallelOp( - converter, eval, beginLoopDirective); + genParallelOp(converter, eval, currentLocation, loopOpClauseList, + /*outerCombined=*/true); } } if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective)) @@ -2513,67 +2597,16 @@ const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { const auto &beginBlockDirective = std::get(blockConstruct.t); - const auto &blockDirective = - std::get(beginBlockDirective.t); const auto &endBlockDirective = std::get(blockConstruct.t); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.genLocation(blockDirective.source); - - Fortran::lower::StatementContext stmtCtx; - llvm::ArrayRef argTy; - mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand, - priorityClauseOperand; - mlir::omp::ClauseProcBindKindAttr procBindKindAttr; - llvm::SmallVector allocateOperands, allocatorOperands, - dependOperands, reductionVars; - llvm::SmallVector dependTypeOperands, reductionDeclSymbols; - mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr; - - // Use placeholder value to avoid uninitialized `directiveName` compiler - // errors. The 'if clause' obtained won't be used for these directives. - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel; - switch (blockDirective.v) { - case llvm::omp::OMPD_parallel: - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel; - break; - case llvm::omp::OMPD_task: - directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::Task; - break; - // Target-related 'if' clauses handled by createTargetOp(). - case llvm::omp::OMPD_target: - case llvm::omp::OMPD_target_data: - // These block directives do not accept an 'if' clause. - case llvm::omp::OMPD_master: - case llvm::omp::OMPD_single: - case llvm::omp::OMPD_ordered: - case llvm::omp::OMPD_taskgroup: - break; - default: - TODO(currentLocation, - "Unhandled block directive (" + - llvm::omp::getOpenMPDirectiveName(blockDirective.v) + ")"); - break; - } - - const auto &opClauseList = + const auto &directive = + std::get(beginBlockDirective.t); + const auto &beginClauseList = std::get(beginBlockDirective.t); - ClauseProcessor cp(converter, opClauseList); - cp.processIf(stmtCtx, directiveName, ifClauseOperand); - cp.processNumThreads(stmtCtx, numThreadsClauseOperand); - cp.processProcBind(procBindKindAttr); - cp.processAllocate(allocatorOperands, allocateOperands); - cp.processDefault(); - cp.processFinal(stmtCtx, finalClauseOperand); - cp.processUntied(untiedAttr); - cp.processMergeable(mergeableAttr); - cp.processPriority(stmtCtx, priorityClauseOperand); - cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols); - cp.processDepend(dependTypeOperands, dependOperands); + const auto &endClauseList = + std::get(endBlockDirective.t); - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { + for (const Fortran::parser::OmpClause &clause : beginClauseList.v) { mlir::Location clauseLocation = converter.genLocation(clause.source); if (!std::get_if(&clause.u) && !std::get_if(&clause.u) && @@ -2586,17 +2619,11 @@ !std::get_if(&clause.u) && !std::get_if(&clause.u) && !std::get_if(&clause.u) && - // Privatisation and copyin clauses are handled elsewhere. !std::get_if(&clause.u) && !std::get_if(&clause.u) && !std::get_if(&clause.u) && - // Shared is the default behavior in the IR, so no handling is required. !std::get_if(&clause.u) && - // Nothing needs to be done for threads clause. !std::get_if(&clause.u) && - // Map, UseDevicePtr, UseDeviceAddr and ThreadLimit clauses are - // exclusive to Target directives. They are handled as part of the - // TargetOp creation. !std::get_if(&clause.u) && !std::get_if(&clause.u) && !std::get_if(&clause.u) && @@ -2605,67 +2632,74 @@ } } - ClauseProcessor(converter, - std::get(endBlockDirective.t)) - .processNowait(nowaitAttr); - for (const auto &clause : - std::get(endBlockDirective.t).v) { + for (const auto &clause : endClauseList.v) { mlir::Location clauseLocation = converter.genLocation(clause.source); if (!std::get_if(&clause.u)) TODO(clauseLocation, "OpenMP Block construct clause"); } - if (blockDirective.v == llvm::omp::OMPD_parallel) { - // Create and insert the operation. - auto parallelOp = firOpBuilder.create( - currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, - allocateOperands, allocatorOperands, reductionVars, - reductionDeclSymbols.empty() - ? nullptr - : mlir::ArrayAttr::get(firOpBuilder.getContext(), - reductionDeclSymbols), - procBindKindAttr); - createBodyOfOp(parallelOp, converter, - currentLocation, eval, &opClauseList); - } else if (blockDirective.v == llvm::omp::OMPD_master) { - auto masterOp = - firOpBuilder.create(currentLocation, argTy); - createBodyOfOp(masterOp, converter, currentLocation, - eval); - } else if (blockDirective.v == llvm::omp::OMPD_single) { - auto singleOp = firOpBuilder.create( - currentLocation, allocateOperands, allocatorOperands, nowaitAttr); - createBodyOfOp(singleOp, converter, currentLocation, - eval, &opClauseList); - } else if (blockDirective.v == llvm::omp::OMPD_ordered) { - auto orderedOp = firOpBuilder.create( - currentLocation, /*simd=*/false); - createBodyOfOp(orderedOp, converter, - currentLocation, eval); - } else if (blockDirective.v == llvm::omp::OMPD_task) { - auto taskOp = firOpBuilder.create( - currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr, - mergeableAttr, /*in_reduction_vars=*/mlir::ValueRange(), - /*in_reductions=*/nullptr, priorityClauseOperand, - dependTypeOperands.empty() - ? nullptr - : mlir::ArrayAttr::get(firOpBuilder.getContext(), - dependTypeOperands), - dependOperands, allocateOperands, allocatorOperands); - createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList); - } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) { - // TODO: Add task_reduction support - auto taskGroupOp = firOpBuilder.create( - currentLocation, /*task_reduction_vars=*/mlir::ValueRange(), - /*task_reductions=*/nullptr, allocateOperands, allocatorOperands); - createBodyOfOp(taskGroupOp, converter, currentLocation, eval, - &opClauseList); - } else if (blockDirective.v == llvm::omp::OMPD_target) { - createTargetOp(converter, opClauseList, blockDirective.v, currentLocation, - &eval); - } else if (blockDirective.v == llvm::omp::OMPD_target_data) { - createTargetOp(converter, opClauseList, blockDirective.v, currentLocation, - &eval); + mlir::Location currentLocation = converter.genLocation(directive.source); + switch (directive.v) { + case llvm::omp::Directive::OMPD_master: + genMasterOp(converter, eval, currentLocation); + break; + case llvm::omp::Directive::OMPD_ordered: + genOrderedRegionOp(converter, eval, currentLocation); + break; + case llvm::omp::Directive::OMPD_parallel: + genParallelOp(converter, eval, currentLocation, beginClauseList); + break; + case llvm::omp::Directive::OMPD_single: + genSingleOp(converter, eval, currentLocation, beginClauseList, + endClauseList); + break; + case llvm::omp::Directive::OMPD_target: + genTargetOp(converter, eval, currentLocation, beginClauseList); + break; + case llvm::omp::Directive::OMPD_target_data: + genDataOp(converter, currentLocation, beginClauseList); + break; + case llvm::omp::Directive::OMPD_task: + genTaskOp(converter, eval, currentLocation, beginClauseList); + break; + case llvm::omp::Directive::OMPD_taskgroup: + genTaskGroupOp(converter, eval, currentLocation, beginClauseList); + break; + case llvm::omp::Directive::OMPD_teams: + TODO(currentLocation, "Teams construct"); + break; + case llvm::omp::Directive::OMPD_workshare: + TODO(currentLocation, "Workshare construct"); + break; + default: { + // Codegen for combined directives + bool combinedDirective = false; + if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet) + .test(directive.v)) { + genTargetOp(converter, eval, currentLocation, beginClauseList, + /*outerCombined=*/true); + combinedDirective = true; + } + if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet).test(directive.v)) { + TODO(currentLocation, "Teams construct"); + combinedDirective = true; + } + if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet).test(directive.v)) { + bool outerCombined = directive.v != llvm::omp::Directive::OMPD_target_parallel; + genParallelOp(converter, eval, currentLocation, beginClauseList, + outerCombined); + combinedDirective = true; + } + if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet).test(directive.v)) { + TODO(currentLocation, "Workshare construct"); + combinedDirective = true; + } + if (!combinedDirective) + TODO(currentLocation, "Unhandled block directive (" + + llvm::omp::getOpenMPDirectiveName(directive.v) + + ")"); + break; + } } } @@ -2739,51 +2773,44 @@ genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::Location currentLocation = converter.getCurrentLocation(); - llvm::SmallVector reductionVars, allocateOperands, - allocatorOperands; + llvm::SmallVector allocateOperands, allocatorOperands; mlir::UnitAttr nowaitClauseOperand; const auto &beginSectionsDirective = std::get(sectionsConstruct.t); const auto §ionsClauseList = std::get(beginSectionsDirective.t); + // Process clauses before optional omp.parallel, so that new variables are + // allocated outside of the parallel region ClauseProcessor cp(converter, sectionsClauseList); cp.processSectionsReduction(currentLocation); cp.processAllocate(allocatorOperands, allocateOperands); - const auto &endSectionsDirective = - std::get(sectionsConstruct.t); - const auto &endSectionsClauseList = - std::get(endSectionsDirective.t); - ClauseProcessor(converter, endSectionsClauseList) - .processNowait(nowaitClauseOperand); - llvm::omp::Directive dir = std::get(beginSectionsDirective.t) .v; - // Parallel Sections Construct + // Parallel wrapper of PARALLEL SECTIONS construct if (dir == llvm::omp::Directive::OMPD_parallel_sections) { - createCombinedParallelOp( - converter, eval, - std::get( - sectionsConstruct.t)); - auto sectionsOp = firOpBuilder.create( - currentLocation, /*reduction_vars*/ mlir::ValueRange(), - /*reductions=*/nullptr, allocateOperands, allocatorOperands, - /*nowait=*/nullptr); - createBodyOfOp(sectionsOp, converter, currentLocation, eval); - - // Sections Construct - } else if (dir == llvm::omp::Directive::OMPD_sections) { - auto sectionsOp = firOpBuilder.create( - currentLocation, reductionVars, /*reductions=*/nullptr, - allocateOperands, allocatorOperands, nowaitClauseOperand); - createBodyOfOp(sectionsOp, converter, - currentLocation, eval); - } + genParallelOp(converter, eval, currentLocation, sectionsClauseList, + /*outerCombined=*/true); + } else { + const auto &endSectionsDirective = + std::get(sectionsConstruct.t); + const auto &endSectionsClauseList = + std::get(endSectionsDirective.t); + ClauseProcessor(converter, endSectionsClauseList) + .processNowait(nowaitClauseOperand); + } + + // SECTIONS construct + genOpWithBody(converter, eval, currentLocation, + /*outerCombined=*/false, + /*clauseList=*/nullptr, + /*reduction_vars=*/mlir::ValueRange(), + /*reductions=*/nullptr, allocateOperands, + allocatorOperands, nowaitClauseOperand); } static bool checkForSingleVariableOnRHS( diff --git a/flang/test/Lower/OpenMP/parallel-sections.f90 b/flang/test/Lower/OpenMP/parallel-sections.f90 --- a/flang/test/Lower/OpenMP/parallel-sections.f90 +++ b/flang/test/Lower/OpenMP/parallel-sections.f90 @@ -38,12 +38,16 @@ subroutine omp_parallel_sections_allocate(x, y) use omp_lib integer, intent(inout) :: x, y - !FIRDialect: %[[allocator:.*]] = arith.constant 1 : i32 - !LLVMDialect: %[[allocator:.*]] = llvm.mlir.constant(1 : i32) : i32 - !OMPDialect: omp.parallel { + !FIRDialect: %[[allocator_1:.*]] = arith.constant 1 : i32 + !FIRDialect: %[[allocator_2:.*]] = arith.constant 1 : i32 + !LLVMDialect: %[[allocator_1:.*]] = llvm.mlir.constant(1 : i32) : i32 + !LLVMDialect: %[[allocator_2:.*]] = llvm.mlir.constant(1 : i32) : i32 + !OMPDialect: omp.parallel allocate( + !FIRDialect: %[[allocator_2]] : i32 -> %{{.*}} : !fir.ref) { + !LLVMDialect: %[[allocator_2]] : i32 -> %{{.*}} : !llvm.ptr) { !OMPDialect: omp.sections allocate( - !FIRDialect: %[[allocator]] : i32 -> %{{.*}} : !fir.ref) { - !LLVMDialect: %[[allocator]] : i32 -> %{{.*}} : !llvm.ptr) { + !FIRDialect: %[[allocator_1]] : i32 -> %{{.*}} : !fir.ref) { + !LLVMDialect: %[[allocator_1]] : i32 -> %{{.*}} : !llvm.ptr) { !$omp parallel sections allocate(omp_high_bw_mem_alloc: x) !OMPDialect: omp.section { !$omp section