diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -3462,8 +3462,8 @@
 // 2.12 if-clause -> IF ([ directive-name-modifier :] scalar-logical-expr)
 struct OmpIfClause {
   TUPLE_CLASS_BOILERPLATE(OmpIfClause);
-  ENUM_CLASS(DirectiveNameModifier, Parallel, Target, TargetEnterData,
-      TargetExitData, TargetData, TargetUpdate, Taskloop, Task)
+  ENUM_CLASS(DirectiveNameModifier, Parallel, Simd, Target, TargetData,
+      TargetEnterData, TargetExitData, TargetUpdate, Task, Taskloop, Teams)
   std::tuple<std::optional<DirectiveNameModifier>, ScalarLogicalExpr> t;
 };
 
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -509,8 +509,10 @@
   bool processCopyin() const;
   bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
                      llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
-  bool processIf(Fortran::lower::StatementContext &stmtCtx,
-                 mlir::Value &result) const;
+  bool
+  processIf(Fortran::lower::StatementContext &stmtCtx,
+            Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+            mlir::Value &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
   bool processMap(llvm::SmallVectorImpl<mlir::Value> &mapOperands,
@@ -1040,11 +1042,19 @@
                                                 pbKind);
 }
 
-static mlir::Value
-getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
-                   Fortran::lower::StatementContext &stmtCtx,
-                   const Fortran::parser::OmpClause::If *ifClause,
-                   mlir::Location clauseLocation) {
+static mlir::Value getIfClauseOperand(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::StatementContext &stmtCtx,
+    const Fortran::parser::OmpClause::If *ifClause,
+    Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+    mlir::Location clauseLocation) {
+  // Only consider the clause if it's intended for the given directive.
+  auto &directive = std::get<
+      std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>(
+      ifClause->v.t);
+  if (directive && directive.value() != directiveName)
+    return nullptr;
+
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
   mlir::Value ifVal = fir::getBase(
@@ -1563,17 +1573,25 @@
       });
 }
 
-bool ClauseProcessor::processIf(Fortran::lower::StatementContext &stmtCtx,
-                                mlir::Value &result) const {
-  return findRepeatableClause<ClauseTy::If>(
+bool ClauseProcessor::processIf(
+    Fortran::lower::StatementContext &stmtCtx,
+    Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+    mlir::Value &result) const {
+  bool found = false;
+  findRepeatableClause<ClauseTy::If>(
       [&](const ClauseTy::If *ifClause,
           const Fortran::parser::CharBlock &source) {
         mlir::Location clauseLocation = converter.genLocation(source);
-        // TODO Consider DirectiveNameModifier of the `ifClause` to only search
-        // for an applicable 'if' clause.
-        result =
-            getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation);
+        mlir::Value operand = getIfClauseOperand(converter, stmtCtx, ifClause,
+                                                 directiveName, clauseLocation);
+        // Assume that, at most, a single 'if' clause will be applicable to the
+        // given directive.
+        if (operand) {
+          result = operand;
+          found = true;
+        }
       });
+  return found;
 }
 
 bool ClauseProcessor::processLink(
@@ -2083,8 +2101,30 @@
   llvm::SmallVector<mlir::Location> useDeviceLocs;
   llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
 
+  Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
+  switch (directive) {
+  case llvm::omp::Directive::OMPD_target:
+    directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::Target;
+    break;
+  case llvm::omp::Directive::OMPD_target_data:
+    directiveName =
+        Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData;
+    break;
+  case llvm::omp::Directive::OMPD_target_enter_data:
+    directiveName =
+        Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
+    break;
+  case llvm::omp::Directive::OMPD_target_exit_data:
+    directiveName =
+        Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
+    break;
+  default:
+    TODO(currentLocation, "OMPD_target directive unknown");
+    break;
+  }
+
   ClauseProcessor cp(converter, opClauseList);
-  cp.processIf(stmtCtx, ifClauseOperand);
+  cp.processIf(stmtCtx, directiveName, ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
   cp.processThreadLimit(stmtCtx, threadLmtOperand);
   cp.processNowait(nowaitAttr);
@@ -2131,8 +2171,6 @@
     firOpBuilder.create<mlir::omp::ExitDataOp>(currentLocation, ifClauseOperand,
                                                deviceOperand, nowaitAttr,
                                                mapOperands, mapTypesArrayAttr);
-  } else {
-    TODO(currentLocation, "OMPD_target directive unknown");
   }
 }
 
@@ -2230,7 +2268,9 @@
   // 1. default
   // Note: rest of the clauses are handled when the inner operation is created
   ClauseProcessor cp(converter, opClauseList);
-  cp.processIf(stmtCtx, ifClauseOperand);
+  cp.processIf(stmtCtx,
+               Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
+               ifClauseOperand);
   cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
   cp.processProcBind(procBindKindAttr);
 
@@ -2289,7 +2329,9 @@
   cp.processCollapse(currentLocation, eval, lowerBound, upperBound, step, iv,
                      loopVarTypeSize);
   cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
-  cp.processIf(stmtCtx, ifClauseOperand);
+  cp.processIf(stmtCtx,
+               Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
+               ifClauseOperand);
   cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
   cp.processSimdlen(simdlenClauseOperand);
   cp.processSafelen(safelenClauseOperand);
@@ -2390,10 +2432,36 @@
   llvm::SmallVector<mlir::Attribute> dependTypeOperands, reductionDeclSymbols;
   mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
 
+  // Use placeholder value to avoid uninitialized `directiveName` compiler
+  // errors. The 'if clause' obtained won't be used for these directives.
+  Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName =
+      Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel;
+  switch (blockDirective.v) {
+  case llvm::omp::OMPD_parallel:
+    directiveName =
+        Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel;
+    break;
+  case llvm::omp::OMPD_task:
+    directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::Task;
+    break;
+  // Target-related 'if' clauses handled by createTargetOp().
+  case llvm::omp::OMPD_target:
+  case llvm::omp::OMPD_target_data:
+  // These block directives do not accept an 'if' clause.
+  case llvm::omp::OMPD_master:
+  case llvm::omp::OMPD_single:
+  case llvm::omp::OMPD_ordered:
+  case llvm::omp::OMPD_taskgroup:
+    break;
+  default:
+    TODO(currentLocation, "Unhandled block directive");
+    break;
+  }
+
   const auto &opClauseList =
       std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
   ClauseProcessor cp(converter, opClauseList);
-  cp.processIf(stmtCtx, ifClauseOperand);
+  cp.processIf(stmtCtx, directiveName, ifClauseOperand);
   cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
   cp.processProcBind(procBindKindAttr);
   cp.processAllocate(allocatorOperands, allocateOperands);
@@ -2491,8 +2559,6 @@
   } else if (blockDirective.v == llvm::omp::OMPD_target_data) {
     createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
                    &eval);
