diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -276,7 +276,8 @@
                mlir::Location &loc, Fortran::lower::pft::Evaluation &eval,
                const Fortran::parser::OmpClauseList *clauses = nullptr,
                const SmallVector<const Fortran::semantics::Symbol *> &args = {},
-               bool outerCombined = false) {
+               bool outerCombined = false,
+               const Fortran::parser::Expr *expr = nullptr) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   // If an argument for the region is provided then create the block with that
   // argument. Also update the symbol's address with the mlir argument value.
@@ -287,11 +288,21 @@
     std::size_t loopVarTypeSize = 0;
     for (const Fortran::semantics::Symbol *arg : args)
       loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
-    mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+    mlir::Type varType;
+    if constexpr (std::is_same_v<Op, mlir::omp::AtomicUpdateOp>) {
+      // The region argument holds the value currently stored at `x`, so its
+      // type is the element type of `x`. Do not lower the RHS here just to
+      // learn a type: that emits IR outside the region (re-running any side
+      // effects of the RHS) and the RHS type may differ from the type of `x`.
+      // The RHS is lowered once, inside the region, when the yield is built.
+      varType = fir::unwrapRefType(op.getX().getType());
+    } else {
+      varType = getLoopVarType(converter, loopVarTypeSize);
+    }
     SmallVector<mlir::Type> tiv;
     SmallVector<mlir::Location> locs;
     for (int i = 0; i < (int)args.size(); i++) {
-      tiv.push_back(loopVarType);
+      tiv.push_back(varType);
       locs.push_back(loc);
     }
     firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
@@ -302,7 +313,7 @@
       mlir::Value val =
           fir::getBase(op.getRegion().front().getArgument(argIndex));
       mlir::Value temp = firOpBuilder.createTemporary(
-          loc, loopVarType,
+          loc, varType,
           llvm::ArrayRef<mlir::NamedAttribute>{
               Fortran::lower::getAdaptToByRefAttr(firOpBuilder)});
       storeOp = firOpBuilder.create<fir::StoreOp>(loc, val, temp);
@@ -327,6 +338,16 @@
   if constexpr (std::is_same_v<Op, omp::WsLoopOp>) {
     mlir::ValueRange results;
     firOpBuilder.create<mlir::omp::YieldOp>(loc, results);
+  } else if constexpr (std::is_same_v<Op, mlir::omp::AtomicUpdateOp>) {
+    // Lower the RHS inside the region and yield it converted to the element
+    // type of `x` (the region argument type): e.g. for `x = x + 1.5` with an
+    // integer `x`, the REAL result is converted back to the type of `x`.
+    Fortran::lower::StatementContext stmtCtx;
+    mlir::Value result = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(*expr), stmtCtx));
+    mlir::Type xType = fir::unwrapRefType(op.getX().getType());
+    firOpBuilder.create<mlir::omp::YieldOp>(
+        loc, firOpBuilder.createConvert(loc, xType, result));
   } else {
     firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
   }
@@ -967,45 +988,212 @@
     }
   }
 }
+static void genOmpAtomicCaptureStatement(
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::parser::Variable &assignmentStmtVariable,
+    const Fortran::parser::Expr &assignmentStmtExpr,
+    const Fortran::parser::OmpAtomicClauseList *leftHandClauseList,
+    const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) {
+  // Generate an `omp.atomic.read` operation for an atomic statement of the
+  // form `v = x`.
+  auto &firOpBuilder = converter.getFirOpBuilder();
+  auto currentLocation = converter.getCurrentLocation();
+  Fortran::lower::StatementContext stmtCtx;
+  // Get the address of atomic read operands.
+  mlir::Value from_address = fir::getBase(converter.genExprAddr(
+      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
+  mlir::Value to_address = fir::getBase(converter.genExprAddr(
+      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+  // If no hint clause is specified, the effect is as if
+  // hint(omp_sync_hint_none) had been specified.
+  mlir::IntegerAttr hint = nullptr;
+  mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
+  if (leftHandClauseList) {
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
+                                          memory_order);
+  }
+  if (rightHandClauseList) {
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
+                                          memory_order);
+  }
+  firOpBuilder.create<mlir::omp::AtomicReadOp>(currentLocation, from_address,
+                                               to_address, hint, memory_order);
+}
 
-static void
-genOmpAtomicWrite(Fortran::lower::AbstractConverter &converter,
-                  Fortran::lower::pft::Evaluation &eval,
-                  const Fortran::parser::OmpAtomicWrite &atomicWrite) {
+static void genOmpAtomicWriteStatement(
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::parser::Variable &assignmentStmtVariable,
+    const Fortran::parser::Expr &assignmentStmtExpr,
+    const Fortran::parser::OmpAtomicClauseList *leftHandClauseList,
+    const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) {
+  // Generate an `omp.atomic.write` operation for an atomic statement of the
+  // form `x = expr`.
   auto &firOpBuilder = converter.getFirOpBuilder();
   auto currentLocation = converter.getCurrentLocation();
-  // Get the value and address of atomic write operands.
-  const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
-      std::get<2>(atomicWrite.t);
-  const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
-      std::get<0>(atomicWrite.t);
-  const auto &assignmentStmtExpr =
-      std::get<Fortran::parser::Expr>(std::get<3>(atomicWrite.t).statement.t);
-  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
-      std::get<3>(atomicWrite.t).statement.t);
   Fortran::lower::StatementContext stmtCtx;
+  // Get the value and address of atomic write operands.
   mlir::Value value = fir::getBase(converter.genExprValue(
       *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
   mlir::Value address = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+
   // If no hint clause is specified, the effect is as if
   // hint(omp_sync_hint_none) had been specified.
   mlir::IntegerAttr hint = nullptr;
   mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
-  genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
-                                        memory_order);
-  genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
-                                        memory_order);
+  if (leftHandClauseList) {
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
+                                          memory_order);
+  }
+  if (rightHandClauseList) {
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
+                                          memory_order);
+  }
   firOpBuilder.create<mlir::omp::AtomicWriteOp>(currentLocation, address, value,
                                                 hint, memory_order);
 }
 
