diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -45,10 +45,10 @@
 }
 
 def IsDeviceAttr : OpenMP_Attr<"IsDevice", "isdevice"> {
-  let parameters = (ins 
+  let parameters = (ins
     "bool":$is_device
   );
-  
+
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
@@ -57,14 +57,14 @@
 //===----------------------------------------------------------------------===//
 
 def FlagsAttr : OpenMP_Attr<"Flags", "flags"> {
-  let parameters = (ins 
+  let parameters = (ins
     DefaultValuedParameter<"uint32_t", "0">:$debug_kind,
     DefaultValuedParameter<"bool", "false">:$assume_teams_oversubscription,
     DefaultValuedParameter<"bool", "false">:$assume_threads_oversubscription,
     DefaultValuedParameter<"bool", "false">:$assume_no_thread_state,
     DefaultValuedParameter<"bool", "false">:$assume_no_nested_parallelism
   );
-  
+
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
@@ -1065,24 +1065,33 @@
-    The optional $nowait elliminates the implicit barrier so the parent task can make progress
+    The optional $nowait eliminates the implicit barrier so the parent task can make progress
     even if the target task is not yet completed.
 
-    TODO: map, is_device_ptr, depend, defaultmap, in_reduction
+    The optional $map_operands are the list items of the map clause, and the
+    optional $map_types attribute holds one map-type flag value per
+    $map_operands entry.
+
+    TODO: is_device_ptr, depend, defaultmap, in_reduction
 
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,
                        Optional<AnyInteger>:$device,
                        Optional<AnyInteger>:$thread_limit,
-                       UnitAttr:$nowait);
+                       UnitAttr:$nowait,
+                       Variadic<OpenMP_PointerLikeType>:$map_operands,
+                       OptionalAttr<I64ArrayAttr>:$map_types);
 
   let regions = (region AnyRegion:$region);
 
   let assemblyFormat = [{
     oilist( `if` `(` $if_expr `)`
-    | `device` `(` $device `:` type($device) `)`
-    | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
-    | `nowait` $nowait
+          | `device` `(` $device `:` type($device) `)`
+          | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
+          | `nowait` $nowait
+          | `map` `(` custom<MapClause>($map_operands, type($map_operands), $map_types) `)`
     ) $region attr-dict
   }];
+
+  let hasVerifier = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -622,7 +622,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Parser, printer and verifier for Target Data
+// Parser, printer and verifier for Target
 //===----------------------------------------------------------------------===//
 /// Parses a Map Clause.
 ///
@@ -756,7 +756,10 @@
 }
 
+/// Verifies the map clause of `op`. When `map_types` is absent, `op` must
+/// have no map operands; otherwise there must be exactly one map type per
+/// map operand, and each map type must be valid for the kind of `op`.
 static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands,
-                                     ArrayAttr map_types) {
+                                     std::optional<ArrayAttr> map_types) {
   // Helper function to get bitwise AND of `value` and 'flag'
   auto bitAnd = [](int64_t value,
                    llvm::omp::OpenMPOffloadMappingFlags flag) -> bool {
@@ -765,10 +768,16 @@
                std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                flag);
   };
-  if (map_operands.size() != map_types.size())
-    return failure();
+  // Every map operand needs exactly one map type. An absent map_types
+  // attribute is only valid when there are no map operands either.
+  const size_t numMapTypes = map_types ? map_types->size() : 0;
+  if (map_operands.size() != numMapTypes)
+    return emitError(op->getLoc(),
+                     "mismatch in number of map operands and map types");
+  if (!map_types)
+    return success();
 
-  for (const auto &mapTypeOp : map_types) {
+  for (const auto &mapTypeOp : *map_types) {
     int64_t mapTypeBits = 0x00;
 
     if (!mapTypeOp.isa<IntegerAttr>())
@@ -783,12 +792,14 @@
     bool del = bitAnd(mapTypeBits,
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
 
-    if (isa<DataOp>(op) && del)
-      return failure();
+    if (isa<DataOp, TargetOp>(op) && del)
+      return emitError(op->getLoc(),
+                       "to, from, tofrom and alloc map types are permitted");
     if (isa<EnterDataOp>(op) && (from || del))
-      return failure();
+      return emitError(op->getLoc(), "to and alloc map types are permitted");
     if (isa<ExitDataOp>(op) && to)
-      return failure();
+      return emitError(op->getLoc(),
+                       "from, release and delete map types are permitted");
   }
 
   return success();
@@ -806,6 +817,10 @@
   return verifyMapClause(*this, getMapOperands(), getMapTypes());
 }
 
+LogicalResult TargetOp::verify() {
+  return verifyMapClause(*this, getMapOperands(), getMapTypes());
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1560,4 +1560,46 @@
   return
 }
 
+// -----
+
+func.func @omp_target(%map1: memref<?xi32>) {
+  // expected-error @below {{to, from, tofrom and alloc map types are permitted}}
+  omp.target map((delete -> %map1 : memref<?xi32>)){}
+  return
+}
+
+// -----
+
+func.func @omp_target_data(%map1: memref<?xi32>) {
+  // expected-error @below {{to, from, tofrom and alloc map types are permitted}}
+  omp.target_data map((delete -> %map1 : memref<?xi32>)){}
+  return
+}
+
+// -----
+
+func.func @omp_target_enter_data(%map1: memref<?xi32>) {
+  // expected-error @below {{to and alloc map types are permitted}}
+  omp.target_enter_data map((from -> %map1 : memref<?xi32>)){}
+  return
+}
+
+// -----
+
+func.func @omp_target_exit_data(%map1: memref<?xi32>) {
+  // expected-error @below {{from, release and delete map types are permitted}}
+  omp.target_exit_data map((to -> %map1 : memref<?xi32>)){}
+  return
+}
+
+// -----
+
+func.func @omp_target_map_without_map_types(%map1: memref<?xi32>) {
+  // expected-error @below {{mismatch in number of map operands and map types}}
+  "omp.target"(%map1) ({
+    omp.terminator
+  }) {operand_segment_sizes = array<i32: 0, 0, 0, 1>} : (memref<?xi32>) -> ()
+  return
+}
+
 llvm.mlir.global internal @_QFsubEx() : i32
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -480,14 +480,21 @@
 }
 
 // CHECK-LABEL: omp_target
-func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32) -> () {
+func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
 
     // Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
     // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operand_segment_sizes = array<i32: 1,1,1>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operand_segment_sizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
+
+    // Test with optional map clause.
+    // CHECK: omp.target map((tofrom -> %{{.*}} : memref<?xi32>), (alloc -> %{{.*}} : memref<?xi32>)) {
+    omp.target map((tofrom -> %map1 : memref<?xi32>), (alloc -> %map2 : memref<?xi32>)){}
+
+    // CHECK: omp.target map((to -> %{{.*}} : memref<?xi32>), (always, from -> %{{.*}} : memref<?xi32>)) {
+    omp.target map((to -> %map1 : memref<?xi32>), (always, from -> %map2 : memref<?xi32>)){}
 
     // CHECK: omp.barrier
     omp.barrier