diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -505,6 +505,9 @@
   /// State to track what kernel entries can reach the associated function.
   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
 
+  /// State to track what parallel levels the associated function can be at.
+  BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
   /// Abstract State interface
   ///{
 
@@ -2742,6 +2745,8 @@
     // Add itself to the reaching kernel and set IsKernelEntry.
     ReachingKernelEntries.insert(Fn);
     IsKernelEntry = true;
+    // A kernel entry is at level 0, i.e., not in a parallel region.
+    ParallelLevels.insert(0);
 
     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
@@ -3227,8 +3232,10 @@
             CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
 
-    if (!IsKernelEntry)
+    if (!IsKernelEntry) {
       updateReachingKernelEntries(A);
+      updateParallelLevels(A);
+    }
 
     // Callback to check a call instruction.
     auto CheckCallInst = [&](Instruction &I) {
@@ -3277,6 +3284,36 @@
                                 AllCallSitesKnown))
       ReachingKernelEntries.indicatePessimisticFixpoint();
   }
+
+  /// Update info regarding parallel levels.
+  void updateParallelLevels(Attributor &A) {
+    auto PredCallSite = [&](AbstractCallSite ACS) {
+      Function *Caller = ACS.getInstruction()->getFunction();
+
+      assert(Caller && "Caller is nullptr");
+
+      auto &CAA =
+          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
+      if (CAA.ParallelLevels.isValidState()) {
+        // TODO: Check if the caller is __kmpc_parallel_51. If so, the
+        // parallel level of the callee is the caller's level plus 1.
+        ParallelLevels ^= CAA.ParallelLevels;
+        return true;
+      }
+
+      // We lost track of the caller of the associated function, so the
+      // parallel level is unknown now.
+      ParallelLevels.indicatePessimisticFixpoint();
+
+      return true;
+    };
+
+    bool AllCallSitesKnown;
+    if (!A.checkForAllCallSites(PredCallSite, *this,
+                                true /* RequireAllCallSites */,
+                                AllCallSitesKnown))
+      ParallelLevels.indicatePessimisticFixpoint();
+  }
 };
 
 /// The call site kernel info abstract attribute, basically, what can we say
@@ -3509,6 +3546,9 @@
     case OMPRTL___kmpc_is_spmd_exec_mode:
       Changed |= foldIsSPMDExecMode(A);
       break;
+    case OMPRTL___kmpc_parallel_level:
+      Changed |= foldParallelLevel(A);
+      break;
     default:
       llvm_unreachable("Unhandled OpenMP runtime function!");
     }
@@ -3594,6 +3634,35 @@
                                        : ChangeStatus::CHANGED;
   }
 
+  /// Fold __kmpc_parallel_level into a constant if possible.
+  ChangeStatus foldParallelLevel(Attributor &A) {
+    Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
+
+    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
+
+    if (!CallerKernelInfoAA.ParallelLevels.isValidState())
+      return indicatePessimisticFixpoint();
+
+    // If we cannot yet deduce the parallel level of the caller, set
+    // SimplifiedValue to nullptr to indicate that the associated function
+    // call cannot be folded for now.
+    if (CallerKernelInfoAA.ParallelLevels.empty())
+      SimplifiedValue = nullptr;
+
+    // If the caller can be at multiple parallel levels, the associated
+    // function call cannot be folded.
+    // TODO: This case should be unreachable, because we invalidate
+    // CallerKernelInfoAA.ParallelLevels when updating AAKernelInfoFunction;
+    // as a result, it is already covered by the valid-state check
+    // above.
+    if (CallerKernelInfoAA.ParallelLevels.size() > 1)
+      return indicatePessimisticFixpoint();
+
+    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
+                                                    : ChangeStatus::CHANGED;
+  }
+
   /// An optional value the associated value is assumed to fold to. That is, we
   /// assume the associated value (which is a call) can be replaced by this
   /// simplified value.
@@ -3631,6 +3700,18 @@
           /* UpdateAfterInit */ false);
       return false;
     });
+
+    auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level];
+    ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) {
+      CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI);
+      if (!CI)
+        return false;
+      A.getOrCreateAAFor<AAFoldRuntimeCall>(
+          IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
+          DepClassTy::NONE, /* ForceUpdate */ false,
+          /* UpdateAfterInit */ false);
+      return false;
+    });
   }
 
   // Create CallSite AA for all Getters.
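
Note on foldParallelLevel: as written, the patch never assigns a constant to SimplifiedValue, so the case where exactly one parallel level is known is not yet folded. A minimal sketch of that missing branch, inserted before the final return, might look like the following; it is not part of the patch, and it assumes __kmpc_parallel_level returns an i8, matching its declaration in OMPKinds.def.

    // Sketch only (not in the patch above): fold the call to a constant once
    // exactly one parallel level is known across all callers. Assumes the
    // runtime call returns an i8, per __kmpc_parallel_level in OMPKinds.def.
    if (CallerKernelInfoAA.ParallelLevels.size() == 1) {
      uint8_t Level = *CallerKernelInfoAA.ParallelLevels.begin();
      SimplifiedValue = ConstantInt::get(
          Type::getInt8Ty(getAnchorValue().getContext()), Level);
    }

With such a branch, a __kmpc_parallel_level call reached only from kernel entries (whose ParallelLevels set is exactly {0}) would fold to the i8 constant 0, and the final SimplifiedValue comparison would report CHANGED.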