diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -1430,6 +1430,14 @@
   virtual void emitNumTeamsClause(CodeGenFunction &CGF, const Expr *NumTeams,
                                   const Expr *ThreadLimit, SourceLocation Loc);
 
+  /// Emits call to void __kmpc_set_thread_limit(ident_t *loc, kmp_int32
+  /// global_tid, kmp_int32 thread_limit) to generate code for
+  /// thread_limit clause on target directive.
+  /// \param ThreadLimit An integer expression of threads.
+  virtual void emitThreadLimitClause(CodeGenFunction &CGF,
+                                     const Expr *ThreadLimit,
+                                     SourceLocation Loc);
+
   /// Struct that keeps all the relevant information that should be kept
   /// throughout a 'target data' region.
   class TargetDataInfo : public llvm::OpenMPIRBuilder::TargetDataInfo {
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -9788,9 +9788,12 @@
 
   assert((OffloadingMandatory || OutlinedFn) && "Invalid outlined function!");
 
-  const bool RequiresOuterTask = D.hasClausesOfKind<OMPDependClause>() ||
-                                 D.hasClausesOfKind<OMPNowaitClause>() ||
-                                 D.hasClausesOfKind<OMPInReductionClause>();
+  const bool RequiresOuterTask =
+      D.hasClausesOfKind<OMPDependClause>() ||
+      D.hasClausesOfKind<OMPNowaitClause>() ||
+      D.hasClausesOfKind<OMPInReductionClause>() ||
+      (CGM.getLangOpts().OpenMP >= 51 && D.getDirectiveKind() == OMPD_target &&
+       D.hasClausesOfKind<OMPThreadLimitClause>());
   llvm::SmallVector<llvm::Value *, 16> CapturedVars;
   const CapturedStmt &CS = *D.getCapturedStmt(OMPD_target);
   auto &&ArgsCodegen = [&CS, &CapturedVars](CodeGenFunction &CGF,
@@ -10610,6 +10613,24 @@
                       PushNumTeamsArgs);
 }
 
+void CGOpenMPRuntime::emitThreadLimitClause(CodeGenFunction &CGF,
+                                            const Expr *ThreadLimit,
+                                            SourceLocation Loc) {
+  llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc);
+  llvm::Value *ThreadLimitVal =
+      ThreadLimit
+          ? CGF.Builder.CreateIntCast(CGF.EmitScalarExpr(ThreadLimit),
+                                      CGF.CGM.Int32Ty, /* isSigned = */ true)
+          : CGF.Builder.getInt32(0);
+
+  // Build call __kmpc_set_thread_limit(&loc, global_tid, thread_limit)
+  llvm::Value *ThreadLimitArgs[] = {RTLoc, getThreadID(CGF, Loc),
+                                    ThreadLimitVal};
+  CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
+                          CGM.getModule(), OMPRTL___kmpc_set_thread_limit),
+                      ThreadLimitArgs);
+}
+
 void CGOpenMPRuntime::emitTargetDataCalls(
     CodeGenFunction &CGF, const OMPExecutableDirective &D, const Expr *IfCond,
     const Expr *Device, const RegionCodeGenTy &CodeGen,
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -5138,6 +5138,16 @@
 
       Action.Enter(CGF);
     OMPLexicalScope LexScope(CGF, S, OMPD_task, /*EmitPreInitStmt=*/false);
+    if (CGF.CGM.getLangOpts().OpenMP >= 51 &&
+        S.getDirectiveKind() == OMPD_target &&
+        S.getSingleClause<OMPThreadLimitClause>()) {
+      // Emit __kmpc_set_thread_limit() to set the thread_limit for the task
+      // enclosing this target region. This will indirectly set the thread_limit
+      // for every applicable construct within target region.
+      CGF.CGM.getOpenMPRuntime().emitThreadLimitClause(
+          CGF, S.getSingleClause<OMPThreadLimitClause>()->getThreadLimit(),
+          S.getBeginLoc());
+    }
     BodyGen(CGF);
   };
   llvm::Function *OutlinedFn = CGM.getOpenMPRuntime().emitTaskOutlinedFunction(
diff --git a/clang/test/OpenMP/target_codegen.cpp b/clang/test/OpenMP/target_codegen.cpp
--- a/clang/test/OpenMP/target_codegen.cpp
+++ b/clang/test/OpenMP/target_codegen.cpp
@@ -846,7 +846,8 @@
 // OMP51: store {{.*}} [[TL]], {{.*}} [[CEA:%.*]]
 // OMP51: load {{.*}} [[CEA]]
 // OMP51: [[CE:%.*]] = load {{.*}} [[CEA]]
-// OMP51: call i32 @__tgt_target_kernel({{.*}}, i64 -1, i32 -1, i32 [[CE]],
+// OMP51: call i8* @__kmpc_omp_task_alloc({{.*@.omp_task_entry.*}})
+// OMP51: call i32 [[OMP_TASK_ENTRY]]
 
 #pragma omp target thread_limit(TargetTL)
 #pragma omp teams
@@ -854,8 +855,8 @@
 // OMP51: [[TL:%.*]] = load {{.*}} %TargetTL.addr
 // OMP51: store {{.*}} [[TL]], {{.*}} [[CEA:%.*]]
 // OMP51: load {{.*}} [[CEA]]
-// OMP51: [[CE:%.*]] = load {{.*}} [[CEA]]
-// OMP51: call i32 @__tgt_target_kernel({{.*}}, i64 -1, i32 0, i32 [[CE]],
+// OMP51: call i8* @__kmpc_omp_task_alloc({{.*@.omp_task_entry.*}})
+// OMP51: call i32 [[OMP_TASK_ENTRY]]
 
 #pragma omp target
 #pragma omp teams thread_limit(TeamsTL)
@@ -869,10 +870,25 @@
   {}
 // OMP51: load {{.*}} %TeamsTL.addr
 // OMP51: [[TeamsL:%.*]] = load {{.*}} %TeamsTL.addr
-// OMP51: call i32 @__tgt_target_kernel({{.*}}, i64 -1, i32 0, i32 [[TeamsL]],
+// OMP51: call i8* @__kmpc_omp_task_alloc({{.*@.omp_task_entry.*}})
+// OMP51: call i32 [[OMP_TASK_ENTRY]]
 }
 #endif
 
