diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -218,6 +218,70 @@
   let assemblyFormat = "attr-dict";
 }
 
+//===----------------------------------------------------------------------===//
+// 2.7 teams Construct
+//===----------------------------------------------------------------------===//
+def TeamsOp : OpenMP_Op<"teams", [
+    AttrSizedOperandSegments, RecursiveMemoryEffects,
+    ReductionClauseInterface]> {
+  let summary = "teams construct";
+  let description = [{
+    The teams construct defines a region of code that triggers the creation of a
+    league of teams. Once created, the number of teams remains constant for the
+    duration of its code region.
+
+    If nested directly inside of an omp.target region, this operation must be
+    the only operation in that region other than its terminator.
+
+    The optional $num_teams_upper and $num_teams_lower specify the limit on the
+    number of teams to be created. If only the upper bound is specified, it acts
+    as if the lower bound was set to the same value. It is not supported to set
+    $num_teams_lower if $num_teams_upper is not specified. They define a closed
+    range, where both the lower and upper bounds are included, and must be of
+    the same type.
+
+    If the $if_expr is present and it evaluates to `false`, the number of teams
+    created is one.
+
+    The optional $thread_limit specifies the limit on the number of threads.
+
+    The $allocators_vars and $allocate_vars parameters are variadic lists of
+    values that specify the memory allocator to be used to obtain storage for
+    private values.
+  }];
+
+  let arguments = (ins Optional<AnyInteger>:$num_teams_lower,
+                       Optional<AnyInteger>:$num_teams_upper,
+                       Optional<I1>:$if_expr,
+                       Optional<AnyInteger>:$thread_limit,
+                       Variadic<AnyType>:$allocate_vars,
+                       Variadic<AnyType>:$allocators_vars,
+                       Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+                       OptionalAttr<SymbolRefArrayAttr>:$reductions);
+
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    oilist(
+      `num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to`
+                      $num_teams_upper `:` type($num_teams_upper) `)`
+    | `if` `(` $if_expr `)`
+    | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
+    | `reduction` `(`
+        custom<ReductionVarList>(
+          $reduction_vars, type($reduction_vars), $reductions
+        ) `)`
+    | `allocate` `(`
+        custom<AllocateAndAllocator>(
+          $allocate_vars, type($allocate_vars),
+          $allocators_vars, type($allocators_vars)
+        ) `)`
+    ) $region attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
 def OMP_ScheduleModNone : I32EnumAttrCase<"none", 0>;
 def OMP_ScheduleModMonotonic : I32EnumAttrCase<"monotonic", 1>;
 def OMP_ScheduleModNonmonotonic : I32EnumAttrCase<"nonmonotonic", 2>;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -864,6 +864,46 @@
   return verifyReductionVarList(*this, getReductions(), getReductionVars());
 }
 
+//===----------------------------------------------------------------------===//
+// TeamsOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TeamsOp::verify() {
+  Operation *op = getOperation();
+
+  // Check parent region: if nested directly inside of omp.target, this must
+  // be the only operation in that region other than its terminator. Only the
+  // operations directly inside of the omp.target region are inspected; the
+  // contents of regions nested within them (e.g. this op's body) are not.
+  if (auto target = dyn_cast_or_null<TargetOp>(op->getParentOp())) {
+    for (Operation &nestedOp : target.getRegion().getOps())
+      if (&nestedOp != op && !isa<TerminatorOp>(nestedOp))
+        return emitError("expected to be the only non-terminator operation "
+                         "in the parent omp.target region");
+  }
+
+  // Check for num_teams clause restrictions: a lower bound requires an upper
+  // bound, and both must have the same type.
+  if (auto lowerb = getNumTeamsLower()) {
+    auto upperb = getNumTeamsUpper();
+    if (!upperb)
+      return emitError("expected num_teams upper bound to be defined if the "
+                       "lower bound is defined");
+    if (lowerb.getType() != upperb.getType())
+      return emitError(
+          "expected num_teams upper bound and lower bound to be the same type");
+  }
+
+  // Check for allocate clause restrictions: the allocate and allocator lists
+  // must have the same length.
+  if (getAllocateVars().size() != getAllocatorsVars().size())
+    return emitError(
+        "expected equal sizes for allocate and allocator variables");
+
+  // Check for reduction clause restrictions.
+  return verifyReductionVarList(*this, getReductions(), getReductionVars());
+}
+
 //===----------------------------------------------------------------------===//
 // Verifier for SectionsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1103,6 +1103,50 @@
 
 // -----
 
+func.func @omp_teams_parent() {
+  omp.target {
+    %0 = arith.constant 0.0 : f32
+    // expected-error @below {{expected to be the only non-terminator operation in the parent omp.target region}}
+    omp.teams {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_allocate(%data_var : memref<i32>) {
+  // expected-error @below {{expected equal sizes for allocate and allocator variables}}
+  "omp.teams" (%data_var) ({
+    omp.terminator
+  }) {operand_segment_sizes = array<i32: 0,0,0,0,1,0,0>} : (memref<i32>) -> ()
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams1(%lb : i32) {
+  // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
+  "omp.teams" (%lb) ({
+    omp.terminator
+  }) {operand_segment_sizes = array<i32: 1,0,0,0,0,0,0>} : (i32) -> ()
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) {
+  // expected-error @below {{expected num_teams upper bound and lower bound to be the same type}}
+  omp.teams num_teams(%lb : i32 to %ub : i16) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @omp_sections(%data_var : memref<i32>) -> () {
   // expected-error @below {{expected equal sizes for allocate and allocator variables}}
   "omp.sections" (%data_var) ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -616,6 +616,76 @@
   return
 }
 
+// CHECK-LABEL: omp_teams
+func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
+                     %data_var : memref<i32>) -> () {
+  // Test nesting inside of target region.
+  omp.target {
+    // CHECK: omp.teams
+    omp.teams {
+      %0 = arith.constant 1 : i32
+      // CHECK: omp.terminator
+      omp.terminator
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
+  // Test num teams.
+  // CHECK: omp.teams num_teams(%{{.+}} : i32 to %{{.+}} : i32)
+  omp.teams num_teams(%lb : i32 to %ub : i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  // CHECK: omp.teams num_teams( to %{{.+}} : i32)
+  omp.teams num_teams(to %ub : i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
+  // Test if.
+  // CHECK: omp.teams if(%{{.+}})
+  omp.teams if(%if_cond) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
+  // Test thread limit.
+  // CHECK: omp.teams thread_limit(%{{.+}} : i32)
+  omp.teams thread_limit(%num_threads : i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
+  // Test reduction.
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr<f32>
+  // CHECK: omp.teams reduction(@add_f32 -> %{{.+}} : !llvm.ptr<f32>) {
+  omp.teams reduction(@add_f32 -> %0 : !llvm.ptr<f32>) {
+    %1 = arith.constant 2.0 : f32
+    // CHECK: omp.reduction %{{.+}}, %{{.+}}
+    omp.reduction %1, %0 : f32, !llvm.ptr<f32>
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
+  // Test allocate.
+  // CHECK: omp.teams allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
+  omp.teams allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
+  // Test multiple clauses, given in a different order than they are printed.
+  // CHECK: omp.teams num_teams(%{{.+}} : i32 to %{{.+}} : i32) if(%{{.+}}) thread_limit(%{{.+}} : i32)
+  omp.teams thread_limit(%num_threads : i32) if(%if_cond) num_teams(%lb : i32 to %ub : i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
+  return
+}
+
 // CHECK-LABEL: func @sections_reduction
 func.func @sections_reduction() {
   %c1 = arith.constant 1 : i32