diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1351,7 +1351,15 @@
     localSymbols.pushScope();
     Fortran::lower::genOpenMPConstruct(*this, getEval(), omp);
 
-    for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
+    // For a worksharing loop, genOpenMPConstruct has already created the
+    // omp.wsloop operation; lower only the evaluations nested in the DO
+    // construct so that no separate fir.do_loop is generated for it.
+    Fortran::lower::pft::Evaluation *curEval =
+        std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)
+            ? &getEval().getFirstNestedEvaluation()
+            : &getEval();
+
+    for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
       genFIR(e);
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
@@ -2243,14 +2251,7 @@
   void genFIR(const Fortran::parser::IfStmt &) {}              // nop
   void genFIR(const Fortran::parser::IfThenStmt &) {}          // nop
   void genFIR(const Fortran::parser::NonLabelDoStmt &) {}      // nop
-
-  void genFIR(const Fortran::parser::OmpEndLoopDirective &) {
-    TODO(toLocation(), "OmpEndLoopDirective lowering");
-  }
-
-  void genFIR(const Fortran::parser::NamelistStmt &) {
-    TODO(toLocation(), "NamelistStmt lowering");
-  }
+  void genFIR(const Fortran::parser::OmpEndLoopDirective &) {} // nop
 
   /// Generate FIR for the Evaluation `eval`.
   void genFIR(Fortran::lower::pft::Evaluation &eval,
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -113,13 +113,31 @@
 createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
                mlir::Location &loc,
                const Fortran::parser::OmpClauseList *clauses = nullptr,
+               const Fortran::semantics::Symbol *arg = nullptr,
                bool outerCombined = false) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  firOpBuilder.createBlock(&op.getRegion());
+  // If an argument for the region is provided then create the block with that
+  // argument. Also update the symbol's address with the mlir argument value.
+  // e.g. For loops the argument is the induction variable. And all further
+  // uses of the induction variable should use this mlir value.
+  if (arg) {
+    firOpBuilder.createBlock(&op.getRegion(), {}, {converter.genType(*arg)},
+                             {loc});
+    converter.bindSymbol(*arg, op.getRegion().front().getArgument(0));
+  } else {
+    firOpBuilder.createBlock(&op.getRegion());
+  }
   auto &block = op.getRegion().back();
   firOpBuilder.setInsertionPointToStart(&block);
-  // Ensure the block is well-formed.
-  firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
+
+  // Insert the terminator: omp.yield for omp.wsloop, omp.terminator otherwise.
+  if constexpr (std::is_same_v<Op, omp::WsLoopOp>) {
+    mlir::ValueRange results;
+    firOpBuilder.create<mlir::omp::YieldOp>(loc, results);
+  } else {
+    firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
+  }
+
   // Reset the insertion point to the start of the first block.
   firOpBuilder.setInsertionPointToStart(&block);
   // Handle privatization. Do not privatize if this is the outer operation.
@@ -315,7 +333,7 @@
         allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
         /*reductions=*/nullptr, procBindKindAttr);
     createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
-                                    &opClauseList, /*isCombined=*/false);
+                                    &opClauseList);
   } else if (blockDirective.v == llvm::omp::OMPD_master) {
     auto masterOp =
         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
@@ -333,5 +351,136 @@
   }
 }
 
