diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -246,6 +246,75 @@
       standaloneConstruct.u);
 }
 
+/* When parallel is used in a combined construct, use this function to create
+ * the parallel operation. It handles the parallel-specific clauses and leaves
+ * the remaining clauses to be handled when the inner operations are created.
+ * TODO: Refactor clause handling
+ */
+template <typename Directive>
+static void
+createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
+                         Fortran::lower::pft::Evaluation &eval,
+                         const Directive &directive) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+  Fortran::lower::StatementContext stmtCtx;
+  llvm::ArrayRef<mlir::Type> argTy;
+  mlir::Value ifClauseOperand, numThreadsClauseOperand;
+  SmallVector<Value> allocatorOperands, allocateOperands;
+  mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
+  const auto &opClauseList =
+      std::get<Fortran::parser::OmpClauseList>(directive.t);
+  // TODO: Handle the following clauses:
+  // 1. default
+  // 2. copyin
+  // Note: the rest of the clauses are handled when the inner operation is
+  // created.
+  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+    if (const auto &ifClause =
+            std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
+      auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
+      mlir::Value ifVal = fir::getBase(
+          converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+      ifClauseOperand = firOpBuilder.createConvert(
+          currentLocation, firOpBuilder.getI1Type(), ifVal);
+    } else if (const auto &numThreadsClause =
+                   std::get_if<Fortran::parser::OmpClause::NumThreads>(
+                       &clause.u)) {
+      numThreadsClauseOperand = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
+    } else if (const auto &procBindClause =
+                   std::get_if<Fortran::parser::OmpClause::ProcBind>(
+                       &clause.u)) {
+      omp::ClauseProcBindKind pbKind;
+      switch (procBindClause->v.v) {
+      case Fortran::parser::OmpProcBindClause::Type::Master:
+        pbKind = omp::ClauseProcBindKind::Master;
+        break;
+      case Fortran::parser::OmpProcBindClause::Type::Close:
+        pbKind = omp::ClauseProcBindKind::Close;
+        break;
+      case Fortran::parser::OmpProcBindClause::Type::Spread:
+        pbKind = omp::ClauseProcBindKind::Spread;
+        break;
+      case Fortran::parser::OmpProcBindClause::Type::Primary:
+        pbKind = omp::ClauseProcBindKind::Primary;
+        break;
+      }
+      procBindKindAttr =
+          omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
+    }
+  }
+  // Create and insert the operation.
+  auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
+      currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
+      allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
+      /*reductions=*/nullptr, procBindKindAttr);
+
+  createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
+                                  &opClauseList, /*iv=*/nullptr,
+                                  /*isCombined=*/true);
+}
+
 static void
 genOMP(Fortran::lower::AbstractConverter &converter,
        Fortran::lower::pft::Evaluation &eval,
@@ -385,12 +454,16 @@
   mlir::Value scheduleChunkClauseOperand;
   mlir::Attribute scheduleClauseOperand, collapseClauseOperand,
       noWaitClauseOperand, orderedClauseOperand, orderClauseOperand;
-  const auto &wsLoopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
-      std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t);
-  if (llvm::omp::OMPD_do !=
+
+  const auto ompDirective =
       std::get<Fortran::parser::OmpLoopDirective>(
           std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t)
-          .v) {
+          .v;
+  if (llvm::omp::OMPD_parallel_do == ompDirective) {
+    createCombinedParallelOp<Fortran::parser::OmpBeginLoopDirective>(
+        converter, eval,
+        std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t));
+  } else if (llvm::omp::OMPD_do != ompDirective) {
     TODO(converter.getCurrentLocation(),
          "Combined worksharing loop construct");
   }
@@ -438,6 +511,8 @@
       orderClauseOperand.dyn_cast_or_null<omp::ClauseOrderKindAttr>(),
       /*inclusive=*/firOpBuilder.getUnitAttr());
 
+  const auto &wsLoopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
+      std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t);
   // Handle attribute based clauses.
   for (const Fortran::parser::OmpClause &clause : wsLoopOpClauseList.v) {
     if (const auto &scheduleClause =
diff --git a/flang/test/Lower/OpenMP/omp-parallel-wsloop.f90 b/flang/test/Lower/OpenMP/omp-parallel-wsloop.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-parallel-wsloop.f90
@@ -0,0 +1,66 @@
+! This test checks lowering of the OpenMP PARALLEL DO combined directive.
+
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPsimple_parallel_do()
+subroutine simple_parallel_do
+  integer :: i
+  ! CHECK: omp.parallel
+  ! CHECK: %[[WS_LB:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[WS_UB:.*]] = arith.constant 9 : i32
+  ! CHECK: %[[WS_STEP:.*]] = arith.constant 1 : i32
+  ! CHECK: omp.wsloop for (%[[I:.*]]) : i32 = (%[[WS_LB]]) to (%[[WS_UB]]) inclusive step (%[[WS_STEP]])
+  !$OMP PARALLEL DO
+  do i=1, 9
+    ! CHECK: fir.call @_FortranAioOutputInteger32({{.*}}, %[[I]]) : (!fir.ref<i8>, i32) -> i1
+    print*, i
+  end do
+  ! CHECK: omp.yield
+  ! CHECK: omp.terminator
+  !$OMP END PARALLEL DO
+end subroutine
+
+! CHECK-LABEL: func @_QPparallel_do_with_parallel_clauses
+! CHECK-SAME: %[[COND_REF:.*]]: !fir.ref<!fir.logical<4>> {fir.bindc_name = "cond"}, %[[NT_REF:.*]]: !fir.ref<i32> {fir.bindc_name = "nt"}
+subroutine parallel_do_with_parallel_clauses(cond, nt)
+  logical :: cond
+  integer :: nt
+  integer :: i
+  ! CHECK: %[[COND:.*]] = fir.load %[[COND_REF]] : !fir.ref<!fir.logical<4>>
+  ! CHECK: %[[COND_CVT:.*]] = fir.convert %[[COND]] : (!fir.logical<4>) -> i1
+  ! CHECK: %[[NT:.*]] = fir.load %[[NT_REF]] : !fir.ref<i32>
+  ! CHECK: omp.parallel if(%[[COND_CVT]] : i1) num_threads(%[[NT]] : i32) proc_bind(close)
+  ! CHECK: %[[WS_LB:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[WS_UB:.*]] = arith.constant 9 : i32
+  ! CHECK: %[[WS_STEP:.*]] = arith.constant 1 : i32
+  ! CHECK: omp.wsloop for (%[[I:.*]]) : i32 = (%[[WS_LB]]) to (%[[WS_UB]]) inclusive step (%[[WS_STEP]])
+  !$OMP PARALLEL DO IF(cond) NUM_THREADS(nt) PROC_BIND(close)
+  do i=1, 9
+    ! CHECK: fir.call @_FortranAioOutputInteger32({{.*}}, %[[I]]) : (!fir.ref<i8>, i32) -> i1
+    print*, i
+  end do
+  ! CHECK: omp.yield
+  ! CHECK: omp.terminator
+  !$OMP END PARALLEL DO
+end subroutine
+
+! CHECK-LABEL: func @_QPparallel_do_with_clauses
+! CHECK-SAME: %[[NT_REF:.*]]: !fir.ref<i32> {fir.bindc_name = "nt"}
+subroutine parallel_do_with_clauses(nt)
+  integer :: nt
+  integer :: i
+  ! CHECK: %[[NT:.*]] = fir.load %[[NT_REF]] : !fir.ref<i32>
+  ! CHECK: omp.parallel num_threads(%[[NT]] : i32)
+  ! CHECK: %[[WS_LB:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[WS_UB:.*]] = arith.constant 9 : i32
+  ! CHECK: %[[WS_STEP:.*]] = arith.constant 1 : i32
+  ! CHECK: omp.wsloop schedule(dynamic) for (%[[I:.*]]) : i32 = (%[[WS_LB]]) to (%[[WS_UB]]) inclusive step (%[[WS_STEP]])
+  !$OMP PARALLEL DO NUM_THREADS(nt) SCHEDULE(dynamic)
+  do i=1, 9
+    ! CHECK: fir.call @_FortranAioOutputInteger32({{.*}}, %[[I]]) : (!fir.ref<i8>, i32) -> i1
+    print*, i
+  end do
+  ! CHECK: omp.yield
+  ! CHECK: omp.terminator
+  !$OMP END PARALLEL DO
+end subroutine
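
Note: for readers unfamiliar with the OpenMP dialect, the region nesting this
lowering produces for simple_parallel_do is sketched below. The sketch is
assembled by hand from the CHECK lines above rather than captured from compiler
output, so the SSA value names are illustrative:

  func @_QPsimple_parallel_do() {
    // Outer op, built by createCombinedParallelOp: carries the
    // parallel-specific clauses (if, num_threads, proc_bind).
    omp.parallel {
      %lb = arith.constant 1 : i32
      %ub = arith.constant 9 : i32
      %step = arith.constant 1 : i32
      // Inner op, built by the existing worksharing-loop path: carries the
      // loop-specific clauses (schedule, ordered, ...).
      omp.wsloop for (%iv) : i32 = (%lb) to (%ub) inclusive step (%step) {
        // Loop body: the PRINT statement lowers to Fortran runtime I/O calls.
        omp.yield
      }
      omp.terminator
    }
    return
  }

Splitting the clause handling this way keeps createCombinedParallelOp
responsible only for the outer omp.parallel, while the inner omp.wsloop is
created by the unchanged worksharing-loop lowering.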