[flang][OpenMP] Lower the OpenMP CRITICAL construct to the omp dialect

Summary of this patch (text before the first file header is ignored by
git apply / patch):

* flang/lib/Lower/OpenMP.cpp
  - Adds a static genOMP overload for parser::OpenMPCriticalConstruct and
    calls it from genOpenMPConstruct, replacing the previous
    TODO(..., "OpenMPCriticalConstruct").
  - The optional critical name is read from the begin directive.
  - The hint defaults to 0. The first HINT clause in the begin directive's
    clause list overrides it.
  - Unnamed CRITICAL creates a mlir::omp::CriticalOp with an empty
    FlatSymbolRefAttr().
  - Named CRITICAL looks up a module-level declaration symbol by name. If it
    is absent, it creates one at the start of the module body with
    (location, name, hint). It then references it via
    mlir::FlatSymbolRefAttr::get(ctx, global.sym_name()).
  - The region body is generated with createBodyOfOp.
  - Because an existing declaration is reused as-is, a later CRITICAL with
    the same name but a different HINT keeps the first hint. The test relies
    on this: HINT(omp_lock_hint_none) on the second "help" region still
    expects "i32 2".

* flang/lib/Semantics/resolve-directives.cpp
  - Folds Pre(OmpCriticalDirective) and Pre(OmpEndCriticalDirective) into
    Pre(OpenMPCriticalConstruct).
  - That function pushes the OMPD_critical context using the begin
    directive's source, then resolves both the begin and end names with
    Symbol::Flag::OmpCriticalLock.
  - In the name-resolution fallback (hunk at old line 1515, presumably
    ResolveOmpName), a fresh symbol is created only for OmpCriticalLock.
    It now uses UnknownDetails instead of ObjectEntityDetails.
  - The former else-branch DIE("OpenMP Name resolution failed") is removed.

* flang/test/Lower/OpenMP/critical.f90 (new)
  - Checks FIR, the LLVM dialect (fir-opt --fir-to-llvm-ir) and LLVM IR
    (mlir-translate) for named, re-used and unnamed critical regions.

