diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -315,4 +315,37 @@
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
+//===----------------------------------------------------------------------===//
+// 2.16.3. Wait Directive
+//===----------------------------------------------------------------------===//
+
+def OpenACC_WaitOp : OpenACC_Op<"wait", [AttrSizedOperandSegments]> {
+  let summary = "wait operation";
+
+  let description = [{
+    The "acc.wait" operation represents the OpenACC wait executable
+    directive.
+
+    Example:
+
+    ```mlir
+    acc.wait(%value1: index)
+    acc.wait() async(%async1: i32)
+    ```
+  }];
+
+  let arguments = (ins Variadic<IntOrIndex>:$waitOperands,
+                       Optional<IntOrIndex>:$asyncOperand,
+                       Optional<IntOrIndex>:$waitDevnum,
+                       UnitAttr:$async,
+                       Optional<I1>:$ifCond);
+
+  let extraClassDeclaration = [{
+    static StringRef getAsyncKeyword() { return "async"; }
+    static StringRef getWaitDevnumKeyword() { return "wait_devnum"; }
+    static StringRef getWaitKeyword() { return "wait"; }
+    static StringRef getIfKeyword() { return "if"; }
+  }];
+}
+
 #endif // OPENACC_OPS
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -43,12 +43,9 @@
 }
 
 static ParseResult
-parseOperandList(OpAsmParser &parser, StringRef keyword,
+parseOperandList(OpAsmParser &parser,
                  SmallVectorImpl<OpAsmParser::OperandType> &args,
                  SmallVectorImpl<Type> &argTypes, OperationState &result) {
-  if (failed(parser.parseOptionalKeyword(keyword)))
-    return success();
-
   if (failed(parser.parseLParen()))
     return failure();
 
@@ -74,15 +71,30 @@
                                 result.operands);
 }
 
-static void printOperandList(Operation::operand_range operands,
-                             StringRef listName, OpAsmPrinter &printer) {
+static ParseResult
+parseOperandList(OpAsmParser &parser, StringRef keyword,
+                 SmallVectorImpl<OpAsmParser::OperandType> &args,
+                 SmallVectorImpl<Type> &argTypes, OperationState &result) {
+  if (failed(parser.parseOptionalKeyword(keyword)))
+    return success();
+
+  return parseOperandList(parser, args, argTypes, result);
+}
+
+static void printOperandList(const Operation::operand_range &operands,
+                             OpAsmPrinter &printer) {
+  printer << "(";
+  llvm::interleaveComma(operands, printer, [&](Value op) {
+    printer << op << ": " << op.getType();
+  });
+  printer << ")";
+}
+static void printOperandList(const Operation::operand_range &operands,
+                             StringRef listName, OpAsmPrinter &printer) {
   if (operands.size() > 0) {
-    printer << " " << listName << "(";
-    llvm::interleaveComma(operands, printer, [&](Value op) {
-      printer << op << ": " << op.getType();
-    });
-    printer << ")";
+    printer << " " << listName;
+    printOperandList(operands, printer);
   }
 }
 
@@ -795,5 +807,101 @@
   return success();
 }
 