+// Check that the offloading functions are called after setting thread_limit in the task entry functions
+
+// OMP51: define internal {{.*}}i32 [[OMP_TASK_ENTRY:@.+]](i32 {{.*}}%0, [[KMP_TASK_T_WITH_PRIVATES]].1* noalias noundef %1)
+// OMP51: call void @__kmpc_set_thread_limit(%struct.ident_t* @{{.+}}, i32 %{{.+}}, i32 %{{.+}})
+// OMP51: call i32 @__tgt_target_kernel({{.*}}, i64 -1, i32 -1,
+
+// OMP51: define internal {{.*}}i32 [[OMP_TASK_ENTRY:@.+]](i32 {{.*}}%0, [[KMP_TASK_T_WITH_PRIVATES]].4* noalias noundef %1)
+// OMP51: call void @__kmpc_set_thread_limit(%struct.ident_t* @{{.+}}, i32 %{{.+}}, i32 %{{.+}})
+// OMP51: call i32 @__tgt_target_kernel({{.*}}, i64 -1, i32 0,
+
+// OMP51: define internal {{.*}}i32 [[OMP_TASK_ENTRY:@.+]](i32 {{.*}}%0, [[KMP_TASK_T_WITH_PRIVATES]].7* noalias noundef %1)
+// OMP51: call void @__kmpc_set_thread_limit(%struct.ident_t* @{{.+}}, i32 %{{.+}}, i32 %{{.+}})
+// OMP51: call i32 @__tgt_target_kernel({{.*}}, i64 -1, i32 0,
+
 // CHECK: define internal void @.omp_offloading.requires_reg()
 // CHECK: call void @__tgt_register_requires(i64 1)
 
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -382,6 +382,7 @@
 
 __OMP_RTL(__kmpc_fork_teams, true, Void, IdentPtr, Int32, ParallelTaskPtr)
 __OMP_RTL(__kmpc_push_num_teams, false, Void, IdentPtr, Int32, Int32, Int32)
+__OMP_RTL(__kmpc_set_thread_limit, false, Void, IdentPtr, Int32, Int32)
 
 __OMP_RTL(__kmpc_copyprivate, false, Void, IdentPtr, Int32, SizeTy, VoidPtr,
           CopyFunctionPtr, Int32)
@@ -904,6 +905,8 @@
                 ParamAttrs(ReadOnlyPtrAttrs, SExt, ReadOnlyPtrAttrs))
 __OMP_RTL_ATTRS(__kmpc_push_num_teams, InaccessibleArgOnlyAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs, SExt, SExt, SExt))
