diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -732,8 +732,8 @@ mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand; mlir::UnitAttr nowaitAttr; - llvm::SmallVector useDevicePtrOperand, useDeviceAddrOperand, - mapOperands; + llvm::SmallVector mapOperands, devicePtrDeviceOperands, + devicePtrHostOperands, deviceAddrDeviceOperands, deviceAddrHostOperands; llvm::SmallVector mapTypes; auto addMapClause = [&firOpBuilder, &converter, &mapOperands, @@ -797,13 +797,24 @@ mlir::Type checkType = mapOp.getType(); if (auto refType = checkType.dyn_cast()) checkType = refType.getElementType(); - if (checkType.isa()) - TODO(currentLocation, "OMPD_target_data MapOperand BoxType"); + if (auto boxType = checkType.dyn_cast_or_null()) + if (!boxType.getElementType().isa()) + TODO(currentLocation, "OMPD_target_data MapOperand BoxType"); mapOperands.push_back(mapOp); mapTypes.push_back(mapTypeAttr); } }; + SmallVector dev_ptr_args; + + auto addUseDeviceClause = [&](const auto &useDeviceClause, + auto &deviceOperands, auto &hostOperands) { + genObjectList(useDeviceClause, converter, hostOperands); + for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + dev_ptr_args.push_back(sym); + } + }; for (const Fortran::parser::OmpClause &clause : opClauseList.v) { mlir::Location currentLocation = converter.genLocation(clause.source); @@ -825,12 +836,6 @@ deviceOperand = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx)); } - } else if (std::get_if( - &clause.u)) { - TODO(currentLocation, "OMPD_target Use Device Ptr"); - } else if (std::get_if( - &clause.u)) { - TODO(currentLocation, "OMPD_target Use Device Addr"); } else if (const auto &threadLmtClause = std::get_if( &clause.u)) { @@ -838,6 +843,16 @@ *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx)); } else if (std::get_if(&clause.u)) { nowaitAttr = firOpBuilder.getUnitAttr(); + } else if (const auto &devPtrClause = + std::get_if( + &clause.u)) { + addUseDeviceClause(devPtrClause->v, devicePtrDeviceOperands, + devicePtrHostOperands); + } else if (const auto &devAddrClause = + std::get_if( + &clause.u)) { + // addUseDeviceClause(devAddrClause->v, deviceAddrDeviceOperands, + // deviceAddrHostOperands); } else if (const auto &mapClause = std::get_if(&clause.u)) { addMapClause(mapClause, currentLocation); @@ -859,9 +874,15 @@ createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList); } else if (directive == llvm::omp::Directive::OMPD_target_data) { auto dataOp = firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand, - useDeviceAddrOperand, mapOperands, mapTypesArrayAttr); - createBodyOfOp(dataOp, converter, currentLocation, *eval, &opClauseList); + currentLocation, ifClauseOperand, deviceOperand, mapOperands, + mapTypesArrayAttr, devicePtrDeviceOperands, devicePtrHostOperands, + deviceAddrDeviceOperands, deviceAddrHostOperands); + createBodyOfOp(dataOp, converter, currentLocation, *eval, &opClauseList, + dev_ptr_args); + // auto ®ion = dataOp.getRegion(); + // for(auto &arg : region.getArguments()){ + // dataOp.getUseDevicePtrDeviceMutable().append(arg); + // } } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) { firOpBuilder.create(currentLocation, ifClauseOperand, deviceOperand, nowaitAttr, @@ -1157,7 +1178,17 @@ continue; } else if (std::get_if(&clause.u)) { // Map clause is exclusive to Target Data directives. It is handled - // as part of the DataOp creation. + // as part of the TargetOp creation. + continue; + } else if (std::get_if( + &clause.u)) { + // UseDevicePtr clause is exclusive to Target Data directives. It is + // handled as part of the TargetOp creation. + continue; + } else if (std::get_if( + &clause.u)) { + // UseDeviceAddr clause is exclusive to Target Data directives. It is + // handled as part of the TargetOp creation. continue; } else if (std::get_if( &clause.u)) { diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90 --- a/flang/test/Lower/OpenMP/target.f90 +++ b/flang/test/Lower/OpenMP/target.f90 @@ -162,3 +162,34 @@ !$omp end target !CHECK: } end subroutine omp_target_thread_limit + +!=============================================================================== +! Target `use_device_ptr` clause +!=============================================================================== + +!CHECK-LABEL: func.func @_QPomp_target_device_ptr() { +subroutine omp_target_device_ptr + integer, pointer :: a, b + !CHECK: omp.target_data map((tofrom -> %[[VAL_0]] : !fir.ref>>)) use_device_ptr(%[[VAL_0:.*]], %[[VAL_1:.*]] : !fir.ref>>, !fir.ref>>) + !$omp target data map(tofrom: a) use_device_ptr(a, b) + a = 10 + b = 20 + !CHECK: omp.terminator + !$omp end target data + !CHECK: } + end subroutine omp_target_device_ptr + + !=============================================================================== + ! Target `use_device_addr` clause + !=============================================================================== + + !CHECK-LABEL: func.func @_QPomp_target_device_addr() { + subroutine omp_target_device_addr + integer, pointer :: a + !CHECK: omp.target_data map((tofrom -> %[[VAL_0]] : !fir.ref>>)) use_device_addr(%[[VAL_0:.*]] : !fir.ref>>) + !$omp target data map(tofrom: a) use_device_addr(a) + a = 10 + !CHECK: omp.terminator + !$omp end target data + !CHECK: } + end subroutine omp_target_device_addr diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -978,19 +978,21 @@ let arguments = (ins Optional:$if_expr, Optional:$device, - Variadic:$use_device_ptr, - Variadic:$use_device_addr, Variadic:$map_operands, - I64ArrayAttr:$map_types); + I64ArrayAttr:$map_types, + Variadic:$use_device_ptr_device, + Variadic:$use_device_ptr_host, + Variadic:$use_device_addr_device, + Variadic:$use_device_addr_host); let regions = (region AnyRegion:$region); let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` - | `device` `(` $device `:` type($device) `)` - | `use_device_ptr` `(` $use_device_ptr `:` type($use_device_ptr) `)` - | `use_device_addr` `(` $use_device_addr `:` type($use_device_addr) `)`) + | `device` `(` $device `:` type($device) `)`) `map` `(` custom($map_operands, type($map_operands), $map_types) `)` + oilist(`use_device_ptr` `(` custom($use_device_ptr_device, $use_device_ptr_host, type($use_device_ptr_device), type($use_device_ptr_host)) `)` + | `use_device_addr` `(` custom($use_device_addr_device, $use_device_addr_host, type($use_device_addr_device), type($use_device_addr_host)) `)`) $region attr-dict }]; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -827,8 +827,65 @@ return success(); } +/// Parses a UseDevice Clause. +static ParseResult parseUseDeviceClause( + OpAsmParser &parser, + SmallVectorImpl &device_operands, + SmallVectorImpl &host_operands, + SmallVectorImpl &device_types, SmallVectorImpl &host_types) { + OpAsmParser::UnresolvedOperand device_addr_arg, host_addr_arg; + Type device_addr_type; + + auto parseUseDevice = [&]() -> ParseResult { + if (parser.parseLParen() || parser.parseOperand(device_addr_arg) || + parser.parseArrow() || parser.parseOperand(host_addr_arg) || + parser.parseColon() || parser.parseType(device_addr_type) || + parser.parseRParen()) + return failure(); + device_operands.push_back(device_addr_arg); + host_operands.push_back(host_addr_arg); + device_types.push_back(device_addr_type); + host_types.push_back(device_addr_type); + return success(); + }; + + if (parser.parseCommaSeparatedList(parseUseDevice)) + return failure(); + return success(); +} + +/// Prints a UseDevice Clause. +static void printUseDeviceClause(OpAsmPrinter &p, Operation *op, + OperandRange device_operands, + OperandRange host_operands, + TypeRange device_types, TypeRange host_types) { + assert(device_operands.size() == host_operands.size()); + for (unsigned i = 0, e = device_operands.size(); i < e; i++) { + p << '(' << device_operands[i] << " -> " << host_operands[i] << " : " + << device_types[i] << ')'; + if (i + 1 < e) + p << ", "; + } +} + +/// Verifies a UseDevice Clause. +static LogicalResult verifyUseDeviceClause(Operation *op, + OperandRange ptr_device, + OperandRange ptr_host, + OperandRange addr_device, + OperandRange addr_host) { + if (ptr_device.size() != ptr_host.size()) + return emitError(op->getLoc(), "use_device_ptr clause not well formed"); + if (addr_device.size() != addr_host.size()) + return emitError(op->getLoc(), "use_device_addr clause not well formed"); + return success(); +} + LogicalResult DataOp::verify() { - return verifyMapClause(*this, getMapOperands(), getMapTypes()); + verifyMapClause(*this, getMapOperands(), getMapTypes()); + verifyUseDeviceClause(*this, getUseDevicePtrDevice(), getUseDevicePtrHost(), + getUseDeviceAddrDevice(), getUseDeviceAddrDevice()); + return success(); } LogicalResult EnterDataOp::verify() { diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1416,8 +1416,8 @@ LogicalResult result = llvm::TypeSwitch(op) .Case([&](omp::DataOp dataOp) { - if (!dataOp.getUseDeviceAddr().empty() || - !dataOp.getUseDevicePtr().empty()) + if (!dataOp.getUseDeviceAddrDevice().empty() || + !dataOp.getUseDevicePtrDevice().empty()) return failure(); if (auto ifExprVar = dataOp.getIfExpr())