Lower `!$omp atomic capture` to `omp.atomic.capture` (flang).

Summary of this patch:
  * genOmpAtomicCapture creates an `omp.atomic.capture` op holding one block
    and classifies the two assignment statements as [capture, update],
    [capture, write] or [update, capture] using checkForSingleVariableOnRHS
    and checkForSymbolMatch; the block is closed with a terminator op.
  * The bodies of genOmpAtomicRead / genOmpAtomicWrite move into
    genOmpAtomicCaptureStatement / genOmpAtomicWriteStatement, which take the
    clause lists as pointers so the capture lowering can pass nullptr.
  * The RHS of the second statement is evaluated before the capture op is
    created; the [capture, write] form hands that value to
    genOmpAtomicWriteStatement instead of evaluating inside the region.
  * The new test keeps the pointer case disabled (`!!$omp`); see its TODO.

Reviewer note: template arguments (`std::get<...>`, `create<...Op>`,
`!fir.ref<i32>`) were missing from the received text and are restored here
from the surrounding declarations, comments and CHECK lines. Indentation is
reconstructed. Confirm both against the tree before applying.

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1255,6 +1255,36 @@
   }
 }
 
+static bool checkForSingleVariableOnRHS(
+    const Fortran::parser::AssignmentStmt &assignmentStmt) {
+  // Check if the assignment statement has a single variable on the RHS
+  const Fortran::parser::Expr &expr{
+      std::get<Fortran::parser::Expr>(assignmentStmt.t)};
+  const Fortran::common::Indirection<Fortran::parser::Designator> *designator =
+      std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
+          &expr.u);
+  const Fortran::parser::Name *name =
+      designator ? getDesignatorNameIfDataRef(designator->value()) : nullptr;
+  return name != nullptr;
+}
+
+static bool
+checkForSymbolMatch(const Fortran::parser::AssignmentStmt &assignmentStmt) {
+  // Check if the symbol on the LHS of the assignment statement is present in
+  // the RHS expression
+  const auto &var{std::get<Fortran::parser::Variable>(assignmentStmt.t)};
+  const auto &expr{std::get<Fortran::parser::Expr>(assignmentStmt.t)};
+  const auto *e{Fortran::semantics::GetExpr(expr)};
+  const auto *v{Fortran::semantics::GetExpr(var)};
+  const Fortran::semantics::Symbol &varSymbol =
+      Fortran::evaluate::GetSymbolVector(*v).front();
+  for (const Fortran::semantics::Symbol &symbol :
+       Fortran::evaluate::GetSymbolVector(*e))
+    if (varSymbol == symbol)
+      return true;
+  return false;
+}
+
 static void genOmpAtomicHintAndMemoryOrderClauses(
     Fortran::lower::AbstractConverter &converter,
     const Fortran::parser::OmpAtomicClauseList &clauseList,
@@ -1293,6 +1323,73 @@
   }
 }
 
