diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -434,20 +434,19 @@ }; template -struct BooleanStateWithPtrSetVector : public BooleanState { - - bool contains(Ty *Elem) const { return Set.contains(Elem); } - bool insert(Ty *Elem) { +struct BooleanStateWithSetVector : public BooleanState { + bool contains(const Ty &Elem) const { return Set.contains(Elem); } + bool insert(const Ty &Elem) { if (InsertInvalidates) BooleanState::indicatePessimisticFixpoint(); return Set.insert(Elem); } - Ty *operator[](int Idx) const { return Set[Idx]; } - bool operator==(const BooleanStateWithPtrSetVector &RHS) const { + const Ty &operator[](int Idx) const { return Set[Idx]; } + bool operator==(const BooleanStateWithSetVector &RHS) const { return BooleanState::operator==(RHS) && Set == RHS.Set; } - bool operator!=(const BooleanStateWithPtrSetVector &RHS) const { + bool operator!=(const BooleanStateWithSetVector &RHS) const { return !(*this == RHS); } @@ -455,8 +454,7 @@ size_t size() const { return Set.size(); } /// "Clamp" this state with \p RHS. - BooleanStateWithPtrSetVector & - operator^=(const BooleanStateWithPtrSetVector &RHS) { + BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) { BooleanState::operator^=(RHS); Set.insert(RHS.Set.begin(), RHS.Set.end()); return *this; @@ -464,7 +462,7 @@ private: /// A set to keep track of elements. - SetVector Set; + SetVector Set; public: typename decltype(Set)::iterator begin() { return Set.begin(); } @@ -473,6 +471,10 @@ typename decltype(Set)::const_iterator end() const { return Set.end(); } }; +template +using BooleanStateWithPtrSetVector = + BooleanStateWithSetVector; + struct KernelInfoState : AbstractState { /// Flag to track if we reached a fixpoint. bool IsAtFixpoint = false; @@ -503,6 +505,9 @@ /// State to track what kernel entries can reach the associated function. 
BooleanStateWithPtrSetVector ReachingKernelEntries; + /// State to track what parallel levels the associated function can be at. + BooleanStateWithSetVector ParallelLevels; + /// Abstract State interface ///{ @@ -2740,6 +2745,8 @@ // Add itself to the reaching kernel and set IsKernelEntry. ReachingKernelEntries.insert(Fn); IsKernelEntry = true; + // Kernel entry is at level 0, which means it is not in a parallel region. + ParallelLevels.insert(0); OMPInformationCache::RuntimeFunctionInfo &InitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; @@ -3225,8 +3232,10 @@ CheckRWInst, *this, UsedAssumedInformationInCheckRWInst)) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - if (!IsKernelEntry) + if (!IsKernelEntry) { updateReachingKernelEntries(A); + updateParallelLevels(A); + } // Callback to check a call instruction. auto CheckCallInst = [&](Instruction &I) { @@ -3275,6 +3284,36 @@ AllCallSitesKnown)) ReachingKernelEntries.indicatePessimisticFixpoint(); } + + /// Update info regarding parallel levels. + void updateParallelLevels(Attributor &A) { + auto PredCallSite = [&](AbstractCallSite ACS) { + Function *Caller = ACS.getInstruction()->getFunction(); + + assert(Caller && "Caller is nullptr"); + + auto &CAA = + A.getOrCreateAAFor(IRPosition::function(*Caller)); + if (CAA.ParallelLevels.isValidState()) { + // TODO: Check if caller is __kmpc_parallel_51. If yes, the parallel + // level will be the parallel level of the caller of __kmpc_parallel_51 plus 1. + ParallelLevels ^= CAA.ParallelLevels; + return true; + } + + // We lost track of the caller of the associated function; any parallel + // level could be reached now. 
+ ParallelLevels.indicatePessimisticFixpoint(); + + return true; + }; + + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(PredCallSite, *this, + true /* RequireAllCallSites */, + AllCallSitesKnown)) + ParallelLevels.indicatePessimisticFixpoint(); + } }; /// The call site kernel info abstract attribute, basically, what can we say @@ -3507,6 +3546,9 @@ case OMPRTL___kmpc_is_spmd_exec_mode: Changed |= foldIsSPMDExecMode(A); break; + case OMPRTL___kmpc_parallel_level: + Changed |= foldParallelLevel(A); + break; default: llvm_unreachable("Unhandled OpenMP runtime function!"); } @@ -3592,6 +3634,35 @@ : ChangeStatus::CHANGED; } + /// Fold __kmpc_parallel_level into a constant if possible. + ChangeStatus foldParallelLevel(Attributor &A) { + Optional SimplifiedValueBefore = SimplifiedValue; + + auto &CallerKernelInfoAA = A.getAAFor( + *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); + + if (!CallerKernelInfoAA.ParallelLevels.isValidState()) + return indicatePessimisticFixpoint(); + + // If we cannot deduce the parallel region of the caller, we set SimplifiedValue + // to nullptr to indicate the associated function call cannot be folded for + // now. + if (CallerKernelInfoAA.ParallelLevels.empty()) + SimplifiedValue = nullptr; + + // If the caller is at multiple parallel levels, the associated function + // call cannot be folded anymore. + // TODO: This should not happen, as we will invalidate + // CallerKernelInfoAA.ParallelLevels when updating AAKernelInfoFunction, + // therefore this case will be handled by checking the valid state above + // directly. + if (CallerKernelInfoAA.ParallelLevels.size() > 1) + return indicatePessimisticFixpoint(); + + return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + /// An optional value the associated value is assumed to fold to. That is, we /// assume the associated value (which is a call) can be replaced by this /// simplified value. 
@@ -3629,6 +3700,18 @@ /* UpdateAfterInit */ false); return false; }); + + auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level]; + ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) { + CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI); + if (!CI) + return false; + A.getOrCreateAAFor( + IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, + DepClassTy::NONE, /* ForceUpdate */ false, + /* UpdateAfterInit */ false); + return false; + }); } // Create CallSite AA for all Getters.