NOTE(review) items to confirm before landing:
  - hint = *Fortran::evaluate::ToInt64(*expr) dereferences both the GetExpr
    pointer and the returned optional unchecked. This assumes semantics has
    already validated HINT as a constant integer expression.
  - With the DIE branch gone, an unresolved name for any flag other than
    OmpCriticalLock is now silently left without a symbol. Verify that
    ompFlagsRequireNewSymbol contained no other flags that relied on the old
    path.
  - If the begin name is created fresh and the end name then fails
    ResolveName, try_emplace would hit CHECK(pair.second). Confirm that
    ResolveName finds the just-created symbol in the same scope.
  - The in-class declaration names its parameter "critical" while the
    out-of-line definition uses "x". This is harmless, but inconsistent.
  - The test comment "Also test with the zero hint expression" refers to
    HINT(omp_lock_hint_none). The expected "i32 2" comes from re-using the
    first declaration's hint.
  - This copy of the patch appears to have lost its line breaks and its
    template argument lists (e.g. std::get(criticalConstruct.t),
    std::get>(cd.t)). It must be regenerated from the original commit
    before it can be applied.

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -202,6 +202,49 @@ } } +static void +genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + std::string name; + const Fortran::parser::OmpCriticalDirective &cd = + std::get(criticalConstruct.t); + if (std::get>(cd.t).has_value()) { + name = + std::get>(cd.t).value().ToString(); + } + + uint64_t hint = 0; + const auto &clauseList = std::get(cd.t); + for (const Fortran::parser::OmpClause &clause : clauseList.v) + if (auto hintClause = + std::get_if(&clause.u)) { + const auto *expr = Fortran::semantics::GetExpr(hintClause->v); + hint = *Fortran::evaluate::ToInt64(*expr); + break; + } + + mlir::omp::CriticalOp criticalOp = [&]() { + if (name.empty()) { + return firOpBuilder.create(currentLocation, + FlatSymbolRefAttr()); + } else { + mlir::ModuleOp module = firOpBuilder.getModule(); + mlir::OpBuilder modBuilder(module.getBodyRegion()); + auto global = module.lookupSymbol(name); + if (!global) + global = modBuilder.create( + currentLocation, name, hint); + return firOpBuilder.create( + currentLocation, mlir::FlatSymbolRefAttr::get( + firOpBuilder.getContext(), global.sym_name())); + } + }(); + createBodyOfOp(criticalOp, firOpBuilder, currentLocation); +} + void Fortran::lower::genOpenMPConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -239,7 +282,7 @@ }, [&](const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPCriticalConstruct"); + genOMP(converter, eval, criticalConstruct); }, }, ompConstruct.u); diff --git a/flang/lib/Semantics/resolve-directives.cpp 
b/flang/lib/Semantics/resolve-directives.cpp --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -293,8 +293,6 @@ bool Pre(const parser::OpenMPBlockConstruct &); void Post(const parser::OpenMPBlockConstruct &); - bool Pre(const parser::OmpCriticalDirective &x); - bool Pre(const parser::OmpEndCriticalDirective &x); void Post(const parser::OmpBeginBlockDirective &) { GetContext().withinConstruct = true; @@ -313,7 +311,7 @@ bool Pre(const parser::OpenMPSectionsConstruct &); void Post(const parser::OpenMPSectionsConstruct &) { PopContext(); } - bool Pre(const parser::OpenMPCriticalConstruct &); + bool Pre(const parser::OpenMPCriticalConstruct &critical); void Post(const parser::OpenMPCriticalConstruct &) { PopContext(); } bool Pre(const parser::OpenMPDeclareSimdConstruct &x) { @@ -1376,28 +1374,21 @@ return true; } -bool OmpAttributeVisitor::Pre(const parser::OmpCriticalDirective &x) { - const auto &name{std::get>(x.t)}; - if (name) { - ResolveOmpName(*name, Symbol::Flag::OmpCriticalLock); +bool OmpAttributeVisitor::Pre(const parser::OpenMPCriticalConstruct &x) { + const auto &beginCriticalDir{std::get(x.t)}; + const auto &endCriticalDir{std::get(x.t)}; + PushContext(beginCriticalDir.source, llvm::omp::Directive::OMPD_critical); + if (const auto &criticalName{ + std::get>(beginCriticalDir.t)}) { + ResolveOmpName(*criticalName, Symbol::Flag::OmpCriticalLock); } - return true; -} - -bool OmpAttributeVisitor::Pre(const parser::OmpEndCriticalDirective &x) { - const auto &name{std::get>(x.t)}; - if (name) { - ResolveOmpName(*name, Symbol::Flag::OmpCriticalLock); + if (const auto &endCriticalName{ + std::get>(endCriticalDir.t)}) { + ResolveOmpName(*endCriticalName, Symbol::Flag::OmpCriticalLock); } return true; } -bool OmpAttributeVisitor::Pre(const parser::OpenMPCriticalConstruct &x) { - const auto &criticalDir{std::get(x.t)}; - PushContext(criticalDir.source, llvm::omp::Directive::OMPD_critical); - return true; -} - bool 
OmpAttributeVisitor::Pre(const parser::OpenMPThreadprivate &x) { PushContext(x.source, llvm::omp::Directive::OMPD_threadprivate); const auto &list{std::get(x.t)}; @@ -1515,13 +1506,11 @@ AddToContextObjectWithDSA(*resolvedSymbol, ompFlag); } } - } else if (ompFlagsRequireNewSymbol.test(ompFlag)) { - const auto pair{GetContext().scope.try_emplace( - name.source, Attrs{}, ObjectEntityDetails{})}; + } else if (ompFlag == Symbol::Flag::OmpCriticalLock) { + const auto pair{ + GetContext().scope.try_emplace(name.source, Attrs{}, UnknownDetails{})}; CHECK(pair.second); name.symbol = &pair.first->second.get(); - } else { - DIE("OpenMP Name resolution failed"); } } diff --git a/flang/test/Lower/OpenMP/critical.f90 b/flang/test/Lower/OpenMP/critical.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/critical.f90 @@ -0,0 +1,41 @@ +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefix="FIRDialect" +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefix="LLVMDialect" +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | mlir-translate --mlir-to-llvmir | FileCheck %s --check-prefix="LLVMIR" + +subroutine omp_critical() + use omp_lib + integer :: x, y +!FIRDialect: omp.critical.declare @help hint(contended) +!LLVMDialect: omp.critical.declare @help hint(contended) +!FIRDialect: omp.critical(@help) +!LLVMDialect: omp.critical(@help) +!LLVMIR: call void @__kmpc_critical_with_hint({{.*}}, {{.*}}, {{.*}} @{{.*}}help.var, i32 2) +!$OMP CRITICAL(help) HINT(omp_lock_hint_contended) + x = x + y +!FIRDialect: omp.terminator +!LLVMDialect: omp.terminator +!LLVMIR: call void @__kmpc_end_critical({{.*}}, {{.*}}, {{.*}} @{{.*}}help.var) +!$OMP END CRITICAL(help) + +! Test that the same name can be used again +! 
Also test with the zero hint expression +!FIRDialect: omp.critical(@help) +!LLVMDialect: omp.critical(@help) +!LLVMIR: call void @__kmpc_critical_with_hint({{.*}}, {{.*}}, {{.*}} @{{.*}}help.var, i32 2) +!$OMP CRITICAL(help) HINT(omp_lock_hint_none) + x = x - y +!FIRDialect: omp.terminator +!LLVMDialect: omp.terminator +!LLVMIR: call void @__kmpc_end_critical({{.*}}, {{.*}}, {{.*}} @{{.*}}help.var) +!$OMP END CRITICAL(help) + +!FIRDialect: omp.critical +!LLVMDialect: omp.critical +!LLVMIR: call void @__kmpc_critical({{.*}}, {{.*}}, {{.*}} @{{.*}}_.var) +!$OMP CRITICAL + y = x + y +!FIRDialect: omp.terminator +!LLVMDialect: omp.terminator +!LLVMIR: call void @__kmpc_end_critical({{.*}}, {{.*}}, {{.*}} @{{.*}}_.var) +!$OMP END CRITICAL +end subroutine omp_critical