diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -16,6 +16,7 @@
 #include "flang/Common/Fortran.h"
 #include "flang/Lower/PFTDefs.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Semantics/symbol.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "llvm/ADT/ArrayRef.h"
 
@@ -99,6 +100,13 @@
   virtual void
   copyHostAssociateVar(const Fortran::semantics::Symbol &sym) = 0;
 
+  /// Collect into \p symbolSet every symbol referenced inside the region of
+  /// \p eval that has \p flag set.
+  virtual void collectSymbolSet(
+      pft::Evaluation &eval,
+      llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet,
+      Fortran::semantics::Symbol::Flag flag) = 0;
+
   //===--------------------------------------------------------------------===//
   // Expressions
   //===--------------------------------------------------------------------===//
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -778,6 +778,11 @@
 void visitAllSymbols(const FunctionLikeUnit &funit,
                      std::function<void(const semantics::Symbol &)> callBack);
 
+/// Call the provided \p callBack on all symbols that are referenced inside the
+/// region of \p eval.
+void visitAllSymbols(const Evaluation &eval,
+                     std::function<void(const semantics::Symbol &)> callBack);
+
 } // namespace Fortran::lower::pft
 
 namespace Fortran::lower {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -457,6 +457,17 @@
   // Utility methods
   //===--------------------------------------------------------------------===//
 
+  void collectSymbolSet(
+      Fortran::lower::pft::Evaluation &eval,
+      llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet,
+      Fortran::semantics::Symbol::Flag flag) override final {
+    auto addToList = [&](const Fortran::semantics::Symbol &sym) {
+      if (sym.test(flag))
+        symbolSet.insert(&sym);
+    };
+    Fortran::lower::pft::visitAllSymbols(eval, addToList);
+  }
+
   mlir::Location getCurrentLocation() override final { return toLocation(); }
 
   /// Generate a dummy location.
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -60,21 +60,73 @@
   }
 }
 
+/// Privatize every symbol referenced inside the region of \p eval that carries
+/// the data-sharing flag implied by a DEFAULT clause of kind
+/// \p ompDefaultClauseType: OmpPrivate for DEFAULT(PRIVATE), OmpFirstPrivate
+/// for DEFAULT(FIRSTPRIVATE). Other kinds need no work here.
+static void genDefaultClause(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OmpDefaultClause::Type &ompDefaultClauseType) {
+  using DefaultType = Fortran::parser::OmpDefaultClause::Type;
+  const bool isPrivate = ompDefaultClauseType == DefaultType::Private;
+  if (!isPrivate && ompDefaultClauseType != DefaultType::Firstprivate)
+    return;
+
+  llvm::SetVector<const Fortran::semantics::Symbol *> symbols;
+  converter.collectSymbolSet(
+      eval, symbols,
+      isPrivate ? Fortran::semantics::Symbol::Flag::OmpPrivate
+                : Fortran::semantics::Symbol::Flag::OmpFirstPrivate);
+  for (const Fortran::semantics::Symbol *sym : symbols) {
+    if (isPrivate) {
+      bool success = converter.createHostAssociateVarClone(*sym);
+      (void)success;
+      assert(success && "Privatization failed due to existing binding");
+    } else {
+      converter.copyHostAssociateVar(*sym);
+    }
+  }
+}
+
 static void privatizeVars(Fortran::lower::AbstractConverter &converter,
-                          const Fortran::parser::OmpClauseList &opClauseList) {
+                          const Fortran::parser::OmpClauseList &opClauseList,
+                          Fortran::lower::pft::Evaluation &eval) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   auto insPt = firOpBuilder.saveInsertionPoint();
   firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+
