acc.set: make device_type and device_num single optional operands.

Renames $deviceTypeOperands -> $deviceType (Variadic -> Optional) and
$deviceNumOperand -> $deviceNum in the op definition, updates the
"at least one operand" check in OpenACC.cpp to use the new accessors, and
drops the multi-operand device_type case from ops.mlir.

NOTE(review): this patch was restored from a copy whose line breaks and
TableGen template arguments had been stripped. The operand constraints
(<IntOrIndex>, <I1>) and the indentation of context lines are reconstructed;
confirm them against the tree before applying.

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1511,16 +1511,16 @@
     ```
   }];
 
-  let arguments = (ins Variadic<IntOrIndex>:$deviceTypeOperands,
+  let arguments = (ins Optional<IntOrIndex>:$deviceType,
                        Optional<IntOrIndex>:$defaultAsync,
-                       Optional<IntOrIndex>:$deviceNumOperand,
+                       Optional<IntOrIndex>:$deviceNum,
                        Optional<I1>:$ifCond);
 
   let assemblyFormat = [{
     oilist(
-        `device_type` `(` $deviceTypeOperands `:` type($deviceTypeOperands) `)`
+        `device_type` `(` $deviceType `:` type($deviceType) `)`
       | `default_async` `(` $defaultAsync `:` type($defaultAsync) `)`
-      | `device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
+      | `device_num` `(` $deviceNum `:` type($deviceNum) `)`
       | `if` `(` $ifCond `)`
     ) attr-dict-with-keyword
   }];
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1133,8 +1133,7 @@
   while ((currOp = currOp->getParentOp()))
     if (isComputeOperation(currOp))
       return emitOpError("cannot be nested in a compute operation");
-  if (getDeviceTypeOperands().empty() && !getDefaultAsync() &&
-      !getDeviceNumOperand())
+  if (!getDeviceType() && !getDefaultAsync() && !getDeviceNum())
     return emitOpError("at least one default_async, device_num, or device_type "
                        "operand must appear");
   return success();
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1704,12 +1704,11 @@
 %idxValue = arith.constant 1 : index
 %ifCond = arith.constant true
 acc.set device_type(%i32Value : i32)
-acc.set device_type(%i32Value, %i32Value2 : i32, i32)
+acc.set device_type(%i32Value : i32)
 acc.set device_num(%i64Value : i64)
 acc.set device_num(%i32Value : i32)
 acc.set device_num(%idxValue : index)
 acc.set device_num(%idxValue : index) if(%ifCond)
-acc.set device_num(%idxValue : index) if(%ifCond)
 acc.set default_async(%i32Value : i32)
 
 // CHECK: [[I64VALUE:%.*]] = arith.constant 1 : i64
@@ -1718,10 +1717,9 @@
 // CHECK: [[IDXVALUE:%.*]] = arith.constant 1 : index
 // CHECK: [[IFCOND:%.*]] = arith.constant true
 // CHECK: acc.set device_type([[I32VALUE]] : i32)
-// CHECK: acc.set device_type([[I32VALUE]], [[I32VALUE2]] : i32, i32)
+// CHECK: acc.set device_type([[I32VALUE]] : i32)
 // CHECK: acc.set device_num([[I64VALUE]] : i64)
 // CHECK: acc.set device_num([[I32VALUE]] : i32)
 // CHECK: acc.set device_num([[IDXVALUE]] : index)
 // CHECK: acc.set device_num([[IDXVALUE]] : index) if([[IFCOND]])
-// CHECK: acc.set device_num([[IDXVALUE]] : index) if([[IFCOND]])
 // CHECK: acc.set default_async([[I32VALUE]] : i32)