diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -229,6 +229,48 @@
   let assemblyFormat = "attr-dict";
 }
 
+//===----------------------------------------------------------------------===//
+// 2.6.6 Enter Data Directive
+//===----------------------------------------------------------------------===//
+
+def OpenACC_EnterDataOp : OpenACC_Op<"enter_data", [AttrSizedOperandSegments]> {
+  let summary = "enter data operation";
+
+  let description = [{
+    The "acc.enter_data" operation represents the OpenACC enter data directive.
+
+    Example:
+
+    ```mlir
+    acc.enter_data create(%d1 : memref<10xf32>) attributes {async}
+    ```
+  }];
+
+  let arguments = (ins Optional<I1>:$ifCond,
+                       Optional<IntOrIndex>:$asyncOperand,
+                       UnitAttr:$async,
+                       Optional<IntOrIndex>:$waitDevnum,
+                       Variadic<IntOrIndex>:$waitOperands,
+                       UnitAttr:$wait,
+                       Variadic<AnyType>:$copyinOperands,
+                       Variadic<AnyType>:$createOperands,
+                       Variadic<AnyType>:$createZeroOperands,
+                       Variadic<AnyType>:$attachOperands);
+
+  let assemblyFormat = [{
+    ( `if` `(` $ifCond^ `)` )?
+    ( `async` `(` $asyncOperand^ `:` type($asyncOperand) `)` )?
+    ( `wait_devnum` `(` $waitDevnum^ `:` type($waitDevnum) `)` )?
+    ( `wait` `(` $waitOperands^ `:` type($waitOperands) `)` )?
+    ( `copyin` `(` $copyinOperands^ `:` type($copyinOperands) `)` )?
+    ( `create` `(` $createOperands^ `:` type($createOperands) `)` )?
+    ( `create_zero` `(` $createZeroOperands^ `:`
+        type($createZeroOperands) `)` )?
+    ( `attach` `(` $attachOperands^ `:` type($attachOperands) `)` )?
+    attr-dict-with-keyword
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // 2.6.6 Exit Data Directive
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -682,6 +682,37 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// EnterDataOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(acc::EnterDataOp op) {
+  // 2.6.6. Enter Data Directive restriction
+  // At least one copyin, create, or attach clause must appear on an enter data
+  // directive.
+  if (op.copyinOperands().empty() && op.createOperands().empty() &&
+      op.createZeroOperands().empty() && op.attachOperands().empty())
+    return op.emitError(
+        "at least one operand in copyin, create, "
+        "create_zero or attach must appear on the enter data operation");
+
+  // The async attribute represents the async clause without value. Therefore
+  // the attribute and operand cannot appear at the same time.
+  if (op.asyncOperand() && op.async())
+    return op.emitError("async attribute cannot appear with asyncOperand");
+
+  // The wait attribute represents the wait clause without values. Therefore
+  // the attribute and operands cannot appear at the same time.
+  if (!op.waitOperands().empty() && op.wait())
+    return op.emitError("wait attribute cannot appear with waitOperands");
+
+  // A wait_devnum operand is only accepted together with waitOperands.
+  if (op.waitDevnum() && op.waitOperands().empty())
+    return op.emitError("wait_devnum cannot appear without waitOperands");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // InitOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -168,14 +168,40 @@
 
 // -----
 
 %cst = constant 1 : index
 %value = alloc() : memref<10xf32>
 // expected-error@+1 {{wait attribute cannot appear with waitOperands}}
 acc.exit_data wait(%cst: index) delete(%value : memref<10xf32>) attributes {wait}
 
 // -----
 
 %cst = constant 1 : index
 %value = alloc() : memref<10xf32>
 // expected-error@+1 {{wait_devnum cannot appear without waitOperands}}
 acc.exit_data wait_devnum(%cst: index) delete(%value : memref<10xf32>)
+
+// -----
+
+// expected-error@+1 {{at least one operand in copyin, create, create_zero or attach must appear on the enter data operation}}
+acc.enter_data attributes {async}
+
+// -----
+
+%cst = constant 1 : index
+%value = alloc() : memref<10xf32>
+// expected-error@+1 {{async attribute cannot appear with asyncOperand}}
+acc.enter_data async(%cst: index) create(%value : memref<10xf32>) attributes {async}
+
+// -----
+
+%cst = constant 1 : index
+%value = alloc() : memref<10xf32>
+// expected-error@+1 {{wait attribute cannot appear with waitOperands}}
+acc.enter_data wait(%cst: index) create(%value : memref<10xf32>) attributes {wait}
+
+// -----
+
+%cst = constant 1 : index
+%value = alloc() : memref<10xf32>
+// expected-error@+1 {{wait_devnum cannot appear without waitOperands}}
+acc.enter_data wait_devnum(%cst: index) create(%value : memref<10xf32>)
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -684,3 +684,37 @@
 // CHECK: acc.exit_data async([[I64VALUE]] : i64) copyout([[ARGA]] : memref<10xf32>)
 // CHECK: acc.exit_data if([[IFCOND]]) copyout([[ARGA]] : memref<10xf32>)
 // CHECK: acc.exit_data wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) copyout([[ARGA]] : memref<10xf32>)
+
+// -----
+
+func @testenterdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
+  %ifCond = constant true
+  %i64Value = constant 1 : i64
+  %i32Value = constant 1 : i32
+  %idxValue = constant 1 : index
+
+  acc.enter_data copyin(%a : memref<10xf32>)
+  acc.enter_data create(%a : memref<10xf32>) create_zero(%b, %c : memref<10xf32>, memref<10x10xf32>)
+  acc.enter_data attach(%a : memref<10xf32>)
+  acc.enter_data copyin(%a : memref<10xf32>) attributes {async}
+  acc.enter_data create(%a : memref<10xf32>) attributes {wait}
+  acc.enter_data async(%i64Value : i64) copyin(%a : memref<10xf32>)
+  acc.enter_data if(%ifCond) copyin(%a : memref<10xf32>)
+  acc.enter_data wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) copyin(%a : memref<10xf32>)
+
+  return
+}
+
+// CHECK: func @testenterdataop([[ARGA:%.*]]: memref<10xf32>, [[ARGB:%.*]]: memref<10xf32>, [[ARGC:%.*]]: memref<10x10xf32>) {
+// CHECK: [[IFCOND:%.*]] = constant true
+// CHECK: [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK: [[I32VALUE:%.*]] = constant 1 : i32
+// CHECK: [[IDXVALUE:%.*]] = constant 1 : index
+// CHECK: acc.enter_data copyin([[ARGA]] : memref<10xf32>)
+// CHECK: acc.enter_data create([[ARGA]] : memref<10xf32>) create_zero([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>)
+// CHECK: acc.enter_data attach([[ARGA]] : memref<10xf32>)
+// CHECK: acc.enter_data copyin([[ARGA]] : memref<10xf32>) attributes {async}
+// CHECK: acc.enter_data create([[ARGA]] : memref<10xf32>) attributes {wait}
+// CHECK: acc.enter_data async([[I64VALUE]] : i64) copyin([[ARGA]] : memref<10xf32>)
+// CHECK: acc.enter_data if([[IFCOND]]) copyin([[ARGA]] : memref<10xf32>)
+// CHECK: acc.enter_data wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) copyin([[ARGA]] : memref<10xf32>)