diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -271,13 +271,16 @@ //// region. /// \param [in] outerCombined - is this an outer operation - prevents /// privatization. +/// \param [in] expr - the expression whose evaluation's extended +//// value is required template static void createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OmpClauseList *clauses = nullptr, const SmallVector &args = {}, - bool outerCombined = false) { + bool outerCombined = false, + const Fortran::parser::Expr *expr = nullptr) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); // If an argument for the region is provided then create the block with that // argument. Also update the symbol's address with the mlir argument value. @@ -288,27 +291,39 @@ std::size_t loopVarTypeSize = 0; for (const Fortran::semantics::Symbol *arg : args) loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); - mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + mlir::Type varType; + if constexpr (std::is_same_v) { + Fortran::lower::StatementContext statementCtx; + mlir::Value result = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*expr), statementCtx)); + varType = result.getType(); + } else { + varType = getLoopVarType(converter, loopVarTypeSize); + } SmallVector tiv; SmallVector locs; for (int i = 0; i < (int)args.size(); i++) { - tiv.push_back(loopVarType); + tiv.push_back(varType); locs.push_back(loc); } firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs); - int argIndex = 0; - // The argument is not currently in memory, so make a temporary for the - // argument, and store it there, then bind that location to the argument. 
- for (const Fortran::semantics::Symbol *arg : args) { - mlir::Value val = - fir::getBase(op.getRegion().front().getArgument(argIndex)); - mlir::Value temp = firOpBuilder.createTemporary( - loc, loopVarType, - llvm::ArrayRef{ - Fortran::lower::getAdaptToByRefAttr(firOpBuilder)}); - storeOp = firOpBuilder.create(loc, val, temp); - converter.bindSymbol(*arg, temp); - argIndex++; + if constexpr (!std::is_same_v) { + // No need to create a temporary for the argument in case of + // omp::AtomicUpdateOp + int argIndex = 0; + // The argument is not currently in memory, so make a temporary for the + // argument, and store it there, then bind that location to the argument. + for (const Fortran::semantics::Symbol *arg : args) { + mlir::Value val = + fir::getBase(op.getRegion().front().getArgument(argIndex)); + mlir::Value temp = firOpBuilder.createTemporary( + loc, varType, + llvm::ArrayRef{ + Fortran::lower::getAdaptToByRefAttr(firOpBuilder)}); + storeOp = firOpBuilder.create(loc, val, temp); + converter.bindSymbol(*arg, temp); + argIndex++; + } } } else { firOpBuilder.createBlock(&op.getRegion()); @@ -329,6 +344,11 @@ std::is_same_v) { mlir::ValueRange results; firOpBuilder.create(loc, results); + } else if constexpr (std::is_same_v) { + Fortran::lower::StatementContext statementCtx; + mlir::Value result = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*expr), statementCtx)); + firOpBuilder.create(loc, result); } else { firOpBuilder.create(loc); } @@ -1054,6 +1074,45 @@ } } +static void genOmpAtomicUpdateStatement( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::Variable &assignmentStmtVariable, + const Fortran::parser::Expr &assignmentStmtExpr, + const Fortran::parser::OmpAtomicClauseList *leftHandClauseList, + const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) { + // Generate `omp.atomic.update` operation for atomic assignment statements + auto &firOpBuilder = 
converter.getFirOpBuilder(); + auto currentLocation = converter.getCurrentLocation(); + mlir::Value address; + SmallVector symbolVector; + Fortran::lower::StatementContext stmtCtx; + if (auto varDesignator = std::get_if< + Fortran::common::Indirection>( + &assignmentStmtVariable.u)) { + if (const auto *name = getDesignatorNameIfDataRef(varDesignator->value())) { + address = fir::getBase(converter.genExprAddr( + *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); + symbolVector.push_back(name->symbol); + } + } + // If no hint clause is specified, the effect is as if + // hint(omp_sync_hint_none) had been specified. + mlir::IntegerAttr hint = nullptr; + mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; + if (leftHandClauseList) + genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint, + memory_order); + if (rightHandClauseList) + genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, + memory_order); + auto atomicUpdateOp = firOpBuilder.create( + currentLocation, address, hint, memory_order); + createBodyOfOp(atomicUpdateOp, converter, + currentLocation, eval, nullptr, + symbolVector, false, &assignmentStmtExpr); +} + static void genOmpAtomicWrite(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -1117,6 +1176,43 @@ to_address, hint, memory_order); } +static void +genOmpAtomicUpdate(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { + const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = + std::get<2>(atomicUpdate.t); + const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = + std::get<0>(atomicUpdate.t); + const auto &assignmentStmtExpr = + std::get(std::get<3>(atomicUpdate.t).statement.t); + const auto &assignmentStmtVariable = std::get( + std::get<3>(atomicUpdate.t).statement.t); + + genOmpAtomicUpdateStatement(converter, eval, 
assignmentStmtVariable, + assignmentStmtExpr, &leftHandClauseList, + &rightHandClauseList); +} + +static void genOmpAtomic(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpAtomic &atomicConstruct) { + const Fortran::parser::OmpAtomicClauseList &atomicClauseList = + std::get(atomicConstruct.t); + const auto &assignmentStmtExpr = std::get( + std::get>( + atomicConstruct.t) + .statement.t); + const auto &assignmentStmtVariable = std::get( + std::get>( + atomicConstruct.t) + .statement.t); + // If atomic-clause is not present on the construct, the behaviour is as if + // the update clause is specified + genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable, + assignmentStmtExpr, &atomicClauseList, nullptr); +} + static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -1128,9 +1224,14 @@ [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) { genOmpAtomicWrite(converter, eval, atomicWrite); }, + [&](const Fortran::parser::OmpAtomic &atomicConstruct) { + genOmpAtomic(converter, eval, atomicConstruct); + }, + [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { + genOmpAtomicUpdate(converter, eval, atomicUpdate); + }, [&](const auto &) { - TODO(converter.getCurrentLocation(), - "Atomic update & capture"); + TODO(converter.getCurrentLocation(), "Atomic capture"); }, }, atomicConstruct.u); diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/atomic-update.f90 @@ -0,0 +1,135 @@ +! This test checks lowering of atomic and atomic update constructs +! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s +! 
RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +program OmpAtomicUpdate + use omp_lib + integer :: x, y, z + integer, pointer :: a, b + integer, target :: c, d + a=>c + b=>d + +!CHECK: func.func @_QQmain() { +!CHECK: %[[A:.*]] = fir.alloca !fir.box> {bindc_name = "a", uniq_name = "_QFEa"} +!CHECK: %[[A_ADDR:.*]] = fir.alloca !fir.ptr {uniq_name = "_QFEa.addr"} +!CHECK: %{{.*}} = fir.zero_bits !fir.ptr +!CHECK: fir.store %{{.*}} to %[[A_ADDR]] : !fir.ref> +!CHECK: %[[B:.*]] = fir.alloca !fir.box> {bindc_name = "b", uniq_name = "_QFEb"} +!CHECK: %[[B_ADDR:.*]] = fir.alloca !fir.ptr {uniq_name = "_QFEb.addr"} +!CHECK: %{{.*}} = fir.zero_bits !fir.ptr +!CHECK: fir.store %{{.*}} to %[[B_ADDR]] : !fir.ref> +!CHECK: %[[C_ADDR:.*]] = fir.address_of(@_QFEc) : !fir.ref +!CHECK: %[[D_ADDR:.*]] = fir.address_of(@_QFEd) : !fir.ref +!CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"} +!CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"} +!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFEz"} +!CHECK: %{{.*}} = fir.convert %[[C_ADDR]] : (!fir.ref) -> !fir.ptr +!CHECK: fir.store %{{.*}} to %[[A_ADDR]] : !fir.ref> +!CHECK: %{{.*}} = fir.convert %[[D_ADDR]] : (!fir.ref) -> !fir.ptr +!CHECK: fir.store {{.*}} to %[[B_ADDR]] : !fir.ref> +!CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref> +!CHECK: omp.atomic.update %[[LOADED_A]] : !fir.ptr { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref> +!CHECK: %{{.*}} = fir.load %[[LOADED_A]] : !fir.ptr +!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref> +!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr +!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } + !$omp atomic update + a = a + b + +!CHECK: omp.atomic.update %[[Y]] : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref +!CHECK: {{.*}} = 
arith.constant 1 : i32 +!CHECK: %[[RESULT:.*]] = arith.addi %[[LOADED_Y]], {{.*}} : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } +!CHECK: omp.atomic.update %[[Z]] : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref +!CHECK: %[[RESULT:.*]] = arith.muli %[[LOADED_X]], %[[LOADED_Z]] : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } + !$omp atomic + y = y + 1 + !$omp atomic update + z = x * z + +!CHECK: omp.atomic.update memory_order(relaxed) hint(uncontended) %[[X]] : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: %{{.*}} = arith.constant 1 : i32 +!CHECK: %[[RESULT:.*]] = arith.subi %[[LOADED_X]], {{.*}} : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } +!CHECK: omp.atomic.update memory_order(relaxed) %[[Y]] : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref +!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref +!CHECK: %{{.*}} = arith.cmpi sgt, %[[LOADED_X]], %[[LOADED_Y]] : i32 +!CHECK: %{{.*}} = arith.select %{{.*}}, %[[LOADED_X]], %[[LOADED_Y]] : i32 +!CHECK: %{{.*}} = arith.cmpi sgt, %{{.*}}, %[[LOADED_Z]] : i32 +!CHECK: %[[RESULT:.*]] = arith.select %{{.*}}, %{{.*}}, %[[LOADED_Z]] : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } +!CHECK: omp.atomic.update memory_order(relaxed) hint(contended) %[[Z]] : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref +!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: %[[RESULT:.*]] = arith.addi %[[LOADED_Z]], %[[LOADED_X]] : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } + !$omp atomic relaxed update hint(omp_sync_hint_uncontended) + x = x - 1 + !$omp atomic update relaxed + y = max(x, y, z) + !$omp atomic relaxed hint(omp_sync_hint_contended) + z = z + x + +!CHECK: 
omp.atomic.update memory_order(release) hint(contended) %[[Z]] : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %{{.*}} = arith.constant 10 : i32 +!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref +!CHECK: %[[RESULT:.*]] = arith.muli {{.*}}, %[[LOADED_Z]] : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } +!CHECK: omp.atomic.update memory_order(release) hint(speculative) %[[X]] : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref +!CHECK: %[[RESULT:.*]] = arith.divsi %[[LOADED_X]], %[[LOADED_Z]] : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } + + !$omp atomic release update hint(omp_lock_hint_contended) + z = z * 10 + !$omp atomic hint(omp_lock_hint_speculative) update release + x = x / z + +!CHECK: omp.atomic.update memory_order(seq_cst) hint(nonspeculative) %[[Y]] : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %{{.*}} = arith.constant 10 : i32 +!CHECK: %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref +!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, %[[LOADED_Y]] : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } +!CHECK: omp.atomic.update memory_order(seq_cst) %[[Z]] : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): +!CHECK: %[[LOADED_Y:.*]] = fir.load %[[Y]] : !fir.ref +!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref +!CHECK: %[[RESULT:.*]] = arith.addi %[[LOADED_Y]], %[[LOADED_Z]] : i32 +!CHECK: omp.yield(%[[RESULT]] : i32) +!CHECK: } +!CHECK: return +!CHECK: } + !$omp atomic hint(omp_sync_hint_nonspeculative) seq_cst + y = 10 + y + !$omp atomic seq_cst update + z = y + z +end program OmpAtomicUpdate