diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -315,4 +315,44 @@
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
+//===----------------------------------------------------------------------===//
+// 2.16.3. Wait Directive
+//===----------------------------------------------------------------------===//
+
+def OpenACC_WaitOp : OpenACC_Op<"wait", [AttrSizedOperandSegments]> {
+  let summary = "wait operation";
+
+  let description = [{
+    The "acc.wait" operation represents the OpenACC wait executable
+    directive.
+
+    Example:
+
+    ```mlir
+    acc.wait(%value1: index)
+    acc.wait async(%async1: i32)
+    ```
+  }];
+
+  let arguments = (ins Variadic<IntOrIndex>:$waitOperands,
+                       Optional<IntOrIndex>:$asyncOperand,
+                       Optional<IntOrIndex>:$waitDevnum,
+                       UnitAttr:$async,
+                       Optional<I1>:$ifCond);
+
+  let extraClassDeclaration = [{
+    static StringRef getAsyncKeyword() { return "async"; }
+    static StringRef getWaitDevnumKeyword() { return "wait_devnum"; }
+    static StringRef getWaitKeyword() { return "wait"; }
+    static StringRef getIfKeyword() { return "if"; }
+  }];
+
+  let assemblyFormat = [{
+    ( `(` $waitOperands^ `:` type($waitOperands) `)` )?
+    ( `async` `(` $asyncOperand^ `:` type($asyncOperand) `)` )?
+    ( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )?
+    ( `if` `(` $ifCond^ `)` )? attr-dict-with-keyword
+  }];
+}
+
 #endif // OPENACC_OPS
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -43,12 +43,9 @@
 }
 
 static ParseResult
-parseOperandList(OpAsmParser &parser, StringRef keyword,
+parseOperandList(OpAsmParser &parser,
                  SmallVectorImpl<OpAsmParser::OperandType> &args,
                  SmallVectorImpl<Type> &argTypes, OperationState &result) {
-  if (failed(parser.parseOptionalKeyword(keyword)))
-    return success();
-
   if (failed(parser.parseLParen()))
     return failure();
 
@@ -74,15 +71,30 @@
                                 result.operands);
 }
 
-static void printOperandList(Operation::operand_range operands,
-                             StringRef listName, OpAsmPrinter &printer) {
+static ParseResult
+parseOperandList(OpAsmParser &parser, StringRef keyword,
+                 SmallVectorImpl<OpAsmParser::OperandType> &args,
+                 SmallVectorImpl<Type> &argTypes, OperationState &result) {
+  if (failed(parser.parseOptionalKeyword(keyword)))
+    return success();
+  return parseOperandList(parser, args, argTypes, result);
+}
+
+static void printOperandList(const Operation::operand_range &operands,
+                             OpAsmPrinter &printer) {
+  printer << "(";
+  llvm::interleaveComma(operands, printer, [&](Value op) {
+    printer << op << ": " << op.getType();
+  });
+  printer << ")";
+}
+
+static void printOperandList(const Operation::operand_range &operands,
+                             StringRef listName, OpAsmPrinter &printer) {
   if (operands.size() > 0) {
-    printer << " " << listName << "(";
-    llvm::interleaveComma(operands, printer, [&](Value op) {
-      printer << op << ": " << op.getType();
-    });
-    printer << ")";
+    printer << " " << listName;
+    printOperandList(operands, printer);
   }
 }
 
@@ -804,5 +816,22 @@
   return success();
 }
 
+/// Verifies acc.wait. Emits a diagnostic and returns failure on the first
+/// violated constraint; returns success when none is violated.
+static LogicalResult verify(acc::WaitOp waitOp) {
+  // The async attribute represents the async clause without value. Therefore
+  // the attribute and operand cannot appear at the same time.
+  if (waitOp.asyncOperand() && waitOp.async())
+    return waitOp.emitError("async attribute cannot appear with async operand");
+
+  // wait_devnum is only accepted together with at least one wait operand.
+  if (waitOp.waitDevnum() && waitOp.waitOperands().empty())
+    return waitOp.emitError(WaitOp::getWaitDevnumKeyword() +
+                            " cannot appear without " +
+                            WaitOp::getWaitKeyword() + " operands");
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -68,3 +68,13 @@
 } attributes {auto_, seq}
 
 // -----
+
+%cst = constant 1 : index
+// expected-error@+1 {{wait_devnum cannot appear without wait operands}}
+acc.wait wait_devnum(%cst: index)
+
+// -----
+
+%cst = constant 1 : index
+// expected-error@+1 {{async attribute cannot appear with async operand}}
+acc.wait async(%cst: index) attributes {async}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -510,3 +510,41 @@
 // CHECK-NEXT: }
 // CHECK: acc.data present([[ARGA]]: memref<10xf32>) copyin([[ARGB]]: memref<10xf32>) copyout([[ARGC]]: memref<10x10xf32>) {
 // CHECK-NEXT: }
+
+// -----
+
+%i64Value = constant 1 : i64
+%i32Value = constant 1 : i32
+%idxValue = constant 1 : index
+%ifCond = constant true
+acc.wait
+acc.wait(%i64Value: i64)
+acc.wait(%i32Value: i32)
+acc.wait(%idxValue: index)
+acc.wait(%i32Value, %idxValue : i32, index)
+acc.wait async(%i64Value: i64)
+acc.wait async(%i32Value: i32)
+acc.wait async(%idxValue: index)
+acc.wait(%i32Value: i32) async(%idxValue: index)
+acc.wait(%i64Value: i64) wait_devnum(%i32Value: i32)
+acc.wait attributes {async}
+acc.wait(%i64Value: i64) async(%idxValue: index) wait_devnum(%i32Value: i32)
+acc.wait if(%ifCond)
+
+// CHECK: [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK: [[I32VALUE:%.*]] = constant 1 : i32
+// CHECK: [[IDXVALUE:%.*]] = constant 1 : index
+// CHECK: [[IFCOND:%.*]] = constant true
+// CHECK: acc.wait
+// CHECK: acc.wait([[I64VALUE]] : i64)
+// CHECK: acc.wait([[I32VALUE]] : i32)
+// CHECK: acc.wait([[IDXVALUE]] : index)
+// CHECK: acc.wait([[I32VALUE]], [[IDXVALUE]] : i32, index)
+// CHECK: acc.wait async([[I64VALUE]] : i64)
+// CHECK: acc.wait async([[I32VALUE]] : i32)
+// CHECK: acc.wait async([[IDXVALUE]] : index)
+// CHECK: acc.wait([[I32VALUE]] : i32) async([[IDXVALUE]] : index)
+// CHECK: acc.wait([[I64VALUE]] : i64) wait_devnum([[I32VALUE]] : i32)
+// CHECK: acc.wait attributes {async}
+// CHECK: acc.wait([[I64VALUE]] : i64) async([[IDXVALUE]] : index) wait_devnum([[I32VALUE]] : i32)
+// CHECK: acc.wait if([[IFCOND]])