diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -315,4 +315,44 @@
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
+//===----------------------------------------------------------------------===//
+// 2.14.4. Update Directive
+//===----------------------------------------------------------------------===//
+
+def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
+  let summary = "update operation";
+
+  let description = [{
+    The "acc.update" operation represents the OpenACC update executable
+    directive.
+    As host and self clauses are synonyms, any operands for host and self are
+    added to $hostOperands.
+
+    Example:
+
+    ```mlir
+    acc.update device(%d1 : memref<10xf32>) attributes {async}
+    ```
+  }];
+
+  let arguments = (ins Optional<IntOrIndex>:$asyncOperand,
+                       Optional<IntOrIndex>:$waitDevnum,
+                       Variadic<IntOrIndex>:$waitOperands,
+                       UnitAttr:$async,
+                       UnitAttr:$wait,
+                       Variadic<AnyType>:$hostOperands,
+                       Variadic<AnyType>:$deviceOperands,
+                       UnitAttr:$ifPresent);
+
+  let extraClassDeclaration = [{
+    static StringRef getAsyncKeyword() { return "async"; }
+    static StringRef getWaitDevnumKeyword() { return "wait_devnum"; }
+    static StringRef getWaitKeyword() { return "wait"; }
+    static StringRef getHostKeyword() { return "host"; }
+    static StringRef getDeviceKeyword() { return "device"; }
+    static StringRef getIfPresentAttrName() { return "ifPresent"; }
+  }];
+}
+
+
 #endif // OPENACC_OPS
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -795,5 +795,119 @@
   return success();
 }
 
+/// Parse acc.update operation
+/// operation := `acc.update` (`async` `(` value `)`)?
+/// (`wait_devnum` `(` value `)`)? (`wait` `(` value-list `)`)?
+/// (`host` `(` value-list `)`)? (`device` `(` value-list `)`)?
+/// attr-dict?
+static ParseResult parseUpdateOp(OpAsmParser &parser, OperationState &result) {
+  Builder &builder = parser.getBuilder();
+  SmallVector<OpAsmParser::OperandType, 8> waitOperands, hostOperands,
+      deviceOperands;
+  SmallVector<Type, 8> waitOperandTypes, hostOperandTypes, deviceOperandTypes;
+  OptionalParseResult async, waitDevnum;
+
+  // async(value)?
+  async =
+      parseOptionalOperandAndType(parser, UpdateOp::getAsyncKeyword(), result);
+  if (async.hasValue() && failed(*async))
+    return failure();
+
+  // wait_devnum(value)?
+  waitDevnum = parseOptionalOperandAndType(
+      parser, UpdateOp::getWaitDevnumKeyword(), result);
+  if (waitDevnum.hasValue() && failed(*waitDevnum))
+    return failure();
+
+  // wait(value-list)?
+  if (failed(parseOperandList(parser, UpdateOp::getWaitKeyword(), waitOperands,
+                              waitOperandTypes, result)))
+    return failure();
+
+  // host(value-list)?
+  if (failed(parseOperandList(parser, UpdateOp::getHostKeyword(), hostOperands,
+                              hostOperandTypes, result)))
+    return failure();
+
+  // device(value-list)?
+  if (failed(parseOperandList(parser, UpdateOp::getDeviceKeyword(),
+                              deviceOperands, deviceOperandTypes, result)))
+    return failure();
+
+  // The segment order must match the operand order declared in ODS:
+  // asyncOperand, waitDevnum, waitOperands, hostOperands, deviceOperands.
+  result.addAttribute(UpdateOp::getOperandSegmentSizeAttr(),
+                      builder.getI32VectorAttr(
+                          {static_cast<int32_t>(async.hasValue() ? 1 : 0),
+                           static_cast<int32_t>(waitDevnum.hasValue() ? 1 : 0),
+                           static_cast<int32_t>(waitOperands.size()),
+                           static_cast<int32_t>(hostOperands.size()),
+                           static_cast<int32_t>(deviceOperands.size())}));
+
+  // Additional attributes
+  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
+    return failure();
+
+  return success();
+}
+
+/// Print acc.update in the custom form accepted by parseUpdateOp. The
+/// operand segment size attribute is elided since parseUpdateOp rebuilds it.
+static void print(OpAsmPrinter &printer, UpdateOp &op) {
+  printer << UpdateOp::getOperationName();
+
+  // async()?
+  if (Value asyncOperand = op.asyncOperand())
+    printer << " " << UpdateOp::getAsyncKeyword() << "(" << asyncOperand << ": "
+            << asyncOperand.getType() << ")";
+
+  // wait_devnum()?
+  if (Value waitDevnum = op.waitDevnum())
+    printer << " " << UpdateOp::getWaitDevnumKeyword() << "(" << waitDevnum
+            << ": " << waitDevnum.getType() << ")";
+
+  // wait()?
+  printOperandList(op.waitOperands(), UpdateOp::getWaitKeyword(), printer);
+
+  // host()?
+  printOperandList(op.hostOperands(), UpdateOp::getHostKeyword(), printer);
+
+  // device()?
+  printOperandList(op.deviceOperands(), UpdateOp::getDeviceKeyword(), printer);
+
+  printer.printOptionalAttrDictWithKeyword(
+      op.getAttrs(), {UpdateOp::getOperandSegmentSizeAttr()});
+}
+
+/// Verify acc.update. At least one host or device operand is required, the
+/// unit `async`/`wait` attributes cannot be combined with their operand forms,
+/// and `wait_devnum` requires wait operands. Each failing check returns its
+/// diagnostic so the op is actually rejected rather than only reported.
+static LogicalResult verify(acc::UpdateOp updateOp) {
+  // At least one of host or device should have a value.
+  if (updateOp.hostOperands().empty() && updateOp.deviceOperands().empty())
+    return updateOp.emitError("At least one value must be present in "
+                              "hostOperands or deviceOperands");
+
+  // The async attribute represents the async clause without value. Therefore
+  // the attribute and operand cannot appear at the same time.
+  if (updateOp.asyncOperand() && updateOp.async())
+    return updateOp.emitError(
+        "async attribute cannot appear with async operand");
+
+  // The wait attribute represents the wait clause without values. Therefore
+  // the attribute and operands cannot appear at the same time.
+  if (!updateOp.waitOperands().empty() && updateOp.wait())
+    return updateOp.emitError(
+        "wait attribute cannot appear with wait operands");
+
+  if (updateOp.waitDevnum() && updateOp.waitOperands().empty())
+    return updateOp.emitError(UpdateOp::getWaitDevnumKeyword() +
+                              " cannot appear without " +
+                              UpdateOp::getWaitKeyword() + " operands");
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -68,3 +68,27 @@
 } attributes {auto_, seq}
 
 // -----
