diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -732,9 +732,12 @@
 
   mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand;
   mlir::UnitAttr nowaitAttr;
-  llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand,
-      mapOperands;
+  llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
+      deviceAddrOperands;
   llvm::SmallVector<mlir::IntegerAttr> mapTypes;
+  // Symbols named in use_device_ptr / use_device_addr clauses, per clause kind.
+  llvm::SmallVector<const Fortran::semantics::Symbol *> devicePtrSymbols,
+      deviceAddrSymbols;
 
   auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
                        &mapTypes](const auto &mapClause,
@@ -797,14 +800,26 @@
       mlir::Type checkType = mapOp.getType();
       if (auto refType = checkType.dyn_cast<fir::ReferenceType>())
         checkType = refType.getElementType();
-      if (checkType.isa<fir::BoxType>())
-        TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
+      // Boxed map operands are only supported for POINTER variables.
+      if (auto boxType = checkType.dyn_cast<fir::BoxType>())
+        if (!boxType.getElementType().isa<fir::PointerType>())
+          TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
 
       mapOperands.push_back(mapOp);
       mapTypes.push_back(mapTypeAttr);
     }
   };
 
+  // Lowers the objects of one use_device_ptr/use_device_addr clause into
+  // `operands` (via genObjectList) and records their symbols in `symbols`.
+  // The same vectors are reused when a clause kind is repeated.
+  auto addUseDeviceClause = [&](const auto &useDeviceClause, auto &operands,
+                                auto &symbols) {
+    genObjectList(useDeviceClause, converter, operands);
+    for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v)
+      symbols.push_back(getOmpObjectSymbol(ompObject));
+  };
+
   for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
     mlir::Location currentLocation = converter.genLocation(clause.source);
     if (const auto &ifClause =
@@ -825,12 +840,6 @@
         deviceOperand =
             fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
       }
-    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
-                   &clause.u)) {
-      TODO(currentLocation, "OMPD_target Use Device Ptr");
-    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
-                   &clause.u)) {
-      TODO(currentLocation, "OMPD_target Use Device Addr");
     } else if (const auto &threadLmtClause =
                    std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
                        &clause.u)) {
@@ -838,6 +847,15 @@
           *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx));
     } else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
       nowaitAttr = firOpBuilder.getUnitAttr();
+    } else if (const auto &devPtrClause =
+                   std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                       &clause.u)) {
+      addUseDeviceClause(devPtrClause->v, devicePtrOperands, devicePtrSymbols);
+    } else if (const auto &devAddrClause =
+                   std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                       &clause.u)) {
+      addUseDeviceClause(devAddrClause->v, deviceAddrOperands,
+                         deviceAddrSymbols);
     } else if (const auto &mapClause =
                    std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
       addMapClause(mapClause, currentLocation);
@@ -859,9 +877,40 @@
     createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList);
   } else if (directive == llvm::omp::Directive::OMPD_target_data) {
     auto dataOp = firOpBuilder.create<omp::DataOp>(
-        currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
-        useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
-    createBodyOfOp(dataOp, converter, currentLocation, *eval, &opClauseList);
+        currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
+        deviceAddrOperands, mapOperands, mapTypesArrayAttr);
+
+    // Lay out the entry block arguments in the op's operand order: every
+    // use_device_ptr operand first, then every use_device_addr operand,
+    // whatever the order of the clauses in the source.
+    llvm::SmallVector<mlir::Value> useDeviceOperands(devicePtrOperands);
+    useDeviceOperands.append(deviceAddrOperands.begin(),
+                             deviceAddrOperands.end());
+    llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols(
+        devicePtrSymbols);
+    useDeviceSymbols.append(deviceAddrSymbols.begin(), deviceAddrSymbols.end());
+    assert(useDeviceOperands.size() == useDeviceSymbols.size() &&
+           "expected one symbol per use_device operand");
+
+    llvm::SmallVector<mlir::Type> useDeviceTypes;
+    llvm::SmallVector<mlir::Location> useDeviceLocs;
+    for (mlir::Value operand : useDeviceOperands) {
+      useDeviceTypes.push_back(operand.getType());
+      useDeviceLocs.push_back(operand.getLoc());
+    }
+
+    auto &region = dataOp.getRegion();
+    firOpBuilder.createBlock(&region, {}, useDeviceTypes, useDeviceLocs);
+    firOpBuilder.setInsertionPointToEnd(&region.front());
+    firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+    firOpBuilder.setInsertionPointToStart(&region.front());
+
+    // Inside the region, each use_device object refers to its block argument.
+    for (auto [sym, arg] :
+         llvm::zip(useDeviceSymbols, region.front().getArguments())) {
+      mlir::Value val = arg;
+      converter.bindSymbol(*sym, val);
+    }
   } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
     firOpBuilder.create<omp::EnterDataOp>(currentLocation, ifClauseOperand,
                                           deviceOperand, nowaitAttr,
@@ -1157,7 +1206,17 @@
       continue;
     } else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
       // Map clause is exclusive to Target Data directives. It is handled
       // as part of the DataOp creation.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                   &clause.u)) {
+      // UseDevicePtr clause is exclusive to Target Data directives. It is
+      // handled as part of the DataOp creation.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                   &clause.u)) {
+      // UseDeviceAddr clause is exclusive to Target Data directives. It is
+      // handled as part of the DataOp creation.
       continue;
     } else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
                    &clause.u)) {
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -162,3 +162,36 @@
    !$omp end target
    !CHECK: }
 end subroutine omp_target_thread_limit
+
+!===============================================================================
+! Target `use_device_ptr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device_ptr() {
+subroutine omp_target_device_ptr
+   integer, pointer :: a, b
+   !CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)) use_device_ptr(%[[VAL_0]], %[[VAL_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
+   !CHECK: ^bb0(%{{.*}}: !fir.ref<!fir.box<!fir.ptr<i32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<i32>>>):
+   !$omp target data map(tofrom: a) use_device_ptr(a, b)
+   a = 10
+   b = 20
+   !CHECK: omp.terminator
+   !$omp end target data
+   !CHECK: }
+end subroutine omp_target_device_ptr
+
+!===============================================================================
+! Target `use_device_addr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device_addr() {
+subroutine omp_target_device_addr
+   integer, pointer :: a
+   !CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)) use_device_addr(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)
+   !CHECK: ^bb0(%{{.*}}: !fir.ref<!fir.box<!fir.ptr<i32>>>):
+   !$omp target data map(tofrom: a) use_device_addr(a)
+   a = 10
+   !CHECK: omp.terminator
+   !$omp end target data
+   !CHECK: }
+end subroutine omp_target_device_addr