diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -477,6 +477,8 @@
     )
     $region attr-dict-with-keyword
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -562,6 +564,8 @@
     )
     $region attr-dict-with-keyword
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -646,6 +650,8 @@
     )
     $region attr-dict-with-keyword
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -285,6 +285,29 @@
 // ParallelOp
 //===----------------------------------------------------------------------===//
 
+/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
+///
+/// Succeeds when every value in `operands` is defined by one of the accepted
+/// data operations. Otherwise emits an error on `op` and returns it. A value
+/// with no defining op (getDefiningOp() is null, e.g. a block argument) is
+/// diagnosed like any other unsupported operand rather than handed to isa<>.
+// NOTE(review): the template argument lists were stripped from the patch text
+// under review; the op list below is reconstructed -- confirm before landing.
+template <typename Op>
+static LogicalResult checkDataOperands(Op op,
+                                       const mlir::ValueRange &operands) {
+  for (mlir::Value operand : operands)
+    if (!mlir::isa_and_nonnull<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp,
+                               acc::CreateOp, acc::DeleteOp, acc::DetachOp,
+                               acc::DevicePtrOp, acc::GetDevicePtrOp,
+                               acc::NoCreateOp, acc::PresentOp>(
+            operand.getDefiningOp()))
+      return op.emitError(
+          "expect data entry/exit operation or acc.getdeviceptr "
+          "as defining op");
+  return success();
+}
+
 unsigned ParallelOp::getNumDataOperands() {
   return getReductionOperands().size() + getCopyOperands().size() +
          getCopyinOperands().size() + getCopyinReadonlyOperands().size() +
@@ -306,6 +329,11 @@
   return getOperand(getWaitOperands().size() + numOptional + i);
 }
 
+/// Verify the defining op of every data clause operand of acc.parallel.
+LogicalResult acc::ParallelOp::verify() {
+  return checkDataOperands(*this, getDataClauseOperands());
+}
+
 //===----------------------------------------------------------------------===//
 // SerialOp
 //===----------------------------------------------------------------------===//
@@ -328,6 +356,11 @@
   return getOperand(getWaitOperands().size() + numOptional + i);
 }
 
+/// Verify the defining op of every data clause operand of acc.serial.
+LogicalResult acc::SerialOp::verify() {
+  return checkDataOperands(*this, getDataClauseOperands());
+}
+
 //===----------------------------------------------------------------------===//
 // KernelsOp
 //===----------------------------------------------------------------------===//
@@ -348,6 +381,11 @@
   return getOperand(getWaitOperands().size() + numOptional + i);
 }
 
+/// Verify the defining op of every data clause operand of acc.kernels.
+LogicalResult acc::KernelsOp::verify() {
+  return checkDataOperands(*this, getDataClauseOperands());
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -224,3 +224,38 @@
 %value = memref.alloc() : memref<10xf32>
 // expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}}
 acc.update dataOperands(%value : memref<10xf32>)
+
+// -----
+
+%value = memref.alloc() : memref<10xf32>
+// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}}
+acc.parallel dataOperands(%value : memref<10xf32>) {
+  acc.yield
+}
+
+// -----
+
+%value = memref.alloc() : memref<10xf32>
+// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}}
+acc.serial dataOperands(%value : memref<10xf32>) {
+  acc.yield
+}
+
+// -----
+
+%value = memref.alloc() : memref<10xf32>
+// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}}
+acc.kernels dataOperands(%value : memref<10xf32>) {
+  acc.yield
+}
+
+// -----
+
+// Block arguments have no defining op; the verifier must still diagnose them.
+func.func @parallel_block_arg_operand(%arg0 : memref<10xf32>) {
+  // expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}}
+  acc.parallel dataOperands(%arg0 : memref<10xf32>) {
+    acc.yield
+  }
+  return
+}