+/// Parse acc.wait operation
+/// operation := `acc.wait` `(` (value `:` type (`,` value `:` type)*)? `)`
+///              (`async` `(` value `:` type `)`)?
+///              (`wait_devnum` `(` value `:` type `)`)?
+///              (`if` `(` value `)`)?
+///              (`attributes` attr-dict)?
+/// The parenthesized wait-operand list is mandatory but may be empty.
+static ParseResult parseWaitOp(OpAsmParser &parser, OperationState &result) {
+  Builder &builder = parser.getBuilder();
+  SmallVector<OpAsmParser::OperandType, 2> waitOperands;
+  SmallVector<Type, 2> waitOperandTypes;
+  OptionalParseResult async, waitDevnum;
+  OpAsmParser::OperandType ifCond;
+  Type i1Type = builder.getI1Type();
+  bool hasIfCond = false;
+
+  // (value-list)
+  if (failed(parseOperandList(parser, waitOperands, waitOperandTypes, result)))
+    return failure();
+
+  // async(value: type)?
+  async =
+      parseOptionalOperandAndType(parser, WaitOp::getAsyncKeyword(), result);
+  if (async.hasValue() && failed(*async))
+    return failure();
+
+  // wait_devnum(value: type)?
+  waitDevnum = parseOptionalOperandAndType(
+      parser, WaitOp::getWaitDevnumKeyword(), result);
+  if (waitDevnum.hasValue() && failed(*waitDevnum))
+    return failure();
+
+  // if(value)? -- the condition is always i1, so no type is spelled.
+  if (failed(parseOptionalOperand(parser, WaitOp::getIfKeyword(), ifCond,
+                                  i1Type, hasIfCond, result)))
+    return failure();
+
+  // Record how many operands each clause contributed, in the ODS declaration
+  // order required by AttrSizedOperandSegments.
+  result.addAttribute(WaitOp::getOperandSegmentSizeAttr(),
+                      builder.getI32VectorAttr(
+                          {static_cast<int32_t>(waitOperands.size()),
+                           static_cast<int32_t>(async.hasValue() ? 1 : 0),
+                           static_cast<int32_t>(waitDevnum.hasValue() ? 1 : 0),
+                           static_cast<int32_t>(hasIfCond ? 1 : 0)}));
+
+  // Additional attributes
+  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
+    return failure();
+
+  return success();
+}
+
+/// Print acc.wait in the custom form accepted by parseWaitOp. The operand
+/// segment sizes are implied by the printed clauses and are therefore elided.
+static void print(OpAsmPrinter &printer, WaitOp &op) {
+  printer << WaitOp::getOperationName();
+
+  // wait()?
+  printOperandList(op.waitOperands(), printer);
+
+  // async()?
+  if (Value asyncOperand = op.asyncOperand())
+    printer << " " << WaitOp::getAsyncKeyword() << "(" << asyncOperand << ": "
+            << asyncOperand.getType() << ")";
+
+  // wait_devnum()?
+  if (Value waitDevnum = op.waitDevnum())
+    printer << " " << WaitOp::getWaitDevnumKeyword() << "(" << waitDevnum
+            << ": " << waitDevnum.getType() << ")";
+
+  // if()? -- the type is implicitly i1, matching parseWaitOp.
+  if (Value ifCond = op.ifCond())
+    printer << " " << WaitOp::getIfKeyword() << "(" << ifCond << ")";
+
+  printer.printOptionalAttrDictWithKeyword(
+      op.getAttrs(), {WaitOp::getOperandSegmentSizeAttr()});
+}
+
+/// Verify acc.wait: the valueless `async` attribute and the async operand are
+/// mutually exclusive, and `wait_devnum` requires at least one wait operand.
+static LogicalResult verify(acc::WaitOp waitOp) {
+  // The async attribute represents the async clause without value. Therefore
+  // the attribute and operand cannot appear at the same time.
+  if (waitOp.asyncOperand() && waitOp.async())
+    return waitOp.emitError(
+        "async attribute cannot appear with async operand");
+
+  if (waitOp.waitDevnum() && waitOp.waitOperands().empty())
+    return waitOp.emitError(WaitOp::getWaitDevnumKeyword() +
+                            " cannot appear without " +
+                            WaitOp::getWaitKeyword() + " operands");
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -68,3 +68,13 @@
 } attributes {auto_, seq}
 
 // -----
+
+%cst = constant 1 : index
+// expected-error@+1 {{wait_devnum cannot appear without wait operands}}
+acc.wait() wait_devnum(%cst: index)
+
+// -----
+
+%cst = constant 1 : index
+// expected-error@+1 {{async attribute cannot appear with async operand}}
+acc.wait() async(%cst: index) attributes {async}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -510,3 +510,43 @@
 // CHECK-NEXT: }
 // CHECK: acc.data present([[ARGA]]: memref<10xf32>) copyin([[ARGB]]: memref<10xf32>) copyout([[ARGC]]: memref<10x10xf32>) {
 // CHECK-NEXT: }
+
+// -----
+
+%i64Value = constant 1 : i64
+%i32Value = constant 1 : i32
+%idxValue = constant 1 : index
+%ifCond = constant true
+acc.wait()
+acc.wait(%i64Value: i64)
+acc.wait(%i32Value: i32)
+acc.wait(%idxValue: index)
+acc.wait(%i32Value: i32, %idxValue: index)
+acc.wait() async(%i64Value: i64)
+acc.wait() async(%i32Value: i32)
+acc.wait() async(%idxValue: index)
+acc.wait(%i32Value: i32) async(%idxValue: index)
+acc.wait(%i64Value: i64) wait_devnum(%i32Value: i32)
+acc.wait() attributes {async}
+acc.wait(%i64Value: i64) async(%idxValue: index) wait_devnum(%i32Value: i32)
+acc.wait() if(%ifCond)
+acc.wait(%i64Value: i64) async(%idxValue: index) wait_devnum(%i32Value: i32) if(%ifCond)
+
+// CHECK: [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK: [[I32VALUE:%.*]] = constant 1 : i32
+// CHECK: [[IDXVALUE:%.*]] = constant 1 : index
+// CHECK: [[IFCOND:%.*]] = constant true
+// CHECK: acc.wait()
+// CHECK: acc.wait([[I64VALUE]]: i64)
+// CHECK: acc.wait([[I32VALUE]]: i32)
+// CHECK: acc.wait([[IDXVALUE]]: index)
+// CHECK: acc.wait([[I32VALUE]]: i32, [[IDXVALUE]]: index)
+// CHECK: acc.wait() async([[I64VALUE]]: i64)
+// CHECK: acc.wait() async([[I32VALUE]]: i32)
+// CHECK: acc.wait() async([[IDXVALUE]]: index)
+// CHECK: acc.wait([[I32VALUE]]: i32) async([[IDXVALUE]]: index)
+// CHECK: acc.wait([[I64VALUE]]: i64) wait_devnum([[I32VALUE]]: i32)
+// CHECK: acc.wait() attributes {async}
+// CHECK: acc.wait([[I64VALUE]]: i64) async([[IDXVALUE]]: index) wait_devnum([[I32VALUE]]: i32)
+// CHECK: acc.wait() if([[IFCOND]])
+// CHECK: acc.wait([[I64VALUE]]: i64) async([[IDXVALUE]]: index) wait_devnum([[I32VALUE]]: i32) if([[IFCOND]])