diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -723,6 +723,56 @@
   }
 }
 
+/// Create the entry block of the region of \p dataOp, with one argument per
+/// use_device_ptr/use_device_addr operand (\p useDeviceTypes and
+/// \p useDeviceLocs give their types and locations), and bind each symbol of
+/// \p useDeviceSymbols to the block argument at the same position.
+static void createBodyOfTargetOp(
+    Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
+    const llvm::SmallVector<mlir::Type> &useDeviceTypes,
+    const llvm::SmallVector<mlir::Location> &useDeviceLocs,
+    const llvm::SmallVector<const Fortran::semantics::Symbol *>
+        &useDeviceSymbols,
+    const mlir::Location &currentLocation) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Region &region = dataOp.getRegion();
+
+  firOpBuilder.createBlock(&region, {}, useDeviceTypes, useDeviceLocs);
+  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+  firOpBuilder.setInsertionPointToStart(&region.front());
+
+  unsigned argIndex = 0;
+  for (auto *sym : useDeviceSymbols) {
+    const mlir::BlockArgument &arg = region.front().getArgument(argIndex);
+    mlir::Value val = fir::getBase(arg);
+    fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
+    if (auto refType = val.getType().dyn_cast<fir::ReferenceType>()) {
+      if (fir::isa_builtin_cptr_type(refType.getElementType())) {
+        // A builtin c_ptr operand is bound directly to the block argument.
+        converter.bindSymbol(*sym, val);
+      } else {
+        extVal.match(
+            [&](const fir::MutableBoxValue &mbv) {
+              // Pointer/allocatable symbols are rebound as a MutableBoxValue
+              // over the block argument, keeping non-deferred length params.
+              converter.bindSymbol(
+                  *sym,
+                  fir::MutableBoxValue(
+                      val, fir::factory::getNonDeferredLenParams(extVal), {}));
+            },
+            [&](const auto &) {
+              TODO(converter.getCurrentLocation(),
+                   "use_device clause operand unsupported type");
+            });
+      }
+    } else {
+      TODO(converter.getCurrentLocation(),
+           "use_device clause operand unsupported type");
+    }
+    argIndex++;
+  }
+}
+
 static void createTargetOp(Fortran::lower::AbstractConverter &converter,
                            const Fortran::parser::OmpClauseList &opClauseList,
                            const llvm::omp::Directive &directive,
@@ -732,13 +782,29 @@
 
   mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand;
   mlir::UnitAttr nowaitAttr;
-  llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand,
-      mapOperands;
+  llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
+      deviceAddrOperands;
   llvm::SmallVector<mlir::IntegerAttr> mapTypes;
+  llvm::SmallVector<mlir::Type> useDeviceTypes;
+  llvm::SmallVector<mlir::Location> useDeviceLocs;
+  llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
+  // Object lists of the use_device_ptr/use_device_addr clauses; they are
+  // lowered once all clauses have been visited (see OMPD_target_data below).
+  llvm::SmallVector<const Fortran::parser::OmpObjectList *>
+      useDevicePtrObjects, useDeviceAddrObjects;
+
+  /// Emit a TODO for map and use_device operand types that are not yet
+  /// supported: boxes, possibly behind a reference, other than boxed pointers.
+  auto checkType = [](auto currentLocation, mlir::Type type) {
+    if (auto refType = type.dyn_cast<fir::ReferenceType>())
+      type = refType.getElementType();
+    if (auto boxType = type.dyn_cast_or_null<fir::BoxType>())
+      if (!boxType.getElementType().isa<fir::PointerType>())
+        TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
+  };
 
-  auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
-                       &mapTypes](const auto &mapClause,
-                                  mlir::Location &currentLocation) {
+  auto addMapClause = [&](const auto &mapClause,
+                          mlir::Location &currentLocation) {
     auto mapType = std::get<Fortran::parser::OmpMapType::Type>(
         std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t)
             ->t);
@@ -793,18 +859,32 @@
                   converter, mapOperand);
 
     for (mlir::Value mapOp : mapOperand) {
-      /// Check for unsupported map operand types.
-      mlir::Type checkType = mapOp.getType();
-      if (auto refType = checkType.dyn_cast<fir::ReferenceType>())
-        checkType = refType.getElementType();
-      if (checkType.isa<fir::BoxType>())
-        TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
-
+      checkType(mapOp.getLoc(), mapOp.getType());
       mapOperands.push_back(mapOp);
       mapTypes.push_back(mapTypeAttr);
     }
   };
 
+  /// Lower the objects of a use_device_ptr or use_device_addr clause into
+  /// \p operands, and record the type, location and symbol of each object
+  /// for the block arguments of the omp.target_data region.
+  auto addUseDeviceClause = [&](const auto &useDeviceClause, auto &operands) {
+    // operands may already hold the values of an earlier clause of the same
+    // kind; only the values appended for this clause are recorded here.
+    const size_t firstNewOperand = operands.size();
+    genObjectList(useDeviceClause, converter, operands);
+    for (size_t i = firstNewOperand, e = operands.size(); i < e; ++i) {
+      mlir::Value operand = operands[i];
+      checkType(operand.getLoc(), operand.getType());
+      useDeviceTypes.push_back(operand.getType());
+      useDeviceLocs.push_back(operand.getLoc());
+    }
+    for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
+      Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+      useDeviceSymbols.push_back(sym);
+    }
+  };
+
   for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
     mlir::Location currentLocation = converter.genLocation(clause.source);
     if (const auto &ifClause =
@@ -825,12 +905,6 @@
         deviceOperand =
             fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
       }
