diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -357,5 +357,41 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// 2.16.3. Wait Directive
+//===----------------------------------------------------------------------===//
+
+def OpenACC_WaitOp : OpenACC_Op<"wait", [AttrSizedOperandSegments]> {
+  let summary = "wait operation";
+
+  let description = [{
+    The "acc.wait" operation represents the OpenACC wait executable
+    directive.
+
+    The `async` unit attribute models an async clause without a value, so it
+    cannot be combined with `asyncOperand`. `wait_devnum` is only accepted
+    together with at least one wait operand.
+
+    Example:
+
+    ```mlir
+    acc.wait(%value1: index)
+    acc.wait async(%async1: i32)
+    ```
+  }];
+
+  let arguments = (ins Variadic<IntOrIndex>:$waitOperands,
+                       Optional<IntOrIndex>:$asyncOperand,
+                       Optional<IntOrIndex>:$waitDevnum,
+                       UnitAttr:$async,
+                       Optional<I1>:$ifCond);
+
+  let assemblyFormat = [{
+    ( `(` $waitOperands^ `:` type($waitOperands) `)` )?
+    ( `async` `(` $asyncOperand^ `:` type($asyncOperand) `)` )?
+    ( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )?
+    ( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword
+  }];
+}
 
 #endif // OPENACC_OPS
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -676,5 +676,22 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WaitOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(acc::WaitOp waitOp) {
+  // The async attribute represents the async clause without value. Therefore
+  // the attribute and operand cannot appear at the same time.
+  if (waitOp.asyncOperand() && waitOp.async())
+    return waitOp.emitError("async attribute cannot appear with asyncOperand");
+
+  // A devnum is only accepted alongside an explicit list of wait operands.
+  if (waitOp.waitDevnum() && waitOp.waitOperands().empty())
+    return waitOp.emitError("wait_devnum cannot appear without waitOperands");
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -75,6 +75,7 @@
 }
 
 // -----
+
 // expected-error@+1 {{at least one value must be present in hostOperands or deviceOperands}}
 acc.update
 
@@ -98,3 +99,15 @@
 %value = alloc() : memref<10xf32>
 // expected-error@+1 {{wait attribute cannot appear with waitOperands}}
 acc.update wait(%cst: index) host(%value: memref<10xf32>) attributes {wait}
+
+// -----
+
+%cst = constant 1 : index
+// expected-error@+1 {{wait_devnum cannot appear without waitOperands}}
+acc.wait wait_devnum(%cst: index)
+
+// -----
+
+%cst = constant 1 : index
+// expected-error@+1 {{async attribute cannot appear with asyncOperand}}
+acc.wait async(%cst: index) attributes {async}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -554,3 +554,41 @@
 // CHECK: acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) attributes {async}
 // CHECK: acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) attributes {wait}
 // CHECK: acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) attributes {ifPresent}
+
+// -----
+
+%i64Value = constant 1 : i64
+%i32Value = constant 1 : i32
+%idxValue = constant 1 : index
+%ifCond = constant true
+acc.wait
+acc.wait(%i64Value: i64)
+acc.wait(%i32Value: i32)
+acc.wait(%idxValue: index)
+acc.wait(%i32Value, %idxValue : i32, index)
+acc.wait async(%i64Value: i64)
+acc.wait async(%i32Value: i32)
+acc.wait async(%idxValue: index)
+acc.wait(%i32Value: i32) async(%idxValue: index)
+acc.wait(%i64Value: i64) wait_devnum(%i32Value: i32)
+acc.wait attributes {async}
+acc.wait(%i64Value: i64) async(%idxValue: index) wait_devnum(%i32Value: i32)
+acc.wait if(%ifCond)
+
+// CHECK: [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK: [[I32VALUE:%.*]] = constant 1 : i32
+// CHECK: [[IDXVALUE:%.*]] = constant 1 : index
+// CHECK: [[IFCOND:%.*]] = constant true
+// CHECK: acc.wait
+// CHECK: acc.wait([[I64VALUE]] : i64)
+// CHECK: acc.wait([[I32VALUE]] : i32)
+// CHECK: acc.wait([[IDXVALUE]] : index)
+// CHECK: acc.wait([[I32VALUE]], [[IDXVALUE]] : i32, index)
+// CHECK: acc.wait async([[I64VALUE]] : i64)
+// CHECK: acc.wait async([[I32VALUE]] : i32)
+// CHECK: acc.wait async([[IDXVALUE]] : index)
+// CHECK: acc.wait([[I32VALUE]] : i32) async([[IDXVALUE]] : index)
+// CHECK: acc.wait([[I64VALUE]] : i64) wait_devnum([[I32VALUE]] : i32)
+// CHECK: acc.wait attributes {async}
+// CHECK: acc.wait([[I64VALUE]] : i64) async([[IDXVALUE]] : index) wait_devnum([[I32VALUE]] : i32)
+// CHECK: acc.wait if([[IFCOND]])