diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -404,6 +404,22 @@
   }
 }
 
+/// Lower the scalar logical expression of an OpenMP `if` clause and return
+/// its value converted to `i1`. Kept above createTargetDataOp, which calls
+/// it, so that no forward declaration is needed.
+static mlir::Value
+getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::StatementContext &stmtCtx,
+                   const Fortran::parser::OmpClause::If *ifClause) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+  auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
+  mlir::Value ifVal = fir::getBase(
+      converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+  return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(),
+                                    ifVal);
+}
+
 static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
                                  std::size_t loopVarTypeSize) {
   // OpenMP runtime requires 32-bit or 64-bit loop variables.
@@ -547,6 +563,136 @@
   }
 }
 
+/// Create the operation for a `target data`, `target enter data` or
+/// `target exit data` directive, selected by \p directive. The `if`, `device`,
+/// `nowait` and `map` clauses in \p opClauseList are lowered to operands and
+/// attributes; `use_device_ptr`, `use_device_addr`, the `ancestor` device
+/// modifier and any other clause or directive hit a TODO.
+static void
+createTargetDataOp(Fortran::lower::AbstractConverter &converter,
+                   const Fortran::parser::OmpClauseList &opClauseList,
+                   const llvm::omp::Directive &directive) {
+  Fortran::lower::StatementContext stmtCtx;
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  mlir::Value ifClauseOperand, deviceOperand;
+  mlir::UnitAttr nowaitAttr;
+  llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand,
+      mapOperands;
+  llvm::SmallVector<mlir::IntegerAttr> mapTypes;
+
+  auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
+                       &mapTypes](const auto &mapClause) {
+    auto mapType = std::get<Fortran::parser::OmpMapType::Type>(
+        std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t)
+            ->t);
+    llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
+        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+    switch (mapType) {
+    case Fortran::parser::OmpMapType::Type::To:
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+      break;
+    case Fortran::parser::OmpMapType::Type::From:
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+      break;
+    case Fortran::parser::OmpMapType::Type::Tofrom:
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+                     llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+      break;
+    case Fortran::parser::OmpMapType::Type::Alloc:
+    case Fortran::parser::OmpMapType::Type::Release:
+      // alloc and release are the default map types for the Target Data
+      // ops: when no map-type bits are set, alloc or release is implied by
+      // the directive. Target Data and Enter Data default to alloc; Exit
+      // Data defaults to release.
+      break;
+    case Fortran::parser::OmpMapType::Type::Delete:
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+      break;
+    }
+    if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
+            std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t)
+                ->t)
+            .has_value())
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+
+    // TODO: Add support for MapTypeModifiers close, mapper, present, iterator
+
+    mlir::IntegerAttr mapTypeAttr = firOpBuilder.getIntegerAttr(
+        firOpBuilder.getI64Type(),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            mapTypeBits));
+
+    llvm::SmallVector<mlir::Value> mapOperand;
+    genObjectList(std::get<Fortran::parser::OmpObjectList>(mapClause->v.t),
+                  converter, mapOperand);
+
+    for (mlir::Value mapOp : mapOperand) {
+      mapOperands.push_back(mapOp);
+      mapTypes.push_back(mapTypeAttr);
+    }
+  };
+
+  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+    mlir::Location currentLocation = converter.genLocation(clause.source);
+    if (const auto &ifClause =
+            std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
+      ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
+    } else if (const auto &deviceClause =
+                   std::get_if<Fortran::parser::OmpClause::Device>(&clause.u)) {
+      if (auto deviceModifier = std::get<
+              std::optional<Fortran::parser::OmpDeviceClause::DeviceModifier>>(
+              deviceClause->v.t)) {
+        if (deviceModifier ==
+            Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) {
+          TODO(currentLocation, "OMPD_target Device Modifier Ancestor");
+        }
+      }
+      if (const auto *deviceExpr = Fortran::semantics::GetExpr(
+              std::get<Fortran::parser::ScalarIntExpr>(deviceClause->v.t))) {
+        deviceOperand =
+            fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
+      }
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                   &clause.u)) {
+      TODO(currentLocation, "OMPD_target Use Device Ptr");
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                   &clause.u)) {
+      TODO(currentLocation, "OMPD_target Use Device Addr");
+    } else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
+      nowaitAttr = firOpBuilder.getUnitAttr();
+    } else if (const auto &mapClause =
+                   std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
+      addMapClause(mapClause);
+    } else {
+      TODO(currentLocation, "OMPD_target unhandled clause");
+    }
+  }
+
+  llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
+                                                  mapTypes.end());
+  mlir::ArrayAttr mapTypesArrayAttr =
+      ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
+  mlir::Location currentLocation = converter.getCurrentLocation();
+
+  if (directive == llvm::omp::Directive::OMPD_target_data) {
+    firOpBuilder.create<omp::DataOp>(
+        currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
+        useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
+  } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
+    firOpBuilder.create<omp::EnterDataOp>(currentLocation, ifClauseOperand,
+                                          deviceOperand, nowaitAttr,
+                                          mapOperands, mapTypesArrayAttr);
+  } else if (directive == llvm::omp::Directive::OMPD_target_exit_data) {
+    firOpBuilder.create<omp::ExitDataOp>(currentLocation, ifClauseOperand,
+                                         deviceOperand, nowaitAttr, mapOperands,
+                                         mapTypesArrayAttr);
+  } else {
+    TODO(currentLocation, "OMPD_target directive unknown");
+  }
+}
+
 static void genOMP(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
@@ -554,25 +700,27 @@
   const auto &directive =
       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
           simpleStandaloneConstruct.t);
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  const Fortran::parser::OmpClauseList &opClauseList =
+      std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t);
+
   switch (directive.v) {
   default:
     break;
   case llvm::omp::Directive::OMPD_barrier:
-    converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
-        converter.getCurrentLocation());
+    firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
     break;
   case llvm::omp::Directive::OMPD_taskwait:
