diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -127,6 +127,128 @@ privatizeVars(converter, *clauses); } +/// Collects the clauses of an OpenMP directive and lowers each one into the +/// MLIR value or attribute that the OpenMP dialect op builders consume. +struct OpenMPMLIRClauses { + mlir::Value ifExpr, numThreadsExpr; + mlir::omp::ClauseProcBindKindAttr procBindKind; + SmallVector allocatorOperands, allocateOperands; + mlir::UnitAttr nowait; + uint64_t hint = 0; // Must default to 0: read even when no HINT clause is given. + + OpenMPMLIRClauses(Fortran::lower::AbstractConverter &converter) + : converter(converter) {} + + void genIfClause(const Fortran::parser::OmpClause::If &ifExpr, + mlir::Value &mlirIfExpr) { + auto &expr = std::get(ifExpr.v.t); + mlirIfExpr = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + } + + void genNumThreadsClause( + const Fortran::parser::OmpClause::NumThreads &numThreadsExpr, + mlir::Value &mlirNumThreadsExpr) { + mlirNumThreadsExpr = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numThreadsExpr.v), stmtCtx)); + } + + void + genProcBindClause(const Fortran::parser::OmpClause::ProcBind &procBindKind, + mlir::omp::ClauseProcBindKindAttr &mlirProcBindKind) { + omp::ClauseProcBindKind pbKind; + switch (procBindKind.v.v) { + case Fortran::parser::OmpProcBindClause::Type::Master: + pbKind = omp::ClauseProcBindKind::Master; + break; + case Fortran::parser::OmpProcBindClause::Type::Close: + pbKind = omp::ClauseProcBindKind::Close; + break; + case Fortran::parser::OmpProcBindClause::Type::Spread: + pbKind = omp::ClauseProcBindKind::Spread; + break; + case Fortran::parser::OmpProcBindClause::Type::Primary: + pbKind = omp::ClauseProcBindKind::Primary; + break; + } + mlirProcBindKind = omp::ClauseProcBindKindAttr::get( + converter.getFirOpBuilder().getContext(), pbKind); + } + + void genAllocateClause( + const Fortran::parser::OmpClause::Allocate &ompAllocateClause, + SmallVector &mlirAllocatorOperands, + SmallVector &mlirAllocateOperands) { + auto &firOpBuilder = converter.getFirOpBuilder(); + auto 
currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; + + mlir::Value allocatorOperand; + const Fortran::parser::OmpObjectList &ompObjectList = + std::get(ompAllocateClause.v.t); + const auto &allocatorValue = + std::get>( + ompAllocateClause.v.t); + // Check if allocate clause has allocator specified. If so, add it + // to list of allocators, otherwise, add default allocator to + // list of allocators. + if (allocatorValue) { + allocatorOperand = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(allocatorValue->v), stmtCtx)); + mlirAllocatorOperands.insert(mlirAllocatorOperands.end(), + ompObjectList.v.size(), allocatorOperand); + } else { + allocatorOperand = firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getI32Type(), 1); + mlirAllocatorOperands.insert(mlirAllocatorOperands.end(), + ompObjectList.v.size(), allocatorOperand); + } + genObjectList(ompObjectList, converter, mlirAllocateOperands); + } + + /// Lowers every clause in \p clauseList into the matching public member. + /// Private, Firstprivate and Threads are ignored; other unhandled clauses TODO. + void collect(const Fortran::parser::OmpClauseList &clauseList) { + for (const Fortran::parser::OmpClause &clause : clauseList.v) { + std::visit( + Fortran::common::visitors{ + [this](const Fortran::parser::OmpClause::If &ifExpr) { + genIfClause(ifExpr, this->ifExpr); + }, + [this](const Fortran::parser::OmpClause::NumThreads + &numThreadsExpr) { + genNumThreadsClause(numThreadsExpr, this->numThreadsExpr); + }, + [this](const Fortran::parser::OmpClause::ProcBind &procBindKind) { + genProcBindClause(procBindKind, this->procBindKind); + }, + [this]( + const Fortran::parser::OmpClause::Allocate &allocateClause) { + genAllocateClause(allocateClause, this->allocatorOperands, + this->allocateOperands); + }, + [this](const Fortran::parser::OmpClause::Nowait &) { + this->nowait = converter.getFirOpBuilder().getUnitAttr(); + }, + [this](const Fortran::parser::OmpClause::Hint &hint) { + const auto *expr = Fortran::semantics::GetExpr(hint.v); + this->hint = *Fortran::evaluate::ToInt64(*expr); + }, + 
[](const Fortran::parser::OmpClause::Private &) {}, + [](const Fortran::parser::OmpClause::Firstprivate &) {}, + [](const Fortran::parser::OmpClause::Threads &) {}, + [this](const auto &) { + TODO(converter.getCurrentLocation(), "Unhandled clause"); + }}, + clause.u); + } + } + +private: + Fortran::lower::StatementContext stmtCtx; + Fortran::lower::AbstractConverter &converter; +}; + static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSimpleStandaloneConstruct @@ -160,38 +282,6 @@ } } -static void -genAllocateClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpAllocateClause &ompAllocateClause, - SmallVector &allocatorOperands, - SmallVector &allocateOperands) { - auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); - Fortran::lower::StatementContext stmtCtx; - - mlir::Value allocatorOperand; - const Fortran::parser::OmpObjectList &ompObjectList = - std::get(ompAllocateClause.t); - const auto &allocatorValue = - std::get>( - ompAllocateClause.t); - // Check if allocate clause has allocator specified. If so, add it - // to list of allocators, otherwise, add default allocator to - // list of allocators. 
- if (allocatorValue) { - allocatorOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(allocatorValue->v), stmtCtx)); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); - } else { - allocatorOperand = firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getI32Type(), 1); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); - } - genObjectList(ompObjectList, converter, allocateOperands); -} - static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -243,75 +333,22 @@ Fortran::lower::StatementContext stmtCtx; llvm::ArrayRef argTy; - mlir::Value ifClauseOperand, numThreadsClauseOperand; - mlir::omp::ClauseProcBindKindAttr procBindKindAttr; - SmallVector allocateOperands, allocatorOperands; - mlir::UnitAttr nowaitAttr; const auto &opClauseList = std::get(beginBlockDirective.t); - for (const auto &clause : opClauseList.v) { - if (const auto &ifClause = - std::get_if(&clause.u)) { - auto &expr = std::get(ifClause->v.t); - ifClauseOperand = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); - } else if (const auto &numThreadsClause = - std::get_if( - &clause.u)) { - // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`. 
- numThreadsClauseOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); - } else if (const auto &procBindClause = - std::get_if( - &clause.u)) { - omp::ClauseProcBindKind pbKind; - switch (procBindClause->v.v) { - case Fortran::parser::OmpProcBindClause::Type::Master: - pbKind = omp::ClauseProcBindKind::Master; - break; - case Fortran::parser::OmpProcBindClause::Type::Close: - pbKind = omp::ClauseProcBindKind::Close; - break; - case Fortran::parser::OmpProcBindClause::Type::Spread: - pbKind = omp::ClauseProcBindKind::Spread; - break; - case Fortran::parser::OmpProcBindClause::Type::Primary: - pbKind = omp::ClauseProcBindKind::Primary; - break; - } - procBindKindAttr = - omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind); - } else if (const auto &allocateClause = - std::get_if( - &clause.u)) { - genAllocateClause(converter, allocateClause->v, allocatorOperands, - allocateOperands); - } else if (std::get_if(&clause.u) || - std::get_if( - &clause.u)) { - // Privatisation clauses are handled elsewhere. - continue; - } else if (std::get_if(&clause.u)) { - // Nothing needs to be done for threads clause. - continue; - } else { - TODO(currentLocation, "OpenMP Block construct clauses"); - } - } - for (const auto &clause : - std::get(endBlockDirective.t).v) { - if (std::get_if(&clause.u)) - nowaitAttr = firOpBuilder.getUnitAttr(); - } + OpenMPMLIRClauses clauses(converter); + clauses.collect(opClauseList); + clauses.collect( + std::get(endBlockDirective.t)); if (blockDirective.v == llvm::omp::OMPD_parallel) { // Create and insert the operation. 
auto parallelOp = firOpBuilder.create( - currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, - allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), - /*reductions=*/nullptr, procBindKindAttr); + currentLocation, argTy, clauses.ifExpr, clauses.numThreadsExpr, + clauses.allocateOperands, clauses.allocatorOperands, + /*reduction_vars=*/ValueRange(), + /*reductions=*/nullptr, clauses.procBindKind); createBodyOfOp(parallelOp, converter, currentLocation, &opClauseList, /*isCombined=*/false); } else if (blockDirective.v == llvm::omp::OMPD_master) { @@ -320,7 +357,8 @@ createBodyOfOp(masterOp, converter, currentLocation); } else if (blockDirective.v == llvm::omp::OMPD_single) { auto singleOp = firOpBuilder.create( - currentLocation, allocateOperands, allocatorOperands, nowaitAttr); + currentLocation, clauses.allocateOperands, clauses.allocatorOperands, + clauses.nowait); createBodyOfOp(singleOp, converter, currentLocation); } else if (blockDirective.v == llvm::omp::OMPD_ordered) { auto orderedOp = firOpBuilder.create( @@ -345,15 +383,9 @@ std::get>(cd.t).value().ToString(); } - uint64_t hint = 0; const auto &clauseList = std::get(cd.t); - for (const Fortran::parser::OmpClause &clause : clauseList.v) - if (auto hintClause = - std::get_if(&clause.u)) { - const auto *expr = Fortran::semantics::GetExpr(hintClause->v); - hint = *Fortran::evaluate::ToInt64(*expr); - break; - } + OpenMPMLIRClauses clauses(converter); + clauses.collect(clauseList); mlir::omp::CriticalOp criticalOp = [&]() { if (name.empty()) { @@ -365,7 +397,7 @@ auto global = module.lookupSymbol(name); if (!global) global = modBuilder.create( - currentLocation, name, hint); + currentLocation, name, clauses.hint); return firOpBuilder.create( currentLocation, mlir::FlatSymbolRefAttr::get( firOpBuilder.getContext(), global.sym_name())); @@ -393,35 +425,15 @@ const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) { auto &firOpBuilder = converter.getFirOpBuilder(); auto 
currentLocation = converter.getCurrentLocation(); - SmallVector reductionVars, allocateOperands, allocatorOperands; - mlir::UnitAttr noWaitClauseOperand; const auto &sectionsClauseList = std::get( std::get(sectionsConstruct.t) .t); - for (const Fortran::parser::OmpClause &clause : sectionsClauseList.v) { - - // Reduction Clause - if (std::get_if(&clause.u)) { - TODO(currentLocation, "OMPC_Reduction"); - - // Allocate clause - } else if (const auto &allocateClause = - std::get_if( - &clause.u)) { - genAllocateClause(converter, allocateClause->v, allocatorOperands, - allocateOperands); - } - } - const auto &endSectionsClauseList = - std::get(sectionsConstruct.t); - const auto &clauseList = - std::get(endSectionsClauseList.t); - for (const auto &clause : clauseList.v) { - // Nowait clause - if (std::get_if(&clause.u)) { - noWaitClauseOperand = firOpBuilder.getUnitAttr(); - } - } + const auto &endSectionsClauseList = std::get( + std::get(sectionsConstruct.t) + .t); + OpenMPMLIRClauses clauses(converter); + clauses.collect(sectionsClauseList); + clauses.collect(endSectionsClauseList); llvm::omp::Directive dir = std::get( @@ -434,7 +446,8 @@ if (dir == llvm::omp::Directive::OMPD_parallel_sections) { auto parallelOp = firOpBuilder.create( currentLocation, /*if_expr_var*/ nullptr, /*num_threads_var*/ nullptr, - allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), + clauses.allocateOperands, clauses.allocatorOperands, + /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr, /*proc_bind_val*/ nullptr); createBodyOfOp(parallelOp, converter, currentLocation); auto sectionsOp = firOpBuilder.create( @@ -446,8 +459,8 @@ // Sections Construct } else if (dir == llvm::omp::Directive::OMPD_sections) { auto sectionsOp = firOpBuilder.create( - currentLocation, reductionVars, /*reductions = */ nullptr, - allocateOperands, allocatorOperands, noWaitClauseOperand); + currentLocation, ValueRange(), /*reductions = */ nullptr, + clauses.allocateOperands, 
clauses.allocatorOperands, clauses.nowait); createBodyOfOp(sectionsOp, converter, currentLocation); } }