-  } else {
-    TODO(currentLocation, "Unhandled block directive");
   }
 }
 
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -115,17 +115,22 @@
 TYPE_PARSER(construct<OmpIfClause>(
     maybe(
         ("PARALLEL" >> pure(OmpIfClause::DirectiveNameModifier::Parallel) ||
+            "SIMD" >> pure(OmpIfClause::DirectiveNameModifier::Simd) ||
+            // Multi-word "TARGET ..." modifiers must precede plain "TARGET":
+            // '||' commits to the first alternative that matches, and "TARGET"
+            // alone would consume the prefix of e.g. "TARGET DATA".
             "TARGET ENTER DATA" >>
                 pure(OmpIfClause::DirectiveNameModifier::TargetEnterData) ||
             "TARGET EXIT DATA" >>
                 pure(OmpIfClause::DirectiveNameModifier::TargetExitData) ||
             "TARGET DATA" >>
                 pure(OmpIfClause::DirectiveNameModifier::TargetData) ||
             "TARGET UPDATE" >>
                 pure(OmpIfClause::DirectiveNameModifier::TargetUpdate) ||
             "TARGET" >> pure(OmpIfClause::DirectiveNameModifier::Target) ||
             "TASK"_id >> pure(OmpIfClause::DirectiveNameModifier::Task) ||
-            "TASKLOOP" >> pure(OmpIfClause::DirectiveNameModifier::Taskloop)) /
+            "TASKLOOP" >> pure(OmpIfClause::DirectiveNameModifier::Taskloop) ||
+            "TEAMS" >> pure(OmpIfClause::DirectiveNameModifier::Teams)) /
         ":"),
     scalarLogicalExpr))
 
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2333,17 +2333,19 @@
   using dirNameModifier = parser::OmpIfClause::DirectiveNameModifier;
   static std::unordered_map<dirNameModifier, OmpDirectiveSet>
       dirNameModifierMap{{dirNameModifier::Parallel, llvm::omp::parallelSet},
+          {dirNameModifier::Simd, llvm::omp::simdSet},
           {dirNameModifier::Target, llvm::omp::targetSet},
+          {dirNameModifier::TargetData,
+              {llvm::omp::Directive::OMPD_target_data}},
           {dirNameModifier::TargetEnterData,
               {llvm::omp::Directive::OMPD_target_enter_data}},
           {dirNameModifier::TargetExitData,
               {llvm::omp::Directive::OMPD_target_exit_data}},
-          {dirNameModifier::TargetData,
-              {llvm::omp::Directive::OMPD_target_data}},
           {dirNameModifier::TargetUpdate,
               {llvm::omp::Directive::OMPD_target_update}},
           {dirNameModifier::Task, {llvm::omp::Directive::OMPD_task}},
-          {dirNameModifier::Taskloop, llvm::omp::taskloopSet}};
+          {dirNameModifier::Taskloop, llvm::omp::taskloopSet},
+          {dirNameModifier::Teams, llvm::omp::teamSet}};
   if (const auto &directiveName{
           std::get<std::optional<dirNameModifier>>(x.v.t)}) {
     auto search{dirNameModifierMap.find(*directiveName)};