diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -26,6 +26,7 @@ #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" using DeclareTargetCapturePair = @@ -3094,17 +3095,6 @@ if (rightHandClauseList) genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, memoryOrder); - auto atomicUpdateOp = firOpBuilder.create( - currentLocation, lhsAddr, hint, memoryOrder); - - //// Generate body of Atomic Update operation - // If an argument for the region is provided then create the block with that - // argument. Also update the symbol's address with the argument mlir value. - llvm::SmallVector varTys = {varType}; - llvm::SmallVector locs = {currentLocation}; - firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs); - mlir::Value val = - fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0)); const auto *varDesignator = std::get_if>( &assignmentStmtVariable.u); @@ -3117,21 +3107,96 @@ "Array references as atomic update variable"); assert(name && name->symbol && "No symbol attached to atomic update variable"); - converter.bindSymbol(*name->symbol, val); - // Set the insert for the terminator operation to go at the end of the - // block. - mlir::Block &block = atomicUpdateOp.getRegion().back(); + if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate())) + converter.bindSymbol(*name->symbol, lhsAddr); + + // Lowering is in two steps : + // subroutine sb + // integer :: a, b + // !$omp atomic update + // a = a + b + // end subroutine + // + // 1. 
Lower to scf.execute_region_op + // + // func.func @_QPsb() { + // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"} + // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"} + // %2 = scf.execute_region -> i32 { + // %3 = fir.load %0 : !fir.ref + // %4 = fir.load %1 : !fir.ref + // %5 = arith.addi %3, %4 : i32 + // scf.yield %5 : i32 + // } + // return + // } + auto tempOp = + firOpBuilder.create(currentLocation, varType); + firOpBuilder.createBlock(&tempOp.getRegion()); + mlir::Block &block = tempOp.getRegion().back(); firOpBuilder.setInsertionPointToEnd(&block); - Fortran::lower::StatementContext stmtCtx; mlir::Value rhsExpr = fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); mlir::Value convertResult = firOpBuilder.createConvert(currentLocation, varType, rhsExpr); // Insert the terminator: YieldOp. - firOpBuilder.create(currentLocation, convertResult); - // Reset the insert point to before the terminator. + firOpBuilder.create(currentLocation, convertResult); firOpBuilder.setInsertionPointToStart(&block); + + // 2. Create the omp.atomic.update Operation using the Operations in the + // temporary scf.execute_region Operation. 
+ // + // func.func @_QPsb() { + // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"} + // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"} + // %2 = fir.load %1 : !fir.ref + // omp.atomic.update %0 : !fir.ref { + // ^bb0(%arg0: i32): + // %3 = fir.load %1 : !fir.ref + // %4 = arith.addi %arg0, %3 : i32 + // omp.yield(%4 : i32) + // } + // return + // } + mlir::Value updateVar = converter.getSymbolAddress(*name->symbol); + if (auto decl = updateVar.getDefiningOp()) + updateVar = decl.getBase(); + + firOpBuilder.setInsertionPointAfter(tempOp); + auto atomicUpdateOp = firOpBuilder.create( + currentLocation, updateVar, hint, memoryOrder); + + llvm::SmallVector varTys = {varType}; + llvm::SmallVector locs = {currentLocation}; + firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs); + mlir::Value val = + fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0)); + + llvm::SmallVector ops; + for (mlir::Operation &op : tempOp.getRegion().getOps()) + ops.push_back(&op); + + // SCF Yield is converted to OMP Yield. All other operations are copied + for (mlir::Operation *op : ops) { + if (auto y = mlir::dyn_cast(op)) { + firOpBuilder.setInsertionPointToEnd(&atomicUpdateOp.getRegion().front()); + firOpBuilder.create(currentLocation, y.getResults()); + op->erase(); + } else { + op->remove(); + atomicUpdateOp.getRegion().front().push_back(op); + } + } + + // Remove the load and replace all uses of load with the block argument + for (mlir::Operation &op : atomicUpdateOp.getRegion().getOps()) { + fir::LoadOp y = mlir::dyn_cast(&op); + if (y && y.getMemref() == updateVar) + y.getRes().replaceAllUsesWith(val); + } + + tempOp.erase(); } static void diff --git a/flang/test/Lower/OpenMP/atomic-update-hlfir.f90 b/flang/test/Lower/OpenMP/atomic-update-hlfir.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/atomic-update-hlfir.f90 @@ -0,0 +1,23 @@ +! 
This test checks lowering of atomic and atomic update constructs with HLFIR +! RUN: bbc -hlfir -fopenmp -emit-hlfir %s -o - | FileCheck %s +! RUN: %flang_fc1 -flang-experimental-hlfir -emit-hlfir -fopenmp %s -o - | FileCheck %s + +subroutine sb + integer :: x, y + + !$omp atomic update + x = x + y +end subroutine + +!CHECK-LABEL: @_QPsb +!CHECK: %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsbEx"} +!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"} +!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: omp.atomic.update %[[X_DECL]]#0 : !fir.ref { +!CHECK: ^bb0(%[[ARG_X:.*]]: i32): +!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref +!CHECK: %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32 +!CHECK: omp.yield(%[[X_UPDATE_VAL]] : i32) +!CHECK: } +!CHECK: return