diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -21,6 +21,7 @@
 struct OpenMPDeclarativeConstruct;
 struct OmpEndLoopDirective;
 struct OmpClauseList;
+struct AssignmentStmt;
 } // namespace parser
 
 namespace lower {
@@ -40,6 +41,9 @@
 void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
 void genOpenMPReduction(AbstractConverter &,
                         const Fortran::parser::OmpClauseList &clauseList);
+void genOpenMPReduction(AbstractConverter &,
+                        const Fortran::parser::AssignmentStmt &stmt);
+bool isOpenMPReduction(const Fortran::parser::AssignmentStmt &stmt);
 
 } // namespace lower
 } // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2469,7 +2469,16 @@
   }
 
   void genFIR(const Fortran::parser::AssignmentStmt &stmt) {
-    genAssignment(*stmt.typedAssignment->v);
+    // Assignments recognized as updating an OpenMP reduction variable are
+    // handed to the OpenMP lowering; all others use the regular lowering.
+    // TODO: genOpenMPReduction currently works on the parse-tree
+    // AssignmentStmt node, whereas regular FIR lowering uses the associated
+    // typedAssignment. It will probably be better to switch to the regular
+    // FIR lowering; to begin with only a subset of genAssignment is needed.
+    if (Fortran::lower::isOpenMPReduction(stmt))
+      Fortran::lower::genOpenMPReduction(*this, stmt);
+    else
+      genAssignment(*stmt.typedAssignment->v);
   }
 
   void genFIR(const Fortran::parser::SyncAllStmt &stmt) {
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1669,3 +1669,64 @@
     }
   }
 }
+
+void Fortran::lower::genOpenMPReduction(
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::parser::AssignmentStmt &stmt) {
+  // Not implemented yet. Fail loudly rather than emit nothing: an empty body
+  // would silently drop the assignment from the generated FIR.
+  TODO(converter.getCurrentLocation(),
+       "OpenMP reduction lowering of an assignment statement");
+}
+
+bool Fortran::lower::isOpenMPReduction(
+    const Fortran::parser::AssignmentStmt &stmt) {
+  // Symbol of a designator that is a plain name; nullptr for anything else
+  // (structure component, array element, coarray reference, substring).
+  auto getVarSym =
+      [](const Fortran::parser::Designator &d) -> Fortran::semantics::Symbol * {
+    if (const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&d.u))
+      if (const auto *name = std::get_if<Fortran::parser::Name>(&dataRef->u))
+        return name->symbol;
+    return nullptr;
+  };
+  // Symbol of an expression that is a plain variable name; nullptr otherwise.
+  auto getExprSym =
+      [&](const Fortran::parser::Expr &e) -> Fortran::semantics::Symbol * {
+    if (const auto *designator = std::get_if<
+            Fortran::common::Indirection<Fortran::parser::Designator>>(&e.u))
+      return getVarSym(designator->value());
+    return nullptr;
+  };
+
+  // The left-hand side must be a plain variable that semantics flagged as an
+  // OpenMP reduction variable.
+  const auto &var{std::get<Fortran::parser::Variable>(stmt.t)};
+  const auto *varDesignator =
+      std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
+          &var.u);
+  if (!varDesignator)
+    return false;
+  Fortran::semantics::Symbol *varSym = getVarSym(varDesignator->value());
+  if (!varSym || !varSym->test(Fortran::semantics::Symbol::Flag::OmpReduction))
+    return false;
+
+  // Only `var = var + e` and `var = e + var` are recognized so far.
+  const auto &expr{std::get<Fortran::parser::Expr>(stmt.t)};
+  const auto *add = std::get_if<Fortran::parser::Expr::Add>(&expr.u);
+  if (!add)
+    return false;
+  const bool reductionVarInExpr =
+      getExprSym(std::get<0>(add->t).value()) == varSym ||
+      getExprSym(std::get<1>(add->t).value()) == varSym;
+  if (!reductionVarInExpr)
+    return false;
+
+  // TODO: Find the parent operation implementing the reduction clause
+  // interface and check that the operator (`+` here) and the variable match
+  // one of its reduction symbols and reduction variables. Until that check
+  // exists no assignment is claimed, so every assignment keeps going through
+  // the regular FIR lowering in genFIR(AssignmentStmt).
+  const bool matchingReductionInOp = false;
+  return matchingReductionInOp;
+}