diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
@@ -335,9 +335,11 @@
                          uint32_t ThreadLimitClause[3]) const;
 
   /// The number of threads \p NumThreads can be adjusted by this method.
+  /// \p IsNumThreadsFromUser is true if \p NumThreads was specified by the
+  /// user via the thread_limit clause.
   uint64_t getNumBlocks(GenericDeviceTy &GenericDevice,
                         uint32_t BlockLimitClause[3], uint64_t LoopTripCount,
-                        uint32_t &NumThreads) const;
+                        uint32_t &NumThreads, bool IsNumThreadsFromUser) const;
 
   /// Indicate if the kernel works in Generic SPMD, Generic or SPMD mode.
   bool isGenericSPMDMode() const {
diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
@@ -374,8 +374,9 @@
                                     KernelArgs.NumArgs, Args, Ptrs);
 
   uint32_t NumThreads = getNumThreads(GenericDevice, KernelArgs.ThreadLimit);
-  uint64_t NumBlocks = getNumBlocks(GenericDevice, KernelArgs.NumTeams,
-                                    KernelArgs.Tripcount, NumThreads);
+  uint64_t NumBlocks =
+      getNumBlocks(GenericDevice, KernelArgs.NumTeams, KernelArgs.Tripcount,
+                   NumThreads, KernelArgs.ThreadLimit[0] > 0);
 
   if (auto Err = printLaunchInfo(GenericDevice, KernelArgs, NumThreads,
                                  NumBlocks))
@@ -418,7 +419,8 @@
 uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
                                        uint32_t NumTeamsClause[3],
                                        uint64_t LoopTripCount,
-                                       uint32_t &NumThreads) const {
+                                       uint32_t &NumThreads,
+                                       bool IsNumThreadsFromUser) const {
   assert(NumTeamsClause[1] == 0 && NumTeamsClause[2] == 0 &&
          "Multi dimensional launch not supported yet.");
 
@@ -443,7 +445,12 @@
 
       // Honor the thread_limit clause; only lower the number of threads.
       [[maybe_unused]] auto OldNumThreads = NumThreads;
-      if (LoopTripCount >= DefaultNumBlocks * NumThreads) {
+      if (IsNumThreadsFromUser) {
+        // thread_limit was given by the user: keep NumThreads as is and launch
+        // enough blocks to cover the trip count, even if that is fewer than
+        // DefaultNumBlocks.
+        TripCountNumBlocks = ((LoopTripCount - 1) / NumThreads) + 1;
+      } else if (LoopTripCount >= DefaultNumBlocks * NumThreads) {
         // Enough parallelism for teams and threads.
         TripCountNumBlocks = ((LoopTripCount - 1) / NumThreads) + 1;
         assert(TripCountNumBlocks >= DefaultNumBlocks &&