-    converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
-        converter.getCurrentLocation());
+    firOpBuilder.create<mlir::omp::TaskwaitOp>(converter.getCurrentLocation());
     break;
   case llvm::omp::Directive::OMPD_taskyield:
-    converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
-        converter.getCurrentLocation());
+    firOpBuilder.create<mlir::omp::TaskyieldOp>(converter.getCurrentLocation());
     break;
+  case llvm::omp::Directive::OMPD_target_data:
   case llvm::omp::Directive::OMPD_target_enter_data:
-    TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
   case llvm::omp::Directive::OMPD_target_exit_data:
-    TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
+    createTargetDataOp(converter, opClauseList, directive.v);
+    break;
   case llvm::omp::Directive::OMPD_target_update:
     TODO(converter.getCurrentLocation(), "OMPD_target_update");
   case llvm::omp::Directive::OMPD_ordered:
@@ -669,19 +817,6 @@
   return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
 }
 
-static mlir::Value
-getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
-                   Fortran::lower::StatementContext &stmtCtx,
-                   const Fortran::parser::OmpClause::If *ifClause) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.getCurrentLocation();
-  auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
-  mlir::Value ifVal = fir::getBase(
-      converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
-  return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(),
-                                    ifVal);
-}
-
 /* When parallel is used in a combined construct, then use this function to
  * create the parallel operation. It handles the parallel specific clauses
  * and leaves the rest for handling at the inner operations.
diff --git a/flang/test/Lower/OpenMP/target_data.f90 b/flang/test/Lower/OpenMP/target_data.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target_data.f90
@@ -0,0 +1,109 @@
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! Checks that `target enter data` and `target exit data` are lowered to
+! omp.target_enter_data and omp.target_exit_data, covering the `map`,
+! `nowait`, `if` and `device` clauses exercised below.
+
+!===============================================================================
+! Target_Enter Simple
+!===============================================================================
+
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_simple() {
+subroutine omp_target_enter_simple
+   integer :: a(1024)
+   !CHECK: omp.target_enter_data map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a)
+end subroutine omp_target_enter_simple
+
+!===============================================================================
+! Target_Enter Map types
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_mt() {
+subroutine omp_target_enter_mt
+   integer :: a(1024)
+   integer :: b(1024)
+   integer :: c(1024)
+   integer :: d(1024)
+   !CHECK: omp.target_enter_data map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (always, alloc -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a, b) map(always, alloc: c) map(to: d)
+end subroutine omp_target_enter_mt
+
+!===============================================================================
+! `Nowait` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_nowait() {
+subroutine omp_target_enter_nowait
+   integer :: a(1024)
+   !CHECK: omp.target_enter_data nowait map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a) nowait
+end subroutine omp_target_enter_nowait
+
+!===============================================================================
+! `if` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_if() {
+subroutine omp_target_enter_if
+   integer :: a(1024)
+   integer :: i
+   i = 5
+   !CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1:.*]] : !fir.ref<i32>
+   !CHECK: %[[VAL_4:.*]] = arith.constant 10 : i32
+   !CHECK: %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_3]], %[[VAL_4]] : i32
+   !CHECK: omp.target_enter_data if(%[[VAL_5]] : i1) map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data if(i<10) map(to: a)
+end subroutine omp_target_enter_if
+
+!===============================================================================
+! `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_device() {
+subroutine omp_target_enter_device
+   integer :: a(1024)
+   !CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32
+   !CHECK: omp.target_enter_data device(%[[VAL_1]] : i32) map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a) device(2)
+end subroutine omp_target_enter_device
+
+!===============================================================================
+! Target_Exit Simple
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_simple() {
+subroutine omp_target_exit_simple
+   integer :: a(1024)
+   !CHECK: omp.target_exit_data map((from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target exit data map(from: a)
+end subroutine omp_target_exit_simple
+
+!===============================================================================
+! Target_Exit Map types
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_mt() {
+subroutine omp_target_exit_mt
+   integer :: a(1024)
+   integer :: b(1024)
+   integer :: c(1024)
+   integer :: d(1024)
+   integer :: e(1024)
+   !CHECK: omp.target_exit_data map((from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (release -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (always, delete -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target exit data map(from: a,b) map(release: c) map(always, delete: d) map(from: e)
+end subroutine omp_target_exit_mt
+
+!===============================================================================
+! `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_device() {
+subroutine omp_target_exit_device
+   integer :: a(1024)
+   integer :: d
+   !CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_1:.*]] : !fir.ref<i32>
+   !CHECK: omp.target_exit_data device(%[[VAL_2]] : i32) map((from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target exit data map(from: a) device(d)
+end subroutine omp_target_exit_device