diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1494,6 +1494,43 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// 2.14.3. Set
+//===----------------------------------------------------------------------===//
+
+def OpenACC_SetOp : OpenACC_Op<"set", [AttrSizedOperandSegments]> {
+  let summary = "set operation";
+
+  let description = [{
+    The "acc.set" operation represents the OpenACC set directive.
+
+    At least one of the `device_type`, `default_async` or `device_num`
+    operands must be present, and the operation must not be nested inside a
+    compute operation. The optional `if` operand is an `i1` condition.
+
+    Example:
+
+    ```mlir
+    acc.set device_num(%dev1 : i32)
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyInteger>:$deviceTypeOperands,
+                       Optional<IntOrIndex>:$defaultAsync,
+                       Optional<IntOrIndex>:$deviceNumOperand,
+                       Optional<I1>:$ifCond);
+
+  let assemblyFormat = [{
+    oilist(
+        `device_type` `(` $deviceTypeOperands `:` type($deviceTypeOperands) `)`
+      | `default_async` `(` $defaultAsync `:` type($defaultAsync) `)`
+      | `device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
+      | `if` `(` $ifCond `)`
+    ) attr-dict-with-keyword
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.14.4.
Update Directive
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1124,6 +1124,29 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SetOp
+//===----------------------------------------------------------------------===//
+
+// Verifies the acc.set operation:
+//  - it must not appear anywhere inside a compute operation: every ancestor
+//    is checked with `isComputeOperation`, not only the direct parent;
+//  - at least one of the device_type, default_async or device_num operands
+//    must be present; an `if` operand on its own does not satisfy this.
+LogicalResult acc::SetOp::verify() {
+  for (Operation *ancestor = (*this)->getParentOp(); ancestor;
+       ancestor = ancestor->getParentOp())
+    if (isComputeOperation(ancestor))
+      return emitOpError("cannot be nested in a compute operation");
+
+  bool hasRequiredClause = !getDeviceTypeOperands().empty() ||
+                           getDefaultAsync() || getDeviceNumOperand();
+  if (!hasRequiredClause)
+    return emitOpError("at least one default_async, device_num, or device_type "
+                       "operand must appear");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // UpdateOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -493,3 +493,17 @@
 // expected-error@+1 {{num_gangs expects a maximum of 3 values}}
 acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) {
 }
+
+// -----
+
+%i64value = arith.constant 1 : i64
+acc.parallel {
+// expected-error@+1 {{'acc.set' op cannot be nested in a compute operation}}
+  acc.set device_type(%i64value : i64)
+  acc.yield
+}
+
+// -----
+
+// expected-error@+1 {{'acc.set' op at least one default_async, device_num, or device_type operand must appear}}
+acc.set
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1695,3 +1695,33 @@
 
 // 
CHECK-LABEL: func.func @compute3
 // CHECK: acc.declare dataOperands(
+
+// -----
+
+%i64Value = arith.constant 1 : i64
+%i32Value = arith.constant 1 : i32
+%i32Value2 = arith.constant 2 : i32
+%idxValue = arith.constant 1 : index
+%ifCond = arith.constant true
+acc.set device_type(%i32Value : i32)
+acc.set device_type(%i32Value, %i32Value2 : i32, i32)
+acc.set device_num(%i64Value : i64)
+acc.set device_num(%i32Value : i32)
+acc.set device_num(%idxValue : index)
+acc.set device_num(%idxValue : index) if(%ifCond)
+acc.set device_type(%i32Value : i32) default_async(%i32Value : i32) device_num(%idxValue : index) if(%ifCond)
+acc.set default_async(%i32Value : i32)
+
+// CHECK: [[I64VALUE:%.*]] = arith.constant 1 : i64
+// CHECK: [[I32VALUE:%.*]] = arith.constant 1 : i32
+// CHECK: [[I32VALUE2:%.*]] = arith.constant 2 : i32
+// CHECK: [[IDXVALUE:%.*]] = arith.constant 1 : index
+// CHECK: [[IFCOND:%.*]] = arith.constant true
+// CHECK: acc.set device_type([[I32VALUE]] : i32)
+// CHECK: acc.set device_type([[I32VALUE]], [[I32VALUE2]] : i32, i32)
+// CHECK: acc.set device_num([[I64VALUE]] : i64)
+// CHECK: acc.set device_num([[I32VALUE]] : i32)
+// CHECK: acc.set device_num([[IDXVALUE]] : index)
+// CHECK: acc.set device_num([[IDXVALUE]] : index) if([[IFCOND]])
+// CHECK: acc.set device_type([[I32VALUE]] : i32) default_async([[I32VALUE]] : i32) device_num([[IDXVALUE]] : index) if([[IFCOND]])
+// CHECK: acc.set default_async([[I32VALUE]] : i32)