+  // Without a DEFAULT clause nothing is privatized by default, which is the
+  // same outcome as DEFAULT(SHARED). Start from that value so that it is never
+  // read uninitialized below.
+  Fortran::parser::OmpDefaultClause::Type ompDefaultClauseType =
+      Fortran::parser::OmpDefaultClause::Type::Shared;
+  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+    if (const auto *defaultClause =
+            std::get_if<Fortran::parser::OmpClause::Default>(&clause.u))
+      ompDefaultClauseType = defaultClause->v.v;
+  }
+
   for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
     if (const auto &privateClause =
             std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) {
-      createPrivateVarSyms(converter, privateClause);
+      // With DEFAULT(PRIVATE), genDefaultClause privatizes every symbol
+      // flagged OmpPrivate, including the ones listed here; skip them to
+      // avoid privatizing the same symbol twice.
+      if (ompDefaultClauseType !=
+          Fortran::parser::OmpDefaultClause::Type::Private)
+        createPrivateVarSyms(converter, privateClause);
     } else if (const auto &firstPrivateClause =
                    std::get_if<Fortran::parser::OmpClause::Firstprivate>(
                        &clause.u)) {
-      createPrivateVarSyms(converter, firstPrivateClause);
+      // Likewise, with DEFAULT(FIRSTPRIVATE) genDefaultClause handles every
+      // symbol flagged OmpFirstPrivate.
+      if (ompDefaultClauseType !=
+          Fortran::parser::OmpDefaultClause::Type::Firstprivate)
+        createPrivateVarSyms(converter, firstPrivateClause);
     }
   }
+  genDefaultClause(converter, eval, ompDefaultClauseType);
   firOpBuilder.restoreInsertionPoint(insPt);
 }
 
@@ -111,7 +163,7 @@
 
 template <typename Op>
 static void createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
-                           mlir::Location &loc,
+                           mlir::Location &loc, Fortran::lower::pft::Evaluation &eval,
                            const Fortran::parser::OmpClauseList *clauses = nullptr,
                            bool outerCombined = false) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -124,7 +176,7 @@
   firOpBuilder.setInsertionPointToStart(&block);
   // Handle privatization. Do not privatize if this is the outer operation.
   if (clauses && !outerCombined)
-    privatizeVars(converter, *clauses);
+    privatizeVars(converter, *clauses, eval);
 }
 
 static void genOMP(Fortran::lower::AbstractConverter &converter,
@@ -295,6 +347,18 @@
     } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) {
       // Nothing needs to be done for threads clause.
       continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::Default>(&clause.u)) {
+      // DEFAULT is lowered together with the body of the operation (see
+      // privatizeVars), so nothing needs to be done here.
+      // TODO: Remove this branch once the "OpenMP Block construct clauses"
+      // TODO below is resolved.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) {
+      // SHARED needs no lowering of its own; it only influences which symbols
+      // a DEFAULT clause privatizes.
+      // TODO: Remove this branch once the "OpenMP Block construct clauses"
+      // TODO below is resolved.
+      continue;
     } else {
       TODO(currentLocation, "OpenMP Block construct clauses");
     }
@@ -313,19 +377,21 @@
         allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
         /*reductions=*/nullptr, procBindKindAttr);
     createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
-                                    &opClauseList, /*isCombined=*/false);
+                                    eval, &opClauseList,
+                                    /*isCombined=*/false);
   } else if (blockDirective.v == llvm::omp::OMPD_master) {
     auto masterOp =
         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
-    createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation);
+    createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval);
   } else if (blockDirective.v == llvm::omp::OMPD_single) {
     auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
         currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
-    createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation);
+    createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval);
   } else if (blockDirective.v == llvm::omp::OMPD_ordered) {
     auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
         currentLocation, /*simd=*/nullptr);
-    createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation);
+    createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation,
+                                         eval);
   } else {
     TODO(converter.getCurrentLocation(), "Unhandled block directive");
   }
@@ -371,7 +437,7 @@
                                firOpBuilder.getContext(), global.sym_name()));
     }
   }();
-  createBodyOfOp<omp::CriticalOp>(criticalOp, converter, currentLocation);
+  createBodyOfOp<omp::CriticalOp>(criticalOp, converter, currentLocation, eval);
 }
 
 static void