+static bool checkForAtomicCaptureStmt(
+    const Fortran::parser::AssignmentStmt &assignmentStmt) {
+  // Check if the atomic statement is of the structure `v = x`
+  // Rely on previous phases to ensure correct semantics of `v = x`
+  const auto &expr{std::get<Fortran::parser::Expr>(assignmentStmt.t)};
+  return std::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::common::Indirection<Fortran::parser::Designator>
+                  &designator) {
+            if (getDesignatorNameIfDataRef(designator.value()))
+              return true; // found a variable on the RHS of the atomic
+                           // statement expression
+            return false;
+          },
+          [&](const auto &) { return false; },
+      },
+      expr.u);
+}
+
+template <typename T>
+bool isOmpAtomicUpdateStmtOperatorValid(
+    const T &node, const Fortran::parser::Variable &variable) {
+  using AllowedBinaryOperators =
+      std::variant<Fortran::parser::Expr::Add, Fortran::parser::Expr::Multiply,
+                   Fortran::parser::Expr::Subtract,
+                   Fortran::parser::Expr::Divide, Fortran::parser::Expr::AND,
+                   Fortran::parser::Expr::OR, Fortran::parser::Expr::EQV,
+                   Fortran::parser::Expr::NEQV>;
+  using BinaryOperators =
+      std::variant<Fortran::parser::Expr::Add, Fortran::parser::Expr::Multiply,
+                   Fortran::parser::Expr::Subtract,
+                   Fortran::parser::Expr::Divide, Fortran::parser::Expr::AND,
+                   Fortran::parser::Expr::OR, Fortran::parser::Expr::EQV,
+                   Fortran::parser::Expr::NEQV, Fortran::parser::Expr::Power,
+                   Fortran::parser::Expr::Concat, Fortran::parser::Expr::LT,
+                   Fortran::parser::Expr::LE, Fortran::parser::Expr::EQ,
+                   Fortran::parser::Expr::NE, Fortran::parser::Expr::GE,
+                   Fortran::parser::Expr::GT>;
+
+  if constexpr (Fortran::common::HasMember<T, BinaryOperators>) {
+    const auto &variableName{variable.GetSource().ToString()};
+    const auto &exprLeft{std::get<0>(node.t)};
+    const auto &exprRight{std::get<1>(node.t)};
+    if ((exprLeft.value().source.ToString() != variableName) &&
+        (exprRight.value().source.ToString() != variableName)) {
+      return false;
+    }
+    return Fortran::common::HasMember<T, AllowedBinaryOperators>;
+  }
+  return false;
+}
+
+static bool checkForAtomicUpdateStmt(
+    const Fortran::parser::AssignmentStmt &assignmentStmt) {
+  // Check if the atomic statement is of the structure `x = x operator expr` OR
+  // `x = expr operator x` OR `x = intrinsic_procedure_name(x, expr_list)` OR `x
+  // = intrinsic_procedure_name(expr_list, x)`. Rely on previous phases to
+  // ensure correct semantics of these assignment statements.
+  const auto &expr{std::get<Fortran::parser::Expr>(assignmentStmt.t)};
+  const auto &var{std::get<Fortran::parser::Variable>(assignmentStmt.t)};
+  return std::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::common::Indirection<
+              Fortran::parser::FunctionReference> &) { return true; },
+          [&](const auto &x) {
+            return isOmpAtomicUpdateStmtOperatorValid(x, var);
+          },
+      },
+      expr.u);
+}
+
+static void genOmpAtomicUpdateStatement(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::Variable &assignmentStmtVariable,
+    const Fortran::parser::Expr &assignmentStmtExpr,
+    const Fortran::parser::OmpAtomicClauseList *leftHandClauseList,
+    const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) {
+
+  // Generate an `omp.atomic.update` operation for atomic assignment statements
+  // of the form `x = x operator expr` OR `x = expr operator x` OR
+  // `x = intrinsic_procedure_name(x, expr_list)` OR
+  // `x = intrinsic_procedure_name(expr_list, x)`.
+
+  auto &firOpBuilder = converter.getFirOpBuilder();
+  auto currentLocation = converter.getCurrentLocation();
+  mlir::Value address;
+  SmallVector<const Fortran::semantics::Symbol *> symbolVector;
+  Fortran::lower::StatementContext stmtCtx;
+  if (auto varDesignator = std::get_if<
+          Fortran::common::Indirection<Fortran::parser::Designator>>(
+          &assignmentStmtVariable.u)) {
+    if (const auto *name = getDesignatorNameIfDataRef(varDesignator->value())) {
+      address = fir::getBase(converter.genExprAddr(
+          *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
+      symbolVector.push_back(name->symbol);
+    }
+  }
+  // The region rebinds the variable's symbol to the region argument, so only
+  // whole named variables are handled; otherwise `address` would stay null.
+  if (!address)
+    TODO(currentLocation, "OpenMP atomic update of a non-named variable");
+  // If no hint clause is specified, the effect is as if
+  // hint(omp_sync_hint_none) had been specified.
+  mlir::IntegerAttr hint = nullptr;
+  mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
+  if (leftHandClauseList) {
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
+                                          memory_order);
+  }
+  if (rightHandClauseList) {
+    genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
+                                          memory_order);
+  }
+  auto atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
+      currentLocation, address, hint, memory_order);
+  createBodyOfOp<mlir::omp::AtomicUpdateOp>(atomicUpdateOp, converter,
+                                            currentLocation, eval, nullptr,
+                                            symbolVector, false, &assignmentStmtExpr);
+}
+
+static void
+genOmpAtomicWrite(Fortran::lower::AbstractConverter &converter,
+                  Fortran::lower::pft::Evaluation &eval,
+                  const Fortran::parser::OmpAtomicWrite &atomicWrite) {
+  const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
+      std::get<2>(atomicWrite.t);
+  const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
+      std::get<0>(atomicWrite.t);
+  const auto &assignmentStmtExpr =
+      std::get<Fortran::parser::Expr>(std::get<3>(atomicWrite.t).statement.t);
+  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
+      std::get<3>(atomicWrite.t).statement.t);
+  genOmpAtomicWriteStatement(converter, assignmentStmtVariable,
+                             assignmentStmtExpr, &leftHandClauseList,
+                             &rightHandClauseList);
+}
+
 static void genOmpAtomicRead(Fortran::lower::AbstractConverter &converter,
                              Fortran::lower::pft::Evaluation &eval,
                              const Fortran::parser::OmpAtomicRead &atomicRead) {
-  auto &firOpBuilder = converter.getFirOpBuilder();
-  auto currentLocation = converter.getCurrentLocation();
-  // Get the address of atomic read operands.
   const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
       std::get<2>(atomicRead.t);
   const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
@@ -1014,21 +1202,125 @@
       std::get<Fortran::parser::Expr>(std::get<3>(atomicRead.t).statement.t);
   const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
       std::get<3>(atomicRead.t).statement.t);
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value from_address = fir::getBase(converter.genExprAddr(
-      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
-  mlir::Value to_address = fir::getBase(converter.genExprAddr(
-      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  // If no hint clause is specified, the effect is as if
-  // hint(omp_sync_hint_none) had been specified.
+  genOmpAtomicCaptureStatement(converter, assignmentStmtVariable,
+                               assignmentStmtExpr, &leftHandClauseList,
+                               &rightHandClauseList);
+}
+
+static void
+genOmpAtomicUpdate(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::pft::Evaluation &eval,
+                   const Fortran::parser::OmpAtomicUpdate &atomicUpdate) {
+  const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
+      std::get<2>(atomicUpdate.t);
+  const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
+      std::get<0>(atomicUpdate.t);
+  const auto &assignmentStmtExpr =
+      std::get<Fortran::parser::Expr>(std::get<3>(atomicUpdate.t).statement.t);
+  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
+      std::get<3>(atomicUpdate.t).statement.t);
+
+  genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable,
+                              assignmentStmtExpr, &leftHandClauseList,
+                              &rightHandClauseList);
+}
+
+static void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
+                         Fortran::lower::pft::Evaluation &eval,
+                         const Fortran::parser::OmpAtomic &atomicConstruct) {
+  const Fortran::parser::OmpAtomicClauseList &atomicClauseList =
+      std::get<Fortran::parser::OmpAtomicClauseList>(atomicConstruct.t);
+  const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
+      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
+          atomicConstruct.t)
+          .statement.t);
+  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
+      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
+          atomicConstruct.t)
+          .statement.t);
+  genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable,
+                              assignmentStmtExpr, &atomicClauseList, nullptr);
+}
+
+static void
+genOmpAtomicCapture(Fortran::lower::AbstractConverter &converter,
+                    Fortran::lower::pft::Evaluation &eval,
+                    const Fortran::parser::OmpAtomicCapture &atomicCapture) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
   mlir::IntegerAttr hint = nullptr;
   mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