+__OMP_RTL_ATTRS(__kmpc_set_thread_limit, InaccessibleArgOnlyAttrs, AttributeSet(),
+                ParamAttrs(ReadOnlyPtrAttrs, SExt, SExt))
 
 __OMP_RTL_ATTRS(__kmpc_copyprivate, DefaultAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs, SExt, SizeTyExt,
diff --git a/openmp/runtime/src/kmp.h b/openmp/runtime/src/kmp.h
--- a/openmp/runtime/src/kmp.h
+++ b/openmp/runtime/src/kmp.h
@@ -2074,6 +2074,7 @@
   int nproc; /* internal control for #threads for next parallel region (per
                 thread) */
   int thread_limit; /* internal control for thread-limit-var */
+  int task_thread_limit; /* internal control for thread-limit-var of a task */
   int max_active_levels; /* internal control for max_active_levels */
   kmp_r_sched_t sched; /* internal control for runtime schedule {sched,chunk}
                           pair */
@@ -3302,6 +3303,7 @@
 extern int __kmp_max_nth;
 // maximum total number of concurrently-existing threads in a contention group
 extern int __kmp_cg_max_nth;
+extern int __kmp_task_max_nth; // max threads used in a task
 extern int __kmp_teams_max_nth; // max threads used in a teams construct
 extern int __kmp_threads_capacity; /* capacity of the arrays __kmp_threads and
                                       __kmp_root */
@@ -4244,6 +4246,8 @@
 KMP_EXPORT void __kmpc_push_num_teams(ident_t *loc, kmp_int32 global_tid,
                                       kmp_int32 num_teams,
                                       kmp_int32 num_threads);
+KMP_EXPORT void __kmpc_set_thread_limit(ident_t *loc, kmp_int32 global_tid,
+                                        kmp_int32 thread_limit);
 /* Function for OpenMP 5.1 num_teams clause */
 KMP_EXPORT void __kmpc_push_num_teams_51(ident_t *loc, kmp_int32 global_tid,
                                          kmp_int32 num_teams_lb,
diff --git a/openmp/runtime/src/kmp_csupport.cpp b/openmp/runtime/src/kmp_csupport.cpp
--- a/openmp/runtime/src/kmp_csupport.cpp
+++ b/openmp/runtime/src/kmp_csupport.cpp
@@ -381,6 +381,24 @@
   __kmp_push_num_teams(loc, global_tid, num_teams, num_threads);
 }
 
+/*!
+@ingroup PARALLEL
+@param loc source location information
+@param global_tid global thread number
+@param thread_limit limit on number of threads which can be created within the
+current task
+
+Set the thread_limit for the current task.
+This call is there to support `thread_limit` clause on the `target` construct.
+*/
+void __kmpc_set_thread_limit(ident_t *loc, kmp_int32 global_tid,
+                             kmp_int32 thread_limit) {
+  __kmp_assert_valid_gtid(global_tid);
+  kmp_info_t *thread = __kmp_threads[global_tid];
+  if (thread_limit > 0)
+    thread->th.th_current_task->td_icvs.task_thread_limit = thread_limit;
+}
+
 /*!
 @ingroup PARALLEL
 @param loc source location information
diff --git a/openmp/runtime/src/kmp_ftn_entry.h b/openmp/runtime/src/kmp_ftn_entry.h
--- a/openmp/runtime/src/kmp_ftn_entry.h
+++ b/openmp/runtime/src/kmp_ftn_entry.h
@@ -802,7 +802,12 @@
 
   gtid = __kmp_entry_gtid();
   thread = __kmp_threads[gtid];
-  return thread->th.th_current_task->td_icvs.thread_limit;
+  // If thread_limit for the target task is defined, return that instead of the
+  // regular task thread_limit
+  if (int thread_limit = thread->th.th_current_task->td_icvs.task_thread_limit)
+    return thread_limit;
+  else
+    return thread->th.th_current_task->td_icvs.thread_limit;
 #endif
 }
 
diff --git a/openmp/runtime/src/kmp_global.cpp b/openmp/runtime/src/kmp_global.cpp
--- a/openmp/runtime/src/kmp_global.cpp
+++ b/openmp/runtime/src/kmp_global.cpp
@@ -125,6 +125,7 @@
 int __kmp_sys_max_nth = KMP_MAX_NTH;
 int __kmp_max_nth = 0;
 int __kmp_cg_max_nth = 0;
+int __kmp_task_max_nth = 0;
 int __kmp_teams_max_nth = 0;
 int __kmp_threads_capacity = 0;
 int __kmp_dflt_team_nth = 0;
diff --git a/openmp/runtime/src/kmp_runtime.cpp b/openmp/runtime/src/kmp_runtime.cpp
--- a/openmp/runtime/src/kmp_runtime.cpp
+++ b/openmp/runtime/src/kmp_runtime.cpp
@@ -1867,6 +1867,7 @@
   int nthreads;
   int master_active;
   int master_set_numthreads;
+  int task_thread_limit = 0;
   int level;
   int active_level;
   int teams_level;
@@ -1905,6 +1906,8 @@
     root = master_th->th.th_root;
     master_active = root->r.r_active;
     master_set_numthreads = master_th->th.th_set_nproc;
+    task_thread_limit =
+        master_th->th.th_current_task->td_icvs.task_thread_limit;
 
 #if OMPT_SUPPORT
     ompt_data_t ompt_parallel_data = ompt_data_none;
@@ -1995,6 +1998,11 @@
             ? master_set_numthreads
             // TODO: get nproc directly from current task
             : get__nproc_2(parent_team, master_tid);
+    // Use the thread_limit set for the current target task if exists, else go
+    // with the deduced nthreads
+    nthreads = task_thread_limit > 0 && task_thread_limit < nthreads
+                   ? task_thread_limit
+                   : nthreads;
     // Check if we need to take forkjoin lock? (no need for serialized
     // parallel out of teams construct).
     if (nthreads > 1) {
@@ -3286,6 +3294,8 @@
       // next parallel region (per thread)
       // (use a max ub on value if __kmp_parallel_initialize not called yet)
       __kmp_cg_max_nth, // int thread_limit;
+      __kmp_task_max_nth, // int task_thread_limit; // to set the thread_limit
+      // on task. This is used in the case of target thread_limit
       __kmp_dflt_max_active_levels, // int max_active_levels; //internal control
       // for max_active_levels
       r_sched, // kmp_r_sched_t sched; //internal control for runtime schedule
diff --git a/openmp/runtime/test/target/target_thread_limit.cpp b/openmp/runtime/test/target/target_thread_limit.cpp
new file mode 100644
--- /dev/null
+++ b/openmp/runtime/test/target/target_thread_limit.cpp
@@ -0,0 +1,81 @@
+// RUN: %libomp-cxx-compile -fopenmp-version=51
+// RUN: %libomp-run | FileCheck %s --check-prefix OMP51
+
+#include <stdio.h>
+#include <omp.h>
+
+void foo() {
+#pragma omp parallel num_threads(10)
+  { printf("\ntarget: foo(): parallel num_threads(10)"); }
+}
+
+int main(void) {
+
+  int tl = 4;
+  printf("\nmain: thread_limit = %d", omp_get_thread_limit());
+  // OMP51: main: thread_limit = {{[0-9]+}}
+
+#pragma omp target thread_limit(tl)
+  {
+    printf("\ntarget: thread_limit = %d", omp_get_thread_limit());
+// OMP51: target: thread_limit = 4
+// check whether thread_limit is honoured
+#pragma omp parallel
+    { printf("\ntarget: parallel"); }
+// OMP51: target: parallel
+// OMP51: target: parallel
+// OMP51: target: parallel
+// OMP51: target: parallel
+
+// check whether num_threads is honoured
+#pragma omp parallel num_threads(2)
+    { printf("\ntarget: parallel num_threads(2)"); }
+// OMP51: target: parallel num_threads(2)
+// OMP51: target: parallel num_threads(2)
+
+// check whether thread_limit is honoured when there is a conflicting
+// num_threads
+#pragma omp parallel num_threads(10)
+    { printf("\ntarget: parallel num_threads(10)"); }
+    // OMP51: target: parallel num_threads(10)
+    // OMP51: target: parallel num_threads(10)
+    // OMP51: target: parallel num_threads(10)
+    // OMP51: target: parallel num_threads(10)
+
+    // check whether threads are limited across functions
+    foo();
+    // OMP51: target: foo(): parallel num_threads(10)
+    // OMP51: target: foo(): parallel num_threads(10)
+    // OMP51: target: foo(): parallel num_threads(10)
+    // OMP51: target: foo(): parallel num_threads(10)
+  }
+
+// checking consecutive target regions with different thread_limits
+#pragma omp target thread_limit(3)
+  {
+    printf("\nsecond target: thread_limit = %d", omp_get_thread_limit());
+// OMP51: second target: thread_limit = 3
+#pragma omp parallel
+    { printf("\nsecond target: parallel"); }
+    // OMP51: second target: parallel
+    // OMP51: second target: parallel
+    // OMP51: second target: parallel
+  }
+
+// confirm that thread_limit's effects are limited to target region
+  printf("\nmain: thread_limit = %d", omp_get_thread_limit());
+// OMP51: main: thread_limit = {{[0-9]+}}
+#pragma omp parallel num_threads(10)
+  { printf("\nmain: parallel num_threads(10)"); }
+  // OMP51: main: parallel num_threads(10)
+  // OMP51: main: parallel num_threads(10)
+  // OMP51: main: parallel num_threads(10)
+  // OMP51: main: parallel num_threads(10)
+  // OMP51: main: parallel num_threads(10)
+  // OMP51: main: parallel num_threads(10)
+  // OMP51: main: parallel num_threads(10)
+  // OMP51: main: parallel num_threads(10)
+  // OMP51: main: parallel num_threads(10)
+  // OMP51: main: parallel num_threads(10)
+  return 0;
+}