diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -438,6 +438,28 @@ }); } +// Get the NonDeferredLenParams for \p exv. +// Character values yield their length, mutable boxes their non-deferred +// length parameters and other boxes their explicit parameters; any other +// value yields an empty list. +static llvm::SmallVector +getNonDeferredLenParams(const fir::ExtendedValue &exv) { + return exv.match( + [&](const fir::CharArrayBoxValue &character) + -> llvm::SmallVector { return {character.getLen()}; }, + [&](const fir::CharBoxValue &character) + -> llvm::SmallVector { return {character.getLen()}; }, + [&](const fir::MutableBoxValue &box) -> llvm::SmallVector { + return {box.nonDeferredLenParams().begin(), + box.nonDeferredLenParams().end()}; + }, + [&](const fir::BoxValue &box) -> llvm::SmallVector { + return {box.getExplicitParameters().begin(), + box.getExplicitParameters().end()}; + }, + [&](const auto &) -> llvm::SmallVector { return {}; }); +} + static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval) { auto &firOpBuilder = converter.getFirOpBuilder(); @@ -732,9 +754,14 @@ mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand; mlir::UnitAttr nowaitAttr; - llvm::SmallVector useDevicePtrOperand, useDeviceAddrOperand, - mapOperands; + llvm::SmallVector mapOperands, devicePtrOperands, + deviceAddrOperands; llvm::SmallVector mapTypes; + // Entry-block argument types/locations and the symbols they stand for, + // collected from use_device_ptr and use_device_addr clauses in clause order. + llvm::SmallVector useDeviceTypes; + llvm::SmallVector useDeviceLocs; + SmallVector useDeviceSymbols; auto addMapClause = [&firOpBuilder, &converter, &mapOperands, &mapTypes](const auto &mapClause, @@ -797,14 +824,32 @@ mlir::Type checkType = mapOp.getType(); if (auto refType = checkType.dyn_cast()) checkType = refType.getElementType(); - if (checkType.isa()) - TODO(currentLocation, "OMPD_target_data MapOperand BoxType"); + if (auto boxType = checkType.dyn_cast_or_null()) + if (!boxType.getElementType().isa()) + TODO(currentLocation, "OMPD_target_data MapOperand BoxType"); mapOperands.push_back(mapOp); mapTypes.push_back(mapTypeAttr); } }; + // Collects the list items of a use_device_ptr/use_device_addr clause into + // \p operands, then records the type and location of each operand added by + // this call (used as entry-block argument types) and each item's symbol. + auto
addUseDeviceClause = [&](const auto &useDeviceClause, auto &operands) { + // Earlier clauses of the same kind already recorded their own operands. + const size_t firstNewOperand = operands.size(); + genObjectList(useDeviceClause, converter, operands); + for (size_t i = firstNewOperand, e = operands.size(); i < e; ++i) { + useDeviceTypes.push_back(operands[i].getType()); + useDeviceLocs.push_back(operands[i].getLoc()); + } + for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + useDeviceSymbols.push_back(sym); + } + }; + for (const Fortran::parser::OmpClause &clause : opClauseList.v) { mlir::Location currentLocation = converter.genLocation(clause.source); if (const auto &ifClause = @@ -825,12 +870,6 @@ deviceOperand = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx)); } - } else if (std::get_if( - &clause.u)) { - TODO(currentLocation, "OMPD_target Use Device Ptr"); - } else if (std::get_if( - &clause.u)) { - TODO(currentLocation, "OMPD_target Use Device Addr"); } else if (const auto &threadLmtClause = std::get_if( &clause.u)) { @@ -838,6 +877,14 @@ *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx)); } else if (std::get_if(&clause.u)) { nowaitAttr = firOpBuilder.getUnitAttr(); + } else if (const auto &devPtrClause = + std::get_if( + &clause.u)) { + addUseDeviceClause(devPtrClause->v, devicePtrOperands); + } else if (const auto &devAddrClause = + std::get_if( + &clause.u)) { + addUseDeviceClause(devAddrClause->v, deviceAddrOperands); } else if (const auto &mapClause = std::get_if(&clause.u)) { addMapClause(mapClause, currentLocation); @@ -859,9 +906,31 @@ createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList); } else if (directive == llvm::omp::Directive::OMPD_target_data) { auto dataOp = firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand, - useDeviceAddrOperand, mapOperands, mapTypesArrayAttr); - createBodyOfOp(dataOp, converter, currentLocation, *eval, &opClauseList); + currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands, + deviceAddrOperands, 
mapOperands, mapTypesArrayAttr); + + // Give the region an entry block with one argument per recorded + // use_device operand, terminate it, and continue inserting at its start. + auto &region = dataOp.getRegion(); + firOpBuilder.createBlock(&region, {}, useDeviceTypes, useDeviceLocs); + firOpBuilder.setInsertionPointToEnd(&region.front()); + firOpBuilder.create(currentLocation); + firOpBuilder.setInsertionPointToStart(&region.front()); + + // Rebind each use_device symbol to its entry-block argument. + // NOTE(review): the arguments follow clause order, whereas the op takes + // devicePtrOperands before deviceAddrOperands; confirm the two agree when + // use_device_addr precedes use_device_ptr. Wrapping each argument in a + // fir::MutableBoxValue presumably assumes POINTER/ALLOCATABLE items. + unsigned argIndex = 0; + for (auto *sym : useDeviceSymbols) { + const auto &arg = region.front().getArgument(argIndex); + mlir::Value val = fir::getBase(arg); + auto extVal = converter.getSymbolExtendedValue(*sym); + converter.bindSymbol( + *sym, fir::MutableBoxValue(val, getNonDeferredLenParams(extVal), {})); + argIndex++; + } } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) { firOpBuilder.create(currentLocation, ifClauseOperand, deviceOperand, nowaitAttr, @@ -1157,7 +1226,17 @@ continue; } else if (std::get_if(&clause.u)) { // Map clause is exclusive to Target Data directives. It is handled - // as part of the DataOp creation. + // as part of the TargetOp creation. + continue; + } else if (std::get_if( + &clause.u)) { + // UseDevicePtr clause is exclusive to Target Data directives. It is + // handled as part of the TargetOp creation. + continue; + } else if (std::get_if( + &clause.u)) { + // UseDeviceAddr clause is exclusive to Target Data directives. It is + // handled as part of the TargetOp creation. continue; } else if (std::get_if( &clause.u)) { diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90 --- a/flang/test/Lower/OpenMP/target.f90 +++ b/flang/test/Lower/OpenMP/target.f90 @@ -162,3 +162,39 @@ !$omp end target !CHECK: } end subroutine omp_target_thread_limit + +!=============================================================================== +! 
Target `use_device_ptr` clause (NOTE(review): only one use_device clause per directive is exercised below) +!=============================================================================== + +!CHECK-LABEL: func.func @_QPomp_target_device_ptr() { +subroutine omp_target_device_ptr + integer, pointer :: a, b + !CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref>>)) use_device_ptr(%[[VAL_0]], {{.*}} : !fir.ref>>, !fir.ref>>) + !$omp target data map(tofrom: a) use_device_ptr(a, b) + !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref>>, %[[VAL_2:.*]]: !fir.ref>>): + !CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref>> + a = 10 + !CHECK: {{.*}} = fir.load %[[VAL_2]] : !fir.ref>> + b = 20 + !CHECK: omp.terminator + !$omp end target data + !CHECK: } + end subroutine omp_target_device_ptr + + !=============================================================================== + ! Target `use_device_addr` clause + !=============================================================================== + + !CHECK-LABEL: func.func @_QPomp_target_device_addr() { + subroutine omp_target_device_addr + integer, pointer :: a + !CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref>>)) use_device_addr(%[[VAL_0]] : !fir.ref>>) + !$omp target data map(tofrom: a) use_device_addr(a) + !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref>>): + !CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref>> + a = 10 + !CHECK: omp.terminator + !$omp end target data + !CHECK: } + end subroutine omp_target_device_addr