[flang][OpenMP] Add lowering support for the TEAMS construct

Lower `!$omp teams`, `!$omp target teams`, and loop constructs that
include TEAMS to `omp.teams` instead of reaching the "Teams construct"
TODO. The new genTeamsOp helper processes the IF (TEAMS directive-name
modifier), ALLOCATE, DEFAULT, REDUCTION, NUM_TEAMS and THREAD_LIMIT
clauses. NUM_TEAMS supplies only the upper bound, so the lower-bound
operand is passed as null. The function-level StatementContext in the
block-construct genOMP overload is removed.

Tests: if-clause.f90 (TEAMS / TARGET TEAMS cases), teams.f90 and
teams-reduction-add.f90.

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -485,6 +485,8 @@
   bool processHint(mlir::IntegerAttr &result) const;
   bool processMergeable(mlir::UnitAttr &result) const;
   bool processNowait(mlir::UnitAttr &result) const;
+  bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
+                       mlir::Value &result) const;
   bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
                          mlir::Value &result) const;
   bool processOrdered(mlir::IntegerAttr &result) const;
@@ -1336,6 +1338,16 @@
   return markClauseOccurrence<ClauseTy::Nowait>(result);
 }
 
+bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
+                                      mlir::Value &result) const {
+  if (auto *numTeamsClause = findUniqueClause<ClauseTy::NumTeams>()) {
+    result = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(numTeamsClause->v), stmtCtx));
+    return true;
+  }
+  return false;
+}
+
 bool ClauseProcessor::processNumThreads(
     Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
   if (auto *numThreadsClause = findUniqueClause<ClauseTy::NumThreads>()) {
@@ -2335,6 +2347,39 @@
       mapOperands, mapTypesArrayAttr);
 }
 
+static mlir::omp::TeamsOp
+genTeamsOp(Fortran::lower::AbstractConverter &converter,
+           Fortran::lower::pft::Evaluation &eval,
+           mlir::Location currentLocation,
+           const Fortran::parser::OmpClauseList &clauseList,
+           bool outerCombined = false) {
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand;
+  llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
+      reductionVars;
+  llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+
+  ClauseProcessor cp(converter, clauseList);
+  cp.processIf(stmtCtx,
+               Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
+               ifClauseOperand);
+  cp.processAllocate(allocatorOperands, allocateOperands);
+  cp.processDefault();
+  cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
+  cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
+  cp.processThreadLimit(stmtCtx, threadLimitClauseOperand);
+
+  return genOpWithBody<mlir::omp::TeamsOp>(
+      converter, eval, currentLocation, outerCombined, &clauseList,
+      /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
+      threadLimitClauseOperand, allocateOperands, allocatorOperands,
+      reductionVars,
+      reductionDeclSymbols.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 reductionDeclSymbols));
+}
+
 //===----------------------------------------------------------------------===//
 // genOMP() Code generation helper functions
 //===----------------------------------------------------------------------===//
@@ -2535,7 +2580,8 @@
   }
   if (teamsSet.test(ompDirective)) {
     validDirective = true;
-    TODO(currentLocation, "Teams construct");
+    genTeamsOp(converter, eval, currentLocation, loopOpClauseList,
+               /*outerCombined=*/true);
   }
   if (distributeSet.test(ompDirective)) {
     validDirective = true;
@@ -2677,7 +2723,6 @@
   const auto &endClauseList =
       std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t);
 
-  Fortran::lower::StatementContext stmtCtx;
   mlir::Location currentLocation = converter.genLocation(directive.source);
 
   // Codegen for combined directives
