[mlir][openacc] Verify dataOperands of acc.parallel, acc.serial and acc.kernels

What this patch does:
  * Adds `let hasVerifier = 1;` to three op definitions in OpenACCOps.td
    (the hunks at lines 451, 536 and 620). These are presumably the
    definitions of acc.parallel, acc.serial and acc.kernels, matching the
    three `verify()` implementations added in OpenACC.cpp. TODO confirm.
  * Adds a file-local helper `checkDataOperands` in OpenACC.cpp. It walks
    the given operands and, for the first operand whose defining op is not
    one of the accepted op types, emits
    "expect data entry/exit operation or acc.getdeviceptr as defining op"
    and returns failure. Otherwise it returns success.
  * Implements acc::ParallelOp::verify(), acc::SerialOp::verify() and
    acc::KernelsOp::verify(). Each passes getDataClauseOperands() to that
    helper.
  * Adds three negative tests to invalid.mlir. Each passes a memref.alloc
    result directly to `dataOperands(...)` and expects that diagnostic.

NOTE(review): This copy of the patch is damaged and cannot be applied or
compiled as shown. Restore it from the original change.
  * Template argument lists in angle brackets that begin with a letter
    appear to be missing. Examples are the parameter list after
    `template` and the accepted op-type list after `mlir::isa`.
  * The line breaks have been collapsed.

NOTE(review): `operand.getDefiningOp()` returns null when the operand is a
block argument, and `mlir::isa` (llvm::isa) asserts on a null pointer.
A block-argument data operand would therefore likely trip an assertion
rather than produce the diagnostic. Suggested fixes:
  * Call `mlir::isa_and_nonnull<...>(operand.getDefiningOp())` instead.
  * Add a negative test that passes a function argument as a data operand.

NOTE(review): `mlir::ValueRange` is a lightweight view type. MLIR code
usually passes it by value rather than as `const mlir::ValueRange &`.
Optional cleanup.

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -451,6 +451,8 @@ ) $region attr-dict-with-keyword }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -536,6 +538,8 @@ ) $region attr-dict-with-keyword }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -620,6 +624,8 @@ ) $region attr-dict-with-keyword }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -257,6 +257,21 @@ // ParallelOp //===----------------------------------------------------------------------===// +/// Check dataOperands for acc.parallel, acc.serial and acc.kernels. 
+template +static LogicalResult checkDataOperands(Op op, + const mlir::ValueRange &operands) { + for (mlir::Value operand : operands) + if (!mlir::isa( + operand.getDefiningOp())) + return op.emitError( + "expect data entry/exit operation or acc.getdeviceptr " + "as defining op"); + return success(); +} + unsigned ParallelOp::getNumDataOperands() { return getReductionOperands().size() + getCopyOperands().size() + getCopyinOperands().size() + getCopyinReadonlyOperands().size() + @@ -278,6 +293,10 @@ return getOperand(getWaitOperands().size() + numOptional + i); } +LogicalResult acc::ParallelOp::verify() { + return checkDataOperands(*this, getDataClauseOperands()); +} + //===----------------------------------------------------------------------===// // SerialOp //===----------------------------------------------------------------------===// @@ -300,6 +319,10 @@ return getOperand(getWaitOperands().size() + numOptional + i); } +LogicalResult acc::SerialOp::verify() { + return checkDataOperands(*this, getDataClauseOperands()); +} + //===----------------------------------------------------------------------===// // KernelsOp //===----------------------------------------------------------------------===// @@ -320,6 +343,10 @@ return getOperand(getWaitOperands().size() + numOptional + i); } +LogicalResult acc::KernelsOp::verify() { + return checkDataOperands(*this, getDataClauseOperands()); +} + //===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -205,3 +205,26 @@ // expected-error@+1 {{operand #0 must be integer or index, but got 'f32'}} %1 = acc.bounds lowerbound(%0 : f32) +// ----- + +%value = memref.alloc() : memref<10xf32> +// expected-error@+1 {{expect data 
entry/exit operation or acc.getdeviceptr as defining op}} +acc.parallel dataOperands(%value : memref<10xf32>) { + acc.yield +} + +// ----- + +%value = memref.alloc() : memref<10xf32> +// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}} +acc.serial dataOperands(%value : memref<10xf32>) { + acc.yield +} + +// ----- + +%value = memref.alloc() : memref<10xf32> +// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}} +acc.kernels dataOperands(%value : memref<10xf32>) { + acc.yield +}