diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -13,10 +13,13 @@ #ifndef FORTRAN_LOWER_OPENMP_H #define FORTRAN_LOWER_OPENMP_H +#include <cstdint> + namespace Fortran { namespace parser { struct OpenMPConstruct; struct OpenMPDeclarativeConstruct; +struct OmpClauseList; } // namespace parser namespace lower { @@ -31,6 +34,8 @@ const parser::OpenMPConstruct &); void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &, const parser::OpenMPDeclarativeConstruct &); +/// Return the argument of the COLLAPSE clause in \p clauseList, or 1 if none. +int64_t getCollapseValue(const parser::OmpClauseList &clauseList); } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -1401,15 +1401,28 @@ void genFIR(const Fortran::parser::OpenMPConstruct &omp) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); localSymbols.pushScope(); Fortran::lower::genOpenMPConstruct(*this, getEval(), omp); + + const Fortran::parser::OpenMPLoopConstruct *ompLoop = + std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u); // If loop is part of an OpenMP Construct then the OpenMP dialect // workshare loop operation has already been created. Only the // body needs to be created here and the do_loop can be skipped. - Fortran::lower::pft::Evaluation *curEval = - std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u) - ? &getEval().getFirstNestedEvaluation() - : &getEval(); + // With COLLAPSE(n), descend n-1 further levels so that curEval is the + // innermost collapsed DO construct (n is 1 when there is no COLLAPSE).
+ Fortran::lower::pft::Evaluation *curEval = &getEval(); + if (ompLoop) { + const auto &wsLoopOpClauseList = std::get<Fortran::parser::OmpClauseList>( + std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t); + const int64_t collapseValue = + Fortran::lower::getCollapseValue(wsLoopOpClauseList); + + curEval = &curEval->getFirstNestedEvaluation(); + for (int64_t i = 1; i < collapseValue; i++) { + curEval = &*std::next(curEval->getNestedEvaluations().begin()); + } + } for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) genFIR(e); diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -25,6 +25,19 @@ using namespace mlir; +int64_t Fortran::lower::getCollapseValue( + const Fortran::parser::OmpClauseList &clauseList) { + for (const auto &clause : clauseList.v) { + if (const auto *collapseClause = + std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) { + // The COLLAPSE argument must be a constant expression, so it folds. + const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); + return Fortran::evaluate::ToInt64(*expr).value(); + } + } + return 1; +} + static const Fortran::parser::Name * getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) { const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u); @@ -108,22 +121,42 @@ } } +/// Create the body (block) for an OpenMP Operation. +/// +/// \param [in] op - the operation the body belongs to. +/// \param [inout] converter - converter to use for the clauses. +/// \param [in] loc - location in source code. +/// \param [in] clauses - list of clauses to process. +/// \param [in] args - block arguments (induction variable[s]) for the
+/// region. +/// \param [in] outerCombined - is this an outer operation - prevents +/// privatization. template <typename Op> static void createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc, const Fortran::parser::OmpClauseList *clauses = nullptr, - const Fortran::semantics::Symbol *arg = nullptr, + const SmallVector<const Fortran::semantics::Symbol *> &args = {}, bool outerCombined = false) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - // If an argument for the region is provided then create the block with that - // argument. Also update the symbol's address with the mlir argument value. - // e.g. For loops the argument is the induction variable. And all further + // If arguments for the region are provided then create the block with those + // arguments. Also update the symbol's address with the mlir argument values. + // e.g. For loops the arguments are the induction variables. And all further // uses of the induction variable should use this mlir value. - if (arg) { - firOpBuilder.createBlock(&op.getRegion(), {}, {converter.genType(*arg)}, - {loc}); - converter.bindSymbol(*arg, op.getRegion().front().getArgument(0)); + if (!args.empty()) { + SmallVector<Type> tiv; + SmallVector<Location> locs; + int argIndex = 0; + for (auto &arg : args) { + tiv.push_back(converter.genType(*arg)); + locs.push_back(loc); + } + firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs); + for (auto &arg : args) { + fir::ExtendedValue exval = op.getRegion().front().getArgument(argIndex); + converter.bindSymbol(*arg, exval); + argIndex++; + } } else { firOpBuilder.createBlock(&op.getRegion()); } @@ -394,38 +427,43 @@ TODO(converter.getCurrentLocation(), "Combined worksharing loop construct"); } - Fortran::lower::pft::Evaluation *doConstructEval = - &eval.getFirstNestedEvaluation(); - - Fortran::lower::pft::Evaluation *doLoop = - &doConstructEval->getFirstNestedEvaluation(); - auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>(); - assert(doStmt && "Expected do loop to be in the nested evaluation"); - const auto &loopControl = - std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t); - const Fortran::parser::LoopControl::Bounds *bounds =
- std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u); - assert(bounds && "Expected bounds for worksharing do loop"); - Fortran::semantics::Symbol *iv = nullptr; - Fortran::lower::StatementContext stmtCtx; - lowerBound.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); - upperBound.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); - if (bounds->step) { - step.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); - } else { // If `step` is not present, assume it as `1`. - step.push_back(firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getIntegerType(32), 1)); - } - iv = bounds->name.thing.symbol; + int64_t collapseValue = Fortran::lower::getCollapseValue(wsLoopOpClauseList); + + // Collect the bounds, step and induction variable of each collapsed loop. + auto *doConstructEval = &eval.getFirstNestedEvaluation(); + + SmallVector<const Fortran::semantics::Symbol *> iv; + for (int64_t loopCnt = 0; loopCnt < collapseValue; ++loopCnt) { + auto *doLoop = &doConstructEval->getFirstNestedEvaluation(); + auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>(); + assert(doStmt && "Expected do loop to be in the nested evaluation"); + const auto &loopControl = + std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t); + const Fortran::parser::LoopControl::Bounds *bounds = + std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u); + assert(bounds && "Expected bounds for worksharing do loop"); + Fortran::lower::StatementContext stmtCtx; + lowerBound.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); + upperBound.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); + if (bounds->step) { + step.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); + } else { // If `step` is not present, assume it as `1`.
+ step.push_back(firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getIntegerType(32), 1)); + } + iv.push_back(bounds->name.thing.symbol); + + doConstructEval = + &*std::next(doConstructEval->getNestedEvaluations().begin()); + } // FIXME: Add support for following clauses: // 1. linear // 2. order - // 3. collapse - // 4. schedule (with chunk) + // 3. schedule (with chunk) auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>( currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars, reductionVars, /*reductions=*/nullptr, @@ -440,8 +478,12 @@ // Handle attribute based clauses. for (const Fortran::parser::OmpClause &clause : wsLoopOpClauseList.v) { - if (const auto &scheduleClause = - std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u)) { + if (std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) { + // The clause's value was already extracted into collapseValue above. + wsLoopOp.collapse_valAttr(firOpBuilder.getI64IntegerAttr(collapseValue)); + } else if (const auto &scheduleClause = + std::get_if<Fortran::parser::OmpClause::Schedule>( + &clause.u)) { mlir::MLIRContext *context = firOpBuilder.getContext(); const auto &scheduleType = scheduleClause->v; const auto &scheduleKind = diff --git a/flang/test/Lower/OpenMP/omp-wsloop-collapse.f90 b/flang/test/Lower/OpenMP/omp-wsloop-collapse.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/omp-wsloop-collapse.f90 @@ -0,0 +1,57 @@ +! This test checks lowering of OpenMP DO Directive(Worksharing) with collapse. + +! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s + +program wsloop_collapse + integer :: i, j, k + integer :: a, b, c + integer :: x +! CHECK: %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFEa"} +! CHECK: %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFEb"} +! CHECK: %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "c", uniq_name = "_QFEc"} +!
CHECK: %[[VAL_3:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"} +! CHECK: %[[VAL_4:.*]] = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFEj"} +! CHECK: %[[VAL_5:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFEk"} +! CHECK: %[[VAL_6:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"} + a=3 +! CHECK: %[[VAL_7:.*]] = arith.constant 3 : i32 +! CHECK: fir.store %[[VAL_7]] to %[[VAL_0]] : !fir.ref<i32> + b=2 +! CHECK: %[[VAL_8:.*]] = arith.constant 2 : i32 +! CHECK: fir.store %[[VAL_8]] to %[[VAL_1]] : !fir.ref<i32> + c=5 +! CHECK: %[[VAL_9:.*]] = arith.constant 5 : i32 +! CHECK: fir.store %[[VAL_9]] to %[[VAL_2]] : !fir.ref<i32> + x=0 +! CHECK: %[[VAL_10:.*]] = arith.constant 0 : i32 +! CHECK: fir.store %[[VAL_10]] to %[[VAL_6]] : !fir.ref<i32> + + !$omp do collapse(3) +! CHECK: %[[VAL_20:.*]] = arith.constant 1 : i32 +! CHECK: %[[VAL_21:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32> +! CHECK: %[[VAL_22:.*]] = arith.constant 1 : i32 +! CHECK: %[[VAL_23:.*]] = arith.constant 1 : i32 +! CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32> +! CHECK: %[[VAL_25:.*]] = arith.constant 1 : i32 +! CHECK: %[[VAL_26:.*]] = arith.constant 1 : i32 +! CHECK: %[[VAL_27:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32> +! CHECK: %[[VAL_28:.*]] = arith.constant 1 : i32 + do i = 1, a + do j = 1, b + do k = 1, c +! CHECK: omp.wsloop collapse(3) for (%[[ARG_I:.*]], %[[ARG_J:.*]], %[[ARG_K:.*]]) : i32 = (%[[VAL_20]], %[[VAL_23]], %[[VAL_26]]) to (%[[VAL_21]], %[[VAL_24]], %[[VAL_27]]) inclusive step (%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]) { +! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_6]] : !fir.ref<i32> +! CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[ARG_I]] : i32 +! CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[ARG_J]] : i32 +! CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[ARG_K]] : i32 +! CHECK: fir.store %[[VAL_15]] to %[[VAL_6]] : !fir.ref<i32> +! CHECK: omp.yield +! CHECK: } + x = x + i + j + k + end do + end do + end do + !$omp end do +! CHECK: return +!
CHECK: } +end program wsloop_collapse