diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -21,6 +21,7 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 
 using namespace mlir;
@@ -63,6 +64,90 @@
   firOpBuilder.setInsertionPointToStart(&block);
 }
 
+/// Lowers the clauses of an OpenMP `parallel` construct into the operands and
+/// attribute used to build an `omp.parallel` operation.
+///
+/// Only the `if`, `num_threads` and `proc_bind` clauses are recorded; every
+/// other clause is ignored. A member stays null when its clause is absent.
+struct MLIRClauses {
+  explicit MLIRClauses(Fortran::lower::AbstractConverter &converter)
+      : converter(converter) {}
+
+  /// Condition of the `if` clause, normalized to `i1`.
+  mlir::Value ifVal;
+  /// Operand of the `num_threads` clause.
+  mlir::Value numThreadsVal;
+  /// Kind of the `proc_bind` clause.
+  omp::ClauseProcBindKindAttr pbKindAttr;
+
+  /// Translates a parser `proc_bind` clause into the matching MLIR attribute
+  /// and stores it in \p result.
+  void handleProcBindClause(
+      const Fortran::parser::OmpClause::ProcBind &procBindClause,
+      omp::ClauseProcBindKindAttr &result) {
+    // Map the enumerators explicitly rather than round-tripping through a
+    // lowercased EnumToString() result: with an exhaustive switch, -Wswitch
+    // reports any parser kind that lacks an MLIR counterpart, whereas a failed
+    // string lookup would silently drop the clause.
+    omp::ClauseProcBindKind pbKind;
+    switch (procBindClause.v.v) {
+    case Fortran::parser::OmpProcBindClause::Type::Master:
+      pbKind = omp::ClauseProcBindKind::Master;
+      break;
+    case Fortran::parser::OmpProcBindClause::Type::Close:
+      pbKind = omp::ClauseProcBindKind::Close;
+      break;
+    case Fortran::parser::OmpProcBindClause::Type::Spread:
+      pbKind = omp::ClauseProcBindKind::Spread;
+      break;
+    }
+    result = omp::ClauseProcBindKindAttr::get(
+        converter.getFirOpBuilder().getContext(), pbKind);
+  }
+
+  /// Walks \p clauseList and records the clauses listed above.
+  void collect(const Fortran::parser::OmpClauseList &clauseList) {
+    for (const Fortran::parser::OmpClause &clause : clauseList.v) {
+      std::visit(
+          Fortran::common::visitors{
+              [&](const Fortran::parser::OmpClause::If &ifClause) {
+                const auto &expr =
+                    std::get<Fortran::parser::ScalarLogicalExpr>(ifClause.v.t);
+                mlir::Value cond = fir::getBase(converter.genExprValue(
+                    *Fortran::semantics::GetExpr(expr), stmtCtx));
+                // The lowered expression is not necessarily an i1 (it may be
+                // a FIR logical); normalize it to the i1 condition that
+                // omp.parallel is built with.
+                auto &firOpBuilder = converter.getFirOpBuilder();
+                ifVal = firOpBuilder.createConvert(
+                    converter.getCurrentLocation(), firOpBuilder.getI1Type(),
+                    cond);
+              },
+              [&](const Fortran::parser::OmpClause::NumThreads
+                      &numThreadsClause) {
+                // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
+                numThreadsVal = fir::getBase(converter.genExprValue(
+                    *Fortran::semantics::GetExpr(numThreadsClause.v),
+                    stmtCtx));
+              },
+              [&](const Fortran::parser::OmpClause::ProcBind
+                      &procBindClause) {
+                handleProcBindClause(procBindClause, pbKindAttr);
+              },
+              // TODO: Handle private, firstprivate, shared, copyin and
+              // default clauses.
+              [](const auto &) {}},
+          clause.u);
+    }
+  }
+
+private:
+  /// Statement context handed to genExprValue for the clause expressions;
+  /// it is owned by, and lives as long as, this object.
+  Fortran::lower::StatementContext stmtCtx;
+  Fortran::lower::AbstractConverter &converter;
+};
+
 static void
genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSimpleStandaloneConstruct @@ -142,58 +190,17 @@ auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); - Fortran::lower::StatementContext stmtCtx; llvm::ArrayRef<mlir::Type> argTy; - if (blockDirective.v == llvm::omp::OMPD_parallel) { - - mlir::Value ifClauseOperand, numThreadsClauseOperand; - Attribute procBindClauseOperand; + if (blockDirective.v == llvm::omp::OMPD_parallel) { const auto &parallelOpClauseList = std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t); - for (const auto &clause : parallelOpClauseList.v) { - if (const auto &ifClause = - std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { - auto &expr = - std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t); - ifClauseOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(expr), stmtCtx)); - } else if (const auto &numThreadsClause = - std::get_if<Fortran::parser::OmpClause::NumThreads>( - &clause.u)) { - // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`. - numThreadsClauseOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); - } - // TODO: Handle private, firstprivate, shared and copyin - } + MLIRClauses mlirClauses(converter); + mlirClauses.collect(parallelOpClauseList); // Create and insert the operation. auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>( - currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, - ValueRange(), ValueRange(), - procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>()); - // Handle attribute based clauses.
- for (const auto &clause : parallelOpClauseList.v) { - // TODO: Handle default clause - if (const auto &procBindClause = - std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u)) { - const auto &ompProcBindClause{procBindClause->v}; - omp::ClauseProcBindKind pbKind; - switch (ompProcBindClause.v) { - case Fortran::parser::OmpProcBindClause::Type::Master: - pbKind = omp::ClauseProcBindKind::Master; - break; - case Fortran::parser::OmpProcBindClause::Type::Close: - pbKind = omp::ClauseProcBindKind::Close; - break; - case Fortran::parser::OmpProcBindClause::Type::Spread: - pbKind = omp::ClauseProcBindKind::Spread; - break; - } - parallelOp.proc_bind_valAttr(omp::ClauseProcBindKindAttr::get( - firOpBuilder.getContext(), pbKind)); - } - } + currentLocation, argTy, mlirClauses.ifVal, mlirClauses.numThreadsVal, + ValueRange(), ValueRange(), mlirClauses.pbKindAttr); createBodyOfOp(parallelOp, firOpBuilder, currentLocation); } else if (blockDirective.v == llvm::omp::OMPD_master) { auto masterOp = diff --git a/flang/test/Lower/OpenMP/parallel.f90 b/flang/test/Lower/OpenMP/parallel.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/parallel.f90 @@ -0,0 +1,132 @@ +! RUN: bbc -fopenmp -emit-fir -o - %s | FileCheck %s +! RUN: %flang_fc1 -fopenmp -emit-fir -o - %s | FileCheck %s + +subroutine parallel_simple() + ! CHECK: omp.parallel +!$omp parallel + ! CHECK: fir.call + call f1() +!$omp end parallel +end subroutine parallel_simple + +!=============================================================================== +! `if` clause +!=============================================================================== + +subroutine parallel_if(alpha) + integer, intent(in) :: alpha + + ! CHECK: omp.parallel if(%{{.*}} : i1) { + ! CHECK-NEXT: fir.call + ! CHECK-NEXT: omp.terminator +!$omp parallel if(alpha .le. 0) + call f1() +!$omp end parallel + + ! CHECK: omp.parallel if(%{{.*}} : i1) { + ! CHECK-NEXT: fir.call + ! CHECK-NEXT: omp.terminator +!$omp parallel if(.false.)
+  call f2()
+!$omp end parallel
+
+  ! CHECK: omp.parallel if(%{{.*}} : i1) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel if(alpha .ge. 0)
+  call f3()
+!$omp end parallel
+
+  ! CHECK: omp.parallel if(%{{.*}} : i1) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel if(.true.)
+  call f4()
+!$omp end parallel
+
+end subroutine parallel_if
+
+!===============================================================================
+! `num_threads` clause
+!===============================================================================
+
+subroutine parallel_numthreads(num_threads)
+  integer, intent(inout) :: num_threads
+
+  ! CHECK: omp.parallel num_threads(%{{.*}} : i32) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel num_threads(16)
+  call f1()
+!$omp end parallel
+
+  num_threads = 4
+
+  ! CHECK: omp.parallel num_threads(%{{.*}} : i32) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel num_threads(num_threads)
+  call f2()
+!$omp end parallel
+
+end subroutine parallel_numthreads
+
+!===============================================================================
+! `proc_bind` clause
+!===============================================================================
+
+subroutine parallel_proc_bind()
+
+  ! CHECK: omp.parallel proc_bind(master) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel proc_bind(master)
+  call f1()
+!$omp end parallel
+
+  ! CHECK: omp.parallel proc_bind(close) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel proc_bind(close)
+  call f2()
+!$omp end parallel
+
+  ! CHECK: omp.parallel proc_bind(spread) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel proc_bind(spread)
+  call f3()
+!$omp end parallel
+
+end subroutine parallel_proc_bind
+
+!===============================================================================
+! multiple clauses
+!===============================================================================
+
+subroutine parallel_multiple_clauses(alpha, num_threads)
+  integer, intent(in) :: alpha
+  integer, intent(in) :: num_threads
+
+  ! CHECK: omp.parallel if(%{{.*}} : i1) proc_bind(master) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel if(alpha .le. 0) proc_bind(master)
+  call f1()
+!$omp end parallel
+
+  ! CHECK: omp.parallel num_threads(%{{.*}} : i32) proc_bind(close) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel proc_bind(close) num_threads(num_threads)
+  call f2()
+!$omp end parallel
+
+  ! CHECK: omp.parallel if(%{{.*}} : i1) num_threads(%{{.*}} : i32) {
+  ! CHECK-NEXT: fir.call
+  ! CHECK-NEXT: omp.terminator
+!$omp parallel num_threads(num_threads) if(alpha .le. 0)
+  call f3()
+!$omp end parallel
+
+end subroutine parallel_multiple_clauses