diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -61,11 +61,66 @@ return sym; } -template -static void privatizeSymbol( - Op &op, Fortran::lower::AbstractConverter &converter, +class dataSharingProcessor { + bool hasLastPrivateOp; + mlir::OpBuilder::InsertPoint lastPrivIP; + mlir::OpBuilder::InsertPoint insPt; + // Symbols in private, firstprivate, and/or lastprivate clauses. + llvm::SetVector privatizedSymbols; + llvm::SetVector defaultSymbols; + llvm::SetVector symbolsInNestedRegions; + llvm::SetVector symbolsInParentRegions; + mlir::Operation *op; + Fortran::lower::AbstractConverter &converter; + fir::FirOpBuilder &firOpBuilder; + const Fortran::parser::OmpClauseList &opClauseList; + Fortran::lower::pft::Evaluation &eval; + + void privatizeSymbol( + const Fortran::semantics::Symbol *sym, + [[maybe_unused]] mlir::OpBuilder::InsertPoint *lastPrivIP = nullptr); + bool needBarrier(); + void collectSymbols(Fortran::semantics::Symbol::Flag flag); + void collectOmpObjectListSymbol( + const Fortran::parser::OmpObjectList &ompObjectList, + llvm::SetVector &symbolSet); + void collectSymbolsForPrivatization(); + void insertBarrier(); + void collectDefaultSymbols(); + void privatize(); + void defaultPrivatize(); + void insertLastPrivateCompare(mlir::Operation *op); + +public: + dataSharingProcessor(mlir::Operation *op, + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpClauseList &opClauseList, + Fortran::lower::pft::Evaluation &eval) + : hasLastPrivateOp(false), op(op), converter(converter), + firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList), + eval(eval) {} + bool process(); +}; + +bool dataSharingProcessor::process() { + insPt = firOpBuilder.saveInsertionPoint(); + collectSymbolsForPrivatization(); + insertLastPrivateCompare(op); + if (mlir::isa(op)) + firOpBuilder.setInsertionPointToStart(&op->getRegion(0).back()); + else + firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); + privatize(); + collectDefaultSymbols(); + defaultPrivatize(); + insertBarrier(); + firOpBuilder.restoreInsertionPoint(insPt); + return hasLastPrivateOp; +} + +void dataSharingProcessor::privatizeSymbol( const Fortran::semantics::Symbol *sym, - [[maybe_unused]] mlir::OpBuilder::InsertPoint *lastPrivIP = nullptr) { + [[maybe_unused]] mlir::OpBuilder::InsertPoint *lastPrivIP) { // Privatization for symbols which are pre-determined (like loop index // variables) happen separately, for everything else privatize here. if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined)) @@ -78,7 +133,7 @@ mlir::OpBuilder::InsertPoint firstPrivIP, insPt; if (mlir::isa(op)) { insPt = firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPointToStart(&op.getRegion().front()); + firOpBuilder.setInsertionPointToStart(&op->getRegion(0).front()); firstPrivIP = firOpBuilder.saveInsertionPoint(); } converter.copyHostAssociateVar(*sym, &firstPrivIP); @@ -89,26 +144,17 @@ converter.copyHostAssociateVar(*sym, lastPrivIP); } -template -static bool privatizeVars(Op &op, Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClauseList &opClauseList, - Fortran::lower::pft::Evaluation &eval) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto insPt = firOpBuilder.saveInsertionPoint(); - // Symbols in private, firstprivate, and/or lastprivate clauses. - llvm::SetVector privatizedSymbols; - auto collectOmpObjectListSymbol = - [&](const Fortran::parser::OmpObjectList &ompObjectList, - llvm::SetVector &symbolSet) { - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - symbolSet.insert(sym); - } - }; - // We need just one ICmpOp for multiple LastPrivate clauses. - mlir::arith::CmpIOp cmpOp; - mlir::OpBuilder::InsertPoint lastPrivIP; - bool hasLastPrivateOp = false; +void dataSharingProcessor::collectOmpObjectListSymbol( + const Fortran::parser::OmpObjectList &ompObjectList, + llvm::SetVector &symbolSet) { + for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + symbolSet.insert(sym); + } +} + +void dataSharingProcessor::collectSymbolsForPrivatization() { + bool hasCollapse = false; for (const Fortran::parser::OmpClause &clause : opClauseList.v) { if (const auto &privateClause = std::get_if(&clause.u)) { @@ -120,8 +166,48 @@ } else if (const auto &lastPrivateClause = std::get_if( &clause.u)) { + collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols); + hasLastPrivateOp = true; + } else if (const auto &collapseClause = + std::get_if( + &clause.u)) { + hasCollapse = true; + } + } + + if (hasCollapse && hasLastPrivateOp) + TODO(converter.getCurrentLocation(), "Collapse clause with lastprivate"); +} + +bool dataSharingProcessor ::needBarrier() { + for (auto sym : privatizedSymbols) { + if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate) && + sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate)) + return true; + } + return false; +} + +void dataSharingProcessor ::insertBarrier() { + // Emit implicit barrier to synchronize threads and avoid data races on + // initialization of firstprivate variables and post-update of lastprivate + // variables. + // FIXME: Emit barrier for lastprivate clause when 'sections' directive has + // 'nowait' clause. Otherwise, emit barrier when 'sections' directive has + // both firstprivate and lastprivate clause. + // Emit implicit barrier for linear clause. Maybe on somewhere else. + if (needBarrier()) + firOpBuilder.create(converter.getCurrentLocation()); +} + +void dataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) { + mlir::arith::CmpIOp cmpOp; + bool cmpCreated = false; + for (const Fortran::parser::OmpClause &clause : opClauseList.v) { + if (const auto &lastPrivateClause = + std::get_if(&clause.u)) { // TODO: Add lastprivate support for simd construct - if (std::is_same_v) { + if (mlir::isa(op)) { if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) { // For `omp.sections`, lastprivatized variables occur in // lexically final `omp.section` operation. The following FIR @@ -146,11 +232,10 @@ // lastprivate FIR can reside. Later canonicalizations // will optimize away this operation. - omp::SectionOp *sectionOp = dyn_cast(&op); mlir::scf::IfOp ifOp = firOpBuilder.create( - sectionOp->getLoc(), + op->getLoc(), firOpBuilder.createIntegerConstant( - sectionOp->getLoc(), firOpBuilder.getIntegerType(1), 0x1), + op->getLoc(), firOpBuilder.getIntegerType(1), 0x1), /*else*/ false); firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); @@ -181,86 +266,63 @@ firOpBuilder.setInsertionPoint(ifOp); insPt = firOpBuilder.saveInsertionPoint(); } - } else if (std::is_same_v) { - omp::WsLoopOp *wsLoopOp = dyn_cast(&op); - mlir::Operation *lastOper = - wsLoopOp->getRegion().back().getTerminator(); + } else if (mlir::isa(op)) { + mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); firOpBuilder.setInsertionPoint(lastOper); - // Our goal here is to introduce the following control flow - // just before exiting the worksharing loop. - // Say our wsloop is as follows: + // Update the original variable just before exiting the worksharing + // loop. Conversion as follows: // - // omp.wsloop { - // ... - // store - // omp.yield - // } + // omp.wsloop { + // omp.wsloop { ... + // ... store + // store ===> %cmp = llvm.icmp "eq" %iv %ub + // omp.yield scf.if %cmp { + // } ^%lpv_update_blk: + // } + // omp.yield + // } // - // We want to convert it to the following: - // - // omp.wsloop { - // ... - // store - // %cmp = llvm.icmp "eq" %iv %ub - // scf.if %cmp { - // ^%lpv_update_blk: - // } - // omp.yield - // } - - // TODO: The following will not work when there is collapse present. - // Have to modify this in future. - for (const Fortran::parser::OmpClause &clause : opClauseList.v) - if (const auto &collapseClause = - std::get_if(&clause.u)) - TODO(converter.getCurrentLocation(), - "Collapse clause with lastprivate"); + // Only generate the compare once in presence of multiple LastPrivate // clauses. - if (!hasLastPrivateOp) { + if (!cmpCreated) { cmpOp = firOpBuilder.create( - wsLoopOp->getLoc(), mlir::arith::CmpIPredicate::eq, - wsLoopOp->getRegion().front().getArguments()[0], - wsLoopOp->getUpperBound()[0]); + op->getLoc(), mlir::arith::CmpIPredicate::eq, + op->getRegion(0).front().getArguments()[0], + mlir::dyn_cast(op).getUpperBound()[0]); } mlir::scf::IfOp ifOp = firOpBuilder.create( - wsLoopOp->getLoc(), cmpOp, /*else*/ false); + op->getLoc(), cmpOp, /*else*/ false); firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); lastPrivIP = firOpBuilder.saveInsertionPoint(); } else { TODO(converter.getCurrentLocation(), - "lastprivate clause in constructs other than worksharing-loop"); + "lastprivate clause in constructs other than " + "simd/worksharing-loop"); } - collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols); - hasLastPrivateOp = true; } } +} - // Symbols in regions with default(private/firstprivate) clause. - // FIXME: Collect the symbols with private/firstprivate flag in the region of - // the construct with default(private/firstprivate) clause excluding the - // symbols with the same private/firstprivate flag in the inner nested - // regions. - llvm::SetVector defaultSymbols; - llvm::SetVector symbolsInNestedRegions; - llvm::SetVector symbolsInParentRegions; - auto collectSymbols = [&](Fortran::semantics::Symbol::Flag flag) { - converter.collectSymbolSet(eval, defaultSymbols, flag, - /*collectSymbols=*/true, - /*collectHostAssociatedSymbols=*/true); - for (auto &e : eval.getNestedEvaluations()) { - if (e.hasNestedEvaluations()) - converter.collectSymbolSet(e, symbolsInNestedRegions, flag, - /*collectSymbols=*/true, - /*collectHostAssociatedSymbols=*/false); - else - converter.collectSymbolSet(e, symbolsInParentRegions, flag, - /*collectSymbols=*/false, - /*collectHostAssociatedSymbols=*/true); - } - }; +void dataSharingProcessor::collectSymbols( + Fortran::semantics::Symbol::Flag flag) { + converter.collectSymbolSet(eval, defaultSymbols, flag, + /*collectSymbols=*/true, + /*collectHostAssociatedSymbols=*/true); + for (auto &e : eval.getNestedEvaluations()) { + if (e.hasNestedEvaluations()) + converter.collectSymbolSet(e, symbolsInNestedRegions, flag, + /*collectSymbols=*/true, + /*collectHostAssociatedSymbols=*/false); + else + converter.collectSymbolSet(e, symbolsInParentRegions, flag, + /*collectSymbols=*/false, + /*collectHostAssociatedSymbols=*/true); + } +} +void dataSharingProcessor::collectDefaultSymbols() { for (const Fortran::parser::OmpClause &clause : opClauseList.v) { if (const auto &defaultClause = std::get_if(&clause.u)) { @@ -272,37 +334,19 @@ collectSymbols(Fortran::semantics::Symbol::Flag::OmpFirstPrivate); } } +} - bool needBarrier = false; - if (mlir::isa(op)) - firOpBuilder.setInsertionPointToStart(&op.getRegion().back()); - else - firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); - for (auto sym : privatizedSymbols) { - privatizeSymbol(op, converter, sym, &lastPrivIP); - if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate) && - sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate)) - needBarrier = true; - } +void dataSharingProcessor::privatize() { + for (auto sym : privatizedSymbols) + privatizeSymbol(sym, &lastPrivIP); +} +void dataSharingProcessor::defaultPrivatize() { for (auto sym : defaultSymbols) if (!symbolsInNestedRegions.contains(sym) && !symbolsInParentRegions.contains(sym) && !privatizedSymbols.contains(sym)) - privatizeSymbol(op, converter, sym); - - // Emit implicit barrier to synchronize threads and avoid data races on - // initialization of firstprivate variables and post-update of lastprivate - // variables. - // FIXME: Emit barrier for lastprivate clause when 'sections' directive has - // 'nowait' clause. Otherwise, emit barrier when 'sections' directive has - // both firstprivate and lastprivate clause. - // Emit implicit barrier for linear clause. Maybe on somewhere else. - if (needBarrier) - firOpBuilder.create(converter.getCurrentLocation()); - - firOpBuilder.restoreInsertionPoint(insPt); - return hasLastPrivateOp; + privatizeSymbol(sym); } /// The COMMON block is a global structure. \p commonValue is the base address @@ -604,7 +648,8 @@ // Handle privatization. Do not privatize if this is the outer operation. if (clauses && !outerCombined) { - bool lastPrivateOp = privatizeVars(op, converter, *clauses, eval); + dataSharingProcessor dsp(op, converter, *clauses, eval); + bool lastPrivateOp = dsp.process(); // LastPrivatization, due to introduction of // new control flow, changes the insertion point, // thus restore it.