@@ -2688,7 +2733,8 @@
     combinedDirective = true;
   }
   if (teamsSet.test(directive.v)) {
-    TODO(currentLocation, "Teams construct");
+    genTeamsOp(converter, eval, currentLocation, beginClauseList,
+               /*outerCombined=*/false);
     combinedDirective = true;
   }
   if (parallelSet.test(directive.v)) {
@@ -2731,6 +2777,8 @@
     genTaskGroupOp(converter, eval, currentLocation, beginClauseList);
     break;
   case Directive::OMPD_teams:
+    genTeamsOp(converter, eval, currentLocation, beginClauseList);
+    break;
   case Directive::OMPD_workshare:
   default:
     TODO(currentLocation, "Unhandled block directive");
diff --git a/flang/test/Lower/OpenMP/if-clause.f90 b/flang/test/Lower/OpenMP/if-clause.f90
--- a/flang/test/Lower/OpenMP/if-clause.f90
+++ b/flang/test/Lower/OpenMP/if-clause.f90
@@ -13,7 +13,6 @@
   ! - PARALLEL SECTIONS
   ! - PARALLEL WORKSHARE
   ! - TARGET PARALLEL
-  ! - TARGET TEAMS
   ! - TARGET TEAMS DISTRIBUTE
   ! - TARGET TEAMS DISTRIBUTE PARALLEL DO
   ! - TARGET TEAMS DISTRIBUTE PARALLEL DO SIMD
@@ -21,7 +20,6 @@
   ! - TARGET UPDATE
   ! - TASKLOOP
   ! - TASKLOOP SIMD
-  ! - TEAMS
   ! - TEAMS DISTRIBUTE
   ! - TEAMS DISTRIBUTE PARALLEL DO
   ! - TEAMS DISTRIBUTE PARALLEL DO SIMD
@@ -416,6 +414,54 @@
   end do
   !$omp end target simd
 
+  ! ----------------------------------------------------------------------------
+  ! TARGET TEAMS
+  ! ----------------------------------------------------------------------------
+
+  ! CHECK:      omp.target
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  ! CHECK:      omp.teams
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  !$omp target teams
+  i = 1
+  !$omp end target teams
+
+  ! CHECK:      omp.target
+  ! CHECK-SAME: if({{.*}})
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp target teams if(.true.)
+  i = 1
+  !$omp end target teams
+
+  ! CHECK:      omp.target
+  ! CHECK-SAME: if({{.*}})
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp target teams if(target: .true.) if(teams: .false.)
+  i = 1
+  !$omp end target teams
+
+  ! CHECK:      omp.target
+  ! CHECK-SAME: if({{.*}})
+  ! CHECK:      omp.teams
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  !$omp target teams if(target: .true.)
+  i = 1
+  !$omp end target teams
+
+  ! CHECK:      omp.target
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp target teams if(teams: .true.)
+  i = 1
+  !$omp end target teams
+
   ! ----------------------------------------------------------------------------
   ! TASK
   ! ----------------------------------------------------------------------------
@@ -434,4 +480,26 @@
   ! CHECK-SAME: if({{.*}})
   !$omp task if(task: .true.)
   !$omp end task
+
+  ! ----------------------------------------------------------------------------
+  ! TEAMS
+  ! ----------------------------------------------------------------------------
+  ! CHECK:      omp.teams
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  !$omp teams
+  i = 1
+  !$omp end teams
+
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp teams if(.true.)
+  i = 1
+  !$omp end teams
+
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp teams if(teams: .true.)
+  i = 1
+  !$omp end teams
 end program main
diff --git a/flang/test/Lower/OpenMP/teams-reduction-add.f90 b/flang/test/Lower/OpenMP/teams-reduction-add.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/teams-reduction-add.f90
@@ -0,0 +1,18 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! CHECK: omp.reduction.declare @[[REDUCTION_NAME:.*]] : i32 init
+
+! CHECK: func @_QPteams_reduction_add
+subroutine teams_reduction_add()
+  ! CHECK: %[[I:.*]] = fir.alloca i32
+  integer :: i
+  i = 0
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: reduction(@[[REDUCTION_NAME]] -> %[[I]] : !fir.ref<i32>)
+  !$omp teams reduction(+:i)
+    ! CHECK: omp.reduction %{{.*}}, %[[I]] : i32, !fir.ref<i32>
+    ! CHECK-NEXT: omp.terminator
+    i = i + 1
+  !$omp end teams
+end subroutine teams_reduction_add
diff --git a/flang/test/Lower/OpenMP/teams.f90 b/flang/test/Lower/OpenMP/teams.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/teams.f90
@@ -0,0 +1,114 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPteams_simple
+subroutine teams_simple()
+  ! CHECK: omp.teams
+  !$omp teams
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_simple
+
+!===============================================================================
+! `num_teams` clause
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_numteams
+subroutine teams_numteams(num_teams)
+  integer, intent(inout) :: num_teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams( to %{{.*}}: i32)
+  !$omp teams num_teams(4)
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams( to %{{.*}}: i32)
+  !$omp teams num_teams(num_teams)
+  ! CHECK: fir.call
+  call f2()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+end subroutine teams_numteams
+
+!===============================================================================
+! `if` clause
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_if
+subroutine teams_if(alpha)
+  integer, intent(in) :: alpha
+  logical :: condition
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: if(%{{.*}})
+  !$omp teams if(.false.)
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: if(%{{.*}})
+  !$omp teams if(alpha .le. 0)
+  ! CHECK: fir.call
+  call f2()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: if(%{{.*}})
+  !$omp teams if(condition)
+  ! CHECK: fir.call
+  call f3()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_if
+
+!===============================================================================
+! `thread_limit` clause
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_threadlimit
+subroutine teams_threadlimit(thread_limit)
+  integer, intent(inout) :: thread_limit
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: thread_limit(%{{.*}}: i32)
+  !$omp teams thread_limit(4)
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: thread_limit(%{{.*}}: i32)
+  !$omp teams thread_limit(thread_limit)
+  ! CHECK: fir.call
+  call f2()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+end subroutine teams_threadlimit
+
+!===============================================================================
+! `allocate` clause
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_allocate
+subroutine teams_allocate()
+  use omp_lib
+  integer :: x
+  ! CHECK: omp.teams
+  ! CHECK-SAME: allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref<i32>)
+  !$omp teams allocate(omp_high_bw_mem_alloc: x) private(x)
+  ! CHECK: arith.addi
+  x = x + 12
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_allocate