@@ -383,7 +449,7 @@
   auto currentLocation = converter.getCurrentLocation();
   mlir::omp::SectionOp sectionOp =
       firOpBuilder.create<mlir::omp::SectionOp>(currentLocation);
-  createBodyOfOp<omp::SectionOp>(sectionOp, converter, currentLocation);
+  createBodyOfOp<omp::SectionOp>(sectionOp, converter, currentLocation, eval);
 }
 
 // TODO: Add support for reduction
@@ -436,19 +502,20 @@
         currentLocation, /*if_expr_var*/ nullptr, /*num_threads_var*/ nullptr,
         allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
         /*reductions=*/nullptr, /*proc_bind_val*/ nullptr);
-    createBodyOfOp(parallelOp, converter, currentLocation);
+    createBodyOfOp(parallelOp, converter, currentLocation, eval);
     auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
         currentLocation, /*reduction_vars*/ ValueRange(),
         /*reductions=*/nullptr, /*allocate_vars*/ ValueRange(),
         /*allocators_vars*/ ValueRange(), /*nowait=*/nullptr);
-    createBodyOfOp(sectionsOp, converter, currentLocation);
+    createBodyOfOp(sectionsOp, converter, currentLocation, eval);
 
     // Sections Construct
   } else if (dir == llvm::omp::Directive::OMPD_sections) {
     auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
         currentLocation, reductionVars, /*reductions = */ nullptr,
         allocateOperands, allocatorOperands, noWaitClauseOperand);
-    createBodyOfOp<omp::SectionsOp>(sectionsOp, converter, currentLocation);
+    createBodyOfOp<omp::SectionsOp>(sectionsOp, converter, currentLocation,
+                                    eval);
   }
 }
 
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -1813,3 +1813,12 @@
     parser::Walk(functionParserNode, visitor);
   });
 }
+
+void Fortran::lower::pft::visitAllSymbols(
+    const Fortran::lower::pft::Evaluation &eval,
+    const std::function<void(const Fortran::semantics::Symbol &)> callBack) {
+  SymbolVisitor visitor{callBack};
+  eval.visit([&](const auto &evalParserNode) {
+    parser::Walk(evalParserNode, visitor);
+  });
+}
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1485,6 +1485,16 @@
         }
       }
     }
+    // A symbol not yet recorded as a data-sharing attribute object on this
+    // construct takes the DSA implied by DEFAULT(PRIVATE/FIRSTPRIVATE).
+    const auto defaultDSA{GetContext().defaultDSA};
+    if ((defaultDSA == semantics::Symbol::Flag::OmpPrivate ||
+            defaultDSA == semantics::Symbol::Flag::OmpFirstPrivate) &&
+        !HasDataSharingAttributeObject(*name.symbol)) {
+      name.symbol =
+          DeclarePrivateAccessEntity(*name.symbol, defaultDSA, currScope());
+      AddToContextObjectWithDSA(*name.symbol, defaultDSA);
+    }
   } // within OpenMP construct
 }
 
@@ -1585,6 +1595,8 @@
           CheckMultipleAppearances(*name, *symbol, ompFlag);
         }
         if (privateDataSharingAttributeFlags.test(ompFlag)) {
+          AddDataSharingAttributeObject(
+              common::Reference<const Symbol>(*symbol));
           CheckObjectInNamelist(*name, *symbol, ompFlag);
         }
 