+static void genOMP(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::pft::Evaluation &eval,
+                   const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
+
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+  llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
+      linearStepVars, reductionVars;
+  mlir::Value scheduleChunkClauseOperand;
+  mlir::Attribute scheduleClauseOperand, collapseClauseOperand,
+      noWaitClauseOperand, orderedClauseOperand, orderClauseOperand;
+  const auto &beginLoopDirective =
+      std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t);
+  const auto &wsLoopOpClauseList =
+      std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
+  if (std::get<Fortran::parser::OmpLoopDirective>(beginLoopDirective.t).v !=
+      llvm::omp::OMPD_do) {
+    TODO(currentLocation, "Combined worksharing loop construct");
+  }
+
+  // The DO construct is expected to be the first nested evaluation of the
+  // OpenMP loop construct, and its NonLabelDoStmt the first nested evaluation
+  // of the DO construct.
+  Fortran::lower::pft::Evaluation *doConstructEval =
+      &eval.getFirstNestedEvaluation();
+  Fortran::lower::pft::Evaluation *doLoop =
+      &doConstructEval->getFirstNestedEvaluation();
+  auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
+  assert(doStmt && "Expected do loop to be in the nested evaluation");
+  const auto &loopControl =
+      std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
+  assert(loopControl.has_value() &&
+         "Expected loop control for worksharing do loop");
+  const Fortran::parser::LoopControl::Bounds *bounds =
+      std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+  assert(bounds && "Expected bounds for worksharing do loop");
+  const Fortran::semantics::Symbol *iv = bounds->name.thing.symbol;
+  assert(iv && "Expected a symbol for the worksharing loop variable");
+
+  // The omp.wsloop block argument gets the type of the iteration variable
+  // (see createBodyOfOp), so the bounds and step are converted to that type:
+  // their expressions may have a different integer kind, e.g. default-kind
+  // literals used with an INTEGER(8) loop variable.
+  mlir::Type ivType = converter.genType(*iv);
+  Fortran::lower::StatementContext stmtCtx;
+  auto genLoopBound = [&](const auto &expr) -> mlir::Value {
+    return firOpBuilder.createConvert(
+        currentLocation, ivType,
+        fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(expr), stmtCtx)));
+  };
+  lowerBound.push_back(genLoopBound(bounds->lower));
+  upperBound.push_back(genLoopBound(bounds->upper));
+  if (bounds->step) {
+    step.push_back(genLoopBound(bounds->step));
+  } else { // If `step` is not present, assume it as `1`.
+    step.push_back(
+        firOpBuilder.createIntegerConstant(currentLocation, ivType, 1));
+  }
+
+  // Handle attribute based clauses. The schedule kind is given on the
+  // `!$omp do` directive.
+  mlir::MLIRContext *context = firOpBuilder.getContext();
+  for (const Fortran::parser::OmpClause &clause : wsLoopOpClauseList.v) {
+    if (const auto *scheduleClause =
+            std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u)) {
+      const auto &scheduleType = scheduleClause->v;
+      const auto &scheduleKind =
+          std::get<Fortran::parser::OmpScheduleClause::ScheduleType>(
+              scheduleType.t);
+      switch (scheduleKind) {
+      case Fortran::parser::OmpScheduleClause::ScheduleType::Static:
+        scheduleClauseOperand = omp::ClauseScheduleKindAttr::get(
+            context, omp::ClauseScheduleKind::Static);
+        break;
+      case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic:
+        scheduleClauseOperand = omp::ClauseScheduleKindAttr::get(
+            context, omp::ClauseScheduleKind::Dynamic);
+        break;
+      case Fortran::parser::OmpScheduleClause::ScheduleType::Guided:
+        scheduleClauseOperand = omp::ClauseScheduleKindAttr::get(
+            context, omp::ClauseScheduleKind::Guided);
+        break;
+      case Fortran::parser::OmpScheduleClause::ScheduleType::Auto:
+        scheduleClauseOperand = omp::ClauseScheduleKindAttr::get(
+            context, omp::ClauseScheduleKind::Auto);
+        break;
+      case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime:
+        scheduleClauseOperand = omp::ClauseScheduleKindAttr::get(
+            context, omp::ClauseScheduleKind::Runtime);
+        break;
+      }
+    }
+  }
+
+  // In Fortran the `nowait` clause appears on the `!$omp end do` directive,
+  // i.e.
+  //   !$omp do
+  //   <...>
+  //   !$omp end do nowait
+  if (const auto &endLoopDirective =
+          std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>(
+              loopConstruct.t)) {
+    const auto &endClauseList =
+        std::get<Fortran::parser::OmpClauseList>(endLoopDirective->t);
+    for (const Fortran::parser::OmpClause &clause : endClauseList.v)
+      if (std::holds_alternative<Fortran::parser::OmpClause::Nowait>(clause.u))
+        noWaitClauseOperand = firOpBuilder.getUnitAttr();
+  }
+
+  // FIXME: Add support for following clauses:
+  // 1. linear
+  // 2. order
+  // 3. collapse
+  // 4. schedule (with chunk)
+  auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
+      currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
+      reductionVars, /*reductions=*/nullptr,
+      scheduleClauseOperand.dyn_cast_or_null<omp::ClauseScheduleKindAttr>(),
+      scheduleChunkClauseOperand, /*schedule_modifiers=*/nullptr,
+      /*simd_modifier=*/nullptr,
+      collapseClauseOperand.dyn_cast_or_null<IntegerAttr>(),
+      noWaitClauseOperand.dyn_cast_or_null<UnitAttr>(),
+      orderedClauseOperand.dyn_cast_or_null<IntegerAttr>(),
+      orderClauseOperand.dyn_cast_or_null<omp::ClauseOrderKindAttr>(),
+      /*inclusive=*/firOpBuilder.getUnitAttr());
+
+  createBodyOfOp<omp::WsLoopOp>(wsLoopOp, converter, currentLocation,
+                                &wsLoopOpClauseList, iv);
+}
+
 static void genOMP(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval,
@@ -612,7 +761,7 @@
             genOMP(converter, eval, sectionConstruct);
           },
           [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
-            TODO(converter.getCurrentLocation(), "OpenMPLoopConstruct");
+            genOMP(converter, eval, loopConstruct);
           },
           [&](const Fortran::parser::OpenMPDeclarativeAllocate
                   &execAllocConstruct) {
diff --git a/flang/test/Lower/OpenMP/omp-wsloop.f90 b/flang/test/Lower/OpenMP/omp-wsloop.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-wsloop.f90
@@ -0,0 +1,85 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing).
+
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s --check-prefixes="FIRDialect,OMPDialect"
+
+!FIRDialect: func @_QPsimple_loop()
+subroutine simple_loop
+  integer :: i
+  ! OMPDialect: omp.parallel
+  !$OMP PARALLEL
+  ! FIRDialect: %[[WS_LB:.*]] = arith.constant 1 : i32
+  ! FIRDialect: %[[WS_UB:.*]] = arith.constant 9 : i32
+  ! FIRDialect: %[[WS_STEP:.*]] = arith.constant 1 : i32
+  ! OMPDialect: omp.wsloop for (%[[I:.*]]) : i32 = (%[[WS_LB]]) to (%[[WS_UB]]) inclusive step (%[[WS_STEP]])
+  !$OMP DO
+  do i=1, 9
+    ! FIRDialect: fir.call @_FortranAioOutputInteger32({{.*}}, %[[I]]) : (!fir.ref<i8>, i32) -> i1
+    print*, i
+  end do
+  ! OMPDialect: omp.yield
+  !$OMP END DO
+  ! OMPDialect: omp.terminator
+  !$OMP END PARALLEL
+end subroutine
+
+!FIRDialect: func @_QPsimple_loop_with_step()
+subroutine simple_loop_with_step
+  integer :: i
+  ! OMPDialect: omp.parallel
+  !$OMP PARALLEL
+  ! FIRDialect: %[[WS_LB:.*]] = arith.constant 1 : i32
+  ! FIRDialect: %[[WS_UB:.*]] = arith.constant 9 : i32
+  ! FIRDialect: %[[WS_STEP:.*]] = arith.constant 2 : i32
+  ! OMPDialect: omp.wsloop for (%[[I:.*]]) : i32 = (%[[WS_LB]]) to (%[[WS_UB]]) inclusive step (%[[WS_STEP]])
+  !$OMP DO
+  do i=1, 9, 2
+    ! FIRDialect: fir.call @_FortranAioOutputInteger32({{.*}}, %[[I]]) : (!fir.ref<i8>, i32) -> i1
+    print*, i
+  end do
+  ! OMPDialect: omp.yield
+  !$OMP END DO
+  ! OMPDialect: omp.terminator
+  !$OMP END PARALLEL
+end subroutine
+
+!FIRDialect: func @_QPloop_with_schedule_nowait()
+subroutine loop_with_schedule_nowait
+  integer :: i
+  ! OMPDialect: omp.parallel
+  !$OMP PARALLEL
+  ! FIRDialect: %[[WS_LB:.*]] = arith.constant 1 : i32
+  ! FIRDialect: %[[WS_UB:.*]] = arith.constant 9 : i32
+  ! FIRDialect: %[[WS_STEP:.*]] = arith.constant 1 : i32
+  ! OMPDialect: omp.wsloop schedule(runtime) nowait for (%[[I:.*]]) : i32 = (%[[WS_LB]]) to (%[[WS_UB]]) inclusive step (%[[WS_STEP]])
+  !$OMP DO SCHEDULE(runtime)
+  do i=1, 9
+    ! FIRDialect: fir.call @_FortranAioOutputInteger32({{.*}}, %[[I]]) : (!fir.ref<i8>, i32) -> i1
+    print*, i
+  end do
+  ! OMPDialect: omp.yield
+  !$OMP END DO NOWAIT
+  ! OMPDialect: omp.terminator
+  !$OMP END PARALLEL
+end subroutine
+
+!FIRDialect: func @_QPloop_with_int8_iv()
+subroutine loop_with_int8_iv
+  integer(8) :: i
+  ! OMPDialect: omp.parallel
+  !$OMP PARALLEL
+  ! FIRDialect: %[[LB:.*]] = arith.constant 1 : i32
+  ! FIRDialect: %[[WS_LB:.*]] = fir.convert %[[LB]] : (i32) -> i64
+  ! FIRDialect: %[[UB:.*]] = arith.constant 9 : i32
+  ! FIRDialect: %[[WS_UB:.*]] = fir.convert %[[UB]] : (i32) -> i64
+  ! FIRDialect: %[[WS_STEP:.*]] = arith.constant 1 : i64
+  ! OMPDialect: omp.wsloop for (%[[I:.*]]) : i64 = (%[[WS_LB]]) to (%[[WS_UB]]) inclusive step (%[[WS_STEP]])
+  !$OMP DO
+  do i=1, 9
+    ! FIRDialect: fir.call @_FortranAioOutputInteger64({{.*}}, %[[I]]) : (!fir.ref<i8>, i64) -> i1
+    print*, i
+  end do
+  ! OMPDialect: omp.yield
+  !$OMP END DO
+  ! OMPDialect: omp.terminator
+  !$OMP END PARALLEL
+end subroutine