diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -3308,7 +3308,7 @@ // 2.5 proc-bind-clause -> PROC_BIND (MASTER | CLOSE | SPREAD) struct OmpProcBindClause { - ENUM_CLASS(Type, Close, Master, Spread) + ENUM_CLASS(Type, Close, Master, Spread, Primary) WRAPPER_CLASS_BOILERPLATE(OmpProcBindClause, Type); }; diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -179,75 +179,79 @@ auto currentLocation = converter.getCurrentLocation(); Fortran::lower::StatementContext stmtCtx; llvm::ArrayRef argTy; - if (blockDirective.v == llvm::omp::OMPD_parallel) { - - mlir::Value ifClauseOperand, numThreadsClauseOperand; - Attribute procBindClauseOperand; + mlir::Value ifClauseOperand, numThreadsClauseOperand; + mlir::omp::ClauseProcBindKindAttr procBindKindAttr; + SmallVector allocateOperands, allocatorOperands; + mlir::UnitAttr nowaitAttr; - const auto &parallelOpClauseList = - std::get(beginBlockDirective.t); - for (const auto &clause : parallelOpClauseList.v) { - if (const auto &ifClause = - std::get_if(&clause.u)) { - auto &expr = - std::get(ifClause->v.t); - ifClauseOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(expr), stmtCtx)); - } else if (const auto &numThreadsClause = - std::get_if( - &clause.u)) { - // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`. 
- numThreadsClauseOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); + for (const auto &clause : + std::get(beginBlockDirective.t).v) { + if (const auto &ifClause = + std::get_if(&clause.u)) { + auto &expr = std::get(ifClause->v.t); + ifClauseOperand = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + } else if (const auto &numThreadsClause = + std::get_if( + &clause.u)) { + // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. + numThreadsClauseOperand = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); + } else if (const auto &procBindClause = + std::get_if( + &clause.u)) { + omp::ClauseProcBindKind pbKind; + switch (procBindClause->v.v) { + case Fortran::parser::OmpProcBindClause::Type::Master: + pbKind = omp::ClauseProcBindKind::Master; + break; + case Fortran::parser::OmpProcBindClause::Type::Close: + pbKind = omp::ClauseProcBindKind::Close; + break; + case Fortran::parser::OmpProcBindClause::Type::Spread: + pbKind = omp::ClauseProcBindKind::Spread; + break; + case Fortran::parser::OmpProcBindClause::Type::Primary: + pbKind = omp::ClauseProcBindKind::Primary; + break; } - // TODO: Handle private, firstprivate, shared and copyin + procBindKindAttr = + omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind); + } else if (const auto &allocateClause = + std::get_if( + &clause.u)) { + genAllocateClause(converter, allocateClause->v, allocatorOperands, + allocateOperands); + } else if (const auto &privateClause = + std::get_if( + &clause.u)) { + // TODO: Handle private. This cannot be a hard TODO because testing for + // allocate clause requires private variables. 
+ } else { + TODO(currentLocation, "OpenMP Block construct clauses"); } + } + + for (const auto &clause : + std::get(endBlockDirective.t).v) { + if (std::get_if(&clause.u)) + nowaitAttr = firOpBuilder.getUnitAttr(); + } + + if (blockDirective.v == llvm::omp::OMPD_parallel) { // Create and insert the operation. auto parallelOp = firOpBuilder.create( currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, - /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), - /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr, - procBindClauseOperand.dyn_cast_or_null()); - // Handle attribute based clauses. - for (const auto &clause : parallelOpClauseList.v) { - // TODO: Handle default clause - if (const auto &procBindClause = - std::get_if(&clause.u)) { - const auto &ompProcBindClause{procBindClause->v}; - omp::ClauseProcBindKind pbKind; - switch (ompProcBindClause.v) { - case Fortran::parser::OmpProcBindClause::Type::Master: - pbKind = omp::ClauseProcBindKind::Master; - break; - case Fortran::parser::OmpProcBindClause::Type::Close: - pbKind = omp::ClauseProcBindKind::Close; - break; - case Fortran::parser::OmpProcBindClause::Type::Spread: - pbKind = omp::ClauseProcBindKind::Spread; - break; - } - parallelOp.proc_bind_valAttr(omp::ClauseProcBindKindAttr::get( - firOpBuilder.getContext(), pbKind)); - } - } + allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), + /*reductions=*/nullptr, procBindKindAttr); createBodyOfOp(parallelOp, firOpBuilder, currentLocation); } else if (blockDirective.v == llvm::omp::OMPD_master) { auto masterOp = firOpBuilder.create(currentLocation, argTy); createBodyOfOp(masterOp, firOpBuilder, currentLocation); - - // Single Construct } else if (blockDirective.v == llvm::omp::OMPD_single) { - mlir::UnitAttr nowaitAttr; - for (const auto &clause : - std::get(endBlockDirective.t).v) { - if (std::get_if(&clause.u)) - nowaitAttr = firOpBuilder.getUnitAttr(); - // TODO: Handle allocate clause (D122302) - } auto singleOp 
= firOpBuilder.create( - currentLocation, /*allocate_vars=*/ValueRange(), - /*allocators_vars=*/ValueRange(), nowaitAttr); + currentLocation, allocateOperands, allocatorOperands, nowaitAttr); createBodyOfOp(singleOp, firOpBuilder, currentLocation); } } diff --git a/flang/test/Lower/OpenMP/parallel.f90 b/flang/test/Lower/OpenMP/parallel.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/parallel.f90 @@ -0,0 +1,163 @@ +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes="FIRDialect,OMPDialect" +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="OMPDialect" + +!FIRDialect-LABEL: func @_QPparallel_simple +subroutine parallel_simple() + !OMPDialect: omp.parallel +!$omp parallel + !FIRDialect: fir.call + call f1() +!$omp end parallel +end subroutine parallel_simple + +!=============================================================================== +! `if` clause +!=============================================================================== + +!FIRDialect-LABEL: func @_QPparallel_if +subroutine parallel_if(alpha) + integer, intent(in) :: alpha + + !OMPDialect: omp.parallel if(%{{.*}} : i1) { + !$omp parallel if(alpha .le. 0) + !FIRDialect: fir.call + call f1() + !OMPDialect: omp.terminator + !$omp end parallel + + !OMPDialect: omp.parallel if(%{{.*}} : i1) { + !$omp parallel if(.false.) + !FIRDialect: fir.call + call f2() + !OMPDialect: omp.terminator + !$omp end parallel + + !OMPDialect: omp.parallel if(%{{.*}} : i1) { + !$omp parallel if(alpha .ge. 0) + !FIRDialect: fir.call + call f3() + !OMPDialect: omp.terminator + !$omp end parallel + + !OMPDialect: omp.parallel if(%{{.*}} : i1) { + !$omp parallel if(.true.) + !FIRDialect: fir.call + call f4() + !OMPDialect: omp.terminator + !$omp end parallel + +end subroutine parallel_if + +!=============================================================================== +! 
`num_threads` clause +!=============================================================================== + +!FIRDialect-LABEL: func @_QPparallel_numthreads +subroutine parallel_numthreads(num_threads) + integer, intent(inout) :: num_threads + + !OMPDialect: omp.parallel num_threads(%{{.*}}: i32) { + !$omp parallel num_threads(16) + !FIRDialect: fir.call + call f1() + !OMPDialect: omp.terminator + !$omp end parallel + + num_threads = 4 + + !OMPDialect: omp.parallel num_threads(%{{.*}} : i32) { + !$omp parallel num_threads(num_threads) + !FIRDialect: fir.call + call f2() + !OMPDialect: omp.terminator + !$omp end parallel + +end subroutine parallel_numthreads + +!=============================================================================== +! `proc_bind` clause +!=============================================================================== + +!FIRDialect-LABEL: func @_QPparallel_proc_bind +subroutine parallel_proc_bind() + + !OMPDialect: omp.parallel proc_bind(master) { + !$omp parallel proc_bind(master) + !FIRDialect: fir.call + call f1() + !OMPDialect: omp.terminator + !$omp end parallel + + !OMPDialect: omp.parallel proc_bind(close) { + !$omp parallel proc_bind(close) + !FIRDialect: fir.call + call f2() + !OMPDialect: omp.terminator + !$omp end parallel + + !OMPDialect: omp.parallel proc_bind(spread) { + !$omp parallel proc_bind(spread) + !FIRDialect: fir.call + call f3() + !OMPDialect: omp.terminator + !$omp end parallel + +end subroutine parallel_proc_bind + +!=============================================================================== +! 
`allocate` clause +!=============================================================================== + +!FIRDialect-LABEL: func @_QPparallel_allocate +subroutine parallel_allocate() + use omp_lib + integer :: x + !OMPDialect: omp.parallel allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref) { + !$omp parallel allocate(omp_high_bw_mem_alloc: x) private(x) + !FIRDialect: arith.addi + x = x + 12 + !OMPDialect: omp.terminator + !$omp end parallel +end subroutine parallel_allocate + +!=============================================================================== +! multiple clauses +!=============================================================================== + +!FIRDialect-LABEL: func @_QPparallel_multiple_clauses +subroutine parallel_multiple_clauses(alpha, num_threads) + use omp_lib + integer, intent(inout) :: alpha + integer, intent(in) :: num_threads + + !OMPDialect: omp.parallel if({{.*}} : i1) proc_bind(master) { + !$omp parallel if(alpha .le. 0) proc_bind(master) + !FIRDialect: fir.call + call f1() + !OMPDialect: omp.terminator + !$omp end parallel + + !OMPDialect: omp.parallel num_threads({{.*}} : i32) proc_bind(close) { + !$omp parallel proc_bind(close) num_threads(num_threads) + !FIRDialect: fir.call + call f2() + !OMPDialect: omp.terminator + !$omp end parallel + + !OMPDialect: omp.parallel if({{.*}} : i1) num_threads({{.*}} : i32) { + !$omp parallel num_threads(num_threads) if(alpha .le. 0) + !FIRDialect: fir.call + call f3() + !OMPDialect: omp.terminator + !$omp end parallel + + !OMPDialect: omp.parallel if({{.*}} : i1) num_threads({{.*}} : i32) allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref) { + !$omp parallel num_threads(num_threads) if(alpha .le. 
0) allocate(omp_high_bw_mem_alloc: alpha) private(alpha) + !FIRDialect: fir.call + call f3() + !FIRDialect: arith.addi + alpha = alpha + 12 + !OMPDialect: omp.terminator + !$omp end parallel + +end subroutine parallel_multiple_clauses diff --git a/flang/test/Lower/OpenMP/single.f90 b/flang/test/Lower/OpenMP/single.f90 --- a/flang/test/Lower/OpenMP/single.f90 +++ b/flang/test/Lower/OpenMP/single.f90 @@ -44,3 +44,23 @@ !OMPDialect: omp.terminator !$omp end parallel end subroutine omp_single_nowait + +!=============================================================================== +! Single construct with allocate +!=============================================================================== + +!FIRDialect-LABEL: func @_QPsingle_allocate +subroutine single_allocate() + use omp_lib + integer :: x + !OMPDialect: omp.parallel { + !$omp parallel + !OMPDialect: omp.single allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref) { + !$omp single allocate(omp_high_bw_mem_alloc: x) private(x) + !FIRDialect: arith.addi + x = x + 12 + !OMPDialect: omp.terminator + !$omp end single + !OMPDialect: omp.terminator + !$omp end parallel +end subroutine single_allocate