-    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
-                   &clause.u)) {
-      TODO(currentLocation, "OMPD_target Use Device Ptr");
-    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
-                   &clause.u)) {
-      TODO(currentLocation, "OMPD_target Use Device Addr");
     } else if (const auto &threadLmtClause =
                    std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
                        &clause.u)) {
@@ -838,6 +912,14 @@
           *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx));
     } else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
       nowaitAttr = firOpBuilder.getUnitAttr();
+    } else if (const auto &devPtrClause =
+                   std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                       &clause.u)) {
+      useDevicePtrObjects.push_back(&devPtrClause->v);
+    } else if (const auto &devAddrClause =
+                   std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                       &clause.u)) {
+      useDeviceAddrObjects.push_back(&devAddrClause->v);
     } else if (const auto &mapClause =
                    std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
       addMapClause(mapClause, currentLocation);
@@ -859,9 +941,18 @@
     createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList);
   } else if (directive == llvm::omp::Directive::OMPD_target_data) {
+    // The region arguments of omp.target_data are expected to follow its
+    // operand order: all use_device_ptr operands, then all use_device_addr
+    // operands. Lower the clauses in that order, not in source order, so that
+    // each symbol is bound to the matching block argument.
+    for (const Fortran::parser::OmpObjectList *objects : useDevicePtrObjects)
+      addUseDeviceClause(*objects, devicePtrOperands);
+    for (const Fortran::parser::OmpObjectList *objects : useDeviceAddrObjects)
+      addUseDeviceClause(*objects, deviceAddrOperands);
     auto dataOp = firOpBuilder.create<omp::DataOp>(
-        currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
-        useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
-    createBodyOfOp(dataOp, converter, currentLocation, *eval, &opClauseList);
+        currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
+        deviceAddrOperands, mapOperands, mapTypesArrayAttr);
+    createBodyOfTargetOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
+                         useDeviceSymbols, currentLocation);
   } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
     firOpBuilder.create<omp::EnterDataOp>(currentLocation, ifClauseOperand,
                                           deviceOperand, nowaitAttr,
@@ -1157,7 +1248,17 @@
       continue;
     } else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
       // Map clause is exclusive to Target Data directives. It is handled
-      // as part of the DataOp creation.
+      // as part of the TargetOp creation.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                   &clause.u)) {
+      // UseDevicePtr clause is exclusive to Target Data directives. It is
+      // handled as part of the TargetOp creation.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                   &clause.u)) {
+      // UseDeviceAddr clause is exclusive to Target Data directives. It is
+      // handled as part of the TargetOp creation.
       continue;
     } else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
                    &clause.u)) {
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -162,3 +162,61 @@
    !$omp end target
    !CHECK: }
 end subroutine omp_target_thread_limit
+
+!===============================================================================
+! Target `use_device_ptr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device_ptr() {
+subroutine omp_target_device_ptr
+   use iso_c_binding, only : c_ptr, c_loc
+   type(c_ptr) :: a
+   integer, target :: b
+   !CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)) use_device_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)
+   !$omp target data map(tofrom: a) use_device_ptr(a)
+   !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
+   !CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+   a = c_loc(b)
+   !CHECK: omp.terminator
+   !$omp end target data
+   !CHECK: }
+end subroutine omp_target_device_ptr
+
+!===============================================================================
+! Target `use_device_addr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device_addr() {
+subroutine omp_target_device_addr
+   integer, pointer :: a
+   !CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)) use_device_addr(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)
+   !$omp target data map(tofrom: a) use_device_addr(a)
+   !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
+   !CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
+   a = 10
+   !CHECK: omp.terminator
+   !$omp end target data
+   !CHECK: }
+end subroutine omp_target_device_addr
+
+!===============================================================================
+! Target `use_device_addr` clause given before a `use_device_ptr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device_addr_and_ptr() {
+subroutine omp_target_device_addr_and_ptr
+   use iso_c_binding, only : c_ptr, c_loc
+   type(c_ptr) :: a
+   integer, pointer :: b
+   integer, target :: c
+   ! Block arguments list the use_device_ptr operands before use_device_addr.
+   !CHECK: omp.target_data {{.*}}use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) use_device_addr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>)
+   !$omp target data map(tofrom: a, b) use_device_addr(b) use_device_ptr(a)
+   !CHECK: ^bb0(%{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %[[VAL_B:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
+   a = c_loc(c)
+   b = 10
+   !CHECK: fir.load %[[VAL_B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
+   !CHECK: omp.terminator
+   !$omp end target data
+   !CHECK: }
+end subroutine omp_target_device_addr_and_ptr