diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -548,6 +548,61 @@
   let hasVerifier = 1;
 }
 
+def TaskGroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
+                                          ReductionClauseInterface,
+                                          AutomaticAllocationScope]> {
+  let summary = "taskgroup construct";
+  let description = [{
+    The taskgroup construct specifies a wait on completion of child tasks of the
+    current task and their descendent tasks.
+
+    When a thread encounters a taskgroup construct, it starts executing the
+    region. All child tasks generated in the taskgroup region and all of their
+    descendants that bind to the same parallel region as the taskgroup region
+    are part of the taskgroup set associated with the taskgroup region. There is
+    an implicit task scheduling point at the end of the taskgroup region. The
+    current task is suspended at the task scheduling point until all tasks in
+    the taskgroup set complete execution.
+
+    The `task_reduction` clause specifies a reduction among tasks. For each list
+    item, the number of copies is unspecified. Any copies associated with the
+    reduction are initialized before they are accessed by the tasks
+    participating in the reduction. After the end of the region, the original
+    list item contains the result of the reduction.
+
+    The `allocate_vars` and `allocators_vars` parameters are variadic lists of
+    values that specify the memory allocator to be used to obtain storage for
+    private values.
+  }];
+
+  let arguments = (ins Variadic<OpenMP_PointerLikeType>:$task_reduction_vars,
+                       OptionalAttr<SymbolRefArrayAttr>:$task_reductions,
+                       Variadic<AnyType>:$allocate_vars,
+                       Variadic<AnyType>:$allocators_vars);
+
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    oilist(`task_reduction` `(`
+          custom<ReductionVarList>(
+            $task_reduction_vars, type($task_reduction_vars), $task_reductions
+          ) `)`
+          |`allocate` `(`
+            custom<AllocateAndAllocator>(
+              $allocate_vars, type($allocate_vars),
+              $allocators_vars, type($allocators_vars)
+            ) `)`
+    ) $region attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns the reduction variables.
+    operand_range getReductionVars() { return task_reduction_vars(); }
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.10.4 taskyield Construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -728,6 +728,14 @@
   return verifyReductionVarList(*this, in_reductions(), in_reduction_vars());
 }
 
+//===----------------------------------------------------------------------===//
+// TaskGroupOp
+//===----------------------------------------------------------------------===//
+LogicalResult TaskGroupOp::verify() {
+  return verifyReductionVarList(*this, task_reductions(),
+                                task_reduction_vars());
+}
+
 //===----------------------------------------------------------------------===//
 // WsLoopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1493,3 +1493,37 @@
   }
   return
 }
+
+// CHECK-LABEL: @omp_taskgroup_no_tasks
+func.func @omp_taskgroup_no_tasks() -> () {
+  // CHECK: omp.taskgroup
+  omp.taskgroup {
+    // CHECK: "test.foo"() : () -> ()
+    "test.foo"() : () -> ()
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: @omp_taskgroup_multiple_tasks
+func.func @omp_taskgroup_multiple_tasks() -> () {
+  // CHECK: omp.taskgroup
+  omp.taskgroup {
+    // CHECK: omp.task
+    omp.task {
+      "test.foo"() : () -> ()
+      // CHECK: omp.terminator
+      omp.terminator
+    }
+    // CHECK: omp.task
+    omp.task {
+      "test.foo"() : () -> ()
+      // CHECK: omp.terminator
+      omp.terminator
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}