diff --git a/flang/test/Lower/OpenMP/default-clause.f90 b/flang/test/Lower/OpenMP/default-clause.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/default-clause.f90
@@ -0,0 +1,76 @@
+! This test checks lowering of OpenMP parallel directive
+! with and without a `DEFAULT` clause.
+
+! RUN: bbc -fopenmp -emit-fir %s -o - | \
+! RUN:   FileCheck %s --check-prefix=FIRDialect
+
+!FIRDialect: func @_QQmain() {
+!FIRDialect: %[[W:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFEw"}
+!FIRDialect: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!FIRDialect: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!FIRDialect: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFEz"}
+!FIRDialect: omp.parallel {
+!FIRDialect: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFEx"}
+!FIRDialect: %{{.*}} = fir.load %[[X]] : !fir.ref<i32>
+!FIRDialect: fir.store %{{.*}} to %[[PRIVATE_X]] : !fir.ref<i32>
+!FIRDialect: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFEy"}
+!FIRDialect: %[[PRIVATE_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFEw"}
+!FIRDialect: %{{.*}} = arith.constant 2 : i32
+!FIRDialect: %{{.*}} = fir.load %[[PRIVATE_Y]] : !fir.ref<i32>
+!FIRDialect: %{{.*}} = arith.muli %{{.*}}, %{{.*}} : i32
+!FIRDialect: fir.store %{{.*}} to %[[PRIVATE_X]] : !fir.ref<i32>
+!FIRDialect: %{{.*}} = fir.load %[[PRIVATE_W]] : !fir.ref<i32>
+!FIRDialect: %{{.*}} = arith.constant 45 : i32
+!FIRDialect: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i32
+!FIRDialect: fir.store %{{.*}} to %[[Z]] : !fir.ref<i32>
+!FIRDialect: omp.terminator
+!FIRDialect: }
+!FIRDialect: omp.parallel {
+!FIRDialect: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFEx"}
+!FIRDialect: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFEy"}
+!FIRDialect: %{{.*}} = fir.load %[[Y]] : !fir.ref<i32>
+!FIRDialect: fir.store %{{.*}} to %[[PRIVATE_Y]] : !fir.ref<i32>
+!FIRDialect: %[[PRIVATE_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFEw"}
+!FIRDialect: %{{.*}} = fir.load %[[W]] : !fir.ref<i32>
+!FIRDialect: fir.store %{{.*}} to %[[PRIVATE_W]] : !fir.ref<i32>
+!FIRDialect: %{{.*}} = arith.constant 2 : i32
+!FIRDialect: %{{.*}} = fir.load %[[PRIVATE_Y]] : !fir.ref<i32>
+!FIRDialect: %{{.*}} = arith.muli %{{.*}}, %{{.*}} : i32
+!FIRDialect: fir.store %{{.*}} to %[[PRIVATE_X]] : !fir.ref<i32>
+!FIRDialect: %{{.*}} = fir.load %[[PRIVATE_W]] : !fir.ref<i32>
+!FIRDialect: %{{.*}} = arith.constant 45 : i32
+!FIRDialect: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i32
+!FIRDialect: fir.store %{{.*}} to %[[Z]] : !fir.ref<i32>
+!FIRDialect: omp.terminator
+!FIRDialect: }
+!FIRDialect: omp.parallel {
+!FIRDialect: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFEx"}
+!FIRDialect: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFEy"}
+!FIRDialect: %{{.*}} = fir.load %[[Y]] : !fir.ref<i32>
+!FIRDialect: fir.store %{{.*}} to %[[PRIVATE_Y]] : !fir.ref<i32>
+!FIRDialect: %{{.*}} = arith.constant 2 : i32
+!FIRDialect: %{{.*}} = fir.load %[[PRIVATE_Y]] : !fir.ref<i32>
+!FIRDialect: %{{.*}} = arith.muli %{{.*}}, %{{.*}} : i32
+!FIRDialect: fir.store %{{.*}} to %[[PRIVATE_X]] : !fir.ref<i32>
+!FIRDialect: omp.terminator
+!FIRDialect: }
+!FIRDialect: return
+!FIRDialect: }
+
+program default_clause_lowering
+  integer :: x, y, z, w
+
+  !$omp parallel default(private) firstprivate(x) shared(z)
+    x = y * 2
+    z = w + 45
+  !$omp end parallel
+
+  !$omp parallel default(firstprivate) private(x) shared(z)
+    x = y * 2
+    z = w + 45
+  !$omp end parallel
+
+  !$omp parallel private(x) firstprivate(y)
+    x = y * 2
+  !$omp end parallel
+end program default_clause_lowering