diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -318,4 +318,48 @@
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
+//===----------------------------------------------------------------------===//
+// 2.14.4. Update Directive
+//===----------------------------------------------------------------------===//
+
+def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
+  let summary = "update operation";
+
+  let description = [{
+    The "acc.update" operation represents the OpenACC update executable
+    directive.
+    As host and self clauses are synonyms, any operands for host and self are
+    added to $hostOperands.
+    The `async` and `wait` unit attributes represent the async and wait
+    clauses when they appear without a value; they cannot be combined with
+    $asyncOperand and $waitOperands respectively. $waitDevnum can only be
+    used together with $waitOperands.
+
+    Example:
+
+    ```mlir
+    acc.update device(%d1 : memref<10xf32>) attributes {async}
+    ```
+  }];
+
+  let arguments = (ins Optional<IntOrIndex>:$asyncOperand,
+                       Optional<IntOrIndex>:$waitDevnum,
+                       Variadic<IntOrIndex>:$waitOperands,
+                       UnitAttr:$async,
+                       UnitAttr:$wait,
+                       Variadic<AnyType>:$hostOperands,
+                       Variadic<AnyType>:$deviceOperands,
+                       UnitAttr:$ifPresent);
+
+  let assemblyFormat = [{
+    ( `async` `(` $asyncOperand^ `:` type($asyncOperand) `)` )?
+    ( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )?
+    ( `wait` `(` $waitOperands^ `:` type($waitOperands) `)` )?
+    ( `host` `(` $hostOperands^ `:` type($hostOperands) `)` )?
+    ( `device` `(` $deviceOperands^ `:` type($deviceOperands) `)` )?
+    attr-dict-with-keyword
+  }];
+}
+
+
 #endif // OPENACC_OPS
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -645,6 +645,34 @@
   if (dataOp.getOperands().size() == 0 && !dataOp.defaultAttr())
     return dataOp.emitError("at least one operand or the default attribute "
                             "must appear on the data operation");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// UpdateOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(acc::UpdateOp updateOp) {
+  // At least one of host or device should have a value.
+  if (updateOp.hostOperands().empty() &&
+      updateOp.deviceOperands().empty())
+    return updateOp.emitError("at least one value must be present in"
+                              " hostOperands or deviceOperands");
+
+  // The async attribute represents the async clause without value. Therefore
+  // the attribute and operand cannot appear at the same time.
+  if (updateOp.asyncOperand() && updateOp.async())
+    return updateOp.emitError("async attribute cannot appear with "
+                              "asyncOperand");
+
+  // The wait attribute represents the wait clause without values. Therefore
+  // the attribute and operands cannot appear at the same time.
+  if (!updateOp.waitOperands().empty() && updateOp.wait())
+    return updateOp.emitError("wait attribute cannot appear with waitOperands");
+
+  // wait_devnum is only accepted when the wait clause carries operands.
+  if (updateOp.waitDevnum() && updateOp.waitOperands().empty())
+    return updateOp.emitError("wait_devnum cannot appear without waitOperands");
   return success();
 }
 
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -75,3 +75,26 @@
 }
 
 // -----
+// expected-error@+1 {{at least one value must be present in hostOperands or deviceOperands}}
+acc.update
+
+// -----
+
+%cst = constant 1 : index
+%value = alloc() : memref<10xf32>
+// expected-error@+1 {{wait_devnum cannot appear without waitOperands}}
+acc.update wait_devnum(%cst: index) host(%value: memref<10xf32>)
+
+// -----
+
+%cst = constant 1 : index
+%value = alloc() : memref<10xf32>
+// expected-error@+1 {{async attribute cannot appear with asyncOperand}}
+acc.update async(%cst: index) host(%value: memref<10xf32>) attributes {async}
+
+// -----
+
+%cst = constant 1 : index
+%value = alloc() : memref<10xf32>
+// expected-error@+1 {{wait attribute cannot appear with waitOperands}}
+acc.update wait(%cst: index) host(%value: memref<10xf32>) attributes {wait}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -524,3 +524,33 @@
 // CHECK-NEXT: } attributes {defaultAttr = "present"}
 // CHECK: acc.data {
 // CHECK-NEXT: } attributes {defaultAttr = "none"}
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
+  %i64Value = constant 1 : i64
+  %i32Value = constant 1 : i32
+  %idxValue = constant 1 : index
+  acc.update async(%i64Value: i64) host(%a: memref<10xf32>)
+  acc.update async(%i32Value: i32) host(%a: memref<10xf32>)
+  acc.update async(%idxValue: index) host(%a: memref<10xf32>)
+  acc.update wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) host(%a: memref<10xf32>)
+  acc.update host(%a: memref<10xf32>) device(%b, %c : memref<10xf32>, memref<10x10xf32>)
+  acc.update host(%a: memref<10xf32>) device(%b, %c : memref<10xf32>, memref<10x10xf32>) attributes {async}
+  acc.update host(%a: memref<10xf32>) device(%b, %c : memref<10xf32>, memref<10x10xf32>) attributes {wait}
+  acc.update host(%a: memref<10xf32>) device(%b, %c : memref<10xf32>, memref<10x10xf32>) attributes {ifPresent}
+  return
+}
+
+// CHECK: func @testupdateop([[ARGA:%.*]]: memref<10xf32>, [[ARGB:%.*]]: memref<10xf32>, [[ARGC:%.*]]: memref<10x10xf32>) {
+// CHECK:   [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK:   [[I32VALUE:%.*]] = constant 1 : i32
+// CHECK:   [[IDXVALUE:%.*]] = constant 1 : index
+// CHECK:   acc.update async([[I64VALUE]] : i64) host([[ARGA]] : memref<10xf32>)
+// CHECK:   acc.update async([[I32VALUE]] : i32) host([[ARGA]] : memref<10xf32>)
+// CHECK:   acc.update async([[IDXVALUE]] : index) host([[ARGA]] : memref<10xf32>)
+// CHECK:   acc.update wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) host([[ARGA]] : memref<10xf32>)
+// CHECK:   acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>)
+// CHECK:   acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) attributes {async}
+// CHECK:   acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) attributes {wait}
+// CHECK:   acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) attributes {ifPresent}