+  const Fortran::parser::AssignmentStmt &stmt1 =
+      std::get<3>(atomicCapture.t).v.statement;
+  const Fortran::parser::AssignmentStmt &stmt2 =
+      std::get<4>(atomicCapture.t).v.statement;
+  const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
+      std::get<2>(atomicCapture.t);
+  const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
+      std::get<0>(atomicCapture.t);
+
   genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
                                         memory_order);
   genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
                                         memory_order);
-  firOpBuilder.create<mlir::omp::AtomicReadOp>(currentLocation, from_address,
-                                               to_address, hint, memory_order);
+
+  auto atomicCaptureOp = firOpBuilder.create<mlir::omp::AtomicCaptureOp>(
+      currentLocation, hint, memory_order);
+  firOpBuilder.createBlock(&atomicCaptureOp.getRegion());
+  mlir::Block &block = atomicCaptureOp.getRegion().back();
+  firOpBuilder.setInsertionPointToEnd(&block);
+
+  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+  firOpBuilder.setInsertionPointToStart(&block);
+  if (checkForAtomicCaptureStmt(stmt1) && checkForAtomicUpdateStmt(stmt2)) {
+    // Atomic capture construct is of the form [capture-stmt, update-stmt]
+    const auto &assignmentStmt1Expr = std::get<Fortran::parser::Expr>(stmt1.t);
+    const auto &assignmentStmt1Variable =
+        std::get<Fortran::parser::Variable>(stmt1.t);
+    const auto &assignmentStmt2Expr = std::get<Fortran::parser::Expr>(stmt2.t);
+    const auto &assignmentStmt2Variable =
+        std::get<Fortran::parser::Variable>(stmt2.t);
+    genOmpAtomicCaptureStatement(
+        converter, assignmentStmt1Variable, assignmentStmt1Expr,
+        /*OmpAtomicClauseList =*/nullptr, /*OmpAtomicClauseList =*/nullptr);
+    genOmpAtomicUpdateStatement(
+        converter, eval, assignmentStmt2Variable, assignmentStmt2Expr,
+        /*OmpAtomicClauseList =*/nullptr, /*OmpAtomicClauseList =*/nullptr);
+  } else if (checkForAtomicUpdateStmt(stmt1) &&
+             checkForAtomicCaptureStmt(stmt2)) {
+    // Atomic capture construct is of the form [update-stmt, capture-stmt]
+    const auto &assignmentStmt1Expr = std::get<Fortran::parser::Expr>(stmt1.t);
+    const auto &assignmentStmt1Variable =
+        std::get<Fortran::parser::Variable>(stmt1.t);
+    const auto &assignmentStmt2Expr = std::get<Fortran::parser::Expr>(stmt2.t);
+    const auto &assignmentStmt2Variable =
+        std::get<Fortran::parser::Variable>(stmt2.t);
+    // The `omp.atomic.read` operation must be created outside the
+    // `omp.atomic.update` region. Hence, create the operations bottom-up:
+    // first `omp.atomic.read`, then `omp.atomic.update` above it.
+    genOmpAtomicCaptureStatement(
+        converter, assignmentStmt2Variable, assignmentStmt2Expr,
+        /*OmpAtomicClauseList =*/nullptr, /*OmpAtomicClauseList =*/nullptr);
+    firOpBuilder.setInsertionPointToStart(
+        &block); // insert `omp.atomic.update` "above" `omp.atomic.read`
+    genOmpAtomicUpdateStatement(
+        converter, eval, assignmentStmt1Variable, assignmentStmt1Expr,
+        /*OmpAtomicClauseList =*/nullptr, /*OmpAtomicClauseList =*/nullptr);
+  } else {
+    // Atomic capture construct is of the form [capture-stmt, write-stmt]
+    const auto &assignmentStmt1Expr = std::get<Fortran::parser::Expr>(stmt1.t);
+    const auto &assignmentStmt1Variable =
+        std::get<Fortran::parser::Variable>(stmt1.t);
+    const auto &assignmentStmt2Expr = std::get<Fortran::parser::Expr>(stmt2.t);
+    const auto &assignmentStmt2Variable =
+        std::get<Fortran::parser::Variable>(stmt2.t);
+    genOmpAtomicCaptureStatement(
+        converter, assignmentStmt1Variable, assignmentStmt1Expr,
+        /*OmpAtomicClauseList =*/nullptr, /*OmpAtomicClauseList =*/nullptr);
+    genOmpAtomicWriteStatement(
+        converter, assignmentStmt2Variable, assignmentStmt2Expr,
+        /*OmpAtomicClauseList =*/nullptr, /*OmpAtomicClauseList =*/nullptr);
+  }
 }
 
 static void