+
+// expected-error@+1 {{At least one value must be present in hostOperands or deviceOperands}}
+acc.update
+
+// -----
+
+%cst = constant 1 : index
+%value = alloc() : memref<10xf32>
+// expected-error@+1 {{wait_devnum cannot appear without wait operands}}
+acc.update wait_devnum(%cst: index) host(%value: memref<10xf32>)
+
+// -----
+
+%cst = constant 1 : index
+%value = alloc() : memref<10xf32>
+// expected-error@+1 {{async attribute cannot appear with async operand}}
+acc.update async(%cst: index) host(%value: memref<10xf32>) attributes {async}
+
+// -----
+
+%cst = constant 1 : index
+%value = alloc() : memref<10xf32>
+// expected-error@+1 {{wait attribute cannot appear with wait operands}}
+acc.update wait(%cst: index) host(%value: memref<10xf32>) attributes {wait}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -510,3 +510,33 @@
 // CHECK-NEXT: }
 // CHECK: acc.data present([[ARGA]]: memref<10xf32>) copyin([[ARGB]]: memref<10xf32>) copyout([[ARGC]]: memref<10x10xf32>) {
 // CHECK-NEXT: }
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
+  %i64Value = constant 1 : i64
+  %i32Value = constant 1 : i32
+  %idxValue = constant 1 : index
+  acc.update async(%i64Value: i64) host(%a: memref<10xf32>)
+  acc.update async(%i32Value: i32) host(%a: memref<10xf32>)
+  acc.update async(%idxValue: index) host(%a: memref<10xf32>)
+  acc.update wait_devnum(%i64Value: i64) wait(%i32Value: i32, %idxValue: index) host(%a: memref<10xf32>)
+  acc.update host(%a: memref<10xf32>) device(%b: memref<10xf32>, %c: memref<10x10xf32>)
+  acc.update host(%a: memref<10xf32>) device(%b: memref<10xf32>, %c: memref<10x10xf32>) attributes {async}
+  acc.update host(%a: memref<10xf32>) device(%b: memref<10xf32>, %c: memref<10x10xf32>) attributes {wait}
+  acc.update host(%a: memref<10xf32>) device(%b: memref<10xf32>, %c: memref<10x10xf32>) attributes {ifPresent}
+  return
+}
+
+// CHECK: func @testupdateop([[ARGA:%.*]]: memref<10xf32>, [[ARGB:%.*]]: memref<10xf32>, [[ARGC:%.*]]: memref<10x10xf32>) {
+// CHECK:   [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK:   [[I32VALUE:%.*]] = constant 1 : i32
+// CHECK:   [[IDXVALUE:%.*]] = constant 1 : index
+// CHECK:   acc.update async([[I64VALUE]]: i64) host([[ARGA]]: memref<10xf32>)
+// CHECK:   acc.update async([[I32VALUE]]: i32) host([[ARGA]]: memref<10xf32>)
+// CHECK:   acc.update async([[IDXVALUE]]: index) host([[ARGA]]: memref<10xf32>)
+// CHECK:   acc.update wait_devnum([[I64VALUE]]: i64) wait([[I32VALUE]]: i32, [[IDXVALUE]]: index) host([[ARGA]]: memref<10xf32>)
+// CHECK:   acc.update host([[ARGA]]: memref<10xf32>) device([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>)
+// CHECK:   acc.update host([[ARGA]]: memref<10xf32>) device([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) attributes {async}
+// CHECK:   acc.update host([[ARGA]]: memref<10xf32>) device([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) attributes {wait}
+// CHECK:   acc.update host([[ARGA]]: memref<10xf32>) device([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) attributes {ifPresent}