diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -1491,6 +1491,7 @@ addOperands(operands, operandSegments, deviceTypeOperands); addOperands(operands, operandSegments, hostOperands); addOperands(operands, operandSegments, deviceOperands); + operandSegments.push_back(0); // temporary for dataClauseOperands. mlir::acc::UpdateOp updateOp = createSimpleOp( firOpBuilder, currentLocation, operands, operandSegments); diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -84,6 +84,9 @@ def OpenACC_FirstPrivateClause : I64EnumAttrCase<"acc_firstprivate", 14>; def OpenACC_IsDevicePtrClause : I64EnumAttrCase<"acc_deviceptr", 15>; def OpenACC_GetDevicePtrClause : I64EnumAttrCase<"acc_getdeviceptr", 16>; +def OpenACC_UpdateHost : I64EnumAttrCase<"acc_update_host", 17>; +def OpenACC_UpdateSelf : I64EnumAttrCase<"acc_update_self", 18>; +def OpenACC_UpdateDevice : I64EnumAttrCase<"acc_update_device", 19>; def OpenACC_DataClauseEnum : I64EnumAttr<"DataClause", "data clauses supported by OpenACC", @@ -92,7 +95,8 @@ OpenACC_CreateClause, OpenACC_CreateZeroClause, OpenACC_DeleteClause, OpenACC_AttachClause, OpenACC_DetachClause, OpenACC_NoCreateClause, OpenACC_PrivateClause, OpenACC_FirstPrivateClause, - OpenACC_IsDevicePtrClause, OpenACC_GetDevicePtrClause + OpenACC_IsDevicePtrClause, OpenACC_GetDevicePtrClause, OpenACC_UpdateHost, + OpenACC_UpdateSelf, OpenACC_UpdateDevice, ]> { let cppNamespace = "::mlir::acc"; } @@ -286,6 +290,14 @@ let summary = "Gets device address from host address if it exists on device."; } +//===----------------------------------------------------------------------===// +// 2.14.4 device clause +//===----------------------------------------------------------------------===// +def OpenACC_UpdateDeviceOp 
: OpenACC_DataEntryOp<"update_device", + "mlir::acc::DataClause::acc_update_device"> { + let summary = "Represents acc update device semantics."; +} + // Data exit operation does not refer to OpenACC spec terminology, but to // terminology used in this dialect. It refers to data operations that will appear // after data or compute region. It will be used as the base of acc dialect @@ -361,6 +373,20 @@ let summary = "Represents acc detach semantics - reverse of attach."; } +//===----------------------------------------------------------------------===// +// 2.14.4 host clause +//===----------------------------------------------------------------------===// +def OpenACC_UpdateHostOp : OpenACC_DataExitOp<"update_host", + "mlir::acc::DataClause::acc_update_host"> { + let summary = "Represents acc update host semantics."; + let extraClassDeclaration = [{ + /// Check if this is an acc update self. + bool isSelf() { + return getDataClause() == acc::DataClause::acc_update_self; + } + }]; +} + //===----------------------------------------------------------------------===// // 2.5.1 parallel Construct //===----------------------------------------------------------------------===// @@ -1005,6 +1031,7 @@ Variadic:$deviceTypeOperands, Variadic:$hostOperands, Variadic:$deviceOperands, + Variadic:$dataClauseOperands, UnitAttr:$ifPresent); let extraClassDeclaration = [{ @@ -1025,6 +1052,7 @@ | `wait` `(` $waitOperands `:` type($waitOperands) `)` | `host` `(` $hostOperands `:` type($hostOperands) `)` | `device` `(` $deviceOperands `:` type($deviceOperands) `)` + | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` ) attr-dict-with-keyword }]; diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -206,6 +206,33 @@ return success(); } +//===----------------------------------------------------------------------===// +// UpdateHostOp 
+//===----------------------------------------------------------------------===// +LogicalResult acc::UpdateHostOp::verify() { + // Test for all clauses this operation can be decomposed from: + if (getDataClause() != acc::DataClause::acc_update_host && + getDataClause() != acc::DataClause::acc_update_self) + return emitError( + "data clause associated with host operation must match its intent" + " or specify original clause this operation was decomposed from"); + if (!getVarPtr() || !getAccPtr()) + return emitError("must have both host and device pointers"); + return success(); +} + +//===----------------------------------------------------------------------===// +// UpdateDeviceOp +//===----------------------------------------------------------------------===// +LogicalResult acc::UpdateDeviceOp::verify() { + // Test for all clauses this operation can be decomposed from: + if (getDataClause() != acc::DataClause::acc_update_device) + return emitError( + "data clause associated with device operation must match its intent" + " or specify original clause this operation was decomposed from"); + return success(); +} + template static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions = 1) { @@ -595,7 +622,8 @@ LogicalResult acc::UpdateOp::verify() { // At least one of host or device should have a value. 
- if (getHostOperands().empty() && getDeviceOperands().empty()) + if (getHostOperands().empty() && getDeviceOperands().empty() && + getDataClauseOperands().empty()) return emitError( "at least one value must be present in hostOperands or deviceOperands"); @@ -616,7 +644,8 @@ } unsigned UpdateOp::getNumDataOperands() { - return getHostOperands().size() + getDeviceOperands().size(); + return getHostOperands().size() + getDeviceOperands().size() + + getDataClauseOperands().size(); } Value UpdateOp::getDataOperand(unsigned i) { diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -1116,3 +1116,15 @@ // CHECK: [[DEVPTR:%.*]] = acc.getdeviceptr varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {dataClause = 4 : i64} // CHECK-NEXT: acc.exit_data dataOperands([[DEVPTR]] : memref<10xf32>) // CHECK-NEXT: acc.copyout accPtr([[DEVPTR]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {structured = false} + +// ----- + +func.func @host_device_ops(%a: memref<10xf32>) -> () { + %devptr = acc.getdeviceptr varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = 16} + acc.update_host accPtr(%devptr : memref<10xf32>) to varPtr(%a : memref<10xf32>) {structured = false} + acc.update dataOperands(%devptr : memref<10xf32>) + + %accPtr = acc.update_device varPtr(%a : memref<10xf32>) -> memref<10xf32> + acc.update dataOperands(%accPtr : memref<10xf32>) + return +}