@@ -1042,9 +1334,14 @@
               [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) {
                 genOmpAtomicWrite(converter, eval, atomicWrite);
               },
-              [&](const auto &) {
-                TODO(converter.getCurrentLocation(),
-                     "Atomic update & capture");
+              [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) {
+                genOmpAtomicUpdate(converter, eval, atomicUpdate);
+              },
+              [&](const Fortran::parser::OmpAtomic &atomicConstruct) {
+                genOmpAtomic(converter, eval, atomicConstruct);
+              },
+              [&](const Fortran::parser::OmpAtomicCapture &atomicCapture) {
+                genOmpAtomicCapture(converter, eval, atomicCapture);
               },
           },
           atomicConstruct.u);
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -0,0 +1,129 @@
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: flang-new -fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefix=FIRDialect
+! TODO: Add support for pointers
+
+! 
This test checks the lowering of atomic capture construct
+
+!CHECK: %[[TEMP_1:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_2:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[VAR_X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!CHECK: %[[VAR_Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: omp.atomic.capture memory_order(release) {
+!CHECK: omp.atomic.read %[[VAR_X]] = %[[VAR_Y]] : !fir.ref<i32>
+!CHECK: omp.atomic.update %[[VAR_Y]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_2]] : !fir.ref<i32>
+!CHECK: %[[INTERMEDIATE_1:.*]] = fir.load %[[VAR_X]] : !fir.ref<i32>
+!CHECK: %[[INTERMEDIATE_2:.*]] = fir.load %[[TEMP_2]] : !fir.ref<i32>
+!CHECK: %[[RESULT:.*]] = arith.addi %[[INTERMEDIATE_1]], %[[INTERMEDIATE_2]] : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+!CHECK: }
+!CHECK: omp.atomic.capture hint(uncontended) {
+!CHECK: omp.atomic.update %[[VAR_Y]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_1]] : !fir.ref<i32>
+!CHECK: %[[INTERMEDIATE_3:.*]] = fir.load %[[VAR_X]] : !fir.ref<i32>
+!CHECK: %[[INTERMEDIATE_4:.*]] = fir.load %[[TEMP_1]] : !fir.ref<i32>
+!CHECK: %[[RESULT:.*]] = arith.muli %[[INTERMEDIATE_3]], %[[INTERMEDIATE_4]] : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+!CHECK: omp.atomic.read %[[VAR_X]] = %[[VAR_Y]] : !fir.ref<i32>
+!CHECK: }
+!CHECK: omp.atomic.capture memory_order(acquire) hint(nonspeculative) {
+!CHECK: omp.atomic.read %[[VAR_X]] = %[[VAR_Y]] : !fir.ref<i32>
+!CHECK: {{.*}} = arith.constant {{.*}} : i32
+!CHECK: {{.*}} = arith.constant {{.*}} : i32
+!CHECK: {{.*}} = fir.load %[[VAR_X]] : !fir.ref<i32>
+!CHECK: {{.*}} = arith.subi {{.*}}, {{.*}} : i32
+!CHECK: {{.*}} = fir.no_reassoc {{.*}} : i32
+!CHECK: %[[INTERMEDIATE_5:.*]] = arith.addi {{.*}}, {{.*}} : i32
+!CHECK: omp.atomic.write %[[VAR_Y]] = %[[INTERMEDIATE_5]] : !fir.ref<i32>, i32
+!CHECK: }
+!CHECK: omp.atomic.capture {
+!CHECK: omp.atomic.read %[[VAR_X]] = %[[VAR_Y]] : !fir.ref<i32>
+!CHECK: {{.*}} = arith.constant {{.*}} : i32
+!CHECK: {{.*}} = arith.constant {{.*}} : i32
+!CHECK: {{.*}} = fir.load %[[VAR_X]] : !fir.ref<i32>
+!CHECK: {{.*}} = arith.subi {{.*}}, {{.*}} : i32
+!CHECK: {{.*}} = fir.no_reassoc {{.*}} : i32
+!CHECK: %[[INTERMEDIATE_5:.*]] = arith.addi {{.*}}, {{.*}} : i32
+!CHECK: omp.atomic.write %[[VAR_Y]] = %[[INTERMEDIATE_5]] : !fir.ref<i32>, i32
+!CHECK: }
+!CHECK: return
+!CHECK: }
+
+!FIRDialect: %[[TEMP_1:.*]] = fir.alloca i32 {adapt.valuebyref}
+!FIRDialect: %[[TEMP_2:.*]] = fir.alloca i32 {adapt.valuebyref}
+!FIRDialect: %[[VAR_X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!FIRDialect: %[[VAR_Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!FIRDialect: omp.atomic.capture memory_order(release) {
+!FIRDialect: omp.atomic.read %[[VAR_X]] = %[[VAR_Y]] : !fir.ref<i32>
+!FIRDialect: omp.atomic.update %[[VAR_Y]] : !fir.ref<i32> {
+!FIRDialect: ^bb0(%[[ARG:.*]]: i32):
+!FIRDialect: fir.store %[[ARG]] to %[[TEMP_2]] : !fir.ref<i32>
+!FIRDialect: %[[INTERMEDIATE_1:.*]] = fir.load %[[VAR_X]] : !fir.ref<i32>
+!FIRDialect: %[[INTERMEDIATE_2:.*]] = fir.load %[[TEMP_2]] : !fir.ref<i32>
+!FIRDialect: %[[RESULT:.*]] = arith.addi %[[INTERMEDIATE_1]], %[[INTERMEDIATE_2]] : i32
+!FIRDialect: omp.yield(%[[RESULT]] : i32)
+!FIRDialect: }
+!FIRDialect: }
+!FIRDialect: omp.atomic.capture hint(uncontended) {
+!FIRDialect: omp.atomic.update %[[VAR_Y]] : !fir.ref<i32> {
+!FIRDialect: ^bb0(%[[ARG:.*]]: i32):
+!FIRDialect: fir.store %[[ARG]] to %[[TEMP_1]] : !fir.ref<i32>
+!FIRDialect: %[[INTERMEDIATE_3:.*]] = fir.load %[[VAR_X]] : !fir.ref<i32>
+!FIRDialect: %[[INTERMEDIATE_4:.*]] = fir.load %[[TEMP_1]] : !fir.ref<i32>
+!FIRDialect: %[[RESULT:.*]] = arith.muli %[[INTERMEDIATE_3]], %[[INTERMEDIATE_4]] : i32
+!FIRDialect: omp.yield(%[[RESULT]] : i32)
+!FIRDialect: }
+!FIRDialect: omp.atomic.read %[[VAR_X]] = %[[VAR_Y]] : !fir.ref<i32>
+!FIRDialect: }
+!FIRDialect: omp.atomic.capture memory_order(acquire) hint(nonspeculative) {
+!FIRDialect: omp.atomic.read %[[VAR_X]] = %[[VAR_Y]] : !fir.ref<i32>
+!FIRDialect: {{.*}} = arith.constant {{.*}} : i32
+!FIRDialect: {{.*}} = arith.constant {{.*}} : i32
+!FIRDialect: {{.*}} = fir.load %[[VAR_X]] : !fir.ref<i32>
+!FIRDialect: {{.*}} = arith.subi {{.*}}, {{.*}} : i32
+!FIRDialect: {{.*}} = fir.no_reassoc {{.*}} : i32
+!FIRDialect: %[[INTERMEDIATE_5:.*]] = arith.addi {{.*}}, {{.*}} : i32
+!FIRDialect: omp.atomic.write %[[VAR_Y]] = %[[INTERMEDIATE_5]] : !fir.ref<i32>, i32
+!FIRDialect: }
+!FIRDialect: omp.atomic.capture {
+!FIRDialect: omp.atomic.read %[[VAR_X]] = %[[VAR_Y]] : !fir.ref<i32>
+!FIRDialect: {{.*}} = arith.constant {{.*}} : i32
+!FIRDialect: {{.*}} = arith.constant {{.*}} : i32
+!FIRDialect: {{.*}} = fir.load %[[VAR_X]] : !fir.ref<i32>
+!FIRDialect: {{.*}} = arith.subi {{.*}}, {{.*}} : i32
+!FIRDialect: {{.*}} = fir.no_reassoc {{.*}} : i32
+!FIRDialect: %[[INTERMEDIATE_5:.*]] = arith.addi {{.*}}, {{.*}} : i32
+!FIRDialect: omp.atomic.write %[[VAR_Y]] = %[[INTERMEDIATE_5]] : !fir.ref<i32>, i32
+!FIRDialect: }
+!FIRDialect: return
+!FIRDialect: }
+
+program sample
+  use omp_lib
+  integer :: x, y
+
+  !$omp atomic capture release
+    x = y
+    y = x + y
+  !$omp end atomic
+
+  !$omp atomic hint(omp_sync_hint_uncontended) capture
+    y = x * y
+    x = y
+  !$omp end atomic
+
+  !$omp atomic hint(omp_lock_hint_nonspeculative) capture acquire
+    x = y
+    y = 2 * 10 + (8 - x)
+  !$omp end atomic
+
+
+  !$omp atomic capture
+    x = y
+    y = 2 * 10 + (8 - x)
+  !$omp end atomic
+end program
diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-update.f90
@@ -0,0 +1,155 @@
+! This test checks lowering of atomic update construct
+! RUN: bbc -fopenmp -emit-fir %s -o - | \
+! RUN:   FileCheck %s
+
+program OmpAtomicUpdate
+  use omp_lib
+  integer :: x, y, z
+  integer, pointer :: a, b
+  integer, target :: c, d
+  a=>c
+  b=>d
+
+!CHECK: %[[TEMP_1:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_2:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_3:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_4:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_5:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_6:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_7:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_8:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_9:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: %[[TEMP_10:.*]] = fir.alloca i32 {adapt.valuebyref}
+!CHECK: {{.*}} = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFEa"}
+!CHECK: {{.*}} = fir.alloca !fir.ptr<i32> {uniq_name = "_QFEa.addr"}
+!CHECK: {{.*}} = fir.zero_bits !fir.ptr<i32>
+!CHECK: fir.store {{.*}} to {{.*}} : !fir.ref<!fir.ptr<i32>>
+!CHECK: {{.*}} = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "b", uniq_name = "_QFEb"}
+!CHECK: %[[b_ADDR:.*]] = fir.alloca !fir.ptr<i32> {uniq_name = "_QFEb.addr"}
+!CHECK: {{.*}} = fir.zero_bits !fir.ptr<i32>
+!CHECK: fir.store {{.*}} to {{.*}} : !fir.ref<!fir.ptr<i32>>
+!CHECK: {{.*}} = fir.address_of(@_QFEc) : !fir.ref<i32>
+!CHECK: {{.*}} = fir.address_of(@_QFEd) : !fir.ref<i32>
+!CHECK: %[[VAR_X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!CHECK: %[[VAR_Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[VAR_Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFEz"}
+!CHECK: {{.*}} = fir.convert {{.*}} : (!fir.ref<i32>) -> !fir.ptr<i32>
+!CHECK: fir.store {{.*}} to {{.*}} : !fir.ref<!fir.ptr<i32>>
+!CHECK: {{.*}} = fir.convert {{.*}} : (!fir.ref<i32>) -> !fir.ptr<i32>
+!CHECK: fir.store {{.*}} to {{.*}} : !fir.ref<!fir.ptr<i32>>
+!CHECK: %[[LOADED_a_ADDR:.*]] = fir.load {{.*}} : !fir.ref<!fir.ptr<i32>>
+
+
+!CHECK: omp.atomic.update %[[LOADED_a_ADDR]] : !fir.ptr<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_10]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[TEMP_10]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[b_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: {{.*}} = fir.load {{.*}} : !fir.ptr<i32>
+!CHECK: %[[RESULT:.*]] = arith.addi {{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+  !$omp atomic update
+    a = a + b
+
+
+!CHECK: omp.atomic.update %[[VAR_Y]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_9]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[TEMP_9]] : !fir.ref<i32>
+!CHECK: {{.*}} = arith.constant 1 : i32
+!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+!CHECK: omp.atomic.update %[[VAR_Z]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_8]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[VAR_X]] : !fir.ref<i32>
+!CHECK: %{{.*}} = fir.load %[[TEMP_8]] : !fir.ref<i32>
+!CHECK: %[[RESULT:.*]] = arith.muli {{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+  !$omp atomic
+    y = y + 1
+  !$omp atomic update
+    z = x * z
+
+!CHECK: omp.atomic.update memory_order(relaxed) hint(uncontended) %[[VAR_X]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_7]] : !fir.ref<i32>
+!CHECK: %{{.*}} = fir.load %[[TEMP_7]] : !fir.ref<i32>
+!CHECK: %{{.*}} = arith.constant 1 : i32
+!CHECK: %[[RESULT:.*]] = arith.subi {{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK:}
+!CHECK: omp.atomic.update memory_order(relaxed) %[[VAR_Y]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_6]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[VAR_X]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[TEMP_6]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[VAR_Z]] : !fir.ref<i32>
+!CHECK: {{.*}} = arith.cmpi sgt, {{.*}}, {{.*}} : i32
+!CHECK: {{.*}} = arith.select {{.*}}, {{.*}}, {{.*}} : i32
+!CHECK: {{.*}} = arith.cmpi sgt, {{.*}}, {{.*}} : i32
+!CHECK: %[[RESULT:.*]] = arith.select {{.*}}, {{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+!CHECK: omp.atomic.update memory_order(relaxed) hint(contended) %[[VAR_Z]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_5]] : !fir.ref<i32>
+!CHECK: %{{.*}} = fir.load %[[TEMP_5]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[VAR_X]] : !fir.ref<i32>
+!CHECK: %[[RESULT:.*]] = arith.addi {{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+  !$omp atomic relaxed update hint(omp_sync_hint_uncontended)
+    x = x - 1
+  !$omp atomic update relaxed
+    y = max(x, y, z)
+  !$omp atomic relaxed hint(omp_sync_hint_contended)
+    z = z + x
+
+!CHECK: omp.atomic.update memory_order(release) hint(contended) %[[VAR_Z]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_4]] : !fir.ref<i32>
+!CHECK: {{.*}} = arith.constant 10 : i32
+!CHECK: {{.*}} = fir.load %[[TEMP_4]] : !fir.ref<i32>
+!CHECK: %[[RESULT:.*]] = arith.muli {{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+!CHECK: omp.atomic.update memory_order(release) hint(speculative) %[[VAR_X]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_3]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[TEMP_3]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[VAR_Z]] : !fir.ref<i32>
+!CHECK: %[[RESULT:.*]] = arith.divsi {{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+  !$omp atomic release update hint(omp_lock_hint_contended)
+    z = z * 10
+  !$omp atomic hint(omp_lock_hint_speculative) update release
+    x = x / z
+
+!CHECK: omp.atomic.update memory_order(seq_cst) hint(nonspeculative) %[[VAR_Y]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_2]] : !fir.ref<i32>
+!CHECK: {{.*}} = arith.constant 10 : i32
+!CHECK: {{.*}} = fir.load %[[TEMP_2]] : !fir.ref<i32>
+!CHECK: %[[RESULT:.*]] = arith.addi {{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+!CHECK: omp.atomic.update memory_order(seq_cst) %[[VAR_Z]] : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
+!CHECK: fir.store %[[ARG]] to %[[TEMP_1]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[VAR_Y]] : !fir.ref<i32>
+!CHECK: {{.*}} = fir.load %[[TEMP_1]] : !fir.ref<i32>
+!CHECK: %[[RESULT:.*]] = arith.addi {{.*}}, {{.*}} : i32
+!CHECK: omp.yield(%[[RESULT]] : i32)
+!CHECK: }
+!CHECK: return
+!CHECK: }
+  !$omp atomic hint(omp_sync_hint_nonspeculative) seq_cst
+    y = 10 + y
+  !$omp atomic seq_cst update
+    z = y + z
+end program OmpAtomicUpdate
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -920,17 +920,36 @@
 
 LogicalResult AtomicCaptureOp::verifyRegions() {
   Block::OpListType &ops = region().front().getOperations();
-  if (ops.size() != 3)
+  int numberOfOmpOps{0};
+  for (auto &op : ops) {
+    // Other ops (e.g. computing the value to write) may appear in between.
+    if (isa<AtomicReadOp, AtomicWriteOp, AtomicUpdateOp>(op))
+      numberOfOmpOps++;
+  }
+  if (!(numberOfOmpOps == 2 && isa<TerminatorOp>(ops.back())))
     return emitError()
            << "expected three operations in omp.atomic.capture region (one "
               "terminator, and two atomic ops)";
+
   auto &firstOp = ops.front();
-  auto &secondOp = *ops.getNextNode(firstOp);
-  auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
   auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
+  auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
+  auto &secondOp = *ops.getNextNode(firstOp);
   auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
   auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
   auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
+  if (!secondWriteStmt && !secondUpdateStmt) {
+    // If second statement is neither `omp.atomic.write` nor
+    // `omp.atomic.update`, then the `omp.atomic.capture` structure is
+    // [capture-stmt, write-stmt] and `write-stmt` occurs (if it occurs at all!)
+    // as the second last statement of the block. Verify it thus
+
+    for (auto &op : ops) {
+      secondWriteStmt = dyn_cast<AtomicWriteOp>(op);
+      if (secondWriteStmt)
+        break;
+    }
+  }
   if (!((firstUpdateStmt && secondReadStmt) ||
         (firstReadStmt && secondUpdateStmt) ||