+static void genOmpAtomicCaptureStatement(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::Variable &assignmentStmtVariable,
+    const Fortran::parser::Expr &assignmentStmtExpr,
+    const Fortran::parser::OmpAtomicClauseList *leftHandClauseList,
+    const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) {
+  // Generate `omp.atomic.read` operation for atomic assignment statements
+  auto &firOpBuilder = converter.getFirOpBuilder();
+  auto currentLocation = converter.getCurrentLocation();
+  // Get the address of atomic read operands.
+
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value from_address = fir::getBase(converter.genExprAddr(
+      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
+  mlir::Value to_address = fir::getBase(converter.genExprAddr(
+      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+  // If no hint clause is specified, the effect is as if
+  // hint(omp_sync_hint_none) had been specified.
+  mlir::IntegerAttr hint = nullptr;
+  mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
+  if (leftHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
+                                          memory_order);
+  if (rightHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
+                                          memory_order);
+  firOpBuilder.create<mlir::omp::AtomicReadOp>(currentLocation, from_address,
+                                               to_address, hint, memory_order);
+}
+
+static void genOmpAtomicWriteStatement(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::Variable &assignmentStmtVariable,
+    const Fortran::parser::Expr &assignmentStmtExpr,
+    const Fortran::parser::OmpAtomicClauseList *leftHandClauseList,
+    const Fortran::parser::OmpAtomicClauseList *rightHandClauseList,
+    mlir::Value *evaluatedExprValue = nullptr) {
+  // Generate `omp.atomic.write` operation for atomic assignment statements
+  auto &firOpBuilder = converter.getFirOpBuilder();
+  auto currentLocation = converter.getCurrentLocation();
+  // Get the value and address of atomic write operands.
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value value;
+  if (!evaluatedExprValue)
+    value = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
+  else
+    // A pre-computed expression evaluation is provided
+    value = *evaluatedExprValue;
+  mlir::Value address = fir::getBase(converter.genExprAddr(
+      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+  // If no hint clause is specified, the effect is as if
+  // hint(omp_sync_hint_none) had been specified.
+  mlir::IntegerAttr hint = nullptr;
+  mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
+  if (leftHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
+                                          memory_order);
+  if (rightHandClauseList)
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
+                                          memory_order);
+  firOpBuilder.create<mlir::omp::AtomicWriteOp>(currentLocation, address, value,
+                                                hint, memory_order);
+}
+
 static void genOmpAtomicUpdateStatement(
     Fortran::lower::AbstractConverter &converter,
     Fortran::lower::pft::Evaluation &eval,
@@ -1359,9 +1456,6 @@
 genOmpAtomicWrite(Fortran::lower::AbstractConverter &converter,
                   Fortran::lower::pft::Evaluation &eval,
                   const Fortran::parser::OmpAtomicWrite &atomicWrite) {
-  auto &firOpBuilder = converter.getFirOpBuilder();
-  auto currentLocation = converter.getCurrentLocation();
-  // Get the value and address of atomic write operands.
   const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
       std::get<2>(atomicWrite.t);
   const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
@@ -1370,29 +1464,14 @@
       std::get<Fortran::parser::Expr>(std::get<3>(atomicWrite.t).statement.t);
   const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
       std::get<3>(atomicWrite.t).statement.t);
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value value = fir::getBase(converter.genExprValue(
-      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
-  mlir::Value address = fir::getBase(converter.genExprAddr(
-      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  // If no hint clause is specified, the effect is as if
-  // hint(omp_sync_hint_none) had been specified.
-  mlir::IntegerAttr hint = nullptr;
-  mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
-  genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
-                                        memory_order);
-  genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
-                                        memory_order);
-  firOpBuilder.create<mlir::omp::AtomicWriteOp>(currentLocation, address, value,
-                                                hint, memory_order);
+  genOmpAtomicWriteStatement(converter, eval, assignmentStmtVariable,
+                             assignmentStmtExpr, &leftHandClauseList,
+                             &rightHandClauseList);
 }
-
 static void genOmpAtomicRead(Fortran::lower::AbstractConverter &converter,
                              Fortran::lower::pft::Evaluation &eval,
                              const Fortran::parser::OmpAtomicRead &atomicRead) {
-  auto &firOpBuilder = converter.getFirOpBuilder();
-  auto currentLocation = converter.getCurrentLocation();
-  // Get the address of atomic read operands.
+
   const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
       std::get<2>(atomicRead.t);
   const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
@@ -1401,21 +1480,9 @@
       std::get<Fortran::parser::Expr>(std::get<3>(atomicRead.t).statement.t);
   const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
       std::get<3>(atomicRead.t).statement.t);
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value from_address = fir::getBase(converter.genExprAddr(
-      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
-  mlir::Value to_address = fir::getBase(converter.genExprAddr(
-      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  // If no hint clause is specified, the effect is as if
-  // hint(omp_sync_hint_none) had been specified.
-  mlir::IntegerAttr hint = nullptr;
-  mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
-  genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
-                                        memory_order);
-  genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
-                                        memory_order);
-  firOpBuilder.create<mlir::omp::AtomicReadOp>(currentLocation, from_address,
-                                               to_address, hint, memory_order);
+  genOmpAtomicCaptureStatement(converter, eval, assignmentStmtVariable,
+                               assignmentStmtExpr, &leftHandClauseList,
+                               &rightHandClauseList);
 }
 
 static void
@@ -1455,6 +1522,81 @@
                               assignmentStmtExpr, &atomicClauseList, nullptr);
 }
 
+static void
+genOmpAtomicCapture(Fortran::lower::AbstractConverter &converter,
+                    Fortran::lower::pft::Evaluation &eval,
+                    const Fortran::parser::OmpAtomicCapture &atomicCapture) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+
+  mlir::IntegerAttr hint = nullptr;
+  mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
+  const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
+      std::get<2>(atomicCapture.t);
+  const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
+      std::get<0>(atomicCapture.t);
+  genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
+                                        memory_order);
+  genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
+                                        memory_order);
+
+  const Fortran::parser::AssignmentStmt &stmt1 =
+      std::get<3>(atomicCapture.t).v.statement;
+  const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
+  const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
+  const Fortran::parser::AssignmentStmt &stmt2 =
+      std::get<4>(atomicCapture.t).v.statement;
+  const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
+  const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};
+
+  // Pre-evaluate RHS expression to be used in `omp.atomic.write` since
+  // it is not desirable to have RHS expression evaluation inside
+  // `omp.atomic.capture`
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value evaluatedExprValue = fir::getBase(
+      converter.genExprValue(*Fortran::semantics::GetExpr(stmt2Expr), stmtCtx));
+
+  auto atomicCaptureOp = firOpBuilder.create<mlir::omp::AtomicCaptureOp>(
+      currentLocation, hint, memory_order);
+  firOpBuilder.createBlock(&atomicCaptureOp.getRegion());
+  mlir::Block &block = atomicCaptureOp.getRegion().back();
+  firOpBuilder.setInsertionPointToStart(&block);
+
+  if (checkForSingleVariableOnRHS(stmt1)) {
+    if (checkForSymbolMatch(stmt2)) {
+      // Atomic capture construct is of the form [capture-stmt, update-stmt]
+      genOmpAtomicCaptureStatement(converter, eval, stmt1Var, stmt1Expr,
+                                   /*leftHandClauseList=*/nullptr,
+                                   /*rightHandClauseList=*/nullptr);
+      genOmpAtomicUpdateStatement(converter, eval, stmt2Var, stmt2Expr,
+                                  /*leftHandClauseList=*/nullptr,
+                                  /*rightHandClauseList=*/nullptr);
+    } else {
+      // Atomic capture construct is of the form [capture-stmt, write-stmt]
+      genOmpAtomicCaptureStatement(converter, eval, stmt1Var, stmt1Expr,
+                                   /*leftHandClauseList=*/nullptr,
+                                   /*rightHandClauseList=*/nullptr);
+      // Use pre-evaluated RHS expression value `evaluatedExprValue`
+      genOmpAtomicWriteStatement(
+          converter, eval, stmt2Var, stmt2Expr, /*leftHandClauseList=*/nullptr,
+          /*rightHandClauseList=*/nullptr, &evaluatedExprValue);
+    }
+  } else {
+    // Atomic capture construct is of the form [update-stmt, capture-stmt]
+    firOpBuilder.setInsertionPointToEnd(&block);
+    genOmpAtomicCaptureStatement(converter, eval, stmt2Var, stmt2Expr,
+                                 /*leftHandClauseList=*/nullptr,
+                                 /*rightHandClauseList=*/nullptr);
+    firOpBuilder.setInsertionPointToStart(&block);
+    genOmpAtomicUpdateStatement(converter, eval, stmt1Var, stmt1Expr,
+                                /*leftHandClauseList=*/nullptr,
+                                /*rightHandClauseList=*/nullptr);
+  }
+  firOpBuilder.setInsertionPointToEnd(&block);
+  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+  firOpBuilder.setInsertionPointToStart(&block);
+}
+
 static void
 genOMP(Fortran::lower::AbstractConverter &converter,
        Fortran::lower::pft::Evaluation &eval,
@@ -1472,8 +1614,8 @@
           [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) {
             genOmpAtomicUpdate(converter, eval, atomicUpdate);
           },
-          [&](const auto &) {
-            TODO(converter.getCurrentLocation(), "Atomic capture");
+          [&](const Fortran::parser::OmpAtomicCapture &atomicCapture) {
+            genOmpAtomicCapture(converter, eval, atomicCapture);
           },
       },
       atomicConstruct.u);
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -0,0 +1,91 @@
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! This test checks the lowering of atomic capture
+
+program OmpAtomicCapture
+    use omp_lib
+    integer :: x, y
+
+!CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: omp.atomic.capture memory_order(release) {
+!CHECK: omp.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
+!CHECK: omp.atomic.update %[[Y]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
+!CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
+!CHECK: omp.yield(%[[result]] : i32)
+!CHECK: }
+!CHECK: }
+
+    !$omp atomic capture release
+        x = y
+        y = x + y
+    !$omp end atomic
+
+
+!CHECK: omp.atomic.capture hint(uncontended) {
+!CHECK: omp.atomic.update %[[Y]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
+!CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
+!CHECK: omp.yield(%[[result]] : i32)
+!CHECK: }
+!CHECK: omp.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
+!CHECK: }
+
+    !$omp atomic hint(omp_sync_hint_uncontended) capture
+        y = x * y
+        x = y
+    !$omp end atomic
+
+!CHECK: %[[constant_20:.*]] = arith.constant 20 : i32
+!CHECK: %[[constant_8:.*]] = arith.constant 8 : i32
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
+!CHECK: %[[result:.*]] = arith.subi %[[constant_8]], %[[temp]] : i32
+!CHECK: %[[result_noreassoc:.*]] = fir.no_reassoc %[[result]] : i32
+!CHECK: %[[result:.*]] = arith.addi %[[constant_20]], %[[result_noreassoc]] : i32
+!CHECK: omp.atomic.capture memory_order(acquire) hint(nonspeculative) {
+!CHECK: omp.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
+!CHECK: omp.atomic.write %[[Y]] = %[[result]] : !fir.ref<i32>, i32
+!CHECK: }
+
+    !$omp atomic hint(omp_lock_hint_nonspeculative) capture acquire
+        x = y
+        y = 2 * 10 + (8 - x)
+    !$omp end atomic
+
+
+!CHECK: %[[constant_20:.*]] = arith.constant 20 : i32
+!CHECK: %[[constant_8:.*]] = arith.constant 8 : i32
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
+!CHECK: %[[result:.*]] = arith.subi %[[constant_8]], %[[temp]] : i32
+!CHECK: %[[result_noreassoc:.*]] = fir.no_reassoc %[[result]] : i32
+!CHECK: %[[result:.*]] = arith.addi %[[constant_20]], %[[result_noreassoc]] : i32
+!CHECK: omp.atomic.capture {
+!CHECK: omp.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
+!CHECK: omp.atomic.write %[[Y]] = %[[result]] : !fir.ref<i32>, i32
+!CHECK: }
+
+    !$omp atomic capture
+        x = y
+        y = 2 * 10 + (8 - x)
+    !$omp end atomic
+end program
+
+
+!TODO: Currently giving expected three operations in omp.atomic.capture
+!region (one terminator, and two atomic ops). Decide on an
+!approach
+subroutine pointers_in_atomic_capture()
+    integer, pointer :: a, b
+    integer, target :: c, d
+    a=>c
+    b=>d
+
+    !!$omp atomic capture
+        a = a + b
+        b = a
+    !!$omp end atomic
+end subroutine